From c301ca92f80d858c63f7545c3d7a5b7fef600cad Mon Sep 17 00:00:00 2001 From: Zhichao Date: Tue, 30 Jul 2024 20:44:11 +0800 Subject: [PATCH] feat: add `function_name` to each `send_inference_request` request (#186) add function_name to each send_inference_request --- vision_agent/tools/tools.py | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) diff --git a/vision_agent/tools/tools.py b/vision_agent/tools/tools.py index 27a2f10a..c0af8b21 100644 --- a/vision_agent/tools/tools.py +++ b/vision_agent/tools/tools.py @@ -106,6 +106,7 @@ def grounding_dino( "visual_grounding" if model_size == "large" else "visual_grounding_tiny" ), "kwargs": {"box_threshold": box_threshold, "iou_threshold": iou_threshold}, + "function_name": "grounding_dino", } data: Dict[str, Any] = send_inference_request(request_data, "tools") return_data = [] @@ -161,6 +162,7 @@ def owl_v2( "image": image_b64, "tool": "open_vocab_detection", "kwargs": {"box_threshold": box_threshold, "iou_threshold": iou_threshold}, + "function_name": "owl_v2", } data: Dict[str, Any] = send_inference_request(request_data, "tools") return_data = [] @@ -225,6 +227,7 @@ def grounding_sam( "image": image_b64, "tool": "visual_grounding_segment", "kwargs": {"box_threshold": box_threshold, "iou_threshold": iou_threshold}, + "function_name": "grounding_sam", } data: Dict[str, Any] = send_inference_request(request_data, "tools") return_data = [] @@ -364,6 +367,7 @@ def loca_zero_shot_counting(image: np.ndarray) -> Dict[str, Any]: data = { "image": image_b64, "tool": "zero_shot_counting", + "function_name": "loca_zero_shot_counting", } resp_data = send_inference_request(data, "tools") resp_data["heat_map"] = np.array(b64_to_pil(resp_data["heat_map"][0])) @@ -399,6 +403,7 @@ def loca_visual_prompt_counting( "image": image_b64, "prompt": bbox_str, "tool": "few_shot_counting", + "function_name": "loca_visual_prompt_counting", } resp_data = send_inference_request(data, "tools") resp_data["heat_map"] = np.array(b64_to_pil(resp_data["heat_map"][0])) @@ -428,6 +433,7 @@ def florencev2_roberta_vqa(prompt: str, image: np.ndarray) -> str: "image": image_b64, "prompt": prompt, "tool": "image_question_answering_with_context", + "function_name": "florencev2_roberta_vqa", } answer = send_inference_request(data, "tools") @@ -457,6 +463,7 @@ def git_vqa_v2(prompt: str, image: np.ndarray) -> str: "image": image_b64, "prompt": prompt, "tool": "image_question_answering", + "function_name": "git_vqa_v2", } answer = send_inference_request(data, "tools") @@ -487,6 +494,7 @@ def clip(image: np.ndarray, classes: List[str]) -> Dict[str, Any]: "prompt": ",".join(classes), "image": image_b64, "tool": "closed_set_image_classification", + "function_name": "clip", } resp_data = send_inference_request(data, "tools") resp_data["scores"] = [round(prob, 4) for prob in resp_data["scores"]] @@ -514,6 +522,7 @@ def vit_image_classification(image: np.ndarray) -> Dict[str, Any]: data = { "image": image_b64, "tool": "image_classification", + "function_name": "vit_image_classification", } resp_data = send_inference_request(data, "tools") resp_data["scores"] = [round(prob, 4) for prob in resp_data["scores"]] @@ -541,6 +550,7 @@ def vit_nsfw_classification(image: np.ndarray) -> Dict[str, Any]: data = { "image": image_b64, "tool": "nsfw_image_classification", + "function_name": "vit_nsfw_classification", } resp_data = send_inference_request(data, "tools") resp_data["scores"] = round(resp_data["scores"], 4) @@ -567,6 +577,7 @@ def blip_image_caption(image: np.ndarray) -> str: data = { "image": image_b64, "tool": "image_captioning", + "function_name": "blip_image_caption", } answer = send_inference_request(data, "tools") @@ -595,6 +606,7 @@ def florencev2_image_caption(image: np.ndarray, detail_caption: bool = True) -> "image": image_b64, "tool": "florence2_image_captioning", "detail_caption": detail_caption, + "function_name": "florencev2_image_caption", } answer = send_inference_request(data, "tools") @@ -630,6 +642,7 @@ def florencev2_object_detection(image: np.ndarray) -> List[Dict[str, Any]]: data = { "image": image_b64, "tool": "object_detection", + "function_name": "florencev2_object_detection", } answer = send_inference_request(data, "tools") @@ -686,6 +699,7 @@ def detr_segmentation(image: np.ndarray) -> List[Dict[str, Any]]: data = { "image": image_b64, "tool": "panoptic_segmentation", + "function_name": "detr_segmentation", } answer = send_inference_request(data, "tools") @@ -728,6 +742,7 @@ def depth_anything_v2(image: np.ndarray) -> np.ndarray: data = { "image": image_b64, "tool": "generate_depth", + "function_name": "depth_anything_v2", } answer = send_inference_request(data, "tools") @@ -759,6 +774,7 @@ def generate_soft_edge_image(image: np.ndarray) -> np.ndarray: data = { "image": image_b64, "tool": "generate_hed", + "function_name": "generate_soft_edge_image", } answer = send_inference_request(data, "tools") @@ -791,6 +807,7 @@ def dpt_hybrid_midas(image: np.ndarray) -> np.ndarray: data = { "image": image_b64, "tool": "generate_normal", + "function_name": "dpt_hybrid_midas", } answer = send_inference_request(data, "tools") @@ -822,6 +839,7 @@ def generate_pose_image(image: np.ndarray) -> np.ndarray: data = { "image": image_b64, "tool": "generate_pose", + "function_name": "generate_pose_image", } answer = send_inference_request(data, "tools") @@ -862,6 +880,7 @@ def template_match( "image": image_b64, "template": template_image_b64, "tool": "template_match", + "function_name": "template_match", } answer = send_inference_request(data, "tools")