Skip to content

Commit

Permalink
feat: use media url directly in both LMM and code sandbox (#215)
Browse files Browse the repository at this point in the history
* do not upload to code_interpreter

* endcode_media support url ad mp4

* load_image

* remove print

* add comment for video associated png

* minor revert

* lint

* backwards

* lint

* also save video

* more strict check for vision-agent-ui

* fix lint
  • Loading branch information
yzld2002 authored Aug 29, 2024
1 parent a8e4b62 commit 5df3ced
Show file tree
Hide file tree
Showing 3 changed files with 35 additions and 7 deletions.
7 changes: 6 additions & 1 deletion vision_agent/agent/vision_agent_coder.py
Original file line number Diff line number Diff line change
Expand Up @@ -718,7 +718,12 @@ def chat_with_workflow(
for chat_i in chat:
if "media" in chat_i:
for media in chat_i["media"]:
media = code_interpreter.upload_file(media)
media = (
media
if type(media) is str
and media.startswith(("http", "https"))
else code_interpreter.upload_file(media)
)
chat_i["content"] += f" Media name {media}" # type: ignore
media_list.append(media)

Expand Down
15 changes: 11 additions & 4 deletions vision_agent/lmm/lmm.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,12 @@ def encode_image_bytes(image: bytes) -> str:


def encode_media(media: Union[str, Path]) -> str:
if type(media) is str and media.startswith(("http", "https")):
# for mp4 video url, we assume there is a same url but ends with png
# vision-agent-ui will upload this png when uploading the video
if media.endswith((".mp4", "mov")) and media.find("vision-agent-dev.s3") != -1:
return media[:-4] + ".png"
return media
extension = "png"
extension = Path(media).suffix
if extension.lower() not in {
Expand Down Expand Up @@ -138,7 +144,11 @@ def chat(
{
"type": "image_url",
"image_url": {
"url": f"data:image/png;base64,{encoded_media}",
"url": (
encoded_media
if encoded_media.startswith(("http", "https"))
else f"data:image/png;base64,{encoded_media}"
),
"detail": "low",
},
},
Expand Down Expand Up @@ -390,7 +400,6 @@ def chat(
tmp_kwargs = self.kwargs | kwargs
data.update(tmp_kwargs)
if "stream" in tmp_kwargs and tmp_kwargs["stream"]:

json_data = json.dumps(data)

def f() -> Iterator[Optional[str]]:
Expand Down Expand Up @@ -424,7 +433,6 @@ def generate(
media: Optional[List[Union[str, Path]]] = None,
**kwargs: Any,
) -> Union[str, Iterator[Optional[str]]]:

url = f"{self.url}/generate"
data: Dict[str, Any] = {
"model": self.model_name,
Expand All @@ -439,7 +447,6 @@ def generate(
tmp_kwargs = self.kwargs | kwargs
data.update(tmp_kwargs)
if "stream" in tmp_kwargs and tmp_kwargs["stream"]:

json_data = json.dumps(data)

def f() -> Iterator[Optional[str]]:
Expand Down
20 changes: 18 additions & 2 deletions vision_agent/tools/tools.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import os
import io
import json
import logging
Expand All @@ -14,6 +15,7 @@
from PIL import Image, ImageDraw, ImageFont
from pillow_heif import register_heif_opener # type: ignore
from pytube import YouTube # type: ignore
import urllib.request

from vision_agent.clients.landing_public_api import LandingPublicAPI
from vision_agent.tools.tool_utils import (
Expand Down Expand Up @@ -1220,6 +1222,13 @@ def extract_frames(
video_file_path = video.download(output_path=temp_dir)

return extract_frames_from_video(video_file_path, fps)
elif str(video_uri).startswith(("http", "https")):
_, image_suffix = os.path.splitext(video_uri)
with tempfile.NamedTemporaryFile(delete=False, suffix=image_suffix) as tmp_file:
# Download the video and save it to the temporary file
with urllib.request.urlopen(str(video_uri)) as response:
tmp_file.write(response.read())
return extract_frames_from_video(tmp_file.name, fps)

return extract_frames_from_video(str(video_uri), fps)

Expand Down Expand Up @@ -1250,10 +1259,10 @@ def default(self, obj: Any): # type: ignore


def load_image(image_path: str) -> np.ndarray:
"""'load_image' is a utility function that loads an image from the given file path string.
"""'load_image' is a utility function that loads an image from the given file path string or an URL.
Parameters:
image_path (str): The path to the image.
image_path (str): The path or URL to the image.
Returns:
np.ndarray: The image as a NumPy array.
Expand All @@ -1265,6 +1274,13 @@ def load_image(image_path: str) -> np.ndarray:
# NOTE: sometimes the generated code pass in a NumPy array
if isinstance(image_path, np.ndarray):
return image_path
if image_path.startswith(("http", "https")):
_, image_suffix = os.path.splitext(image_path)
with tempfile.NamedTemporaryFile(delete=False, suffix=image_suffix) as tmp_file:
# Download the image and save it to the temporary file
with urllib.request.urlopen(image_path) as response:
tmp_file.write(response.read())
image_path = tmp_file.name
image = Image.open(image_path).convert("RGB")
return np.array(image)

Expand Down

0 comments on commit 5df3ced

Please sign in to comment.