diff --git a/tests/integ/test_tools.py b/tests/integ/test_tools.py
index 57d536bd..68f5d396 100644
--- a/tests/integ/test_tools.py
+++ b/tests/integ/test_tools.py
@@ -61,9 +61,10 @@ def test_object_detection():
     img = ski.data.coins()
     result = florencev2_object_detection(
         image=img,
+        prompt="coin",
     )
-    assert len(result) == 24
-    assert [res["label"] for res in result] == ["coin"] * 24
+    assert len(result) == 25
+    assert [res["label"] for res in result] == ["coin"] * 25


 def test_template_match():
@@ -118,7 +119,7 @@ def test_nsfw_classification():
     result = vit_nsfw_classification(
         image=img,
     )
-    assert result["labels"] == "normal"
+    assert result["label"] == "normal"


 def test_image_caption() -> None:
diff --git a/vision_agent/tools/tool_utils.py b/vision_agent/tools/tool_utils.py
index 664466bc..e1fa69c3 100644
--- a/vision_agent/tools/tool_utils.py
+++ b/vision_agent/tools/tool_utils.py
@@ -16,7 +16,8 @@
 _LOGGER = logging.getLogger(__name__)

 _LND_API_KEY = LandingaiAPIKey().api_key
-_LND_API_URL = "https://api.landing.ai/v1/agent"
+_LND_API_URL = "https://api.landing.ai/v1/agent/model"
+_LND_API_URL_v2 = "https://api.landing.ai/v1/tools"


 class ToolCallTrace(BaseModel):
@@ -27,13 +28,13 @@ class ToolCallTrace(BaseModel):


 def send_inference_request(
-    payload: Dict[str, Any], endpoint_name: str
+    payload: Dict[str, Any], endpoint_name: str, v2: bool = False
 ) -> Dict[str, Any]:
     try:
         if runtime_tag := os.environ.get("RUNTIME_TAG", ""):
             payload["runtime_tag"] = runtime_tag

-        url = f"{_LND_API_URL}/model/{endpoint_name}"
+        url = f"{_LND_API_URL_v2 if v2 else _LND_API_URL}/{endpoint_name}"
         if "TOOL_ENDPOINT_URL" in os.environ:
             url = os.environ["TOOL_ENDPOINT_URL"]
diff --git a/vision_agent/tools/tools.py b/vision_agent/tools/tools.py
index a78457ed..fc4a59f4 100644
--- a/vision_agent/tools/tools.py
+++ b/vision_agent/tools/tools.py
@@ -126,7 +126,6 @@ def owl_v2(
     prompt: str,
     image: np.ndarray,
     box_threshold: float = 0.10,
-    iou_threshold: float = 0.10,
 ) -> List[Dict[str, Any]]:
     """'owl_v2' is a tool that can detect and count multiple objects given a text
     prompt such as category names or referring expressions. The categories in text prompt
@@ -138,8 +137,6 @@ def owl_v2(
         image (np.ndarray): The image to ground the prompt to.
         box_threshold (float, optional): The threshold for the box detection.
             Defaults to 0.10.
-        iou_threshold (float, optional): The threshold for the Intersection over Union
-            (IoU). Defaults to 0.10.

     Returns:
         List[Dict[str, Any]]: A list of dictionaries containing the score, label, and
@@ -159,22 +156,22 @@ def owl_v2(
     image_size = image.shape[:2]
     image_b64 = convert_to_b64(image)
     request_data = {
-        "prompt": prompt,
+        "prompts": prompt.split("."),
         "image": image_b64,
-        "tool": "open_vocab_detection",
-        "kwargs": {"box_threshold": box_threshold, "iou_threshold": iou_threshold},
+        "confidence": box_threshold,
         "function_name": "owl_v2",
     }
-    data: Dict[str, Any] = send_inference_request(request_data, "tools")
+    data: Dict[str, Any] = send_inference_request(request_data, "owlv2", v2=True)
     return_data = []
-    for i in range(len(data["bboxes"])):
-        return_data.append(
-            {
-                "score": round(data["scores"][i], 2),
-                "label": data["labels"][i].strip(),
-                "bbox": normalize_bbox(data["bboxes"][i], image_size),
-            }
-        )
+    if data is not None:
+        for elt in data:
+            return_data.append(
+                {
+                    "bbox": normalize_bbox(elt["bbox"], image_size),  # type: ignore
+                    "label": elt["label"],  # type: ignore
+                    "score": round(elt["score"], 2),  # type: ignore
+                }
+            )
     return return_data
@@ -367,11 +364,10 @@ def loca_zero_shot_counting(image: np.ndarray) -> Dict[str, Any]:
     image_b64 = convert_to_b64(image)
     data = {
         "image": image_b64,
-        "tool": "zero_shot_counting",
         "function_name": "loca_zero_shot_counting",
     }
-    resp_data = send_inference_request(data, "tools")
-    resp_data["heat_map"] = np.array(b64_to_pil(resp_data["heat_map"][0]))
+    resp_data = send_inference_request(data, "loca", v2=True)
+    resp_data["heat_map"] = np.array(resp_data["heat_map"][0]).astype(np.uint8)
     return resp_data
@@ -397,17 +393,15 @@ def loca_visual_prompt_counting(
     image_size = get_image_size(image)
     bbox = visual_prompt["bbox"]
-    bbox_str = ", ".join(map(str, denormalize_bbox(bbox, image_size)))
     image_b64 = convert_to_b64(image)

     data = {
         "image": image_b64,
-        "prompt": bbox_str,
-        "tool": "few_shot_counting",
+        "bbox": list(map(int, denormalize_bbox(bbox, image_size))),
         "function_name": "loca_visual_prompt_counting",
     }
-    resp_data = send_inference_request(data, "tools")
-    resp_data["heat_map"] = np.array(b64_to_pil(resp_data["heat_map"][0]))
+    resp_data = send_inference_request(data, "loca", v2=True)
+    resp_data["heat_map"] = np.array(resp_data["heat_map"][0]).astype(np.uint8)
     return resp_data
@@ -432,13 +426,12 @@ def florencev2_roberta_vqa(prompt: str, image: np.ndarray) -> str:
     image_b64 = convert_to_b64(image)
     data = {
         "image": image_b64,
-        "prompt": prompt,
-        "tool": "image_question_answering_with_context",
+        "question": prompt,
         "function_name": "florencev2_roberta_vqa",
     }

-    answer = send_inference_request(data, "tools")
-    return answer["text"][0]  # type: ignore
+    answer = send_inference_request(data, "florence2-qa", v2=True)
+    return answer  # type: ignore


 def git_vqa_v2(prompt: str, image: np.ndarray) -> str:
@@ -544,17 +537,16 @@ def vit_nsfw_classification(image: np.ndarray) -> Dict[str, Any]:
     Example
     -------
     >>> vit_nsfw_classification(image)
-        {"labels": "normal", "scores": 0.68},
+        {"label": "normal", "score": 0.68},
     """
     image_b64 = convert_to_b64(image)
     data = {
         "image": image_b64,
-        "tool": "nsfw_image_classification",
         "function_name": "vit_nsfw_classification",
     }
-    resp_data = send_inference_request(data, "tools")
-    resp_data["scores"] = round(resp_data["scores"], 4)
+    resp_data = send_inference_request(data, "nsfw-classification", v2=True)
+    resp_data["score"] = round(resp_data["score"], 4)
     return resp_data
@@ -603,21 +595,21 @@ def florencev2_image_caption(image: np.ndarray, detail_caption: bool = True) ->
         'This image contains a cat sitting on a table with a bowl of milk.'
     """
     image_b64 = convert_to_b64(image)
+    task = "<MORE_DETAILED_CAPTION>" if detail_caption else "<CAPTION>"
     data = {
         "image": image_b64,
-        "tool": "florence2_image_captioning",
-        "detail_caption": detail_caption,
+        "task": task,
         "function_name": "florencev2_image_caption",
     }

-    answer = send_inference_request(data, "tools")
-    return answer["text"][0]  # type: ignore
+    answer = send_inference_request(data, "florence2", v2=True)
+    return answer[task]  # type: ignore


-def florencev2_object_detection(image: np.ndarray) -> List[Dict[str, Any]]:
-    """'florencev2_object_detection' is a tool that can detect common objects in an
-    image without any text prompt or thresholding. It returns a list of detected objects
-    as labels and their location as bounding boxes.
+def florencev2_object_detection(image: np.ndarray, prompt: str) -> List[Dict[str, Any]]:
+    """'florencev2_object_detection' is a tool that can detect objects given a text
+    prompt such as a phrase or class names separated by commas. It returns a list of
+    detected objects as labels and their location as bounding boxes with score of 1.0.

     Parameters:
         image (np.ndarray): The image to used to detect objects
@@ -631,29 +623,30 @@ def florencev2_object_detection(image: np.ndarray) -> List[Dict[str, Any]]:

     Example
     -------
-    >>> florencev2_object_detection(image)
+    >>> florencev2_object_detection(image, 'person looking at a coyote')
     [
-        {'score': 1.0, 'label': 'window', 'bbox': [0.1, 0.11, 0.35, 0.4]},
-        {'score': 1.0, 'label': 'car', 'bbox': [0.2, 0.21, 0.45, 0.5},
-        {'score': 1.0, 'label': 'person', 'bbox': [0.34, 0.21, 0.85, 0.5},
+        {'score': 1.0, 'label': 'person', 'bbox': [0.1, 0.11, 0.35, 0.4]},
+        {'score': 1.0, 'label': 'coyote', 'bbox': [0.34, 0.21, 0.85, 0.5]},
     ]
     """
     image_size = image.shape[:2]
     image_b64 = convert_to_b64(image)
     data = {
         "image": image_b64,
-        "tool": "object_detection",
+        "task": "<CAPTION_TO_PHRASE_GROUNDING>",
+        "prompt": prompt,
         "function_name": "florencev2_object_detection",
     }

-    answer = send_inference_request(data, "tools")
+    detections = send_inference_request(data, "florence2", v2=True)
+    detections = detections["<CAPTION_TO_PHRASE_GROUNDING>"]
     return_data = []
-    for i in range(len(answer["bboxes"])):
+    for i in range(len(detections["bboxes"])):
         return_data.append(
             {
-                "score": round(answer["scores"][i], 2),
-                "label": answer["labels"][i],
-                "bbox": normalize_bbox(answer["bboxes"][i], image_size),
+                "score": 1.0,
+                "label": detections["labels"][i],
+                "bbox": normalize_bbox(detections["bboxes"][i], image_size),
             }
         )
     return return_data
@@ -742,13 +735,16 @@ def depth_anything_v2(image: np.ndarray) -> np.ndarray:
     image_b64 = convert_to_b64(image)
     data = {
         "image": image_b64,
-        "tool": "generate_depth",
         "function_name": "depth_anything_v2",
     }

-    answer = send_inference_request(data, "tools")
-    return_data = np.array(b64_to_pil(answer["masks"][0]).convert("L"))
-    return return_data
+    depth_map = send_inference_request(data, "depth-anything-v2", v2=True)
+    depth_map_np = np.array(depth_map["map"])
+    depth_map_np = (depth_map_np - depth_map_np.min()) / (
+        depth_map_np.max() - depth_map_np.min()
+    )
+    depth_map_np = (255 * depth_map_np).astype(np.uint8)
+    return depth_map_np


 def generate_soft_edge_image(image: np.ndarray) -> np.ndarray:
@@ -839,12 +835,11 @@ def generate_pose_image(image: np.ndarray) -> np.ndarray:
     image_b64 = convert_to_b64(image)
     data = {
         "image": image_b64,
-        "tool": "generate_pose",
         "function_name": "generate_pose_image",
     }

-    answer = send_inference_request(data, "tools")
-    return_data = np.array(b64_to_pil(answer["masks"][0]).convert("RGB"))
+    pos_img = send_inference_request(data, "pose-detector", v2=True)
+    return_data = np.array(b64_to_pil(pos_img["data"]).convert("RGB"))
     return return_data
@@ -1253,7 +1248,6 @@ def overlay_heat_map(
     loca_visual_prompt_counting,
     florencev2_roberta_vqa,
     florencev2_image_caption,
-    florencev2_object_detection,
     detr_segmentation,
     depth_anything_v2,
     generate_soft_edge_image,
diff --git a/vision_agent/utils/type_defs.py b/vision_agent/utils/type_defs.py
index a9398ee5..83ab8f62 100644
--- a/vision_agent/utils/type_defs.py
+++ b/vision_agent/utils/type_defs.py
@@ -14,7 +14,7 @@ class LandingaiAPIKey(BaseSettings):
     """

     api_key: str = Field(
-        default="land_sk_fnmSzD0ksknSfvhyD8UGu9R4ss3bKfLL1Im5gb6tDQTy2z1Oy5",
+        default="land_sk_zKvyPcPV2bVoq7q87KwduoerAxuQpx33DnqP8M1BliOCiZOSoI",
        alias="LANDINGAI_API_KEY",
        description="The API key of LandingAI.",
    )
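
To make the routing change in tool_utils.py easier to verify in isolation, here is a minimal sketch of just the URL-building logic. The two URL constants are copied from the diff; build_inference_url is a hypothetical helper name (the real send_inference_request also posts the payload and records a ToolCallTrace):

    import os

    _LND_API_URL = "https://api.landing.ai/v1/agent/model"
    _LND_API_URL_v2 = "https://api.landing.ai/v1/tools"

    def build_inference_url(endpoint_name: str, v2: bool = False) -> str:
        # v2 endpoints (owlv2, loca, florence2, florence2-qa, nsfw-classification,
        # depth-anything-v2, pose-detector) live under /v1/tools; everything else
        # keeps the legacy /v1/agent/model prefix.
        url = f"{_LND_API_URL_v2 if v2 else _LND_API_URL}/{endpoint_name}"
        # TOOL_ENDPOINT_URL still overrides either URL, as in the diff.
        return os.environ.get("TOOL_ENDPOINT_URL", url)

    assert build_inference_url("owlv2", v2=True) == "https://api.landing.ai/v1/tools/owlv2"
    # "my-model" is a placeholder endpoint name for the legacy path.
    assert build_inference_url("my-model") == "https://api.landing.ai/v1/agent/model/my-model"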
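Callers also see new public signatures. A hedged usage sketch, assuming both tools are exported from vision_agent.tools as the integration tests suggest; the result shapes follow the updated docstrings:

    import skimage as ski

    from vision_agent.tools import florencev2_object_detection, owl_v2

    img = ski.data.coins()  # same sample image the integration test uses

    # owl_v2 no longer takes iou_threshold; box_threshold is sent as "confidence",
    # and multiple categories are separated by periods in the prompt.
    owl_dets = owl_v2(prompt="coin", image=img, box_threshold=0.10)

    # florencev2_object_detection now requires a text prompt and always reports
    # a score of 1.0 for each detection.
    florence_dets = florencev2_object_detection(image=img, prompt="coin")
    for det in florence_dets:
        print(det["score"], det["label"], det["bbox"])  # bbox is normalized to [0, 1]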
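Finally, depth_anything_v2 now min-max scales the raw depth values to uint8 itself instead of decoding a base64 mask. The same normalization step in isolation, on a made-up 2x2 depth map (note the division would fail on a constant-depth input, which the new code does not guard against):

    import numpy as np

    depth = np.array([[0.5, 1.0], [1.5, 2.0]])  # made-up raw depth values
    norm = (depth - depth.min()) / (depth.max() - depth.min())  # scale to [0, 1]
    as_uint8 = (255 * norm).astype(np.uint8)  # scale to [0, 255], as in the diff
    print(as_uint8)  # [[  0  85]
                     #  [170 255]]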