Skip to content

Commit

Permalink
reformat classes
Browse files Browse the repository at this point in the history
  • Loading branch information
dillonalaird committed Mar 6, 2024
1 parent 4767d6d commit b059ea1
Show file tree
Hide file tree
Showing 4 changed files with 48 additions and 43 deletions.
26 changes: 12 additions & 14 deletions vision_agent/lmm/lmm.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,10 @@
from vision_agent.tools import (
SYSTEM_PROMPT,
CHOOSE_PARAMS,
GROUNDING_DINO,
GROUNDING_SAM,
ImageTool,
CLIP,
Classifier,
Detector,
Segmentor,
GroundingDINO,
GroundingSAM,
)

logging.basicConfig(level=logging.INFO)
Expand Down Expand Up @@ -102,8 +100,8 @@ def generate(self, prompt: str, image: Optional[Union[str, Path]] = None) -> str
)
return cast(str, response.choices[0].message.content)

def generate_classifier(self, prompt: str) -> Classifier:
prompt = CHOOSE_PARAMS.format(api_doc=CLIP, question=prompt)
def generate_classifier(self, prompt: str) -> ImageTool:
prompt = CHOOSE_PARAMS.format(api_doc=CLIP.doc, question=prompt)
response = self.client.chat.completions.create(
model="gpt-4-turbo-preview", # no need to use vision model here
response_format={"type": "json_object"},
Expand All @@ -113,10 +111,10 @@ def generate_classifier(self, prompt: str) -> Classifier:
],
)
prompt = json.loads(response.choices[0].message.content)["prompt"]
return Classifier(prompt)
return CLIP(prompt)

def generate_detector(self, prompt: str) -> Detector:
prompt = CHOOSE_PARAMS.format(api_doc=GROUNDING_DINO, question=prompt)
def generate_detector(self, prompt: str) -> ImageTool:
prompt = CHOOSE_PARAMS.format(api_doc=GroundingDINO.doc, question=prompt)
response = self.client.chat.completions.create(
model="gpt-4-turbo-preview", # no need to use vision model here
response_format={"type": "json_object"},
Expand All @@ -126,10 +124,10 @@ def generate_detector(self, prompt: str) -> Detector:
],
)
prompt = json.loads(response.choices[0].message.content)["prompt"]
return Detector(prompt)
return GroundingDINO(prompt)

def generate_segmentor(self, prompt: str) -> Segmentor:
prompt = CHOOSE_PARAMS.format(api_doc=GROUNDING_SAM, question=prompt)
def generate_segmentor(self, prompt: str) -> ImageTool:
prompt = CHOOSE_PARAMS.format(api_doc=GroundingSAM.doc, question=prompt)
response = self.client.chat.completions.create(
model="gpt-4-turbo-preview", # no need to use vision model here
response_format={"type": "json_object"},
Expand All @@ -139,7 +137,7 @@ def generate_segmentor(self, prompt: str) -> Segmentor:
],
)
prompt = json.loads(response.choices[0].message.content)["prompt"]
return Segmentor(prompt)
return GroundingSAM(prompt)


def get_lmm(name: str) -> LMM:
Expand Down
4 changes: 2 additions & 2 deletions vision_agent/tools/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
from .prompts import SYSTEM_PROMPT, CHOOSE_PARAMS, GROUNDING_DINO, GROUNDING_SAM, CLIP
from .tools import Classifier, Detector, Segmentor
from .prompts import SYSTEM_PROMPT, CHOOSE_PARAMS
from .tools import ImageTool, CLIP, GroundingDINO, GroundingSAM
24 changes: 0 additions & 24 deletions vision_agent/tools/prompts.py
Original file line number Diff line number Diff line change
@@ -1,29 +1,5 @@
SYSTEM_PROMPT = "You are a helpful assistant."

GROUNDING_DINO = (
"Grounding DINO is a tool that can detect arbitrary objects with inputs such as category names or referring expressions."
"Here are some exmaples of how to use the tool, the examples are in the format of User Question: which will have the user's question in quotes followed by the parameters in JSON format, which is the parameters you need to output to call the API to solve the user's question.\n"
'Example 1: User Question: "Can you build me a car detector?" {{"Parameters":{{"prompt": "car"}}}}\n'
'Example 2: User Question: "Can you detect the person on the left?" {{"Parameters":{{"prompt": "person on the left"}}\n'
'Exmaple 3: User Question: "Can you build me a tool that detects red shirts and green shirts?" {{"Parameters":{{"prompt": "red shirt. green shirt"}}}}\n'
)

