From 5df3ced38526efcbf4ae61568d58e07e4c3335b3 Mon Sep 17 00:00:00 2001 From: Zhichao Date: Thu, 29 Aug 2024 16:02:41 +0800 Subject: [PATCH] feat: use media url directly in both LMM and code sandbox (#215) * do not upload to code_interpreter * endcode_media support url ad mp4 * load_image * remove print * add comment for video associated png * minor revert * lint * backwards * lint * also save video * more strict check for vision-agent-ui * fix lint --- vision_agent/agent/vision_agent_coder.py | 7 ++++++- vision_agent/lmm/lmm.py | 15 +++++++++++---- vision_agent/tools/tools.py | 20 ++++++++++++++++++-- 3 files changed, 35 insertions(+), 7 deletions(-) diff --git a/vision_agent/agent/vision_agent_coder.py b/vision_agent/agent/vision_agent_coder.py index 3b3b5f68..7856bdb8 100644 --- a/vision_agent/agent/vision_agent_coder.py +++ b/vision_agent/agent/vision_agent_coder.py @@ -718,7 +718,12 @@ def chat_with_workflow( for chat_i in chat: if "media" in chat_i: for media in chat_i["media"]: - media = code_interpreter.upload_file(media) + media = ( + media + if type(media) is str + and media.startswith(("http", "https")) + else code_interpreter.upload_file(media) + ) chat_i["content"] += f" Media name {media}" # type: ignore media_list.append(media) diff --git a/vision_agent/lmm/lmm.py b/vision_agent/lmm/lmm.py index e78a0593..15df5ac9 100644 --- a/vision_agent/lmm/lmm.py +++ b/vision_agent/lmm/lmm.py @@ -30,6 +30,12 @@ def encode_image_bytes(image: bytes) -> str: def encode_media(media: Union[str, Path]) -> str: + if type(media) is str and media.startswith(("http", "https")): + # for mp4 video url, we assume there is a same url but ends with png + # vision-agent-ui will upload this png when uploading the video + if media.endswith((".mp4", "mov")) and media.find("vision-agent-dev.s3") != -1: + return media[:-4] + ".png" + return media extension = "png" extension = Path(media).suffix if extension.lower() not in { @@ -138,7 +144,11 @@ def chat( { "type": "image_url", "image_url": { - "url": f"data:image/png;base64,{encoded_media}", + "url": ( + encoded_media + if encoded_media.startswith(("http", "https")) + else f"data:image/png;base64,{encoded_media}" + ), "detail": "low", }, }, @@ -390,7 +400,6 @@ def chat( tmp_kwargs = self.kwargs | kwargs data.update(tmp_kwargs) if "stream" in tmp_kwargs and tmp_kwargs["stream"]: - json_data = json.dumps(data) def f() -> Iterator[Optional[str]]: @@ -424,7 +433,6 @@ def generate( media: Optional[List[Union[str, Path]]] = None, **kwargs: Any, ) -> Union[str, Iterator[Optional[str]]]: - url = f"{self.url}/generate" data: Dict[str, Any] = { "model": self.model_name, @@ -439,7 +447,6 @@ def generate( tmp_kwargs = self.kwargs | kwargs data.update(tmp_kwargs) if "stream" in tmp_kwargs and tmp_kwargs["stream"]: - json_data = json.dumps(data) def f() -> Iterator[Optional[str]]: diff --git a/vision_agent/tools/tools.py b/vision_agent/tools/tools.py index 4e1a0f40..62a1908a 100644 --- a/vision_agent/tools/tools.py +++ b/vision_agent/tools/tools.py @@ -1,3 +1,4 @@ +import os import io import json import logging @@ -14,6 +15,7 @@ from PIL import Image, ImageDraw, ImageFont from pillow_heif import register_heif_opener # type: ignore from pytube import YouTube # type: ignore +import urllib.request from vision_agent.clients.landing_public_api import LandingPublicAPI from vision_agent.tools.tool_utils import ( @@ -1220,6 +1222,13 @@ def extract_frames( video_file_path = video.download(output_path=temp_dir) return extract_frames_from_video(video_file_path, fps) + elif str(video_uri).startswith(("http", "https")): + _, image_suffix = os.path.splitext(video_uri) + with tempfile.NamedTemporaryFile(delete=False, suffix=image_suffix) as tmp_file: + # Download the video and save it to the temporary file + with urllib.request.urlopen(str(video_uri)) as response: + tmp_file.write(response.read()) + return extract_frames_from_video(tmp_file.name, fps) return extract_frames_from_video(str(video_uri), fps) @@ -1250,10 +1259,10 @@ def default(self, obj: Any): # type: ignore def load_image(image_path: str) -> np.ndarray: - """'load_image' is a utility function that loads an image from the given file path string. + """'load_image' is a utility function that loads an image from the given file path string or an URL. Parameters: - image_path (str): The path to the image. + image_path (str): The path or URL to the image. Returns: np.ndarray: The image as a NumPy array. @@ -1265,6 +1274,13 @@ def load_image(image_path: str) -> np.ndarray: # NOTE: sometimes the generated code pass in a NumPy array if isinstance(image_path, np.ndarray): return image_path + if image_path.startswith(("http", "https")): + _, image_suffix = os.path.splitext(image_path) + with tempfile.NamedTemporaryFile(delete=False, suffix=image_suffix) as tmp_file: + # Download the image and save it to the temporary file + with urllib.request.urlopen(image_path) as response: + tmp_file.write(response.read()) + image_path = tmp_file.name image = Image.open(image_path).convert("RGB") return np.array(image)