Commit
updated tools
dillonalaird committed Mar 20, 2024
1 parent b36f384 commit f038363
Showing 4 changed files with 90 additions and 73 deletions.
32 changes: 17 additions & 15 deletions vision_agent/llm/llm.py
@@ -1,6 +1,6 @@
 import json
 from abc import ABC, abstractmethod
-from typing import Dict, List, Mapping, Union, cast
+from typing import Callable, Dict, List, Mapping, Union, cast
 
 from openai import OpenAI
@@ -10,7 +10,6 @@
     SYSTEM_PROMPT,
     GroundingDINO,
     GroundingSAM,
-    ImageTool,
 )
@@ -65,9 +64,9 @@ def __call__(self, input: Union[str, List[Dict[str, str]]]) -> str:
             return self.generate(input)
         return self.chat(input)
 
-    def generate_classifier(self, prompt: str) -> ImageTool:
+    def generate_classifier(self, question: str) -> Callable:
         api_doc = CLIP.description + "\n" + str(CLIP.usage)
-        prompt = CHOOSE_PARAMS.format(api_doc=api_doc, question=prompt)
+        prompt = CHOOSE_PARAMS.format(api_doc=api_doc, question=question)
         response = self.client.chat.completions.create(
             model=self.model_name,
             response_format={"type": "json_object"},
@@ -80,38 +79,41 @@ def generate_classifier(self, prompt: str) -> ImageTool:
         params = json.loads(cast(str, response.choices[0].message.content))[
             "Parameters"
         ]
-        return CLIP(**cast(Mapping, params))
+        return lambda x: CLIP()(**{"prompt": params["prompt"], "image": x})
 
-    def generate_detector(self, params: str) -> ImageTool:
+    def generate_detector(self, question: str) -> Callable:
         api_doc = GroundingDINO.description + "\n" + str(GroundingDINO.usage)
-        params = CHOOSE_PARAMS.format(api_doc=api_doc, question=params)
+        prompt = CHOOSE_PARAMS.format(api_doc=api_doc, question=question)
         response = self.client.chat.completions.create(
             model=self.model_name,
             response_format={"type": "json_object"},
             messages=[
                 {"role": "system", "content": SYSTEM_PROMPT},
-                {"role": "user", "content": params},
+                {"role": "user", "content": prompt},
             ],
         )
 
-        params = json.loads(cast(str, response.choices[0].message.content))[
+        params: Mapping = json.loads(cast(str, response.choices[0].message.content))[
             "Parameters"
         ]
-        return GroundingDINO(**cast(Mapping, params))
+        return lambda x: GroundingDINO()(**{"prompt": params["prompt"], "image": x})
 
-    def generate_segmentor(self, params: str) -> ImageTool:
+    def generate_segmentor(self, question: str) -> Callable:
         api_doc = GroundingSAM.description + "\n" + str(GroundingSAM.usage)
-        params = CHOOSE_PARAMS.format(api_doc=api_doc, question=params)
+        prompt = CHOOSE_PARAMS.format(api_doc=api_doc, question=question)
         response = self.client.chat.completions.create(
             model=self.model_name,
             response_format={"type": "json_object"},
             messages=[
                 {"role": "system", "content": SYSTEM_PROMPT},
-                {"role": "user", "content": params},
+                {"role": "user", "content": prompt},
             ],
         )
 
-        params = json.loads(cast(str, response.choices[0].message.content))[
+        params: Mapping = json.loads(cast(str, response.choices[0].message.content))[
             "Parameters"
         ]
-        return GroundingSAM(**cast(Mapping, params))
+        return lambda x: GroundingSAM()(**{"prompt": params["prompt"], "image": x})
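
For orientation, a minimal sketch of how the reworked API might be used after this change. The `OpenAILLM` class name is an assumption (the diff shows only its methods), and the question and image path are illustrative:

```python
from vision_agent.llm.llm import OpenAILLM  # assumed class name; not shown in this diff

llm = OpenAILLM()

# generate_detector now returns a plain Callable instead of a configured
# ImageTool instance: the LLM picks the prompt parameters once, and the
# returned lambda forwards them plus an image to GroundingDINO.__call__.
detector = llm.generate_detector("Can you build me a car detector?")

detections = detector("cars.jpg")  # hypothetical image path
```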
27 changes: 13 additions & 14 deletions vision_agent/lmm/lmm.py
@@ -3,7 +3,7 @@
 import logging
 from abc import ABC, abstractmethod
 from pathlib import Path
-from typing import Any, Dict, List, Mapping, Optional, Union, cast
+from typing import Any, Callable, Dict, List, Optional, Union, cast
 
 import requests
 from openai import OpenAI
@@ -14,7 +14,6 @@
     SYSTEM_PROMPT,
     GroundingDINO,
     GroundingSAM,
-    ImageTool,
 )
 
 _LOGGER = logging.getLogger(__name__)
@@ -168,9 +167,9 @@ def generate(self, prompt: str, image: Optional[Union[str, Path]] = None) -> str:
         )
         return cast(str, response.choices[0].message.content)
 
-    def generate_classifier(self, prompt: str) -> ImageTool:
+    def generate_classifier(self, question: str) -> Callable:
         api_doc = CLIP.description + "\n" + str(CLIP.usage)
-        prompt = CHOOSE_PARAMS.format(api_doc=api_doc, question=prompt)
+        prompt = CHOOSE_PARAMS.format(api_doc=api_doc, question=question)
         response = self.client.chat.completions.create(
             model=self.model_name,
             messages=[
@@ -180,7 +179,7 @@ def generate_classifier(self, prompt: str) -> ImageTool:
         )
 
         try:
-            prompt = json.loads(cast(str, response.choices[0].message.content))[
+            params = json.loads(cast(str, response.choices[0].message.content))[
                 "Parameters"
             ]
         except json.JSONDecodeError:
@@ -189,16 +188,16 @@ def generate_classifier(self, prompt: str) -> ImageTool:
             )
             raise ValueError("Failed to decode response")
 
-        return CLIP(**cast(Mapping, prompt))
+        return lambda x: CLIP()(**{"prompt": params["prompt"], "image": x})
 
-    def generate_detector(self, params: str) -> ImageTool:
+    def generate_detector(self, question: str) -> Callable:
         api_doc = GroundingDINO.description + "\n" + str(GroundingDINO.usage)
-        params = CHOOSE_PARAMS.format(api_doc=api_doc, question=params)
+        prompt = CHOOSE_PARAMS.format(api_doc=api_doc, question=question)
         response = self.client.chat.completions.create(
             model=self.model_name,
             messages=[
                 {"role": "system", "content": SYSTEM_PROMPT},
-                {"role": "user", "content": params},
+                {"role": "user", "content": prompt},
             ],
         )
 
@@ -212,11 +211,11 @@ def generate_detector(self, params: str) -> ImageTool:
             )
             raise ValueError("Failed to decode response")
 
-        return GroundingDINO(**cast(Mapping, params))
+        return lambda x: GroundingDINO()(**{"prompt": params["prompt"], "image": x})
 
-    def generate_segmentor(self, prompt: str) -> ImageTool:
+    def generate_segmentor(self, question: str) -> Callable:
         api_doc = GroundingSAM.description + "\n" + str(GroundingSAM.usage)
-        prompt = CHOOSE_PARAMS.format(api_doc=api_doc, question=prompt)
+        prompt = CHOOSE_PARAMS.format(api_doc=api_doc, question=question)
         response = self.client.chat.completions.create(
             model=self.model_name,
             messages=[
@@ -226,7 +225,7 @@ def generate_segmentor(self, prompt: str) -> ImageTool:
         )
 
         try:
-            prompt = json.loads(cast(str, response.choices[0].message.content))[
+            params = json.loads(cast(str, response.choices[0].message.content))[
                 "Parameters"
             ]
         except json.JSONDecodeError:
@@ -235,7 +234,7 @@ def generate_segmentor(self, prompt: str) -> ImageTool:
             )
             raise ValueError("Failed to decode response")
 
-        return GroundingSAM(**cast(Mapping, prompt))
+        return lambda x: GroundingSAM()(**{"prompt": params["prompt"], "image": x})
 
 
 def get_lmm(name: str) -> LMM:
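
The LMM side mirrors the same pattern, with an extra JSONDecodeError guard around the model's output. A hedged usage sketch, assuming `get_lmm` accepts an "openai" registry name (the image paths are illustrative):

```python
from vision_agent.lmm.lmm import get_lmm

lmm = get_lmm("openai")  # assumed registry name

# Because generate_classifier now returns a Callable that closes over the
# parameters the model chose, one classifier can be reused across images.
classifier = lmm.generate_classifier("Can you tag this photograph with cat or dog?")
for path in ["cat_dog_1.jpg", "cat_dog_2.jpg"]:  # hypothetical paths
    print(classifier(path))
```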
2 changes: 1 addition & 1 deletion vision_agent/tools/__init__.py
@@ -1,2 +1,2 @@
 from .prompts import CHOOSE_PARAMS, SYSTEM_PROMPT
-from .tools import CLIP, TOOLS, GroundingDINO, GroundingSAM, ImageTool
+from .tools import CLIP, TOOLS, GroundingDINO, GroundingSAM, Tool
102 changes: 59 additions & 43 deletions vision_agent/tools/tools.py
@@ -1,5 +1,5 @@
 import logging
-from abc import ABC, abstractmethod
+from abc import ABC
 from pathlib import Path
 from typing import Any, Dict, List, Tuple, Union, cast
 
@@ -17,10 +17,10 @@ def normalize_bbox(
 ) -> List[float]:
     r"""Normalize the bounding box coordinates to be between 0 and 1."""
     x1, y1, x2, y2 = bbox
-    x1 = x1 / image_size[1]
-    y1 = y1 / image_size[0]
-    x2 = x2 / image_size[1]
-    y2 = y2 / image_size[0]
+    x1 = round(x1 / image_size[1], 2)
+    y1 = round(y1 / image_size[0], 2)
+    x2 = round(x2 / image_size[1], 2)
+    y2 = round(y2 / image_size[0], 2)
     return [x1, y1, x2, y2]
 
 
@@ -47,13 +47,7 @@ class Tool(ABC):
     usage: Dict
 
 
-class ImageTool(Tool):
-    @abstractmethod
-    def __call__(self, image: Union[str, ImageType]) -> List[Dict]:
-        pass
-
-
-class CLIP(ImageTool):
+class CLIP(Tool):
     r"""CLIP is a tool that can classify or tag any image given a set of input classes
     or tags.
@@ -70,19 +64,32 @@ class CLIP(ImageTool):
     description = (
         "'clip_' is a tool that can classify or tag any image given a set of input classes or tags."
         "Here are some examples of how to use the tool, the examples are in the format of User Question: which will have the user's question in quotes followed by the parameters in JSON format, which is the parameters you need to output to call the API to solve the user's question.\n"
-        'Example 1: User Question: "Can you classify this image as a cat?" {{"Parameters":{{"prompt": ["cat"]}}}}\n'
-        'Example 2: User Question: "Can you tag this photograph with cat or dog?" {{"Parameters":{{"prompt": ["cat", "dog"]}}}}\n'
-        'Example 3: User Question: "Can you build me a classifier that classifies red shirts, green shirts and other?" {{"Parameters":{{"prompt": ["red shirt", "green shirt", "other"]}}}}\n'
     )
-    usage: Dict = {}
-
-    def __init__(self, prompt: list[str]):
-        self.prompt = prompt
+    usage = {
+        "required_parameters": [{"name": "prompt", "type": "List[str]"}, {"name": "image", "type": "str"}],
+        "examples": [
+            {
+                "scenario": "Can you classify this image as a cat? Image name: cat.jpg",
+                "parameters": {"prompt": ["cat"], "image": "cat.jpg"},
+            },
+            {
+                "scenario": "Can you tag this photograph with cat or dog? Image name: cat_dog.jpg",
+                "parameters": {"prompt": ["cat", "dog"], "image": "cat_dog.jpg"},
+            },
+            {
+                "scenario": "Can you build me a classifier that classifies red shirts, green shirts and other? Image name: shirts.jpg",
+                "parameters": {
+                    "prompt": ["red shirt", "green shirt", "other"],
+                    "image": "shirts.jpg",
+                },
+            },
+        ],
+    }
 
-    def __call__(self, image: Union[str, ImageType]) -> List[Dict]:
+    def __call__(self, prompt: List[str], image: Union[str, ImageType]) -> List[Dict]:
         image_b64 = convert_to_b64(image)
         data = {
-            "classes": self.prompt,
+            "classes": prompt,
             "images": [image_b64],
         }
         res = requests.post(
@@ -99,7 +106,7 @@ def __call__(self, image: Union[str, ImageType]) -> List[Dict]:
         return cast(List[Dict], resp_json["data"])
 
 
-class GroundingDINO(ImageTool):
+class GroundingDINO(Tool):
     _ENDPOINT = "https://chnicr4kes5ku77niv2zoytggq0qyqlp.lambda-url.us-east-2.on.aws"
 
     name = "grounding_dino_"
@@ -113,31 +120,28 @@ class GroundingDINO(ImageTool):
         'An example output would be: [{"label": ["car"], "score": [0.99], "bbox": [[0.1, 0.2, 0.3, 0.4]]}]\n'
     )
     usage = {
-        "required_parameters": {"name": "prompt", "type": "str"},
+        "required_parameters": [{"name": "prompt", "type": "str"}, {"name": "image", "type": "str"}],
         "examples": [
             {
                 "scenario": "Can you build me a car detector?",
-                "parameters": {"prompt": "car"},
+                "parameters": {"prompt": "car", "image": ""},
             },
             {
-                "scenario": "Can you detect the person on the left?",
-                "parameters": {"prompt": "person on the left"},
+                "scenario": "Can you detect the person on the left? Image name: person.jpg",
+                "parameters": {"prompt": "person on the left", "image": "person.jpg"},
             },
             {
-                "scenario": "Detect the red shirts and green shirts.",
-                "parameters": {"prompt": "red shirt. green shirt"},
+                "scenario": "Detect the red shirts and green shirts. Image name: shirts.jpg",
+                "parameters": {"prompt": "red shirt. green shirt", "image": "shirts.jpg"},
             },
         ],
     }
 
-    def __init__(self, prompt: str):
-        self.prompt = prompt
-
-    def __call__(self, image: Union[str, Path, ImageType]) -> List[Dict]:
+    def __call__(self, prompt: str, image: Union[str, Path, ImageType]) -> List[Dict]:
         image_size = get_image_size(image)
         image_b64 = convert_to_b64(image)
         data = {
-            "prompt": self.prompt,
+            "prompt": prompt,
             "images": [image_b64],
         }
         res = requests.post(
@@ -157,10 +161,12 @@ def __call__(self, image: Union[str, Path, ImageType]) -> List[Dict]:
                 elt["bboxes"] = [
                     normalize_bbox(box, image_size) for box in elt["bboxes"]
                 ]
+            if "scores" in elt:
+                elt["scores"] = [round(score, 2) for score in elt["scores"]]
         return cast(List[Dict], resp_data)
 
 
-class GroundingSAM(ImageTool):
+class GroundingSAM(Tool):
     r"""Grounding SAM is a tool that can detect and segment arbitrary objects with
     inputs such as category names or referring expressions.
@@ -185,19 +191,29 @@ class GroundingSAM(ImageTool):
     description = (
         "'grounding_sam_' is a tool that can detect and segment arbitrary objects with inputs such as category names or referring expressions."
        "Here are some examples of how to use the tool, the examples are in the format of User Question: which will have the user's question in quotes followed by the parameters in JSON format, which is the parameters you need to output to call the API to solve the user's question.\n"
-        'Example 1: User Question: "Can you build me a car segmentor?" {{"Parameters":{{"prompt": ["car"]}}}}\n'
-        'Example 2: User Question: "Can you segment the person on the left?" {{"Parameters":{{"prompt": ["person on the left"]}}}}\n'
-        'Example 3: User Question: "Can you build me a tool that segments red shirts and green shirts?" {{"Parameters":{{"prompt": ["red shirt", "green shirt"]}}}}\n'
     )
-    usage: Dict = {}
-
-    def __init__(self, prompt: list[str]):
-        self.prompt = prompt
+    usage = {
+        "required_parameters": [{"name": "prompt", "type": "List[str]"}, {"name": "image", "type": "str"}],
+        "examples": [
+            {
+                "scenario": "Can you build me a car segmentor?",
+                "parameters": {"prompt": ["car"], "image": ""},
+            },
+            {
+                "scenario": "Can you segment the person on the left? Image name: person.jpg",
+                "parameters": {"prompt": ["person on the left"], "image": "person.jpg"},
+            },
+            {
+                "scenario": "Can you build me a tool that segments red shirts and green shirts? Image name: shirts.jpg",
+                "parameters": {"prompt": ["red shirt", "green shirt"], "image": "shirts.jpg"},
+            },
+        ],
+    }
 
-    def __call__(self, image: Union[str, ImageType]) -> List[Dict]:
+    def __call__(self, prompt: List[str], image: Union[str, ImageType]) -> List[Dict]:
         image_b64 = convert_to_b64(image)
         data = {
-            "classes": self.prompt,
+            "classes": prompt,
             "image": image_b64,
         }
         res = requests.post(
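
Taken together, the tools are now stateless: prompt parameters move from `__init__` into `__call__`, and `usage` documents both `prompt` and `image`. A minimal sketch of calling the updated tools directly (image paths are illustrative):

```python
from vision_agent.tools import CLIP, GroundingDINO

# Prompts are supplied per call rather than stored on the instance.
tags = CLIP()(prompt=["cat", "dog"], image="cat_dog.jpg")

detections = GroundingDINO()(prompt="red shirt. green shirt", image="shirts.jpg")
# Bounding boxes come back normalized to [0, 1] and, after this commit,
# rounded to two decimals by normalize_bbox; scores are rounded likewise.
```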
