From 91d046e399c2760bad5dc7c3101f1ea9338ed072 Mon Sep 17 00:00:00 2001
From: Dillon Laird
Date: Wed, 6 Mar 2024 13:39:49 -0800
Subject: [PATCH 1/5] adding basic tools

---
 vision_agent/lmm/lmm.py        | 51 ++++++++++++++++++++++++++++++++++
 vision_agent/tools/__init__.py |  2 ++
 vision_agent/tools/prompts.py  | 43 ++++++++++++++++++++++++++++
 vision_agent/tools/tools.py    | 22 +++++++++++++++
 4 files changed, 118 insertions(+)
 create mode 100644 vision_agent/tools/__init__.py
 create mode 100644 vision_agent/tools/prompts.py
 create mode 100644 vision_agent/tools/tools.py

diff --git a/vision_agent/lmm/lmm.py b/vision_agent/lmm/lmm.py
index bdd4cc52..44a7fa5d 100644
--- a/vision_agent/lmm/lmm.py
+++ b/vision_agent/lmm/lmm.py
@@ -1,4 +1,5 @@
 import base64
+import json
 import logging
 from abc import ABC, abstractmethod
 from pathlib import Path
@@ -6,6 +7,17 @@
 
 import requests
 
+from vision_agent.tools import (
+    SYSTEM_PROMPT,
+    CHOOSE_PARAMS,
+    GROUNDING_DINO,
+    GROUNDING_SAM,
+    CLIP,
+    Classifier,
+    Detector,
+    Segmentor,
+)
+
 logging.basicConfig(level=logging.INFO)
 _LOGGER = logging.getLogger(__name__)
 
@@ -90,6 +102,45 @@ def generate(self, prompt: str, image: Optional[Union[str, Path]] = None) -> str
         )
         return cast(str, response.choices[0].message.content)
 
+    def generate_classifier(self, prompt: str) -> Classifier:
+        prompt = CHOOSE_PARAMS.format(api_doc=CLIP, question=prompt)
+        response = self.client.chat.completions.create(
+            model="gpt-4-turbo-preview",  # no need to use vision model here
+            response_format={"type": "json_object"},
+            messages=[
+                {"role": "system", "content": SYSTEM_PROMPT},
+                {"role": "user", "content": prompt},
+            ],
+        )
+        prompt = json.loads(response.choices[0].message.content)["prompt"]
+        return Classifier(prompt)
+
+    def generate_detector(self, prompt: str) -> Detector:
+        prompt = CHOOSE_PARAMS.format(api_doc=GROUNDING_DINO, question=prompt)
+        response = self.client.chat.completions.create(
+            model="gpt-4-turbo-preview",  # no need to use vision model here
+            response_format={"type": "json_object"},
+            messages=[
+                {"role": "system", "content": SYSTEM_PROMPT},
+                {"role": "user", "content": prompt},
+            ],
+        )
+        prompt = json.loads(response.choices[0].message.content)["prompt"]
+        return Detector(prompt)
+
+    def generate_segmentor(self, prompt: str) -> Segmentor:
+        prompt = CHOOSE_PARAMS.format(api_doc=GROUNDING_SAM, question=prompt)
+        response = self.client.chat.completions.create(
+            model="gpt-4-turbo-preview",  # no need to use vision model here
+            response_format={"type": "json_object"},
+            messages=[
+                {"role": "system", "content": SYSTEM_PROMPT},
+                {"role": "user", "content": prompt},
+            ],
+        )
+        prompt = json.loads(response.choices[0].message.content)["prompt"]
+        return Segmentor(prompt)
+
 
 def get_lmm(name: str) -> LMM:
     if name == "openai":
diff --git a/vision_agent/tools/__init__.py b/vision_agent/tools/__init__.py
new file mode 100644
index 00000000..600f287c
--- /dev/null
+++ b/vision_agent/tools/__init__.py
@@ -0,0 +1,2 @@
+from .prompts import SYSTEM_PROMPT, CHOOSE_PARAMS, GROUNDING_DINO, GROUNDING_SAM, CLIP
+from .tools import Classifier, Detector, Segmentor
diff --git a/vision_agent/tools/prompts.py b/vision_agent/tools/prompts.py
new file mode 100644
index 00000000..76bd7fb2
--- /dev/null
+++ b/vision_agent/tools/prompts.py
@@ -0,0 +1,43 @@
+SYSTEM_PROMPT = "You are a helpful assistant."
+
+GROUNDING_DINO = (
+    "Grounding DINO is a tool that can detect arbitrary objects with inputs such as category names or referring expressions. "
+ "Here are some exmaples of how to use the tool, the examples are in the format of User Question: which will have the user's question in quotes followed by the parameters in JSON format, which is the parameters you need to output to call the API to solve the user's question.\n" + 'Example 1: User Question: "Can you build me a car detector?" {{"Parameters":{{"prompt": "car"}}}}\n' + 'Example 2: User Question: "Can you detect the person on the left?" {{"Parameters":{{"prompt": "person on the left"}}\n' + 'Exmaple 3: User Question: "Can you build me a tool that detects red shirts and green shirts?" {{"Parameters":{{"prompt": "red shirt. green shirt"}}}}\n' +) + +GROUNDING_SAM = ( + "Grounding SAM is a tool that can detect and segment arbitrary objects with inputs such as category names or referring expressions." + "Here are some exmaples of how to use the tool, the examples are in the format of User Question: which will have the user's question in quotes followed by the parameters in JSON format, which is the parameters you need to output to call the API to solve the user's question.\n" + 'Example 1: User Question: "Can you build me a car segmentor?" {{"Parameters":{{"prompt": "car"}}}}\n' + 'Example 2: User Question: "Can you segment the person on the left?" {{"Parameters":{{"prompt": "person on the left"}}\n' + 'Exmaple 3: User Question: "Can you build me a tool that segments red shirts and green shirts?" {{"Parameters":{{"prompt": "red shirt. green shirt"}}}}\n' +) + +CLIP = ( + "CLIP is a tool that can classify or tag any image given a set if input classes or tags." + "Here are some exmaples of how to use the tool, the examples are in the format of User Question: which will have the user's question in quotes followed by the parameters in JSON format, which is the parameters you need to output to call the API to solve the user's question.\n" + 'Example 1: User Question: "Can you classify this image as a cat?" {{"Parameters":{{"prompt": ["cat"]}}}}\n' + 'Example 2: User Question: "Can you tag this photograph with cat or dog?" {{"Parameters":{{"prompt": ["cat", "dog"]}}}}\n' + 'Exmaple 3: User Question: "Can you build me a classifier taht classifies red shirts, green shirts and other?" {{"Parameters":{{"prompt": ["red shirt", "green shirt", "other"]}}}}\n' +) + +# EasyTool prompts +CHOOSE_PARAMS = ( + "This is an API tool documentation. Given a user's question, you need to output parameters according to the API tool documentation to successfully call the API to solve the user's question.\n" + "This is the API tool documentation: {api_doc}\n" + "Please note that: \n" + "1. The Example in the API tool documentation can help you better understand the use of the API.\n" + '2. Ensure the parameters you output are correct. The output must contain the required parameters, and can contain the optional parameters based on the question. If no paremters in the required parameters and optional parameters, just leave it as {{"Parameters":{{}}}}\n' + "3. If the user's question mentions other APIs, you should ONLY consider the API tool documentation I give and do not consider other APIs.\n" + '4. If you need to use this API multiple times, please set "Parameters" to a list.\n' + "5. You must ONLY output in a parsible JSON format. 
Two examples output looks like:\n" + "'''\n" + 'Example 1: {{"Parameters":{{"keyword": "Artificial Intelligence", "language": "English"}}}}\n' + 'Example 2: {{"Parameters":[{{"keyword": "Artificial Intelligence", "language": "English"}}, {{"keyword": "Machine Learning", "language": "English"}}]}}\n' + "'''\n" + "This is user's question: {question}\n" + "Output:\n" +) diff --git a/vision_agent/tools/tools.py b/vision_agent/tools/tools.py new file mode 100644 index 00000000..8f3dfd50 --- /dev/null +++ b/vision_agent/tools/tools.py @@ -0,0 +1,22 @@ +class Classifier: + def __init__(self, prompt: str): + self.prompt = prompt + + def __call__(self: image: Union[str, Image]) -> List[Dict]: + raise NotImplementedError + + +class Detector: + def __init__(self, prompt: str): + self.prompt = prompt + + def __call__(self: image: Union[str, Image]) -> List[Dict]: + raise NotImplementedError + + +class Segmentor: + def __init__(self, prompt: str): + self.prompt = prompt + + def __call__(self: image: Union[str, Image]) -> List[Dict]: + raise NotImplementedError From 4767d6df50b97c1fa20f331ff81e0a8bdfab9fbf Mon Sep 17 00:00:00 2001 From: Dillon Laird Date: Wed, 6 Mar 2024 15:15:56 -0800 Subject: [PATCH 2/5] fix type errors --- vision_agent/tools/tools.py | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/vision_agent/tools/tools.py b/vision_agent/tools/tools.py index 8f3dfd50..bd640e76 100644 --- a/vision_agent/tools/tools.py +++ b/vision_agent/tools/tools.py @@ -1,8 +1,13 @@ +from typing import Dict, List, Union + +from PIL.Image import Image as ImageType + + class Classifier: def __init__(self, prompt: str): self.prompt = prompt - def __call__(self: image: Union[str, Image]) -> List[Dict]: + def __call__(self, image: Union[str, ImageType]) -> List[Dict]: raise NotImplementedError @@ -10,7 +15,7 @@ class Detector: def __init__(self, prompt: str): self.prompt = prompt - def __call__(self: image: Union[str, Image]) -> List[Dict]: + def __call__(self, image: Union[str, ImageType]) -> List[Dict]: raise NotImplementedError @@ -18,5 +23,5 @@ class Segmentor: def __init__(self, prompt: str): self.prompt = prompt - def __call__(self: image: Union[str, Image]) -> List[Dict]: + def __call__(self, image: Union[str, ImageType]) -> List[Dict]: raise NotImplementedError From b059ea1790bb97e427de6eae44214fd850cd20db Mon Sep 17 00:00:00 2001 From: Dillon Laird Date: Wed, 6 Mar 2024 15:37:16 -0800 Subject: [PATCH 3/5] reformat classes --- vision_agent/lmm/lmm.py | 26 +++++++++++------------- vision_agent/tools/__init__.py | 4 ++-- vision_agent/tools/prompts.py | 24 ---------------------- vision_agent/tools/tools.py | 37 +++++++++++++++++++++++++++++++--- 4 files changed, 48 insertions(+), 43 deletions(-) diff --git a/vision_agent/lmm/lmm.py b/vision_agent/lmm/lmm.py index 44a7fa5d..8992072d 100644 --- a/vision_agent/lmm/lmm.py +++ b/vision_agent/lmm/lmm.py @@ -10,12 +10,10 @@ from vision_agent.tools import ( SYSTEM_PROMPT, CHOOSE_PARAMS, - GROUNDING_DINO, - GROUNDING_SAM, + ImageTool, CLIP, - Classifier, - Detector, - Segmentor, + GroundingDINO, + GroundingSAM, ) logging.basicConfig(level=logging.INFO) @@ -102,8 +100,8 @@ def generate(self, prompt: str, image: Optional[Union[str, Path]] = None) -> str ) return cast(str, response.choices[0].message.content) - def generate_classifier(self, prompt: str) -> Classifier: - prompt = CHOOSE_PARAMS.format(api_doc=CLIP, question=prompt) + def generate_classifier(self, prompt: str) -> ImageTool: + prompt = CHOOSE_PARAMS.format(api_doc=CLIP.doc, 
question=prompt) response = self.client.chat.completions.create( model="gpt-4-turbo-preview", # no need to use vision model here response_format={"type": "json_object"}, @@ -113,10 +111,10 @@ def generate_classifier(self, prompt: str) -> Classifier: ], ) prompt = json.loads(response.choices[0].message.content)["prompt"] - return Classifier(prompt) + return CLIP(prompt) - def generate_detector(self, prompt: str) -> Detector: - prompt = CHOOSE_PARAMS.format(api_doc=GROUNDING_DINO, question=prompt) + def generate_detector(self, prompt: str) -> ImageTool: + prompt = CHOOSE_PARAMS.format(api_doc=GroundingDINO.doc, question=prompt) response = self.client.chat.completions.create( model="gpt-4-turbo-preview", # no need to use vision model here response_format={"type": "json_object"}, @@ -126,10 +124,10 @@ def generate_detector(self, prompt: str) -> Detector: ], ) prompt = json.loads(response.choices[0].message.content)["prompt"] - return Detector(prompt) + return GroundingDINO(prompt) - def generate_segmentor(self, prompt: str) -> Segmentor: - prompt = CHOOSE_PARAMS.format(api_doc=GROUNDING_SAM, question=prompt) + def generate_segmentor(self, prompt: str) -> ImageTool: + prompt = CHOOSE_PARAMS.format(api_doc=GroundingSAM.doc, question=prompt) response = self.client.chat.completions.create( model="gpt-4-turbo-preview", # no need to use vision model here response_format={"type": "json_object"}, @@ -139,7 +137,7 @@ def generate_segmentor(self, prompt: str) -> Segmentor: ], ) prompt = json.loads(response.choices[0].message.content)["prompt"] - return Segmentor(prompt) + return GroundingSAM(prompt) def get_lmm(name: str) -> LMM: diff --git a/vision_agent/tools/__init__.py b/vision_agent/tools/__init__.py index 600f287c..eedfb21d 100644 --- a/vision_agent/tools/__init__.py +++ b/vision_agent/tools/__init__.py @@ -1,2 +1,2 @@ -from .prompts import SYSTEM_PROMPT, CHOOSE_PARAMS, GROUNDING_DINO, GROUNDING_SAM, CLIP -from .tools import Classifier, Detector, Segmentor +from .prompts import SYSTEM_PROMPT, CHOOSE_PARAMS +from .tools import ImageTool, CLIP, GroundingDINO, GroundingSAM diff --git a/vision_agent/tools/prompts.py b/vision_agent/tools/prompts.py index 76bd7fb2..0488c3f2 100644 --- a/vision_agent/tools/prompts.py +++ b/vision_agent/tools/prompts.py @@ -1,29 +1,5 @@ SYSTEM_PROMPT = "You are a helpful assistant." -GROUNDING_DINO = ( - "Grounding DINO is a tool that can detect arbitrary objects with inputs such as category names or referring expressions." - "Here are some exmaples of how to use the tool, the examples are in the format of User Question: which will have the user's question in quotes followed by the parameters in JSON format, which is the parameters you need to output to call the API to solve the user's question.\n" - 'Example 1: User Question: "Can you build me a car detector?" {{"Parameters":{{"prompt": "car"}}}}\n' - 'Example 2: User Question: "Can you detect the person on the left?" {{"Parameters":{{"prompt": "person on the left"}}\n' - 'Exmaple 3: User Question: "Can you build me a tool that detects red shirts and green shirts?" {{"Parameters":{{"prompt": "red shirt. green shirt"}}}}\n' -) - -GROUNDING_SAM = ( - "Grounding SAM is a tool that can detect and segment arbitrary objects with inputs such as category names or referring expressions." 
- "Here are some exmaples of how to use the tool, the examples are in the format of User Question: which will have the user's question in quotes followed by the parameters in JSON format, which is the parameters you need to output to call the API to solve the user's question.\n" - 'Example 1: User Question: "Can you build me a car segmentor?" {{"Parameters":{{"prompt": "car"}}}}\n' - 'Example 2: User Question: "Can you segment the person on the left?" {{"Parameters":{{"prompt": "person on the left"}}\n' - 'Exmaple 3: User Question: "Can you build me a tool that segments red shirts and green shirts?" {{"Parameters":{{"prompt": "red shirt. green shirt"}}}}\n' -) - -CLIP = ( - "CLIP is a tool that can classify or tag any image given a set if input classes or tags." - "Here are some exmaples of how to use the tool, the examples are in the format of User Question: which will have the user's question in quotes followed by the parameters in JSON format, which is the parameters you need to output to call the API to solve the user's question.\n" - 'Example 1: User Question: "Can you classify this image as a cat?" {{"Parameters":{{"prompt": ["cat"]}}}}\n' - 'Example 2: User Question: "Can you tag this photograph with cat or dog?" {{"Parameters":{{"prompt": ["cat", "dog"]}}}}\n' - 'Exmaple 3: User Question: "Can you build me a classifier taht classifies red shirts, green shirts and other?" {{"Parameters":{{"prompt": ["red shirt", "green shirt", "other"]}}}}\n' -) - # EasyTool prompts CHOOSE_PARAMS = ( "This is an API tool documentation. Given a user's question, you need to output parameters according to the API tool documentation to successfully call the API to solve the user's question.\n" diff --git a/vision_agent/tools/tools.py b/vision_agent/tools/tools.py index bd640e76..9ca70452 100644 --- a/vision_agent/tools/tools.py +++ b/vision_agent/tools/tools.py @@ -1,9 +1,24 @@ from typing import Dict, List, Union +from abc import ABC, abstractmethod from PIL.Image import Image as ImageType -class Classifier: +class ImageTool(ABC): + @abstractmethod + def __call__(self, image: Union[str, ImageType]) -> List[Dict]: + pass + + +class CLIP(ImageTool): + doc = ( + "CLIP is a tool that can classify or tag any image given a set if input classes or tags." + "Here are some exmaples of how to use the tool, the examples are in the format of User Question: which will have the user's question in quotes followed by the parameters in JSON format, which is the parameters you need to output to call the API to solve the user's question.\n" + 'Example 1: User Question: "Can you classify this image as a cat?" {{"Parameters":{{"prompt": ["cat"]}}}}\n' + 'Example 2: User Question: "Can you tag this photograph with cat or dog?" {{"Parameters":{{"prompt": ["cat", "dog"]}}}}\n' + 'Exmaple 3: User Question: "Can you build me a classifier taht classifies red shirts, green shirts and other?" {{"Parameters":{{"prompt": ["red shirt", "green shirt", "other"]}}}}\n' + ) + def __init__(self, prompt: str): self.prompt = prompt @@ -11,7 +26,15 @@ def __call__(self, image: Union[str, ImageType]) -> List[Dict]: raise NotImplementedError -class Detector: +class GroundingDINO(ImageTool): + doc = ( + "Grounding DINO is a tool that can detect arbitrary objects with inputs such as category names or referring expressions." 
+ "Here are some exmaples of how to use the tool, the examples are in the format of User Question: which will have the user's question in quotes followed by the parameters in JSON format, which is the parameters you need to output to call the API to solve the user's question.\n" + 'Example 1: User Question: "Can you build me a car detector?" {{"Parameters":{{"prompt": "car"}}}}\n' + 'Example 2: User Question: "Can you detect the person on the left?" {{"Parameters":{{"prompt": "person on the left"}}\n' + 'Exmaple 3: User Question: "Can you build me a tool that detects red shirts and green shirts?" {{"Parameters":{{"prompt": "red shirt. green shirt"}}}}\n' + ) + def __init__(self, prompt: str): self.prompt = prompt @@ -19,7 +42,15 @@ def __call__(self, image: Union[str, ImageType]) -> List[Dict]: raise NotImplementedError -class Segmentor: +class GroundingSAM(ImageTool): + doc = ( + "Grounding SAM is a tool that can detect and segment arbitrary objects with inputs such as category names or referring expressions." + "Here are some exmaples of how to use the tool, the examples are in the format of User Question: which will have the user's question in quotes followed by the parameters in JSON format, which is the parameters you need to output to call the API to solve the user's question.\n" + 'Example 1: User Question: "Can you build me a car segmentor?" {{"Parameters":{{"prompt": "car"}}}}\n' + 'Example 2: User Question: "Can you segment the person on the left?" {{"Parameters":{{"prompt": "person on the left"}}\n' + 'Exmaple 3: User Question: "Can you build me a tool that segments red shirts and green shirts?" {{"Parameters":{{"prompt": "red shirt. green shirt"}}}}\n' + ) + def __init__(self, prompt: str): self.prompt = prompt From 27d0a5a75f8b57d2872e39b079cf5bbe20fef2fe Mon Sep 17 00:00:00 2001 From: Dillon Laird Date: Wed, 6 Mar 2024 16:39:54 -0800 Subject: [PATCH 4/5] add better error handling for json decoding --- vision_agent/lmm/lmm.py | 24 +++++++++++++++++++++--- 1 file changed, 21 insertions(+), 3 deletions(-) diff --git a/vision_agent/lmm/lmm.py b/vision_agent/lmm/lmm.py index 8992072d..c1e402f7 100644 --- a/vision_agent/lmm/lmm.py +++ b/vision_agent/lmm/lmm.py @@ -110,7 +110,13 @@ def generate_classifier(self, prompt: str) -> ImageTool: {"role": "user", "content": prompt}, ], ) - prompt = json.loads(response.choices[0].message.content)["prompt"] + + try: + prompt = json.loads(cast(str, response.choices[0].message.content))["prompt"] + except json.JSONDecodeError: + _LOGGER.error(f"Failed to decode response: {response.choices[0].message.content}") + raise ValueError("Failed to decode response") + return CLIP(prompt) def generate_detector(self, prompt: str) -> ImageTool: @@ -123,7 +129,13 @@ def generate_detector(self, prompt: str) -> ImageTool: {"role": "user", "content": prompt}, ], ) - prompt = json.loads(response.choices[0].message.content)["prompt"] + + try: + prompt = json.loads(cast(str, response.choices[0].message.content))["prompt"] + except json.JSONDecodeError: + _LOGGER.error(f"Failed to decode response: {response.choices[0].message.content}") + raise ValueError("Failed to decode response") + return GroundingDINO(prompt) def generate_segmentor(self, prompt: str) -> ImageTool: @@ -136,7 +148,13 @@ def generate_segmentor(self, prompt: str) -> ImageTool: {"role": "user", "content": prompt}, ], ) - prompt = json.loads(response.choices[0].message.content)["prompt"] + + try: + prompt = json.loads(cast(str, response.choices[0].message.content))["prompt"] + except 
json.JSONDecodeError: + _LOGGER.error(f"Failed to decode response: {response.choices[0].message.content}") + raise ValueError("Failed to decode response") + return GroundingSAM(prompt) From d39d3d239e9762d8a6d97c015e6fbaa94d3559c6 Mon Sep 17 00:00:00 2001 From: Dillon Laird Date: Wed, 6 Mar 2024 16:49:27 -0800 Subject: [PATCH 5/5] fix formatting --- vision_agent/lmm/lmm.py | 24 ++++++++++++++++++------ 1 file changed, 18 insertions(+), 6 deletions(-) diff --git a/vision_agent/lmm/lmm.py b/vision_agent/lmm/lmm.py index c1e402f7..488048fc 100644 --- a/vision_agent/lmm/lmm.py +++ b/vision_agent/lmm/lmm.py @@ -112,9 +112,13 @@ def generate_classifier(self, prompt: str) -> ImageTool: ) try: - prompt = json.loads(cast(str, response.choices[0].message.content))["prompt"] + prompt = json.loads(cast(str, response.choices[0].message.content))[ + "prompt" + ] except json.JSONDecodeError: - _LOGGER.error(f"Failed to decode response: {response.choices[0].message.content}") + _LOGGER.error( + f"Failed to decode response: {response.choices[0].message.content}" + ) raise ValueError("Failed to decode response") return CLIP(prompt) @@ -131,9 +135,13 @@ def generate_detector(self, prompt: str) -> ImageTool: ) try: - prompt = json.loads(cast(str, response.choices[0].message.content))["prompt"] + prompt = json.loads(cast(str, response.choices[0].message.content))[ + "prompt" + ] except json.JSONDecodeError: - _LOGGER.error(f"Failed to decode response: {response.choices[0].message.content}") + _LOGGER.error( + f"Failed to decode response: {response.choices[0].message.content}" + ) raise ValueError("Failed to decode response") return GroundingDINO(prompt) @@ -150,9 +158,13 @@ def generate_segmentor(self, prompt: str) -> ImageTool: ) try: - prompt = json.loads(cast(str, response.choices[0].message.content))["prompt"] + prompt = json.loads(cast(str, response.choices[0].message.content))[ + "prompt" + ] except json.JSONDecodeError: - _LOGGER.error(f"Failed to decode response: {response.choices[0].message.content}") + _LOGGER.error( + f"Failed to decode response: {response.choices[0].message.content}" + ) raise ValueError("Failed to decode response") return GroundingSAM(prompt)
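
A note on the prompt templates introduced in these patches: CHOOSE_PARAMS is a plain str.format template, so {api_doc} and {question} are the substitution slots, and the doubled {{ / }} sequences in CHOOSE_PARAMS itself are format escapes that render as single literal braces in the finished prompt. The tool docs passed in as api_doc are substituted as values and therefore pass through verbatim. A minimal sketch of the expansion, assuming the package layout as of PATCH 3/5:

# Illustrative sketch (not part of the patches): how generate_detector
# assembles the prompt it sends to the chat model.
from vision_agent.tools import CHOOSE_PARAMS, GroundingDINO

filled = CHOOSE_PARAMS.format(
    api_doc=GroundingDINO.doc,
    question="Can you build me a car detector?",
)
# The {{ / }} escapes in CHOOSE_PARAMS render as literal braces here, e.g.
# {"Parameters":{}} in rule 2; GroundingDINO.doc is inserted unchanged.
print(filled)  # the template ends with "Output:\n" to elicit the JSON parameters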
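The new factory methods can then be exercised end to end. The sketch below is illustrative rather than definitive: it assumes OPENAI_API_KEY is configured and that get_lmm("openai") returns the LMM implementation carrying generate_classifier, generate_detector and generate_segmentor:

# Hypothetical usage sketch (not part of the patch series).
from vision_agent.lmm.lmm import get_lmm

lmm = get_lmm("openai")

# The LMM formats CHOOSE_PARAMS with the tool's doc, asks gpt-4-turbo-preview
# for JSON parameters, reads the "prompt" key (raising ValueError if the
# response does not decode, per PATCH 4/5), and returns a primed tool.
detector = lmm.generate_detector(
    "Can you build me a tool that detects red shirts and green shirts?"
)
print(detector.prompt)  # expected to be something like "red shirt. green shirt"

# The tools themselves are still stubs at this point; calling one raises
# NotImplementedError until a real inference backend is wired in.
# detector("shirts.jpg")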
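Because ImageTool is a small ABC, a toy subclass is enough to see the whole interface in action. EchoDetector below is hypothetical and not part of the patches; it only demonstrates the contract the base class imposes, namely that __call__ accepts a path or PIL image and returns a list of dicts:

# Hypothetical stand-in subclass for local wiring tests.
from typing import Dict, List, Union

from PIL.Image import Image as ImageType

from vision_agent.tools import ImageTool


class EchoDetector(ImageTool):
    def __init__(self, prompt: str):
        self.prompt = prompt

    def __call__(self, image: Union[str, ImageType]) -> List[Dict]:
        # A real tool would run detection here; this stub echoes the prompt
        # with a full-image box so callers can be tested without a model.
        return [{"label": self.prompt, "bbox": [0.0, 0.0, 1.0, 1.0], "score": 1.0}]


print(EchoDetector("car")("street.jpg"))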