adding vqa tool for agent and also for general purposes (#63)
* adding vqa tool for agent and also for general purposes

* fix linting issues

* fix linting
shankar-vision-eng authored Apr 24, 2024
1 parent c0dde36 commit 87b2193
Showing 4 changed files with 164 additions and 9 deletions.
5 changes: 5 additions & 0 deletions vision_agent/llm/llm.py
@@ -131,6 +131,11 @@ def generate_segmentor(self, question: str) -> Callable:
     def generate_zero_shot_counter(self, question: str) -> Callable:
         return lambda x: ZeroShotCounting()(**{"image": x})
 
+    def generate_image_qa_tool(self, question: str) -> Callable:
+        from vision_agent.tools import ImageQuestionAnswering
+
+        return lambda x: ImageQuestionAnswering()(**{"prompt": question, "image": x})
+
 
 class AzureOpenAILLM(OpenAILLM):
     def __init__(
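With this change an OpenAILLM can hand back a ready-made VQA callable: the returned lambda closes over the question and only needs an image. A minimal usage sketch (the image path is hypothetical, and this assumes OpenAILLM is re-exported as shown and an OpenAI API key is configured):

from vision_agent.llm import OpenAILLM

llm = OpenAILLM()
# The question is fixed when the tool is generated; the callable takes only an image.
qa = llm.generate_image_qa_tool("What color is the car?")
result = qa("car.jpg")  # per the tool docstring, a dict like {'text': '...'}
print(result)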
17 changes: 13 additions & 4 deletions vision_agent/lmm/lmm.py
@@ -11,11 +11,7 @@
 
 from vision_agent.tools import (
     CHOOSE_PARAMS,
-    CLIP,
     SYSTEM_PROMPT,
-    GroundingDINO,
-    GroundingSAM,
-    ZeroShotCounting,
 )
 
 _LOGGER = logging.getLogger(__name__)
@@ -205,6 +201,8 @@ def generate(
         return cast(str, response.choices[0].message.content)
 
     def generate_classifier(self, question: str) -> Callable:
+        from vision_agent.tools import CLIP
+
         api_doc = CLIP.description + "\n" + str(CLIP.usage)
         prompt = CHOOSE_PARAMS.format(api_doc=api_doc, question=question)
         response = self.client.chat.completions.create(
@@ -228,6 +226,8 @@ def generate_classifier(self, question: str) -> Callable:
         return lambda x: CLIP()(**{"prompt": params["prompt"], "image": x})
 
     def generate_detector(self, question: str) -> Callable:
+        from vision_agent.tools import GroundingDINO
+
         api_doc = GroundingDINO.description + "\n" + str(GroundingDINO.usage)
         prompt = CHOOSE_PARAMS.format(api_doc=api_doc, question=question)
         response = self.client.chat.completions.create(
@@ -251,6 +251,8 @@ def generate_detector(self, question: str) -> Callable:
         return lambda x: GroundingDINO()(**{"prompt": params["prompt"], "image": x})
 
     def generate_segmentor(self, question: str) -> Callable:
+        from vision_agent.tools import GroundingSAM
+
         api_doc = GroundingSAM.description + "\n" + str(GroundingSAM.usage)
         prompt = CHOOSE_PARAMS.format(api_doc=api_doc, question=question)
         response = self.client.chat.completions.create(
@@ -274,8 +276,15 @@ def generate_segmentor(self, question: str) -> Callable:
         return lambda x: GroundingSAM()(**{"prompt": params["prompt"], "image": x})
 
     def generate_zero_shot_counter(self, question: str) -> Callable:
+        from vision_agent.tools import ZeroShotCounting
+
         return lambda x: ZeroShotCounting()(**{"image": x})
 
+    def generate_image_qa_tool(self, question: str) -> Callable:
+        from vision_agent.tools import ImageQuestionAnswering
+
+        return lambda x: ImageQuestionAnswering()(**{"prompt": question, "image": x})
+
 
 class AzureOpenAILMM(OpenAILMM):
     def __init__(
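Note the pattern running through this file: the tool imports move from module level into the method bodies. Because tools.py now imports OpenAILMM from vision_agent.lmm (see the tools.py hunk below), top-level imports in both directions would form a circular import; deferring the tools import until the method is called breaks the cycle. A generic sketch of the idea, with illustrative module names that are not from this repo:

# lmm_module.py -- imported by tools_module at load time
def make_tool():
    # Deferred import: resolved only when the function runs,
    # after both modules have finished loading.
    from tools_module import SomeTool
    return SomeTool()

# tools_module.py
from lmm_module import make_tool  # safe: no top-level import back into tools_module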
6 changes: 4 additions & 2 deletions vision_agent/tools/__init__.py
@@ -12,10 +12,12 @@
     GroundingDINO,
     GroundingSAM,
     ImageCaption,
-    ZeroShotCounting,
-    VisualPromptCounting,
+    VisualQuestionAnswering,
+    ImageQuestionAnswering,
     SegArea,
     SegIoU,
     Tool,
+    VisualPromptCounting,
+    ZeroShotCounting,
     register_tool,
 )
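With the re-export in place, both new tools are importable straight from the package namespace; a quick sanity check against the names defined in tools.py below:

from vision_agent.tools import ImageQuestionAnswering, VisualQuestionAnswering

print(VisualQuestionAnswering.name)  # visual_question_answering_
print(ImageQuestionAnswering.name)   # image_question_answering_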
145 changes: 142 additions & 3 deletions vision_agent/tools/tools.py
@@ -19,6 +19,7 @@
 )
 from vision_agent.tools.video import extract_frames_from_video
 from vision_agent.type_defs import LandingaiAPIKey
+from vision_agent.lmm import OpenAILMM
 
 _LOGGER = logging.getLogger(__name__)
 _LND_API_KEY = LandingaiAPIKey().api_key
@@ -502,7 +503,7 @@ class ZeroShotCounting(Tool):
 
     # TODO: Add support for input multiple images, which aligns with the output type.
     def __call__(self, image: Union[str, ImageType]) -> Dict:
-        """Invoke the Image captioning model.
+        """Invoke the Zero shot counting model.
 
         Parameters:
             image: the input image.
Expand Down Expand Up @@ -566,7 +567,7 @@ class VisualPromptCounting(Tool):

# TODO: Add support for input multiple images, which aligns with the output type.
def __call__(self, image: Union[str, ImageType], prompt: str) -> Dict:
"""Invoke the Image captioning model.
"""Invoke the few shot counting model.
Parameters:
image: the input image.
@@ -587,6 +588,144 @@ def __call__(self, image: Union[str, ImageType], prompt: str) -> Dict:
         return _send_inference_request(data, "tools")
 
 
+class VisualQuestionAnswering(Tool):
+    r"""VisualQuestionAnswering is a tool that can explain the contents of an image and answer questions about it.
+
+    Example
+    -------
+        >>> import vision_agent as va
+        >>> vqa_tool = va.tools.VisualQuestionAnswering()
+        >>> vqa_tool(image="image1.jpg", prompt="describe this image in detail")
+        {'text': "The image contains a cat sitting on a table with a bowl of milk."}
+    """
+
+    name = "visual_question_answering_"
+    description = "'visual_question_answering_' is a tool that can describe the contents of an image and answer basic questions about it."
+
+    usage = {
+        "required_parameters": [
+            {"name": "image", "type": "str"},
+            {"name": "prompt", "type": "str"},
+        ],
+        "examples": [
+            {
+                "scenario": "Describe this image in detail. Image name: cat.jpg",
+                "parameters": {
+                    "image": "cat.jpg",
+                    "prompt": "Describe this image in detail",
+                },
+            },
+            {
+                "scenario": "Can you help me with the street sign in this image? What does it say? Image name: sign.jpg",
+                "parameters": {
+                    "image": "sign.jpg",
+                    "prompt": "Can you help me with this street sign? What does it say?",
+                },
+            },
+            {
+                "scenario": "Describe the weather in the image for me. Image name: weather.jpg",
+                "parameters": {
+                    "image": "weather.jpg",
+                    "prompt": "Describe the weather in the image for me",
+                },
+            },
+            {
+                "scenario": "Which 2 are the least frequent bins in this histogram? Image name: chart.jpg",
+                "parameters": {
+                    "image": "chart.jpg",
+                    "prompt": "Which 2 are the least frequent bins in this histogram",
+                },
+            },
+        ],
+    }
+
+    def __call__(self, image: str, prompt: str) -> Dict:
+        """Invoke the visual question answering model.
+
+        Parameters:
+            image: the input image.
+            prompt: the question to ask about the image.
+
+        Returns:
+            A dictionary containing the key 'text' and the answer to the prompt. E.g. {'text': 'This image contains a cat sitting on a table with a bowl of milk.'}
+        """
+
+        gpt = OpenAILMM()
+        return {"text": gpt(input=prompt, images=[image])}
+
+
+class ImageQuestionAnswering(Tool):
+    r"""ImageQuestionAnswering is a tool that can explain the contents of an image and answer questions about it.
+    It is the same as VisualQuestionAnswering, but it is not used by agents; it is returned when a user requests a VQA tool via the generate_image_qa_tool function.
+    It is also useful when the user does not want their data exposed to OpenAI endpoints.
+
+    Example
+    -------
+        >>> import vision_agent as va
+        >>> vqa_tool = va.tools.ImageQuestionAnswering()
+        >>> vqa_tool(image="image1.jpg", prompt="describe this image in detail")
+        {'text': "The image contains a cat sitting on a table with a bowl of milk."}
+    """
+
+    name = "image_question_answering_"
+    description = "'image_question_answering_' is a tool that can describe the contents of an image and answer basic questions about it."
+
+    usage = {
+        "required_parameters": [
+            {"name": "image", "type": "str"},
+            {"name": "prompt", "type": "str"},
+        ],
+        "examples": [
+            {
+                "scenario": "Describe this image in detail. Image name: cat.jpg",
+                "parameters": {
+                    "image": "cat.jpg",
+                    "prompt": "Describe this image in detail",
+                },
+            },
+            {
+                "scenario": "Can you help me with the street sign in this image? What does it say? Image name: sign.jpg",
+                "parameters": {
+                    "image": "sign.jpg",
+                    "prompt": "Can you help me with this street sign? What does it say?",
+                },
+            },
+            {
+                "scenario": "Describe the weather in the image for me. Image name: weather.jpg",
+                "parameters": {
+                    "image": "weather.jpg",
+                    "prompt": "Describe the weather in the image for me",
+                },
+            },
+            {
+                "scenario": "Can you generate an image question answering tool? Image name: chart.jpg, prompt: Which 2 are the least frequent bins in this histogram",
+                "parameters": {
+                    "image": "chart.jpg",
+                    "prompt": "Which 2 are the least frequent bins in this histogram",
+                },
+            },
+        ],
+    }
+
+    def __call__(self, image: Union[str, ImageType], prompt: str) -> Dict:
+        """Invoke the image question answering model.
+
+        Parameters:
+            image: the input image.
+            prompt: the question to ask about the image.
+
+        Returns:
+            A dictionary containing the key 'text' and the answer to the prompt. E.g. {'text': 'This image contains a cat sitting on a table with a bowl of milk.'}
+        """
+
+        image_b64 = convert_to_b64(image)
+        data = {
+            "image": image_b64,
+            "prompt": prompt,
+            "tool": "image_question_answering",
+        }
+
+        return _send_inference_request(data, "tools")
+
 
 class Crop(Tool):
     r"""Crop crops an image given a bounding box and returns a file name of the cropped image."""

@@ -944,11 +1083,11 @@ def __call__(self, equation: str) -> float:
 [
     NoOp,
     CLIP,
     ImageCaption,
     GroundingDINO,
     AgentGroundingSAM,
     ZeroShotCounting,
     VisualPromptCounting,
+    VisualQuestionAnswering,
     AgentDINOv,
     ExtractFrames,
     Crop,
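The two new classes share a call signature but route differently: VisualQuestionAnswering (registered in the tool list above, so agents can select it) answers through an OpenAILMM call, while ImageQuestionAnswering base64-encodes the image and posts it to the hosted "tools" inference endpoint, keeping the request off OpenAI. A side-by-side sketch (the image path is hypothetical; the relevant API keys are assumed to be configured):

import vision_agent as va

prompt = "How many cars are visible in this image?"

vqa = va.tools.VisualQuestionAnswering()  # answered via the OpenAI endpoint
print(vqa(image="street.jpg", prompt=prompt))

iqa = va.tools.ImageQuestionAnswering()  # answered via the hosted tools endpoint
print(iqa(image="street.jpg", prompt=prompt))

# Both return a dict of the form {'text': '...'}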
