diff --git a/vision_agent/__init__.py b/vision_agent/__init__.py
index 7e7db515..ccb910f8 100644
--- a/vision_agent/__init__.py
+++ b/vision_agent/__init__.py
@@ -1,3 +1,4 @@
 from .lmm import LMM, LLaVALMM, OpenAILMM, get_lmm
+from .llm import LLM, OpenAILLM
 from .emb import Embedder, SentenceTransformerEmb, OpenAIEmb, get_embedder
 from .data import DataStore, build_data_store
diff --git a/vision_agent/image_utils.py b/vision_agent/image_utils.py
index 86533972..c99b213e 100644
--- a/vision_agent/image_utils.py
+++ b/vision_agent/image_utils.py
@@ -1,7 +1,7 @@
 import base64
 from io import BytesIO
 from pathlib import Path
-from typing import Union
+from typing import Union, Tuple
 
 import numpy as np
 from PIL import Image
@@ -14,6 +14,16 @@ def b64_to_pil(b64_str: str) -> Image.Image:
     return Image.open(BytesIO(base64.b64decode(b64_str)))
 
 
+def get_image_size(data: Union[str, Path, np.ndarray, Image.Image]) -> Tuple[int, ...]:
+    if isinstance(data, (str, Path)):
+        data = Image.open(data)
+
+    if isinstance(data, Image.Image):
+        return data.size[::-1]
+    else:
+        return data.shape[:2]
+
+
 def convert_to_b64(data: Union[str, Path, np.ndarray, Image.Image]) -> str:
     if data is None:
         raise ValueError(f"Invalid input image: {data}. Input image can't be None.")
diff --git a/vision_agent/lmm/lmm.py b/vision_agent/lmm/lmm.py
index 14472884..cfc38311 100644
--- a/vision_agent/lmm/lmm.py
+++ b/vision_agent/lmm/lmm.py
@@ -38,8 +38,8 @@ def generate(self, prompt: str, image: Optional[Union[str, Path]] = None) -> str
 class LLaVALMM(LMM):
     r"""An LMM class for the LLaVA-1.6 34B model."""
 
-    def __init__(self, name: str):
-        self.name = name
+    def __init__(self, model_name: str):
+        self.model_name = model_name
 
     def generate(
         self,
@@ -67,10 +67,10 @@ def generate(
 class OpenAILMM(LMM):
     r"""An LMM class for the OpenAI GPT-4 Vision model."""
 
-    def __init__(self, name: str):
+    def __init__(self, model_name: str = "gpt-4-vision-preview"):
         from openai import OpenAI
 
-        self.name = name
+        self.model_name = model_name
         self.client = OpenAI()
 
     def generate(self, prompt: str, image: Optional[Union[str, Path]] = None) -> str:
@@ -96,15 +96,14 @@ def generate(self, prompt: str, image: Optional[Union[str, Path]] = None) -> str
             )
 
         response = self.client.chat.completions.create(
-            model="gpt-4-vision-preview", messages=message  # type: ignore
+            model=self.model_name, messages=message  # type: ignore
         )
         return cast(str, response.choices[0].message.content)
 
     def generate_classifier(self, prompt: str) -> ImageTool:
         prompt = CHOOSE_PARAMS.format(api_doc=CLIP.doc, question=prompt)
         response = self.client.chat.completions.create(
-            model="gpt-4-turbo-preview",  # no need to use vision model here
-            response_format={"type": "json_object"},
+            model=self.model_name,
             messages=[
                 {"role": "system", "content": SYSTEM_PROMPT},
                 {"role": "user", "content": prompt},
@@ -123,20 +122,19 @@ def generate_classifier(self, prompt: str) -> ImageTool:
 
         return CLIP(prompt)
 
-    def generate_detector(self, prompt: str) -> ImageTool:
-        prompt = CHOOSE_PARAMS.format(api_doc=GroundingDINO.doc, question=prompt)
+    def generate_detector(self, params: str) -> ImageTool:
+        params = CHOOSE_PARAMS.format(api_doc=GroundingDINO.doc, question=params)
         response = self.client.chat.completions.create(
-            model="gpt-4-turbo-preview",  # no need to use vision model here
-            response_format={"type": "json_object"},
+            model=self.model_name,
             messages=[
                 {"role": "system", "content": SYSTEM_PROMPT},
-                {"role": "user", "content": prompt},
+                {"role": "user", "content": params},
             ],
         )
 
         try:
-            prompt = json.loads(cast(str, response.choices[0].message.content))[
-                "prompt"
+            params = json.loads(cast(str, response.choices[0].message.content))[
+                "Parameters"
             ]
         except json.JSONDecodeError:
             _LOGGER.error(
@@ -144,13 +142,12 @@ def generate_detector(self, prompt: str) -> ImageTool:
             )
             raise ValueError("Failed to decode response")
 
-        return GroundingDINO(prompt)
+        return GroundingDINO(**params)
 
     def generate_segmentor(self, prompt: str) -> ImageTool:
         prompt = CHOOSE_PARAMS.format(api_doc=GroundingSAM.doc, question=prompt)
        response = self.client.chat.completions.create(
-            model="gpt-4-turbo-preview",  # no need to use vision model here
-            response_format={"type": "json_object"},
+            model=self.model_name,
             messages=[
                 {"role": "system", "content": SYSTEM_PROMPT},
                 {"role": "user", "content": prompt},
diff --git a/vision_agent/tools/tools.py b/vision_agent/tools/tools.py
index 474eba1e..86fb5d93 100644
--- a/vision_agent/tools/tools.py
+++ b/vision_agent/tools/tools.py
@@ -1,16 +1,28 @@
 import logging
 from abc import ABC, abstractmethod
 from pathlib import Path
-from typing import Any, Dict, List, Union, cast
+from typing import Any, Dict, List, Tuple, Union, cast
 
 import requests
 from PIL.Image import Image as ImageType
 
-from vision_agent.image_utils import convert_to_b64
+from vision_agent.image_utils import convert_to_b64, get_image_size
 
 _LOGGER = logging.getLogger(__name__)
 
 
+def normalize_bbox(
+    bbox: List[Union[int, float]], image_size: Tuple[int, ...]
+) -> List[float]:
+    r"""Normalize the bounding box coordinates to be between 0 and 1."""
+    x1, y1, x2, y2 = bbox
+    x1 = x1 / image_size[1]
+    y1 = y1 / image_size[0]
+    x2 = x2 / image_size[1]
+    y2 = y2 / image_size[0]
+    return [x1, y1, x2, y2]
+
+
 class ImageTool(ABC):
     @abstractmethod
     def __call__(self, image: Union[str, ImageType]) -> List[Dict]:
@@ -42,12 +54,18 @@ class GroundingDINO(ImageTool):
         'Example 1: User Question: "Can you build me a car detector?" {{"Parameters":{{"prompt": "car"}}}}\n'
         'Example 2: User Question: "Can you detect the person on the left?" {{"Parameters":{{"prompt": "person on the left"}}\n'
         'Exmaple 3: User Question: "Can you build me a tool that detects red shirts and green shirts?" {{"Parameters":{{"prompt": "red shirt. green shirt"}}}}\n'
+        "The tool returns a list of dictionaries, each containing the following keys:\n"
+        "   - 'label': The label of the detected object.\n"
+        "   - 'score': The confidence score of the detection.\n"
+        "   - 'bbox': The bounding box of the detected object. The box coordinates are normalized to [0, 1].\n"
+        "An example output would be: [{'label': ['car'], 'score': [0.99], 'bbox': [[0.1, 0.2, 0.3, 0.4]]}]\n"
     )
 
     def __init__(self, prompt: str):
         self.prompt = prompt
 
     def __call__(self, image: Union[str, Path, ImageType]) -> List[Dict]:
+        image_size = get_image_size(image)
         image_b64 = convert_to_b64(image)
         data = {
             "prompt": self.prompt,
@@ -59,9 +77,18 @@ def __call__(self, image: Union[str, Path, ImageType]) -> List[Dict]:
             json=data,
         )
         resp_json: Dict[str, Any] = res.json()
-        if resp_json["statusCode"] != 200:
+        if (
+            "statusCode" in resp_json and resp_json["statusCode"] != 200
+        ) or "statusCode" not in resp_json:
             _LOGGER.error(f"Request failed: {resp_json}")
-        return cast(List[Dict], resp_json["data"])
+            return cast(List[Dict], [resp_json])
+        resp_data = resp_json["data"]
+        for elt in resp_data:
+            if "bboxes" in elt:
+                elt["bboxes"] = [
+                    normalize_bbox(box, image_size) for box in elt["bboxes"]
+                ]
+        return cast(List[Dict], resp_data)
 
 
 class GroundingSAM(ImageTool):
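
A minimal sketch (not part of the diff) of how the two new helpers fit together, assuming the changes above are applied and the package is installed; the in-memory PIL image is a stand-in for a real photo:

    # Sketch only: get_image_size returns (height, width); normalize_bbox divides
    # x coordinates by the width and y coordinates by the height.
    from PIL import Image
    from vision_agent.image_utils import get_image_size
    from vision_agent.tools.tools import normalize_bbox

    img = Image.new("RGB", (640, 480))   # PIL takes (width, height)
    image_size = get_image_size(img)     # -> (480, 640)
    print(normalize_bbox([64, 48, 320, 240], image_size))
    # -> [0.1, 0.1, 0.5, 0.5]

This is the same conversion GroundingDINO.__call__ now performs on each "bboxes" entry before returning results, so downstream consumers see coordinates in [0, 1] regardless of the input image resolution.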