diff --git a/tests/unit/test_lmm.py b/tests/unit/test_lmm.py index c954b173..97788d4f 100644 --- a/tests/unit/test_lmm.py +++ b/tests/unit/test_lmm.py @@ -1,8 +1,6 @@ import json import tempfile -from unittest.mock import patch -import numpy as np import pytest from PIL import Image @@ -163,60 +161,3 @@ def test_chat_ollama_mock(chat_ollama_lmm_mock): # noqa: F811 assert response == "mocked response" call_args = json.loads(chat_ollama_lmm_mock.call_args.kwargs["data"]) assert call_args["messages"][0]["content"] == "test prompt" - - -@pytest.mark.parametrize( - "openai_lmm_mock", - ['{"Parameters": {"prompt": "cat"}}'], - indirect=["openai_lmm_mock"], -) -def test_generate_classifier(openai_lmm_mock): # noqa: F811 - with patch("vision_agent.tools.clip") as clip_mock: - clip_mock.return_value = "test" - clip_mock.__name__ = "clip" - clip_mock.__doc__ = "clip" - - lmm = OpenAILMM() - prompt = "Can you generate a cat classifier?" - classifier = lmm.generate_classifier(prompt) - dummy_image = np.zeros((10, 10, 3)).astype(np.uint8) - classifier(dummy_image) - assert clip_mock.call_args[0][1] == "cat" - - -@pytest.mark.parametrize( - "openai_lmm_mock", - ['{"Parameters": {"prompt": "cat"}}'], - indirect=["openai_lmm_mock"], -) -def test_generate_detector(openai_lmm_mock): # noqa: F811 - with patch("vision_agent.tools.owl_v2") as owl_v2_mock: - owl_v2_mock.return_value = "test" - owl_v2_mock.__name__ = "owl_v2" - owl_v2_mock.__doc__ = "owl_v2" - - lmm = OpenAILMM() - prompt = "Can you generate a cat classifier?" - detector = lmm.generate_detector(prompt) - dummy_image = np.zeros((10, 10, 3)).astype(np.uint8) - detector(dummy_image) - assert owl_v2_mock.call_args[0][0] == "cat" - - -@pytest.mark.parametrize( - "openai_lmm_mock", - ['{"Parameters": {"prompt": "cat"}}'], - indirect=["openai_lmm_mock"], -) -def test_generate_segmentor(openai_lmm_mock): # noqa: F811 - with patch("vision_agent.tools.grounding_sam") as grounding_sam_mock: - grounding_sam_mock.return_value = "test" - grounding_sam_mock.__name__ = "grounding_sam" - grounding_sam_mock.__doc__ = "grounding_sam" - - lmm = OpenAILMM() - prompt = "Can you generate a cat classifier?" 
- segmentor = lmm.generate_segmentor(prompt) - dummy_image = np.zeros((10, 10, 3)).astype(np.uint8) - segmentor(dummy_image) - assert grounding_sam_mock.call_args[0][0] == "cat" diff --git a/vision_agent/lmm/lmm.py b/vision_agent/lmm/lmm.py index e68666d2..4f42380c 100644 --- a/vision_agent/lmm/lmm.py +++ b/vision_agent/lmm/lmm.py @@ -1,77 +1,36 @@ -import base64 -import io import json import logging import os from abc import ABC, abstractmethod from pathlib import Path -from typing import Any, Callable, Dict, Iterator, List, Optional, Union, cast +from typing import Any, Dict, Iterator, List, Optional, Sequence, Union, cast import anthropic import requests from anthropic.types import ImageBlockParam, MessageParam, TextBlockParam from openai import AzureOpenAI, OpenAI -from PIL import Image -import vision_agent.tools as T -from vision_agent.tools.prompts import CHOOSE_PARAMS, SYSTEM_PROMPT +from vision_agent.utils.image_utils import encode_media from .types import Message _LOGGER = logging.getLogger(__name__) -def encode_image_bytes(image: bytes) -> str: - image = Image.open(io.BytesIO(image)).convert("RGB") # type: ignore - buffer = io.BytesIO() - image.save(buffer, format="PNG") # type: ignore - encoded_image = base64.b64encode(buffer.getvalue()).decode("utf-8") - return encoded_image - - -def encode_media(media: Union[str, Path]) -> str: - if type(media) is str and media.startswith(("http", "https")): - # for mp4 video url, we assume there is a same url but ends with png - # vision-agent-ui will upload this png when uploading the video - if media.endswith((".mp4", "mov")) and media.find("vision-agent-dev.s3") != -1: - return media[:-4] + ".png" - return media - extension = "png" - extension = Path(media).suffix - if extension.lower() not in { - ".jpg", - ".jpeg", - ".png", - ".webp", - ".bmp", - ".mp4", - ".mov", - }: - raise ValueError(f"Unsupported image extension: {extension}") - - image_bytes = b"" - if extension.lower() in {".mp4", ".mov"}: - frames = T.extract_frames(media) - image = frames[len(frames) // 2] - buffer = io.BytesIO() - Image.fromarray(image[0]).convert("RGB").save(buffer, format="PNG") - image_bytes = buffer.getvalue() - else: - image_bytes = open(media, "rb").read() - return encode_image_bytes(image_bytes) - - class LMM(ABC): @abstractmethod def generate( - self, prompt: str, media: Optional[List[Union[str, Path]]] = None, **kwargs: Any + self, + prompt: str, + media: Optional[Sequence[Union[str, Path]]] = None, + **kwargs: Any, ) -> Union[str, Iterator[Optional[str]]]: pass @abstractmethod def chat( self, - chat: List[Message], + chat: Sequence[Message], **kwargs: Any, ) -> Union[str, Iterator[Optional[str]]]: pass @@ -79,7 +38,7 @@ def chat( @abstractmethod def __call__( self, - input: Union[str, List[Message]], + input: Union[str, Sequence[Message]], **kwargs: Any, ) -> Union[str, Iterator[Optional[str]]]: pass @@ -111,7 +70,7 @@ def __init__( def __call__( self, - input: Union[str, List[Message]], + input: Union[str, Sequence[Message]], **kwargs: Any, ) -> Union[str, Iterator[Optional[str]]]: if isinstance(input, str): @@ -120,13 +79,13 @@ def __call__( def chat( self, - chat: List[Message], + chat: Sequence[Message], **kwargs: Any, ) -> Union[str, Iterator[Optional[str]]]: """Chat with the LMM model. Parameters: - chat (List[Dict[str, str]]): A list of dictionaries containing the chat + chat (Sequence[Dict[str, str]]): A list of dictionaries containing the chat messages. The messages can be in the format: [{"role": "user", "content": "Hello!"}, ...]
or if it contains media, it should be in the format: @@ -147,6 +106,7 @@ def chat( "url": ( encoded_media if encoded_media.startswith(("http", "https")) + or encoded_media.startswith("data:image/") else f"data:image/png;base64,{encoded_media}" ), "detail": "low", @@ -174,7 +134,7 @@ def f() -> Iterator[Optional[str]]: def generate( self, prompt: str, - media: Optional[List[Union[str, Path]]] = None, + media: Optional[Sequence[Union[str, Path]]] = None, **kwargs: Any, ) -> Union[str, Iterator[Optional[str]]]: message: List[Dict[str, Any]] = [ @@ -192,7 +152,12 @@ def generate( { "type": "image_url", "image_url": { - "url": f"data:image/png;base64,{encoded_media}", + "url": ( + encoded_media + if encoded_media.startswith(("http", "https")) + or encoded_media.startswith("data:image/") + else f"data:image/png;base64,{encoded_media}" + ), "detail": "low", }, }, @@ -214,81 +179,6 @@ def f() -> Iterator[Optional[str]]: else: return cast(str, response.choices[0].message.content) - def generate_classifier(self, question: str) -> Callable: - api_doc = T.get_tool_documentation([T.clip]) - prompt = CHOOSE_PARAMS.format(api_doc=api_doc, question=question) - response = self.client.chat.completions.create( - model=self.model_name, - messages=[ - {"role": "system", "content": SYSTEM_PROMPT}, - {"role": "user", "content": prompt}, - ], - response_format={"type": "json_object"}, - ) - - try: - params = json.loads(cast(str, response.choices[0].message.content))[ - "Parameters" - ] - except json.JSONDecodeError: - _LOGGER.error( - f"Failed to decode response: {response.choices[0].message.content}" - ) - raise ValueError("Failed to decode response") - - return lambda x: T.clip(x, params["prompt"]) - - def generate_detector(self, question: str) -> Callable: - api_doc = T.get_tool_documentation([T.owl_v2]) - prompt = CHOOSE_PARAMS.format(api_doc=api_doc, question=question) - response = self.client.chat.completions.create( - model=self.model_name, - messages=[ - {"role": "system", "content": SYSTEM_PROMPT}, - {"role": "user", "content": prompt}, - ], - response_format={"type": "json_object"}, - ) - - try: - params = json.loads(cast(str, response.choices[0].message.content))[ - "Parameters" - ] - except json.JSONDecodeError: - _LOGGER.error( - f"Failed to decode response: {response.choices[0].message.content}" - ) - raise ValueError("Failed to decode response") - - return lambda x: T.owl_v2(params["prompt"], x) - - def generate_segmentor(self, question: str) -> Callable: - api_doc = T.get_tool_documentation([T.grounding_sam]) - prompt = CHOOSE_PARAMS.format(api_doc=api_doc, question=question) - response = self.client.chat.completions.create( - model=self.model_name, - messages=[ - {"role": "system", "content": SYSTEM_PROMPT}, - {"role": "user", "content": prompt}, - ], - response_format={"type": "json_object"}, - ) - - try: - params = json.loads(cast(str, response.choices[0].message.content))[ - "Parameters" - ] - except json.JSONDecodeError: - _LOGGER.error( - f"Failed to decode response: {response.choices[0].message.content}" - ) - raise ValueError("Failed to decode response") - - return lambda x: T.grounding_sam(params["prompt"], x) - - def generate_image_qa_tool(self, question: str) -> Callable: - return lambda x: T.git_vqa_v2(question, x) - class AzureOpenAILMM(OpenAILMM): def __init__( @@ -362,7 +252,7 @@ def __init__( def __call__( self, - input: Union[str, List[Message]], + input: Union[str, Sequence[Message]], **kwargs: Any, ) -> Union[str, Iterator[Optional[str]]]: if isinstance(input, str): @@ -371,13 
+261,13 @@ def __call__( def chat( self, - chat: List[Message], + chat: Sequence[Message], **kwargs: Any, ) -> Union[str, Iterator[Optional[str]]]: """Chat with the LMM model. Parameters: - chat (List[Dict[str, str]]): A list of dictionaries containing the chat + chat (Sequence[Dict[str, str]]): A list of dictionaries containing the chat messages. The messages can be in the format: [{"role": "user", "content": "Hello!"}, ...] or if it contains media, it should be in the format: @@ -429,7 +319,7 @@ def f() -> Iterator[Optional[str]]: def generate( self, prompt: str, - media: Optional[List[Union[str, Path]]] = None, + media: Optional[Sequence[Union[str, Path]]] = None, **kwargs: Any, ) -> Union[str, Iterator[Optional[str]]]: url = f"{self.url}/generate" @@ -493,7 +383,7 @@ def __init__( def __call__( self, - input: Union[str, List[Dict[str, Any]]], + input: Union[str, Sequence[Dict[str, Any]]], **kwargs: Any, ) -> Union[str, Iterator[Optional[str]]]: if isinstance(input, str): @@ -502,7 +392,7 @@ def __call__( def chat( self, - chat: List[Dict[str, Any]], + chat: Sequence[Dict[str, Any]], **kwargs: Any, ) -> Union[str, Iterator[Optional[str]]]: messages: List[MessageParam] = [] @@ -551,7 +441,7 @@ def f() -> Iterator[Optional[str]]: def generate( self, prompt: str, - media: Optional[List[Union[str, Path]]] = None, + media: Optional[Sequence[Union[str, Path]]] = None, **kwargs: Any, ) -> Union[str, Iterator[Optional[str]]]: content: List[Union[TextBlockParam, ImageBlockParam]] = [ diff --git a/vision_agent/tools/__init__.py b/vision_agent/tools/__init__.py index 43460fbd..90858569 100644 --- a/vision_agent/tools/__init__.py +++ b/vision_agent/tools/__init__.py @@ -16,6 +16,8 @@ clip, closest_box_distance, closest_mask_distance, + countgd_counting, + countgd_example_based_counting, depth_anything_v2, detr_segmentation, dpt_hybrid_midas, @@ -30,6 +32,8 @@ generate_soft_edge_image, get_tool_documentation, git_vqa_v2, + gpt4o_image_vqa, + gpt4o_video_vqa, grounding_dino, grounding_sam, ixc25_image_vqa, @@ -37,13 +41,11 @@ load_image, loca_visual_prompt_counting, loca_zero_shot_counting, - countgd_counting, - countgd_example_based_counting, ocr, overlay_bounding_boxes, + overlay_counting_results, overlay_heat_map, overlay_segmentation_masks, - overlay_counting_results, owl_v2, save_image, save_json, diff --git a/vision_agent/tools/tool_utils.py b/vision_agent/tools/tool_utils.py index a14443bd..67306c9d 100644 --- a/vision_agent/tools/tool_utils.py +++ b/vision_agent/tools/tool_utils.py @@ -1,6 +1,6 @@ -import os import inspect import logging +import os from typing import Any, Callable, Dict, List, MutableMapping, Optional, Tuple import pandas as pd @@ -10,10 +10,10 @@ from requests.adapters import HTTPAdapter from urllib3.util.retry import Retry +from vision_agent.tools.tools_types import BoundingBoxes from vision_agent.utils.exceptions import RemoteToolCallFailed from vision_agent.utils.execute import Error, MimeType from vision_agent.utils.type_defs import LandingaiAPIKey -from vision_agent.tools.tools_types import BoundingBoxes _LOGGER = logging.getLogger(__name__) _LND_API_KEY = os.environ.get("LANDINGAI_API_KEY", LandingaiAPIKey().api_key) diff --git a/vision_agent/tools/tools.py b/vision_agent/tools/tools.py index 958b2cf6..8dddf8bc 100644 --- a/vision_agent/tools/tools.py +++ b/vision_agent/tools/tools.py @@ -13,26 +13,27 @@ import numpy as np import requests from moviepy.editor import ImageSequenceClip -from PIL import Image, ImageDraw, ImageFont, ImageEnhance +from PIL import Image, 
ImageDraw, ImageEnhance, ImageFont from pillow_heif import register_heif_opener # type: ignore from pytube import YouTube # type: ignore from vision_agent.clients.landing_public_api import LandingPublicAPI +from vision_agent.lmm.lmm import OpenAILMM from vision_agent.tools.tool_utils import ( + filter_bboxes_by_threshold, get_tool_descriptions, get_tool_documentation, get_tools_df, get_tools_info, send_inference_request, send_task_inference_request, - filter_bboxes_by_threshold, ) from vision_agent.tools.tools_types import ( FineTuning, Florence2FtRequest, JobStatus, - PromptTask, ODResponseData, + PromptTask, ) from vision_agent.utils import extract_frames_from_video from vision_agent.utils.exceptions import FineTuneModelIsNotReady @@ -42,6 +43,7 @@ convert_quad_box_to_bbox, convert_to_b64, denormalize_bbox, + encode_image_bytes, frames_to_bytes, get_image_size, normalize_bbox, @@ -691,6 +693,69 @@ def ixc25_video_vqa(prompt: str, frames: List[np.ndarray]) -> str: return cast(str, data["answer"]) +def gpt4o_image_vqa(prompt: str, image: np.ndarray) -> str: + """'gpt4o_image_vqa' is a tool that can answer any questions about arbitrary images + including regular images or images of documents or presentations. It returns text + as an answer to the question. + + Parameters: + prompt (str): The question about the image + image (np.ndarray): The reference image used for the question + + Returns: + str: A string which is the answer to the given prompt. + + Example + ------- + >>> gpt4o_image_vqa('What is the cat doing?', image) + 'drinking milk' + """ + + lmm = OpenAILMM() + buffer = io.BytesIO() + Image.fromarray(image).save(buffer, format="PNG") + image_bytes = buffer.getvalue() + image_b64 = "data:image/png;base64," + encode_image_bytes(image_bytes) + resp = lmm.generate(prompt, [image_b64]) + return cast(str, resp) + + +def gpt4o_video_vqa(prompt: str, frames: List[np.ndarray]) -> str: + """'gpt4o_video_vqa' is a tool that can answer any questions about arbitrary videos + including regular videos or videos of documents or presentations. It returns text + as an answer to the question. + + Parameters: + prompt (str): The question about the video + frames (List[np.ndarray]): The reference frames used for the question + + Returns: + str: A string which is the answer to the given prompt. + + Example + ------- + >>> gpt4o_video_vqa('Which football player made the goal?', frames) + 'Lionel Messi' + """ + + lmm = OpenAILMM() + + if len(frames) > 10: + step = len(frames) / 10 + frames = [frames[int(i * step)] for i in range(10)] + + frames_b64 = [] + for frame in frames: + buffer = io.BytesIO() + Image.fromarray(frame).save(buffer, format="PNG") + image_bytes = buffer.getvalue() + image_b64 = "data:image/png;base64," + encode_image_bytes(image_bytes) + frames_b64.append(image_b64) + + resp = lmm.generate(prompt, frames_b64) + return cast(str, resp) + + def git_vqa_v2(prompt: str, image: np.ndarray) -> str: """'git_vqa_v2' is a tool that can answer questions about the visual contents of an image given a question and an image. 
It returns an answer to the diff --git a/vision_agent/tools/tools_types.py b/vision_agent/tools/tools_types.py index f61c2cf1..6ebcf468 100644 --- a/vision_agent/tools/tools_types.py +++ b/vision_agent/tools/tools_types.py @@ -1,8 +1,8 @@ from enum import Enum +from typing import List, Optional, Tuple, Union from uuid import UUID -from typing import List, Tuple, Optional, Union -from pydantic import BaseModel, ConfigDict, Field, field_serializer, SerializationInfo +from pydantic import BaseModel, ConfigDict, Field, SerializationInfo, field_serializer class BboxInput(BaseModel): diff --git a/vision_agent/utils/image_utils.py b/vision_agent/utils/image_utils.py index f0113c9f..3612592d 100644 --- a/vision_agent/utils/image_utils.py +++ b/vision_agent/utils/image_utils.py @@ -13,6 +13,8 @@ from PIL import Image, ImageDraw, ImageFont from PIL.Image import Image as ImageType +from vision_agent.utils import extract_frames_from_video + COLORS = [ (158, 218, 229), (219, 219, 141), @@ -172,6 +174,51 @@ def convert_to_b64(data: Union[str, Path, np.ndarray, ImageType]) -> str: ) +def encode_image_bytes(image: bytes) -> str: + image = Image.open(io.BytesIO(image)).convert("RGB") # type: ignore + buffer = io.BytesIO() + image.save(buffer, format="PNG") # type: ignore + encoded_image = base64.b64encode(buffer.getvalue()).decode("utf-8") + return encoded_image + + +def encode_media(media: Union[str, Path]) -> str: + if isinstance(media, str) and media.startswith(("http", "https")): + # for mp4 video url, we assume there is a same url but ends with png + # vision-agent-ui will upload this png when uploading the video + if media.endswith((".mp4", "mov")) and media.find("vision-agent-dev.s3") != -1: + return media[:-4] + ".png" + return media + + # if media is already a base64 encoded image return + if isinstance(media, str) and media.startswith("data:image/"): + return media + + extension = "png" + extension = Path(media).suffix + if extension.lower() not in { + ".jpg", + ".jpeg", + ".png", + ".webp", + ".bmp", + ".mp4", + ".mov", + }: + raise ValueError(f"Unsupported image extension: {extension}") + + image_bytes = b"" + if extension.lower() in {".mp4", ".mov"}: + frames = extract_frames_from_video(str(media), fps=1) + image = frames[len(frames) // 2] + buffer = io.BytesIO() + Image.fromarray(image[0]).convert("RGB").save(buffer, format="PNG") + image_bytes = buffer.getvalue() + else: + image_bytes = open(media, "rb").read() + return encode_image_bytes(image_bytes) + + def denormalize_bbox( bbox: List[Union[int, float]], image_size: Tuple[int, ...] ) -> List[float]:
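The short sketch below is not part of the patch; it is a hedged illustration of how the pieces introduced here fit together: the new gpt4o_image_vqa and gpt4o_video_vqa tools exported from vision_agent.tools, and encode_media, which now lives in vision_agent.utils.image_utils and passes already-encoded data URLs through unchanged. The file names cat.jpg and clip.mp4 are placeholders, and a configured OpenAI API key is assumed.

import numpy as np
from PIL import Image

from vision_agent.tools import gpt4o_image_vqa, gpt4o_video_vqa
from vision_agent.utils import extract_frames_from_video
from vision_agent.utils.image_utils import encode_media

# Single-image VQA: the tool wraps OpenAILMM and sends the frame as a base64
# "data:image/png" URL, which OpenAILMM.generate now forwards as-is.
image = np.array(Image.open("cat.jpg").convert("RGB"))  # placeholder image path
print(gpt4o_image_vqa("What is the cat doing?", image))

# Video VQA: gpt4o_video_vqa subsamples to at most 10 frames before calling the model.
# Index 0 of each extracted item is the frame array, mirroring encode_media above.
frames = [f[0] for f in extract_frames_from_video("clip.mp4", fps=1)]  # placeholder video path
print(gpt4o_video_vqa("What happens in this clip?", frames))

# encode_media now returns media that is already a base64 data URL unchanged
# instead of trying to read it from disk.
data_url = "data:image/png;base64,iVBORw0KGgo..."  # truncated example payload
assert encode_media(data_url) == data_url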