diff --git a/vision_agent/lmm/lmm.py b/vision_agent/lmm/lmm.py index e68666d2..0595872a 100644 --- a/vision_agent/lmm/lmm.py +++ b/vision_agent/lmm/lmm.py @@ -1,77 +1,33 @@ -import base64 -import io import json import logging import os from abc import ABC, abstractmethod from pathlib import Path -from typing import Any, Callable, Dict, Iterator, List, Optional, Union, cast +from typing import Any, Dict, Iterator, List, Optional, Union, cast, Sequence import anthropic import requests from anthropic.types import ImageBlockParam, MessageParam, TextBlockParam from openai import AzureOpenAI, OpenAI -from PIL import Image -import vision_agent.tools as T -from vision_agent.tools.prompts import CHOOSE_PARAMS, SYSTEM_PROMPT +from vision_agent.utils.image_utils import encode_media from .types import Message _LOGGER = logging.getLogger(__name__) -def encode_image_bytes(image: bytes) -> str: - image = Image.open(io.BytesIO(image)).convert("RGB") # type: ignore - buffer = io.BytesIO() - image.save(buffer, format="PNG") # type: ignore - encoded_image = base64.b64encode(buffer.getvalue()).decode("utf-8") - return encoded_image - - -def encode_media(media: Union[str, Path]) -> str: - if type(media) is str and media.startswith(("http", "https")): - # for mp4 video url, we assume there is a same url but ends with png - # vision-agent-ui will upload this png when uploading the video - if media.endswith((".mp4", "mov")) and media.find("vision-agent-dev.s3") != -1: - return media[:-4] + ".png" - return media - extension = "png" - extension = Path(media).suffix - if extension.lower() not in { - ".jpg", - ".jpeg", - ".png", - ".webp", - ".bmp", - ".mp4", - ".mov", - }: - raise ValueError(f"Unsupported image extension: {extension}") - - image_bytes = b"" - if extension.lower() in {".mp4", ".mov"}: - frames = T.extract_frames(media) - image = frames[len(frames) // 2] - buffer = io.BytesIO() - Image.fromarray(image[0]).convert("RGB").save(buffer, format="PNG") - image_bytes = buffer.getvalue() - else: - image_bytes = open(media, "rb").read() - return encode_image_bytes(image_bytes) - - class LMM(ABC): @abstractmethod def generate( - self, prompt: str, media: Optional[List[Union[str, Path]]] = None, **kwargs: Any + self, prompt: str, media: Optional[Sequence[Union[str, Path]]] = None, **kwargs: Any ) -> Union[str, Iterator[Optional[str]]]: pass @abstractmethod def chat( self, - chat: List[Message], + chat: Sequence[Message], **kwargs: Any, ) -> Union[str, Iterator[Optional[str]]]: pass @@ -79,7 +35,7 @@ def chat( @abstractmethod def __call__( self, - input: Union[str, List[Message]], + input: Union[str, Sequence[Message]], **kwargs: Any, ) -> Union[str, Iterator[Optional[str]]]: pass @@ -111,7 +67,7 @@ def __init__( def __call__( self, - input: Union[str, List[Message]], + input: Union[str, Sequence[Message]], **kwargs: Any, ) -> Union[str, Iterator[Optional[str]]]: if isinstance(input, str): @@ -120,13 +76,13 @@ def __call__( def chat( self, - chat: List[Message], + chat: Sequence[Message], **kwargs: Any, ) -> Union[str, Iterator[Optional[str]]]: """Chat with the LMM model. Parameters: - chat (List[Dict[str, str]]): A list of dictionaries containing the chat + chat (Squence[Dict[str, str]]): A list of dictionaries containing the chat messages. The messages can be in the format: [{"role": "user", "content": "Hello!"}, ...] or if it contains media, it should be in the format: @@ -147,6 +103,7 @@ def chat( "url": ( encoded_media if encoded_media.startswith(("http", "https")) + or encoded_media.startswith("data:image/") else f"data:image/png;base64,{encoded_media}" ), "detail": "low", @@ -174,7 +131,7 @@ def f() -> Iterator[Optional[str]]: def generate( self, prompt: str, - media: Optional[List[Union[str, Path]]] = None, + media: Optional[Sequence[Union[str, Path]]] = None, **kwargs: Any, ) -> Union[str, Iterator[Optional[str]]]: message: List[Dict[str, Any]] = [ @@ -192,7 +149,12 @@ def generate( { "type": "image_url", "image_url": { - "url": f"data:image/png;base64,{encoded_media}", + "url": ( + encoded_media + if encoded_media.startswith(("http", "https")) + or encoded_media.startswith("data:image/") + else f"data:image/png;base64,{encoded_media}" + ), "detail": "low", }, }, @@ -214,81 +176,6 @@ def f() -> Iterator[Optional[str]]: else: return cast(str, response.choices[0].message.content) - def generate_classifier(self, question: str) -> Callable: - api_doc = T.get_tool_documentation([T.clip]) - prompt = CHOOSE_PARAMS.format(api_doc=api_doc, question=question) - response = self.client.chat.completions.create( - model=self.model_name, - messages=[ - {"role": "system", "content": SYSTEM_PROMPT}, - {"role": "user", "content": prompt}, - ], - response_format={"type": "json_object"}, - ) - - try: - params = json.loads(cast(str, response.choices[0].message.content))[ - "Parameters" - ] - except json.JSONDecodeError: - _LOGGER.error( - f"Failed to decode response: {response.choices[0].message.content}" - ) - raise ValueError("Failed to decode response") - - return lambda x: T.clip(x, params["prompt"]) - - def generate_detector(self, question: str) -> Callable: - api_doc = T.get_tool_documentation([T.owl_v2]) - prompt = CHOOSE_PARAMS.format(api_doc=api_doc, question=question) - response = self.client.chat.completions.create( - model=self.model_name, - messages=[ - {"role": "system", "content": SYSTEM_PROMPT}, - {"role": "user", "content": prompt}, - ], - response_format={"type": "json_object"}, - ) - - try: - params = json.loads(cast(str, response.choices[0].message.content))[ - "Parameters" - ] - except json.JSONDecodeError: - _LOGGER.error( - f"Failed to decode response: {response.choices[0].message.content}" - ) - raise ValueError("Failed to decode response") - - return lambda x: T.owl_v2(params["prompt"], x) - - def generate_segmentor(self, question: str) -> Callable: - api_doc = T.get_tool_documentation([T.grounding_sam]) - prompt = CHOOSE_PARAMS.format(api_doc=api_doc, question=question) - response = self.client.chat.completions.create( - model=self.model_name, - messages=[ - {"role": "system", "content": SYSTEM_PROMPT}, - {"role": "user", "content": prompt}, - ], - response_format={"type": "json_object"}, - ) - - try: - params = json.loads(cast(str, response.choices[0].message.content))[ - "Parameters" - ] - except json.JSONDecodeError: - _LOGGER.error( - f"Failed to decode response: {response.choices[0].message.content}" - ) - raise ValueError("Failed to decode response") - - return lambda x: T.grounding_sam(params["prompt"], x) - - def generate_image_qa_tool(self, question: str) -> Callable: - return lambda x: T.git_vqa_v2(question, x) - class AzureOpenAILMM(OpenAILMM): def __init__( @@ -362,7 +249,7 @@ def __init__( def __call__( self, - input: Union[str, List[Message]], + input: Union[str, Sequence[Message]], **kwargs: Any, ) -> Union[str, Iterator[Optional[str]]]: if isinstance(input, str): @@ -371,13 +258,13 @@ def __call__( def chat( self, - chat: List[Message], + chat: Sequence[Message], **kwargs: Any, ) -> Union[str, Iterator[Optional[str]]]: """Chat with the LMM model. Parameters: - chat (List[Dict[str, str]]): A list of dictionaries containing the chat + chat (Sequence[Dict[str, str]]): A list of dictionaries containing the chat messages. The messages can be in the format: [{"role": "user", "content": "Hello!"}, ...] or if it contains media, it should be in the format: @@ -429,7 +316,7 @@ def f() -> Iterator[Optional[str]]: def generate( self, prompt: str, - media: Optional[List[Union[str, Path]]] = None, + media: Optional[Sequence[Union[str, Path]]] = None, **kwargs: Any, ) -> Union[str, Iterator[Optional[str]]]: url = f"{self.url}/generate" @@ -493,7 +380,7 @@ def __init__( def __call__( self, - input: Union[str, List[Dict[str, Any]]], + input: Union[str, Sequence[Dict[str, Any]]], **kwargs: Any, ) -> Union[str, Iterator[Optional[str]]]: if isinstance(input, str): @@ -502,7 +389,7 @@ def __call__( def chat( self, - chat: List[Dict[str, Any]], + chat: Sequence[Dict[str, Any]], **kwargs: Any, ) -> Union[str, Iterator[Optional[str]]]: messages: List[MessageParam] = [] @@ -551,7 +438,7 @@ def f() -> Iterator[Optional[str]]: def generate( self, prompt: str, - media: Optional[List[Union[str, Path]]] = None, + media: Optional[Sequence[Union[str, Path]]] = None, **kwargs: Any, ) -> Union[str, Iterator[Optional[str]]]: content: List[Union[TextBlockParam, ImageBlockParam]] = [ diff --git a/vision_agent/utils/image_utils.py b/vision_agent/utils/image_utils.py index f0113c9f..3612592d 100644 --- a/vision_agent/utils/image_utils.py +++ b/vision_agent/utils/image_utils.py @@ -13,6 +13,8 @@ from PIL import Image, ImageDraw, ImageFont from PIL.Image import Image as ImageType +from vision_agent.utils import extract_frames_from_video + COLORS = [ (158, 218, 229), (219, 219, 141), @@ -172,6 +174,51 @@ def convert_to_b64(data: Union[str, Path, np.ndarray, ImageType]) -> str: ) +def encode_image_bytes(image: bytes) -> str: + image = Image.open(io.BytesIO(image)).convert("RGB") # type: ignore + buffer = io.BytesIO() + image.save(buffer, format="PNG") # type: ignore + encoded_image = base64.b64encode(buffer.getvalue()).decode("utf-8") + return encoded_image + + +def encode_media(media: Union[str, Path]) -> str: + if isinstance(media, str) and media.startswith(("http", "https")): + # for mp4 video url, we assume there is a same url but ends with png + # vision-agent-ui will upload this png when uploading the video + if media.endswith((".mp4", "mov")) and media.find("vision-agent-dev.s3") != -1: + return media[:-4] + ".png" + return media + + # if media is already a base64 encoded image return + if isinstance(media, str) and media.startswith("data:image/"): + return media + + extension = "png" + extension = Path(media).suffix + if extension.lower() not in { + ".jpg", + ".jpeg", + ".png", + ".webp", + ".bmp", + ".mp4", + ".mov", + }: + raise ValueError(f"Unsupported image extension: {extension}") + + image_bytes = b"" + if extension.lower() in {".mp4", ".mov"}: + frames = extract_frames_from_video(str(media), fps=1) + image = frames[len(frames) // 2] + buffer = io.BytesIO() + Image.fromarray(image[0]).convert("RGB").save(buffer, format="PNG") + image_bytes = buffer.getvalue() + else: + image_bytes = open(media, "rb").read() + return encode_image_bytes(image_bytes) + + def denormalize_bbox( bbox: List[Union[int, float]], image_size: Tuple[int, ...] ) -> List[float]: