Adding a VQA tool for the agent and for general purposes #63

Merged: 5 commits merged on Apr 24, 2024

Changes from 2 commits
6 changes: 6 additions & 0 deletions vision_agent/llm/llm.py
@@ -132,6 +132,12 @@ def generate_zero_shot_counter(self, question: str) -> Callable:
return lambda x: ZeroShotCounting()(**{"image": x})


def generate_image_qa_tool(self, question: str) -> Callable:
from vision_agent.tools import ImageQuestionAnswering

return lambda x: ImageQuestionAnswering()(**{"prompt": question, "image": x})


class AzureOpenAILLM(OpenAILLM):
def __init__(
self,
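As a quick sanity check of the new helper above, the sketch below shows how the returned callable would be used. It assumes OpenAILLM is importable from vision_agent.llm and can be constructed with default arguments (for example an API key picked up from the environment); neither of those details is part of this diff.

from vision_agent.llm import OpenAILLM

llm = OpenAILLM()  # assumed: default construction works (API key from env)

# generate_image_qa_tool closes over the question and returns a callable
# that runs ImageQuestionAnswering on whatever image it is given.
answer_fn = llm.generate_image_qa_tool("What objects are on the table?")
result = answer_fn("image1.jpg")  # e.g. {'text': 'A bowl of milk and a cat.'}
print(result)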
17 changes: 13 additions & 4 deletions vision_agent/lmm/lmm.py
@@ -11,11 +11,7 @@

from vision_agent.tools import (
CHOOSE_PARAMS,
CLIP,
SYSTEM_PROMPT,
GroundingDINO,
GroundingSAM,
ZeroShotCounting,
)

_LOGGER = logging.getLogger(__name__)
@@ -205,6 +201,8 @@ def generate(
return cast(str, response.choices[0].message.content)

def generate_classifier(self, question: str) -> Callable:
from vision_agent.tools import CLIP

api_doc = CLIP.description + "\n" + str(CLIP.usage)
prompt = CHOOSE_PARAMS.format(api_doc=api_doc, question=question)
response = self.client.chat.completions.create(
@@ -228,6 +226,8 @@ def generate_classifier(self, question: str) -> Callable:
return lambda x: CLIP()(**{"prompt": params["prompt"], "image": x})

def generate_detector(self, question: str) -> Callable:
from vision_agent.tools import GroundingDINO

api_doc = GroundingDINO.description + "\n" + str(GroundingDINO.usage)
prompt = CHOOSE_PARAMS.format(api_doc=api_doc, question=question)
response = self.client.chat.completions.create(
@@ -251,6 +251,8 @@ def generate_detector(self, question: str) -> Callable:
return lambda x: GroundingDINO()(**{"prompt": params["prompt"], "image": x})

def generate_segmentor(self, question: str) -> Callable:
from vision_agent.tools import GroundingSAM

api_doc = GroundingSAM.description + "\n" + str(GroundingSAM.usage)
prompt = CHOOSE_PARAMS.format(api_doc=api_doc, question=question)
response = self.client.chat.completions.create(
@@ -274,8 +276,15 @@ def generate_segmentor(self, question: str) -> Callable:
return lambda x: GroundingSAM()(**{"prompt": params["prompt"], "image": x})

def generate_zero_shot_counter(self, question: str) -> Callable:
from vision_agent.tools import ZeroShotCounting

return lambda x: ZeroShotCounting()(**{"image": x})

def generate_image_qa_tool(self, question: str) -> Callable:
from vision_agent.tools import ImageQuestionAnswering

return lambda x: ImageQuestionAnswering()(**{"prompt": question, "image": x})


class AzureOpenAILMM(OpenAILMM):
def __init__(
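The tool imports in this file move from module level into the individual methods, most likely because tools.py now imports OpenAILMM from vision_agent.lmm (see the tools.py diff below), which would otherwise create a circular import at load time. The pattern in isolation (names mirror the diff; the sketch assumes the vision_agent package is installed):

from typing import Callable

def make_qa_tool(question: str) -> Callable:
    # Import inside the function: the tools module is only loaded when the
    # tool is actually requested, so lmm.py no longer needs tools at import
    # time and the lmm <-> tools import cycle is avoided.
    from vision_agent.tools import ImageQuestionAnswering
    return lambda image: ImageQuestionAnswering()(prompt=question, image=image)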
6 changes: 4 additions & 2 deletions vision_agent/tools/__init__.py
@@ -12,10 +12,12 @@
GroundingDINO,
GroundingSAM,
ImageCaption,
ZeroShotCounting,
VisualPromptCounting,
VisualQuestionAnswering,
ImageQuestionAnswering,
SegArea,
SegIoU,
Tool,
VisualPromptCounting,
ZeroShotCounting,
register_tool,
)
145 changes: 142 additions & 3 deletions vision_agent/tools/tools.py
@@ -19,6 +19,7 @@
)
from vision_agent.tools.video import extract_frames_from_video
from vision_agent.type_defs import LandingaiAPIKey
from vision_agent.lmm import OpenAILMM

_LOGGER = logging.getLogger(__name__)
_LND_API_KEY = LandingaiAPIKey().api_key
@@ -502,7 +503,7 @@ class ZeroShotCounting(Tool):

# TODO: Add support for input multiple images, which aligns with the output type.
def __call__(self, image: Union[str, ImageType]) -> Dict:
"""Invoke the Image captioning model.
"""Invoke the Zero shot counting model.

Parameters:
image: the input image.
@@ -566,7 +567,7 @@ class VisualPromptCounting(Tool):

# TODO: Add support for input multiple images, which aligns with the output type.
def __call__(self, image: Union[str, ImageType], prompt: str) -> Dict:
"""Invoke the Image captioning model.
"""Invoke the few shot counting model.

Parameters:
image: the input image.
@@ -587,6 +588,144 @@ def __call__(self, image: Union[str, ImageType], prompt: str) -> Dict:
return _send_inference_request(data, "tools")


class VisualQuestionAnswering(Tool):
r"""VisualQuestionAnswering is a tool that can explain contents of an image and answer questions about the same

Example
-------
>>> import vision_agent as va
>>> vqa_tool = va.tools.VisualQuestionAnswering()
>>> vqa_tool(image="image1.jpg", prompt="describe this image in detail")
{'text': "The image contains a cat sitting on a table with a bowl of milk."}
"""

name = "visual_question_answering_"
description = "'visual_question_answering_' is a tool that can describe the contents of the image and it can also answer basic questions about the image."

usage = {
"required_parameters": [
{"name": "image", "type": "str"},
{"name": "prompt", "type": "str"},
],
"examples": [
{
"scenario": "Describe this image in detail. Image name: cat.jpg",
"parameters": {
"image": "cats.jpg",
"prompt": "Describe this image in detail",
},
},
{
"scenario": "Can you help me with this street sign in this image ? What does it say ? Image name: sign.jpg",
"parameters": {
"image": "sign.jpg",
"prompt": "Can you help me with this street sign ? What does it say ?",
},
},
{
"scenario": "Describe the weather in the image for me ? Image name: weather.jpg",
"parameters": {
"image": "weather.jpg",
"prompt": "Describe the weather in the image for me ",
},
},
{
"scenario": "Which 2 are the least frequent bins in this histogram ? Image name: chart.jpg",
"parameters": {
"image": "chart.jpg",
"prompt": "Which 2 are the least frequent bins in this histogram",
},
},
],
}

def __call__(self, image: Union[str, ImageType], prompt: str) -> Dict:
"""Invoke the visual question answering model.

Parameters:
image: the input image.
prompt: the question to ask about the image.

Returns:
A dictionary containing the key 'text' and the answer to the prompt. E.g. {'text': 'This image contains a cat sitting on a table with a bowl of milk.'}
"""

gpt = OpenAILMM()
return gpt(input=prompt, images=[image])


class ImageQuestionAnswering(Tool):
r"""ImageQuestionAnswering is a tool that can explain contents of an image and answer questions about the same
It is same as VisualQuestionAnswering but this tool is not used by agents. It is used when user requests a tool for VQA using generate_image_qa_tool function.
It is also useful if the user wants the data to be not exposed to OpenAI endpoints

Example
-------
>>> import vision_agent as va
>>> vqa_tool = va.tools.ImageQuestionAnswering()
>>> vqa_tool(image="image1.jpg", prompt="describe this image in detail")
{'text': "The image contains a cat sitting on a table with a bowl of milk."}
"""

name = "image_question_answering_"
description = "'image_question_answering_' is a tool that can describe the contents of the image and it can also answer basic questions about the image."

usage = {
"required_parameters": [
{"name": "image", "type": "str"},
{"name": "prompt", "type": "str"},
],
"examples": [
{
"scenario": "Describe this image in detail. Image name: cat.jpg",
"parameters": {
"image": "cats.jpg",
"prompt": "Describe this image in detail",
},
},
{
"scenario": "Can you help me with this street sign in this image ? What does it say ? Image name: sign.jpg",
"parameters": {
"image": "sign.jpg",
"prompt": "Can you help me with this street sign ? What does it say ?",
},
},
{
"scenario": "Describe the weather in the image for me ? Image name: weather.jpg",
"parameters": {
"image": "weather.jpg",
"prompt": "Describe the weather in the image for me ",
},
},
{
"scenario": "Can you generate an image question answering tool ? Image name: chart.jpg, prompt: Which 2 are the least frequent bins in this histogram",
"parameters": {
"image": "chart.jpg",
"prompt": "Which 2 are the least frequent bins in this histogram",
},
},
],
}

def __call__(self, image: Union[str, ImageType], prompt: str) -> Dict:
"""Invoke the visual question answering model.

Parameters:
image: the input image.
prompt: the question to ask about the image.

Returns:
A dictionary containing the key 'text' and the answer to the prompt. E.g. {'text': 'This image contains a cat sitting on a table with a bowl of milk.'}
"""

image_b64 = convert_to_b64(image)
data = {
"image": image_b64,
"prompt": prompt,
"tool": "image_question_answering",
}

return _send_inference_request(data, "tools")


class Crop(Tool):
r"""Crop crops an image given a bounding box and returns a file name of the cropped image."""

@@ -944,11 +1083,11 @@ def __call__(self, equation: str) -> float:
[
NoOp,
CLIP,
ImageCaption,
GroundingDINO,
AgentGroundingSAM,
ZeroShotCounting,
VisualPromptCounting,
VisualQuestionAnswering,
AgentDINOv,
ExtractFrames,
Crop,
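Taken together, the PR exposes two VQA paths: VisualQuestionAnswering, which is registered in the agent's TOOLS list above and answers through an OpenAILMM call, and ImageQuestionAnswering, which posts to the hosted tools inference endpoint (_send_inference_request) and is intended for direct use, for example via generate_image_qa_tool, when image data should not be sent to OpenAI. A minimal side-by-side sketch, assuming both classes are importable from vision_agent.tools as the __init__.py change indicates and that valid API credentials are configured:

from vision_agent.tools import ImageQuestionAnswering, VisualQuestionAnswering

prompt = "Which 2 are the least frequent bins in this histogram?"

# Agent-facing tool: delegates to the OpenAI LMM under the hood.
vqa = VisualQuestionAnswering()
print(vqa(image="chart.jpg", prompt=prompt))  # {'text': '...'}

# General-purpose tool: calls the tools inference endpoint instead, so the
# image never goes to OpenAI.
iqa = ImageQuestionAnswering()
print(iqa(image="chart.jpg", prompt=prompt))  # {'text': '...'}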