From 87b2193ca3b2352544f3077ce2159499d0c8db72 Mon Sep 17 00:00:00 2001
From: Shankar <90070882+shankar-landing-ai@users.noreply.github.com>
Date: Wed, 24 Apr 2024 10:59:49 -0700
Subject: [PATCH] adding vqa tool for agent and also for general purposes (#63)

* adding vqa tool for agent and also for general purposes

* fix linting issues

* fix linting
---
 vision_agent/llm/llm.py        |   5 ++
 vision_agent/lmm/lmm.py        |  17 +++-
 vision_agent/tools/__init__.py |   6 +-
 vision_agent/tools/tools.py    | 145 ++++++++++++++++++++++++++++++++-
 4 files changed, 164 insertions(+), 9 deletions(-)

diff --git a/vision_agent/llm/llm.py b/vision_agent/llm/llm.py
index 3f83c269..842d27aa 100644
--- a/vision_agent/llm/llm.py
+++ b/vision_agent/llm/llm.py
@@ -131,6 +131,11 @@ def generate_segmentor(self, question: str) -> Callable:
     def generate_zero_shot_counter(self, question: str) -> Callable:
         return lambda x: ZeroShotCounting()(**{"image": x})
 
+    def generate_image_qa_tool(self, question: str) -> Callable:
+        from vision_agent.tools import ImageQuestionAnswering
+
+        return lambda x: ImageQuestionAnswering()(**{"prompt": question, "image": x})
+
 
 class AzureOpenAILLM(OpenAILLM):
     def __init__(
diff --git a/vision_agent/lmm/lmm.py b/vision_agent/lmm/lmm.py
index 06ce94a2..a1fcc3c2 100644
--- a/vision_agent/lmm/lmm.py
+++ b/vision_agent/lmm/lmm.py
@@ -11,11 +11,7 @@
 
 from vision_agent.tools import (
     CHOOSE_PARAMS,
-    CLIP,
     SYSTEM_PROMPT,
-    GroundingDINO,
-    GroundingSAM,
-    ZeroShotCounting,
 )
 
 _LOGGER = logging.getLogger(__name__)
@@ -205,6 +201,8 @@ def generate(
         return cast(str, response.choices[0].message.content)
 
     def generate_classifier(self, question: str) -> Callable:
+        from vision_agent.tools import CLIP
+
         api_doc = CLIP.description + "\n" + str(CLIP.usage)
         prompt = CHOOSE_PARAMS.format(api_doc=api_doc, question=question)
         response = self.client.chat.completions.create(
@@ -228,6 +226,8 @@ def generate_classifier(self, question: str) -> Callable:
         return lambda x: CLIP()(**{"prompt": params["prompt"], "image": x})
 
     def generate_detector(self, question: str) -> Callable:
+        from vision_agent.tools import GroundingDINO
+
         api_doc = GroundingDINO.description + "\n" + str(GroundingDINO.usage)
         prompt = CHOOSE_PARAMS.format(api_doc=api_doc, question=question)
         response = self.client.chat.completions.create(
@@ -251,6 +251,8 @@ def generate_detector(self, question: str) -> Callable:
         return lambda x: GroundingDINO()(**{"prompt": params["prompt"], "image": x})
 
     def generate_segmentor(self, question: str) -> Callable:
+        from vision_agent.tools import GroundingSAM
+
         api_doc = GroundingSAM.description + "\n" + str(GroundingSAM.usage)
         prompt = CHOOSE_PARAMS.format(api_doc=api_doc, question=question)
         response = self.client.chat.completions.create(
@@ -274,8 +276,15 @@ def generate_segmentor(self, question: str) -> Callable:
         return lambda x: GroundingSAM()(**{"prompt": params["prompt"], "image": x})
 
     def generate_zero_shot_counter(self, question: str) -> Callable:
+        from vision_agent.tools import ZeroShotCounting
+
         return lambda x: ZeroShotCounting()(**{"image": x})
 
+    def generate_image_qa_tool(self, question: str) -> Callable:
+        from vision_agent.tools import ImageQuestionAnswering
+
+        return lambda x: ImageQuestionAnswering()(**{"prompt": question, "image": x})
+
 
 class AzureOpenAILMM(OpenAILMM):
     def __init__(
diff --git a/vision_agent/tools/__init__.py b/vision_agent/tools/__init__.py
index 67248156..60870b56 100644
--- a/vision_agent/tools/__init__.py
+++ b/vision_agent/tools/__init__.py
@@ -12,10 +12,12 @@
     GroundingDINO,
     GroundingSAM,
     ImageCaption,
+    ZeroShotCounting,
+    VisualPromptCounting,
+    VisualQuestionAnswering,
+    ImageQuestionAnswering,
     SegArea,
     SegIoU,
     Tool,
-    VisualPromptCounting,
-    ZeroShotCounting,
     register_tool,
 )
diff --git a/vision_agent/tools/tools.py b/vision_agent/tools/tools.py
index a661aeb0..3bf2bfbf 100644
--- a/vision_agent/tools/tools.py
+++ b/vision_agent/tools/tools.py
@@ -19,6 +19,7 @@
 )
 from vision_agent.tools.video import extract_frames_from_video
 from vision_agent.type_defs import LandingaiAPIKey
+from vision_agent.lmm import OpenAILMM
 
 _LOGGER = logging.getLogger(__name__)
 _LND_API_KEY = LandingaiAPIKey().api_key
@@ -502,7 +503,7 @@ class ZeroShotCounting(Tool):
 
     # TODO: Add support for input multiple images, which aligns with the output type.
     def __call__(self, image: Union[str, ImageType]) -> Dict:
-        """Invoke the Image captioning model.
+        """Invoke the zero-shot counting model.
 
         Parameters:
             image: the input image.
@@ -566,7 +567,7 @@ class VisualPromptCounting(Tool):
 
     # TODO: Add support for input multiple images, which aligns with the output type.
     def __call__(self, image: Union[str, ImageType], prompt: str) -> Dict:
-        """Invoke the Image captioning model.
+        """Invoke the few-shot counting model.
 
         Parameters:
             image: the input image.
@@ -587,6 +588,144 @@ def __call__(self, image: Union[str, ImageType], prompt: str) -> Dict:
         return _send_inference_request(data, "tools")
 
 
+class VisualQuestionAnswering(Tool):
+    r"""VisualQuestionAnswering is a tool that can explain the contents of an image and answer questions about it.
+
+    Example
+    -------
+        >>> import vision_agent as va
+        >>> vqa_tool = va.tools.VisualQuestionAnswering()
+        >>> vqa_tool(image="image1.jpg", prompt="describe this image in detail")
+        {'text': "The image contains a cat sitting on a table with a bowl of milk."}
+    """
+
+    name = "visual_question_answering_"
+    description = "'visual_question_answering_' is a tool that can describe the contents of the image and it can also answer basic questions about the image."
+
+    usage = {
+        "required_parameters": [
+            {"name": "image", "type": "str"},
+            {"name": "prompt", "type": "str"},
+        ],
+        "examples": [
+            {
+                "scenario": "Describe this image in detail. Image name: cat.jpg",
+                "parameters": {
+                    "image": "cats.jpg",
+                    "prompt": "Describe this image in detail",
+                },
+            },
+            {
+                "scenario": "Can you help me with this street sign in this image ? What does it say ? Image name: sign.jpg",
+                "parameters": {
+                    "image": "sign.jpg",
+                    "prompt": "Can you help me with this street sign ? What does it say ?",
+                },
+            },
+            {
+                "scenario": "Describe the weather in the image for me ? Image name: weather.jpg",
+                "parameters": {
+                    "image": "weather.jpg",
+                    "prompt": "Describe the weather in the image for me ",
+                },
+            },
+            {
+                "scenario": "Which 2 are the least frequent bins in this histogram ? Image name: chart.jpg",
+                "parameters": {
+                    "image": "chart.jpg",
+                    "prompt": "Which 2 are the least frequent bins in this histogram",
+                },
+            },
+        ],
+    }
+
+    def __call__(self, image: str, prompt: str) -> Dict:
+        """Invoke the visual question answering model.
+
+        Parameters:
+            image: the input image.
+
+        Returns:
+            A dictionary containing the key 'text' and the answer to the prompt. E.g. {'text': 'This image contains a cat sitting on a table with a bowl of milk.'}
+        """
+
+        gpt = OpenAILMM()
+        return {"text": gpt(input=prompt, images=[image])}
+
+
+class ImageQuestionAnswering(Tool):
+    r"""ImageQuestionAnswering is a tool that can explain the contents of an image and answer questions about it.
+    It is the same as VisualQuestionAnswering, but this tool is not used by agents. It is used when the user requests a tool for VQA via the generate_image_qa_tool function.
+    It is also useful if the user does not want the data exposed to OpenAI endpoints.
+
+    Example
+    -------
+        >>> import vision_agent as va
+        >>> vqa_tool = va.tools.ImageQuestionAnswering()
+        >>> vqa_tool(image="image1.jpg", prompt="describe this image in detail")
+        {'text': "The image contains a cat sitting on a table with a bowl of milk."}
+    """
+
+    name = "image_question_answering_"
+    description = "'image_question_answering_' is a tool that can describe the contents of the image and it can also answer basic questions about the image."
+
+    usage = {
+        "required_parameters": [
+            {"name": "image", "type": "str"},
+            {"name": "prompt", "type": "str"},
+        ],
+        "examples": [
+            {
+                "scenario": "Describe this image in detail. Image name: cat.jpg",
+                "parameters": {
+                    "image": "cats.jpg",
+                    "prompt": "Describe this image in detail",
+                },
+            },
+            {
+                "scenario": "Can you help me with this street sign in this image ? What does it say ? Image name: sign.jpg",
+                "parameters": {
+                    "image": "sign.jpg",
+                    "prompt": "Can you help me with this street sign ? What does it say ?",
+                },
+            },
+            {
+                "scenario": "Describe the weather in the image for me ? Image name: weather.jpg",
+                "parameters": {
+                    "image": "weather.jpg",
+                    "prompt": "Describe the weather in the image for me ",
+                },
+            },
+            {
+                "scenario": "Can you generate an image question answering tool ? Image name: chart.jpg, prompt: Which 2 are the least frequent bins in this histogram",
+                "parameters": {
+                    "image": "chart.jpg",
+                    "prompt": "Which 2 are the least frequent bins in this histogram",
+                },
+            },
+        ],
+    }
+
+    def __call__(self, image: Union[str, ImageType], prompt: str) -> Dict:
+        """Invoke the image question answering model.
+
+        Parameters:
+            image: the input image.
+
+        Returns:
+            A dictionary containing the key 'text' and the answer to the prompt. E.g. {'text': 'This image contains a cat sitting on a table with a bowl of milk.'}
+        """
+
+        image_b64 = convert_to_b64(image)
+        data = {
+            "image": image_b64,
+            "prompt": prompt,
+            "tool": "image_question_answering",
+        }
+
+        return _send_inference_request(data, "tools")
+
+
 class Crop(Tool):
     r"""Crop crops an image given a bounding box and returns a file name of the cropped image."""
 
@@ -944,11 +1083,11 @@ def __call__(self, equation: str) -> float:
         [
             NoOp,
             CLIP,
-            ImageCaption,
             GroundingDINO,
             AgentGroundingSAM,
             ZeroShotCounting,
             VisualPromptCounting,
+            VisualQuestionAnswering,
             AgentDINOv,
             ExtractFrames,
             Crop,
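
Usage sketch (not part of the patch): a minimal, illustrative example of how the tools added above are intended to be called, based only on the APIs shown in this diff. It assumes a local image file named cat.jpg (hypothetical) and that OpenAI and Landing AI credentials are configured in the environment; the printed answers will vary with the underlying models.

    # Minimal usage sketch for the tools introduced in this patch.
    # Assumes "cat.jpg" exists locally and API keys are set up (hypothetical setup).
    import vision_agent as va
    from vision_agent.lmm import OpenAILMM

    # VisualQuestionAnswering answers the prompt with an OpenAI LMM under the hood.
    vqa = va.tools.VisualQuestionAnswering()
    print(vqa(image="cat.jpg", prompt="Describe this image in detail"))

    # ImageQuestionAnswering sends the same kind of request to the hosted "tools"
    # inference endpoint instead, so the image is not sent to OpenAI.
    iqa = va.tools.ImageQuestionAnswering()
    print(iqa(image="cat.jpg", prompt="What animal is in this picture?"))

    # generate_image_qa_tool binds a fixed question and returns a callable
    # that applies that question to whatever image it is given.
    qa_tool = OpenAILMM().generate_image_qa_tool("What animal is in this picture?")
    print(qa_tool("cat.jpg"))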