From f8e05ee2559604f0fa63ee66c009dc6ed43b12e9 Mon Sep 17 00:00:00 2001 From: Dayanne Fernandes Date: Mon, 2 Sep 2024 20:58:40 -0300 Subject: [PATCH] linter --- vision_agent/tools/tool_utils.py | 12 ++++++------ vision_agent/tools/tools.py | 24 ++++++++++++++---------- 2 files changed, 20 insertions(+), 16 deletions(-) diff --git a/vision_agent/tools/tool_utils.py b/vision_agent/tools/tool_utils.py index 30ac659b..a14443bd 100644 --- a/vision_agent/tools/tool_utils.py +++ b/vision_agent/tools/tool_utils.py @@ -35,7 +35,7 @@ def send_inference_request( files: Optional[List[Tuple[Any, ...]]] = None, v2: bool = False, metadata_payload: Optional[Dict[str, Any]] = None, -) -> Dict[str, Any]: +) -> Any: # TODO: runtime_tag and function_name should be metadata_payload and now included # in the service payload if runtime_tag := os.environ.get("RUNTIME_TAG", ""): @@ -70,11 +70,11 @@ def send_inference_request( def send_task_inference_request( payload: Dict[str, Any], - endpoint_name: str, + task_name: str, files: Optional[List[Tuple[Any, ...]]] = None, metadata: Optional[Dict[str, Any]] = None, -) -> Dict[str, Any]: - url = f"{_LND_API_URL_v2}/{endpoint_name}" +) -> Any: + url = f"{_LND_API_URL_v2}/{task_name}" headers = {"apikey": _LND_API_KEY} session = _create_requests_session( url=url, @@ -201,7 +201,7 @@ def _call_post( session: Session, files: Optional[List[Tuple[Any, ...]]] = None, function_name: str = "unknown", -) -> dict[str, Any]: +) -> Any: try: tool_call_trace = ToolCallTrace( endpoint_url=url, @@ -238,4 +238,4 @@ def _call_post( def filter_bboxes_by_threshold( bboxes: BoundingBoxes, threshold: float ) -> BoundingBoxes: - return list(map(lambda bbox: bbox["score"] >= threshold, bboxes)) + return list(filter(lambda bbox: bbox.score >= threshold, bboxes)) diff --git a/vision_agent/tools/tools.py b/vision_agent/tools/tools.py index 960f51f9..e0961398 100644 --- a/vision_agent/tools/tools.py +++ b/vision_agent/tools/tools.py @@ -458,7 +458,7 @@ def loca_zero_shot_counting(image: np.ndarray) -> Dict[str, Any]: "image": image_b64, "function_name": "loca_zero_shot_counting", } - resp_data = send_inference_request(data, "loca", v2=True) + resp_data: dict[str, Any] = send_inference_request(data, "loca", v2=True) resp_data["heat_map"] = np.array(resp_data["heat_map"][0]).astype(np.uint8) return resp_data @@ -501,7 +501,7 @@ def loca_visual_prompt_counting( "bbox": list(map(int, denormalize_bbox(bbox, image_size))), "function_name": "loca_visual_prompt_counting", } - resp_data = send_inference_request(data, "loca", v2=True) + resp_data: dict[str, Any] = send_inference_request(data, "loca", v2=True) resp_data["heat_map"] = np.array(resp_data["heat_map"][0]).astype(np.uint8) return resp_data @@ -542,12 +542,13 @@ def countgd_counting( files = [("image", buffer_bytes)] payload = {"prompts": [prompt], "model": "countgd"} metadata = {"function_name": "countgd_counting"} - resp_data: List[Dict[str, Any]] = send_task_inference_request( + resp_data = send_task_inference_request( payload, "text-to-object-detection", files=files, metadata=metadata ) bboxes_per_frame = resp_data[0] bboxes_formatted = [ODResponseData(**bbox) for bbox in bboxes_per_frame] - return filter_bboxes_by_threshold(bboxes_formatted, box_threshold) + filtered_bboxes = filter_bboxes_by_threshold(bboxes_formatted, box_threshold) + return [bbox.model_dump() for bbox in filtered_bboxes] def countgd_example_based_counting( @@ -591,14 +592,15 @@ def countgd_example_based_counting( visual_prompts = [ denormalize_bbox(bbox, image.shape[:2]) for bbox in visual_prompts ] - payload = {"visual_prompts": json.loads(visual_prompts), "model": "countgd"} + payload = {"visual_prompts": json.dumps(visual_prompts), "model": "countgd"} metadata = {"function_name": "countgd_example_based_counting"} - resp_data: List[Dict[str, Any]] = send_task_inference_request( + resp_data = send_task_inference_request( payload, "visual-prompts-to-object-detection", files=files, metadata=metadata ) bboxes_per_frame = resp_data[0] bboxes_formatted = [ODResponseData(**bbox) for bbox in bboxes_per_frame] - return filter_bboxes_by_threshold(bboxes_formatted, box_threshold) + filtered_bboxes = filter_bboxes_by_threshold(bboxes_formatted, box_threshold) + return [bbox.model_dump() for bbox in filtered_bboxes] def florence2_roberta_vqa(prompt: str, image: np.ndarray) -> str: @@ -746,7 +748,7 @@ def clip(image: np.ndarray, classes: List[str]) -> Dict[str, Any]: "tool": "closed_set_image_classification", "function_name": "clip", } - resp_data = send_inference_request(data, "tools") + resp_data: dict[str, Any] = send_inference_request(data, "tools") resp_data["scores"] = [round(prob, 4) for prob in resp_data["scores"]] return resp_data @@ -774,7 +776,7 @@ def vit_image_classification(image: np.ndarray) -> Dict[str, Any]: "tool": "image_classification", "function_name": "vit_image_classification", } - resp_data = send_inference_request(data, "tools") + resp_data: dict[str, Any] = send_inference_request(data, "tools") resp_data["scores"] = [round(prob, 4) for prob in resp_data["scores"]] return resp_data @@ -801,7 +803,9 @@ def vit_nsfw_classification(image: np.ndarray) -> Dict[str, Any]: "image": image_b64, "function_name": "vit_nsfw_classification", } - resp_data = send_inference_request(data, "nsfw-classification", v2=True) + resp_data: dict[str, Any] = send_inference_request( + data, "nsfw-classification", v2=True + ) resp_data["score"] = round(resp_data["score"], 4) return resp_data