Skip to content

Commit

Permalink
Minor improvements and fixes for code generation (#114)
Browse files Browse the repository at this point in the history
1. Add save_video tool for better video serialization
2. Save image as an intermediate output when save_image is called
3. Add default imports from typing.
4. Keep the e2b sandbox alive for 5 minutes whenever it runs a command.
  • Loading branch information
humpydonkey authored Jun 5, 2024
1 parent f6d6748 commit cff0756
Show file tree
Hide file tree
Showing 7 changed files with 49 additions and 33 deletions.
2 changes: 1 addition & 1 deletion examples/custom_tools/run_custom_tool.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import numpy as np
from template_match import template_matching_with_rotation

import vision_agent as va
from template_match import template_matching_with_rotation
from vision_agent.utils.image_utils import get_image_size, normalize_bbox


Expand Down
6 changes: 3 additions & 3 deletions poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

6 changes: 5 additions & 1 deletion vision_agent/agent/vision_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,11 @@
_LOGGER = logging.getLogger(__name__)
_MAX_TABULATE_COL_WIDTH = 80
_CONSOLE = Console()
_DEFAULT_IMPORT = "\n".join(T.__new_tools__)
# Import preamble built from the tool imports plus the typing wildcard.
# All entries go through ONE join so that every statement lands on its own
# line; concatenating two separately-joined strings would fuse the last tool
# import with "from typing import *" into a single invalid line.
_DEFAULT_IMPORT = "\n".join(
    [
        *T.__new_tools__,
        "from typing import *",
    ]
)


def get_diff(before: str, after: str) -> str:
Expand Down
2 changes: 1 addition & 1 deletion vision_agent/tools/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
overlay_segmentation_masks,
save_image,
save_json,
save_video_to_result,
save_video,
visual_prompt_counting,
zero_shot_counting,
)
Expand Down
61 changes: 36 additions & 25 deletions vision_agent/tools/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,13 @@
import tempfile
from importlib import resources
from pathlib import Path
from typing import Any, Callable, Dict, List, Tuple, Union, cast
from typing import Any, Callable, Dict, List, Optional, Tuple, Union, cast

import cv2
import numpy as np
import pandas as pd
import requests
from moviepy.editor import ImageSequenceClip
from PIL import Image, ImageDraw, ImageFont

from vision_agent.tools.tool_utils import _send_inference_request
Expand Down Expand Up @@ -545,24 +546,49 @@ def save_image(image: np.ndarray) -> str:
>>> save_image(image)
"/tmp/tmpabc123.png"
"""
from IPython.display import display

pil_image = Image.fromarray(image.astype(np.uint8))
display(pil_image)
with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as f:
pil_image = Image.fromarray(image.astype(np.uint8))
pil_image.save(f, "PNG")
return f.name


def save_video_to_result(video_uri: str) -> None:
"""'save_video_to_result' a utility function that saves a video into the result of the code execution (as an intermediate output).
This function is required to run if user wants to visualize the video generated by the code.
def save_video(
    frames: List[np.ndarray], output_video_path: Optional[str] = None, fps: float = 4
) -> str:
    """'save_video' is a utility function that saves a list of frames as an mp4 video file on disk.

    Parameters:
        frames (list[np.ndarray]): A list of frames to save.
        output_video_path (str): The path to save the video file. If not provided, a temporary file will be created.
        fps (float): The number of frames per second in the video. A
            non-positive value is replaced by the default of 4 (a warning is logged).

    Returns:
        str: The path to the saved video file.

    Example
    -------
        >>> save_video(frames)
        "/tmp/tmpvideo123.mp4"
    """
    if fps <= 0:
        _LOGGER.warning(f"Invalid fps value: {fps}. Setting fps to 4 (default value).")
        fps = 4

    if not output_video_path:
        # Reserve a unique ".mp4" path and release our handle immediately.
        # The encoder is handed the path by name, so keeping a second open
        # handle on it is unnecessary and would leak if encoding raised.
        with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as tmp_file:
            output_video_path = tmp_file.name

    with ImageSequenceClip(frames, fps=fps) as video:
        video.write_videofile(output_video_path, codec="libx264")
        # Publish the encoded file as an intermediate output of the execution.
        _save_video_to_result(output_video_path)
    return output_video_path


def _save_video_to_result(video_uri: str) -> None:
"""Saves a video into the result of the code execution (as an intermediate output)."""
from IPython.display import display

