Skip to content

Commit 56764af

Browse files
authored
Support video as an intermediate output and visualization (#112)
* Support video as an intermediate output and local video visualization
1 parent f3d5bcf commit 56764af

File tree

7 files changed

+122
-17
lines changed

7 files changed

+122
-17
lines changed

poetry.lock

Lines changed: 13 additions & 13 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ pandas = "2.*"
2525
openai = "1.*"
2626
typing_extensions = "4.*"
2727
moviepy = "1.*"
28-
opencv-python-headless = "4.*"
28+
opencv-python = "4.*"
2929
tabulate = "^0.9.0"
3030
pydantic-settings = "^2.2.1"
3131
scipy = "1.13.*"

vision_agent/agent/vision_agent.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
from vision_agent.utils.execute import CodeInterpreter
2929
from vision_agent.utils.image_utils import b64_to_pil
3030
from vision_agent.utils.sim import Sim
31+
from vision_agent.utils.video import play_video
3132

3233
logging.basicConfig(stream=sys.stdout)
3334
_LOGGER = logging.getLogger(__name__)
@@ -522,6 +523,9 @@ def chat_with_workflow(
522523
for res in execution_result.results:
523524
if res.png:
524525
b64_to_pil(res.png).show()
526+
if res.mp4:
527+
play_video(res.mp4)
528+
525529
return {
526530
"code": code,
527531
"test": test,

vision_agent/tools/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
overlay_segmentation_masks,
2323
save_image,
2424
save_json,
25+
save_video_to_result,
2526
visual_prompt_counting,
2627
zero_shot_counting,
2728
)

vision_agent/tools/tools.py

Lines changed: 44 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515

1616
from vision_agent.tools.tool_utils import _send_inference_request
1717
from vision_agent.utils import extract_frames_from_video
18+
from vision_agent.utils.execute import FileSerializer, MimeType
1819
from vision_agent.utils.image_utils import (
1920
b64_to_pil,
2021
convert_to_b64,
@@ -550,6 +551,29 @@ def save_image(image: np.ndarray) -> str:
550551
return f.name
551552

552553

554+
def save_video_to_result(video_uri: str) -> None:
    """'save_video_to_result' is a utility function that saves a video into the result
    of the code execution (as an intermediate output). This function must be called if
    the user wants to visualize the video generated by the code.

    Parameters:
        video_uri (str): The URI to the video file. Currently only local file paths are supported.

    Example
    -------
        >>> save_video_to_result("path/to/video.mp4")
    """
    from IPython.display import display

    video_file = FileSerializer(video_uri)
    # Publish a raw MIME bundle: the base64-encoded video under the custom
    # video MIME key, plus a plain-text representation of the serializer.
    mime_bundle = {
        MimeType.VIDEO_MP4_B64: video_file.base64(),
        MimeType.TEXT_PLAIN: str(video_file),
    }
    display(mime_bundle, raw=True)
575+
576+
553577
def overlay_bounding_boxes(
554578
image: np.ndarray, bboxes: List[Dict[str, Any]]
555579
) -> np.ndarray:
@@ -570,6 +594,8 @@ def overlay_bounding_boxes(
570594
image, [{'score': 0.99, 'label': 'dinosaur', 'bbox': [0.1, 0.11, 0.35, 0.4]}],
571595
)
572596
"""
597+
from IPython.display import display
598+
573599
pil_image = Image.fromarray(image.astype(np.uint8))
574600

575601
if len(set([box["label"] for box in bboxes])) > len(COLORS):
@@ -606,7 +632,10 @@ def overlay_bounding_boxes(
606632
text_box = draw.textbbox((box[0], box[1]), text=text, font=font)
607633
draw.rectangle((box[0], box[1], text_box[2], text_box[3]), fill=color[label])
608634
draw.text((box[0], box[1]), text, fill="black", font=font)
609-
return np.array(pil_image.convert("RGB"))
635+
636+
pil_image = pil_image.convert("RGB")
637+
display(pil_image)
638+
return np.array(pil_image)
610639

611640

612641
def overlay_segmentation_masks(
@@ -637,6 +666,8 @@ def overlay_segmentation_masks(
637666
}],
638667
)
639668
"""
669+
from IPython.display import display
670+
640671
pil_image = Image.fromarray(image.astype(np.uint8)).convert("RGBA")
641672

642673
if len(set([mask["label"] for mask in masks])) > len(COLORS):
@@ -656,7 +687,10 @@ def overlay_segmentation_masks(
656687
np_mask[mask > 0, :] = color[label] + (255 * 0.5,)
657688
mask_img = Image.fromarray(np_mask.astype(np.uint8))
658689
pil_image = Image.alpha_composite(pil_image, mask_img)
659-
return np.array(pil_image.convert("RGB"))
690+
691+
pil_image = pil_image.convert("RGB")
692+
display(pil_image)
693+
return np.array(pil_image)
660694

661695

662696
def overlay_heat_map(
@@ -686,6 +720,8 @@ def overlay_heat_map(
686720
},
687721
)
688722
"""
723+
from IPython.display import display
724+
689725
pil_image = Image.fromarray(image.astype(np.uint8)).convert("RGB")
690726

691727
if "heat_map" not in heat_map or len(heat_map["heat_map"]) == 0:
@@ -701,7 +737,10 @@ def overlay_heat_map(
701737
combined = Image.alpha_composite(
702738
pil_image.convert("RGBA"), overlay.resize(pil_image.size)
703739
)
704-
return np.array(combined.convert("RGB"))
740+
741+
pil_image = combined.convert("RGB")
742+
display(pil_image)
743+
return np.array(pil_image)
705744

706745

707746
def get_tool_documentation(funcs: List[Callable[..., Any]]) -> str:
@@ -763,6 +802,7 @@ def get_tools_df(funcs: List[Callable[..., Any]]) -> pd.DataFrame:
763802
save_json,
764803
load_image,
765804
save_image,
805+
save_video_to_result,
766806
overlay_bounding_boxes,
767807
overlay_segmentation_masks,
768808
overlay_heat_map,
@@ -775,6 +815,7 @@ def get_tools_df(funcs: List[Callable[..., Any]]) -> pd.DataFrame:
775815
save_json,
776816
load_image,
777817
save_image,
818+
save_video_to_result,
778819
overlay_bounding_boxes,
779820
overlay_segmentation_masks,
780821
overlay_heat_map,

vision_agent/utils/execute.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import abc
22
import atexit
3+
import base64
34
import copy
45
import logging
56
import os
@@ -45,12 +46,31 @@ class MimeType(str, Enum):
4546
IMAGE_SVG = "image/svg+xml"
4647
IMAGE_PNG = "image/png"
4748
IMAGE_JPEG = "image/jpeg"
49+
VIDEO_MP4_B64 = "video/mp4/base64"
4850
APPLICATION_PDF = "application/pdf"
4951
TEXT_LATEX = "text/latex"
5052
APPLICATION_JSON = "application/json"
5153
APPLICATION_JAVASCRIPT = "application/javascript"
5254

5355

56+
class FileSerializer:
    """Adaptor class that allows IPython.display.display() to serialize a file to a base64 string representation.

    The path is validated once, at construction time. The file contents are
    read from disk each time ``base64()`` is called.
    """

    def __init__(self, file_uri: str):
        """Validate and remember the path of the file to serialize.

        Parameters:
            file_uri (str): Path to an existing local file.

        Raises:
            AssertionError: If ``file_uri`` is not an existing regular file.
        """
        # Raise explicitly rather than using an `assert` statement: asserts are
        # stripped when Python runs with -O, which would silently disable this
        # validation. AssertionError is kept so the exception type seen by
        # callers does not change.
        # os.path.isfile() is False for missing paths too, so a separate
        # existence check would be unreachable.
        if not os.path.isfile(file_uri):
            raise AssertionError(f"Only support local files currently: {file_uri}")
        # The attribute keeps its historical name even though any file type is accepted.
        self.video_uri = file_uri

    def __repr__(self) -> str:
        return f"FileSerializer({self.video_uri})"

    def base64(self) -> str:
        """Return the file contents encoded as a base64 UTF-8 string."""
        with open(self.video_uri, "rb") as file:
            return base64.b64encode(file.read()).decode("utf-8")
72+
73+
5474
class Result:
5575
"""
5676
Represents the data to be displayed as a result of executing a cell in a Jupyter notebook.
@@ -70,6 +90,7 @@ class Result:
7090
png: Optional[str] = None
7191
jpeg: Optional[str] = None
7292
pdf: Optional[str] = None
93+
mp4: Optional[str] = None
7394
latex: Optional[str] = None
7495
json: Optional[Dict[str, Any]] = None
7596
javascript: Optional[str] = None
@@ -93,6 +114,7 @@ def __init__(self, is_main_result: bool, data: Dict[str, Any]):
93114
self.png = data.pop(MimeType.IMAGE_PNG, None)
94115
self.jpeg = data.pop(MimeType.IMAGE_JPEG, None)
95116
self.pdf = data.pop(MimeType.APPLICATION_PDF, None)
117+
self.mp4 = data.pop(MimeType.VIDEO_MP4_B64, None)
96118
self.latex = data.pop(MimeType.TEXT_LATEX, None)
97119
self.json = data.pop(MimeType.APPLICATION_JSON, None)
98120
self.javascript = data.pop(MimeType.APPLICATION_JAVASCRIPT, None)
@@ -190,6 +212,8 @@ def formats(self) -> Iterable[str]:
190212
formats.append("json")
191213
if self.javascript:
192214
formats.append("javascript")
215+
if self.mp4:
216+
formats.append("mp4")
193217
if self.extra:
194218
formats.extend(iter(self.extra))
195219
return formats

vision_agent/utils/video.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
1+
import base64
12
import logging
23
import math
34
import os
45
from concurrent.futures import ProcessPoolExecutor, as_completed
6+
import tempfile
57
from typing import List, Tuple, cast
68

79
import cv2
@@ -14,6 +16,39 @@
1416
_CLIP_LENGTH = 30.0
1517

1618

19+
def play_video(video_base64: str) -> None:
    """Play a base64-encoded video in a local OpenCV window.

    The decoded bytes are written to a temporary ``.mp4`` file so OpenCV can
    open it by path. The first frame is shown and playback starts after any
    key press; pressing 'q' during playback stops it early. The temporary
    file is removed before the function returns.

    Parameters:
        video_base64 (str): The base64-encoded contents of the video file.
    """
    video_data = base64.b64decode(video_base64)
    # delete=False so the file can be closed here and reopened by path below;
    # it is removed explicitly in the outer `finally`.
    with tempfile.NamedTemporaryFile(delete=False, suffix=".mp4") as temp_video:
        temp_video.write(video_data)
        temp_video_path = temp_video.name

    try:
        cap = cv2.VideoCapture(temp_video_path)
        if not cap.isOpened():
            _LOGGER.error("Error: Could not open video.")
            return

        try:
            # Frames are passed to imshow exactly as VideoCapture returns them:
            # both use BGR channel order, so converting to RGB first would
            # swap red and blue on screen.
            # Display the first frame and wait for any key press to start the video
            ret, frame = cap.read()
            if ret:
                cv2.imshow("Video Player", frame)
                _LOGGER.info(
                    "Press any key to start playing the video: %s", temp_video_path
                )
                cv2.waitKey(0)  # Wait for any key press

            while cap.isOpened():
                ret, frame = cap.read()
                if not ret:
                    break
                cv2.imshow("Video Player", frame)
                # Press 'q' to exit the video
                if cv2.waitKey(200) & 0xFF == ord("q"):
                    break
        finally:
            # Release the capture and close the window even if playback raised.
            cap.release()
            cv2.destroyAllWindows()
    finally:
        # Do not leak the temporary copy of the video.
        try:
            os.remove(temp_video_path)
        except OSError:
            _LOGGER.warning("Could not remove temporary video file: %s", temp_video_path)
50+
51+
1752
def extract_frames_from_video(
1853
video_uri: str, fps: float = 0.5, motion_detection_threshold: float = 0.0
1954
) -> List[Tuple[np.ndarray, float]]:

0 commit comments

Comments
 (0)