GROUNDING_SAM = (
"Grounding SAM is a tool that can detect and segment arbitrary objects with inputs such as category names or referring expressions."
"Here are some exmaples of how to use the tool, the examples are in the format of User Question: which will have the user's question in quotes followed by the parameters in JSON format, which is the parameters you need to output to call the API to solve the user's question.\n"
'Example 1: User Question: "Can you build me a car segmentor?" {{"Parameters":{{"prompt": "car"}}}}\n'
'Example 2: User Question: "Can you segment the person on the left?" {{"Parameters":{{"prompt": "person on the left"}}\n'
'Exmaple 3: User Question: "Can you build me a tool that segments red shirts and green shirts?" {{"Parameters":{{"prompt": "red shirt. green shirt"}}}}\n'
)

CLIP = (
"CLIP is a tool that can classify or tag any image given a set if input classes or tags."
"Here are some exmaples of how to use the tool, the examples are in the format of User Question: which will have the user's question in quotes followed by the parameters in JSON format, which is the parameters you need to output to call the API to solve the user's question.\n"
'Example 1: User Question: "Can you classify this image as a cat?" {{"Parameters":{{"prompt": ["cat"]}}}}\n'
'Example 2: User Question: "Can you tag this photograph with cat or dog?" {{"Parameters":{{"prompt": ["cat", "dog"]}}}}\n'
'Exmaple 3: User Question: "Can you build me a classifier taht classifies red shirts, green shirts and other?" {{"Parameters":{{"prompt": ["red shirt", "green shirt", "other"]}}}}\n'
)

# EasyTool prompts
CHOOSE_PARAMS = (
"This is an API tool documentation. Given a user's question, you need to output parameters according to the API tool documentation to successfully call the API to solve the user's question.\n"
Expand Down
37 changes: 34 additions & 3 deletions vision_agent/tools/tools.py
Original file line number Diff line number Diff line change
@@ -1,25 +1,56 @@
from typing import Dict, List, Union
from abc import ABC, abstractmethod

from PIL.Image import Image as ImageType


class Classifier:
class ImageTool(ABC):
@abstractmethod
def __call__(self, image: Union[str, ImageType]) -> List[Dict]:
pass


class CLIP(ImageTool):
doc = (
"CLIP is a tool that can classify or tag any image given a set if input classes or tags."
"Here are some exmaples of how to use the tool, the examples are in the format of User Question: which will have the user's question in quotes followed by the parameters in JSON format, which is the parameters you need to output to call the API to solve the user's question.\n"
'Example 1: User Question: "Can you classify this image as a cat?" {{"Parameters":{{"prompt": ["cat"]}}}}\n'
'Example 2: User Question: "Can you tag this photograph with cat or dog?" {{"Parameters":{{"prompt": ["cat", "dog"]}}}}\n'
'Exmaple 3: User Question: "Can you build me a classifier taht classifies red shirts, green shirts and other?" {{"Parameters":{{"prompt": ["red shirt", "green shirt", "other"]}}}}\n'
)

def __init__(self, prompt: str):
self.prompt = prompt

def __call__(self, image: Union[str, ImageType]) -> List[Dict]:
raise NotImplementedError


class Detector:
class GroundingDINO(ImageTool):
doc = (
"Grounding DINO is a tool that can detect arbitrary objects with inputs such as category names or referring expressions."
"Here are some exmaples of how to use the tool, the examples are in the format of User Question: which will have the user's question in quotes followed by the parameters in JSON format, which is the parameters you need to output to call the API to solve the user's question.\n"
'Example 1: User Question: "Can you build me a car detector?" {{"Parameters":{{"prompt": "car"}}}}\n'
'Example 2: User Question: "Can you detect the person on the left?" {{"Parameters":{{"prompt": "person on the left"}}\n'
'Exmaple 3: User Question: "Can you build me a tool that detects red shirts and green shirts?" {{"Parameters":{{"prompt": "red shirt. green shirt"}}}}\n'
)

def __init__(self, prompt: str):
self.prompt = prompt

def __call__(self, image: Union[str, ImageType]) -> List[Dict]:
raise NotImplementedError


class Segmentor:
class GroundingSAM(ImageTool):
doc = (
"Grounding SAM is a tool that can detect and segment arbitrary objects with inputs such as category names or referring expressions."
"Here are some exmaples of how to use the tool, the examples are in the format of User Question: which will have the user's question in quotes followed by the parameters in JSON format, which is the parameters you need to output to call the API to solve the user's question.\n"
'Example 1: User Question: "Can you build me a car segmentor?" {{"Parameters":{{"prompt": "car"}}}}\n'
'Example 2: User Question: "Can you segment the person on the left?" {{"Parameters":{{"prompt": "person on the left"}}\n'
'Exmaple 3: User Question: "Can you build me a tool that segments red shirts and green shirts?" {{"Parameters":{{"prompt": "red shirt. green shirt"}}}}\n'
)

def __init__(self, prompt: str):
self.prompt = prompt

Expand Down

0 comments on commit b059ea1

Please sign in to comment.