serializer = FileSerializer(video_uri)
Expand Down Expand Up @@ -595,8 +621,6 @@ def overlay_bounding_boxes(
image, [{'score': 0.99, 'label': 'dinosaur', 'bbox': [0.1, 0.11, 0.35, 0.4]}],
)
"""
from IPython.display import display

pil_image = Image.fromarray(image.astype(np.uint8))

if len(set([box["label"] for box in bboxes])) > len(COLORS):
Expand Down Expand Up @@ -634,9 +658,6 @@ def overlay_bounding_boxes(
text_box = draw.textbbox((box[0], box[1]), text=text, font=font)
draw.rectangle((box[0], box[1], text_box[2], text_box[3]), fill=color[label])
draw.text((box[0], box[1]), text, fill="black", font=font)

pil_image = pil_image.convert("RGB")
display(pil_image)
return np.array(pil_image)


Expand Down Expand Up @@ -668,8 +689,6 @@ def overlay_segmentation_masks(
}],
)
"""
from IPython.display import display

pil_image = Image.fromarray(image.astype(np.uint8)).convert("RGBA")

if len(set([mask["label"] for mask in masks])) > len(COLORS):
Expand All @@ -690,9 +709,6 @@ def overlay_segmentation_masks(
np_mask[mask > 0, :] = color[label] + (255 * 0.5,)
mask_img = Image.fromarray(np_mask.astype(np.uint8))
pil_image = Image.alpha_composite(pil_image, mask_img)

pil_image = pil_image.convert("RGB")
display(pil_image)
return np.array(pil_image)


Expand Down Expand Up @@ -723,8 +739,6 @@ def overlay_heat_map(
},
)
"""
from IPython.display import display

pil_image = Image.fromarray(image.astype(np.uint8)).convert("RGB")

if "heat_map" not in heat_map or len(heat_map["heat_map"]) == 0:
Expand All @@ -740,10 +754,7 @@ def overlay_heat_map(
combined = Image.alpha_composite(
pil_image.convert("RGBA"), overlay.resize(pil_image.size)
)

pil_image = combined.convert("RGB")
display(pil_image)
return np.array(pil_image)
return np.array(combined)


def get_tool_documentation(funcs: List[Callable[..., Any]]) -> str:
Expand Down Expand Up @@ -805,7 +816,7 @@ def get_tools_df(funcs: List[Callable[..., Any]]) -> pd.DataFrame:
save_json,
load_image,
save_image,
save_video_to_result,
save_video,
overlay_bounding_boxes,
overlay_segmentation_masks,
overlay_heat_map,
Expand All @@ -818,7 +829,7 @@ def get_tools_df(funcs: List[Callable[..., Any]]) -> pd.DataFrame:
save_json,
load_image,
save_image,
save_video_to_result,
save_video,
overlay_bounding_boxes,
overlay_segmentation_masks,
overlay_heat_map,
Expand Down
3 changes: 3 additions & 0 deletions vision_agent/utils/execute.py
Original file line number Diff line number Diff line change
Expand Up @@ -401,6 +401,8 @@ def download_file(self, file_path: str) -> Path:


class E2BCodeInterpreter(CodeInterpreter):
KEEP_ALIVE_SEC: int = 300

def __init__(self, *args: Any, **kwargs: Any) -> None:
super().__init__(*args, **kwargs)
assert os.getenv("E2B_API_KEY"), "E2B_API_KEY environment variable must be set"
Expand Down Expand Up @@ -432,6 +434,7 @@ def restart_kernel(self) -> None:
retry=tenacity.retry_if_exception_type(TimeoutError),
)
def exec_cell(self, code: str) -> Execution:
self.interpreter.keep_alive(E2BCodeInterpreter.KEEP_ALIVE_SEC)
execution = self.interpreter.notebook.exec_cell(code, timeout=self.timeout)
return Execution.from_e2b_execution(execution)

Expand Down
2 changes: 0 additions & 2 deletions vision_agent/utils/video.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,6 @@ def play_video(video_base64: str) -> None:
# Display the first frame and wait for any key press to start the video
ret, frame = cap.read()
if ret:
frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
cv2.imshow("Video Player", frame)
_LOGGER.info(f"Press any key to start playing the video: {temp_video_path}")
cv2.waitKey(0) # Wait for any key press
Expand All @@ -40,7 +39,6 @@ def play_video(video_base64: str) -> None:
ret, frame = cap.read()
if not ret:
break
frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
cv2.imshow("Video Player", frame)
# Press 'q' to exit the video
if cv2.waitKey(200) & 0xFF == ord("q"):
Expand Down

0 comments on commit cff0756

Please sign in to comment.