Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: use media url directly in both LMM and code sandbox #215

Merged
merged 12 commits into from
Aug 29, 2024
7 changes: 6 additions & 1 deletion vision_agent/agent/vision_agent_coder.py
Original file line number Diff line number Diff line change
Expand Up @@ -718,7 +718,12 @@ def chat_with_workflow(
for chat_i in chat:
if "media" in chat_i:
for media in chat_i["media"]:
media = code_interpreter.upload_file(media)
media = (
media
if type(media) is str
and media.startswith(("http", "https"))
else code_interpreter.upload_file(media)
)
chat_i["content"] += f" Media name {media}" # type: ignore
media_list.append(media)

Expand Down
15 changes: 11 additions & 4 deletions vision_agent/lmm/lmm.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,12 @@ def encode_image_bytes(image: bytes) -> str:


def encode_media(media: Union[str, Path]) -> str:
if type(media) is str and media.startswith(("http", "https")):
# For an mp4 video URL, we assume a URL with the same name but ending in .png also exists;
# vision-agent-ui uploads that png whenever it uploads the video.
if media.endswith((".mp4", "mov")) and media.find("vision-agent-dev.s3") != -1:
return media[:-4] + ".png"
return media
extension = "png"
extension = Path(media).suffix
if extension.lower() not in {
Expand Down Expand Up @@ -138,7 +144,11 @@ def chat(
{
"type": "image_url",
"image_url": {
"url": f"data:image/png;base64,{encoded_media}",
"url": (
encoded_media
if encoded_media.startswith(("http", "https"))
else f"data:image/png;base64,{encoded_media}"
),
"detail": "low",
},
},
Expand Down Expand Up @@ -390,7 +400,6 @@ def chat(
tmp_kwargs = self.kwargs | kwargs
data.update(tmp_kwargs)
if "stream" in tmp_kwargs and tmp_kwargs["stream"]:

json_data = json.dumps(data)

def f() -> Iterator[Optional[str]]:
Expand Down Expand Up @@ -424,7 +433,6 @@ def generate(
media: Optional[List[Union[str, Path]]] = None,
**kwargs: Any,
) -> Union[str, Iterator[Optional[str]]]:

url = f"{self.url}/generate"
data: Dict[str, Any] = {
"model": self.model_name,
Expand All @@ -439,7 +447,6 @@ def generate(
tmp_kwargs = self.kwargs | kwargs
data.update(tmp_kwargs)
if "stream" in tmp_kwargs and tmp_kwargs["stream"]:

json_data = json.dumps(data)

def f() -> Iterator[Optional[str]]:
Expand Down
20 changes: 18 additions & 2 deletions vision_agent/tools/tools.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import os
import io
import json
import logging
Expand All @@ -14,6 +15,7 @@
from PIL import Image, ImageDraw, ImageFont
from pillow_heif import register_heif_opener # type: ignore
from pytube import YouTube # type: ignore
import urllib.request

from vision_agent.clients.landing_public_api import LandingPublicAPI
from vision_agent.tools.tool_utils import (
Expand Down Expand Up @@ -1220,6 +1222,13 @@ def extract_frames(
video_file_path = video.download(output_path=temp_dir)

return extract_frames_from_video(video_file_path, fps)
elif str(video_uri).startswith(("http", "https")):
_, image_suffix = os.path.splitext(video_uri)
with tempfile.NamedTemporaryFile(delete=False, suffix=image_suffix) as tmp_file:
# Download the video and save it to the temporary file
with urllib.request.urlopen(str(video_uri)) as response:
tmp_file.write(response.read())
return extract_frames_from_video(tmp_file.name, fps)

return extract_frames_from_video(str(video_uri), fps)

Expand Down Expand Up @@ -1250,10 +1259,10 @@ def default(self, obj: Any): # type: ignore


def load_image(image_path: str) -> np.ndarray:
"""'load_image' is a utility function that loads an image from the given file path string.
"""'load_image' is a utility function that loads an image from the given file path string or an URL.

Parameters:
image_path (str): The path to the image.
image_path (str): The path or URL to the image.

Returns:
np.ndarray: The image as a NumPy array.
Expand All @@ -1265,6 +1274,13 @@ def load_image(image_path: str) -> np.ndarray:
# NOTE: sometimes the generated code passes in a NumPy array
if isinstance(image_path, np.ndarray):
return image_path
if image_path.startswith(("http", "https")):
_, image_suffix = os.path.splitext(image_path)
with tempfile.NamedTemporaryFile(delete=False, suffix=image_suffix) as tmp_file:
# Download the image and save it to the temporary file
with urllib.request.urlopen(image_path) as response:
tmp_file.write(response.read())
image_path = tmp_file.name
image = Image.open(image_path).convert("RGB")
return np.array(image)

Expand Down
Loading