From d0076f7841ce8d0b8dc474a1e1b9bd2abfd0ecb8 Mon Sep 17 00:00:00 2001 From: Shankar <90070882+shankar-landing-ai@users.noreply.github.com> Date: Mon, 15 Apr 2024 13:49:17 -0700 Subject: [PATCH] Add image caption tool (#52) added image caption tool --- vision_agent/agent/vision_agent.py | 2 +- vision_agent/tools/__init__.py | 1 + vision_agent/tools/tools.py | 69 ++++++++++++++++++++++++++++++ 3 files changed, 71 insertions(+), 1 deletion(-) diff --git a/vision_agent/agent/vision_agent.py b/vision_agent/agent/vision_agent.py index 2f1d58b4..67bce0a1 100644 --- a/vision_agent/agent/vision_agent.py +++ b/vision_agent/agent/vision_agent.py @@ -351,7 +351,7 @@ def __init__( task_model: Optional[Union[LLM, LMM]] = None, answer_model: Optional[Union[LLM, LMM]] = None, reflect_model: Optional[Union[LLM, LMM]] = None, - max_retries: int = 3, + max_retries: int = 2, verbose: bool = False, report_progress_callback: Optional[Callable[[str], None]] = None, ): diff --git a/vision_agent/tools/__init__.py b/vision_agent/tools/__init__.py index 20dc503a..d2dea3e2 100644 --- a/vision_agent/tools/__init__.py +++ b/vision_agent/tools/__init__.py @@ -9,6 +9,7 @@ ExtractFrames, GroundingDINO, GroundingSAM, + ImageCaption, SegArea, SegIoU, Tool, diff --git a/vision_agent/tools/tools.py b/vision_agent/tools/tools.py index 12450753..2c686c43 100644 --- a/vision_agent/tools/tools.py +++ b/vision_agent/tools/tools.py @@ -144,6 +144,74 @@ def __call__(self, prompt: str, image: Union[str, ImageType]) -> Dict: return resp_json["data"] # type: ignore +class ImageCaption(Tool): + r"""ImageCaption is a tool that can caption an image based on its contents + or tags. + + Example + ------- + >>> import vision_agent as va + >>> caption = va.tools.ImageCaption() + >>> caption("image1.jpg") + {'text': ['a box of orange and white socks']} + """ + + _ENDPOINT = "https://soi4ewr6fjqqdf5vuss6rrilee0kumxq.lambda-url.us-east-2.on.aws" + + name = "image_caption_" + description = "'image_caption_' is a tool that can caption an image based on its contents or tags. It returns a text describing the image" + usage = { + "required_parameters": [ + {"name": "image", "type": "str"}, + ], + "examples": [ + { + "scenario": "Can you describe this image ? Image name: cat.jpg", + "parameters": {"image": "cat.jpg"}, + }, + { + "scenario": "Can you caption this image with their main contents ? Image name: cat_dog.jpg", + "parameters": {"image": "cat_dog.jpg"}, + }, + { + "scenario": "Can you build me a image captioning tool ? Image name: shirts.jpg", + "parameters": { + "image": "shirts.jpg", + }, + }, + ], + } + + # TODO: Add support for input multiple images, which aligns with the output type. + def __call__(self, image: Union[str, ImageType]) -> Dict: + """Invoke the Image captioning model. + + Parameters: + image: the input image to caption. + + Returns: + A list of dictionaries containing the labels and scores. Each dictionary contains the classification result for an image. E.g. [{"labels": ["red line", "yellow dot"], "scores": [0.98, 0.02]}] + """ + image_b64 = convert_to_b64(image) + data = { + "image": image_b64, + "tool": "image_captioning", + } + res = requests.post( + self._ENDPOINT, + headers={"Content-Type": "application/json"}, + json=data, + ) + resp_json: Dict[str, Any] = res.json() + if ( + "statusCode" in resp_json and resp_json["statusCode"] != 200 + ) or "statusCode" not in resp_json: + _LOGGER.error(f"Request failed: {resp_json}") + raise ValueError(f"Request failed: {resp_json}") + + return resp_json["data"] # type: ignore + + class GroundingDINO(Tool): r"""Grounding DINO is a tool that can detect arbitrary objects with inputs such as category names or referring expressions. @@ -631,6 +699,7 @@ def __call__(self, equation: str) -> float: [ NoOp, CLIP, + ImageCaption, GroundingDINO, AgentGroundingSAM, ExtractFrames,