diff --git a/tests/unit/tools/test_tools.py b/tests/unit/tools/test_tools.py index b2f1a87a..5fc82b84 100644 --- a/tests/unit/tools/test_tools.py +++ b/tests/unit/tools/test_tools.py @@ -1,25 +1,69 @@ -# Generated by CodiumAI +import os +import tempfile from pathlib import Path import numpy as np -from vision_agent.tools.tools import save_video +from vision_agent.tools.tools import save_image, save_video -class TestSaveVideo: - def test_saves_frames_without_output_path(self): - frames = [ - np.random.randint(0, 256, (480, 640, 3), dtype=np.uint8) for _ in range(10) - ] - output_path = save_video(frames) - assert Path(output_path).exists() +def test_saves_frames_without_output_path(): + frames = [ + np.random.randint(0, 256, (480, 640, 3), dtype=np.uint8) for _ in range(10) + ] + output_path = save_video(frames) + assert Path(output_path).exists() + os.remove(output_path) + - def test_saves_frames_with_output_path(self, tmp_path): - frames = [ - np.random.randint(0, 256, (480, 640, 3), dtype=np.uint8) for _ in range(10) - ] - video_output_path = str(tmp_path / "output.mp4") - output_path = save_video(frames, video_output_path) +def test_saves_frames_with_output_path(): + frames = [ + np.random.randint(0, 256, (480, 640, 3), dtype=np.uint8) for _ in range(10) + ] - assert output_path == video_output_path + with tempfile.TemporaryDirectory() as tmp_dir: + video_output_path = Path(tmp_dir) / "output.mp4" + output_path = save_video(frames, str(video_output_path)) + + assert output_path == str(video_output_path) assert Path(output_path).exists() + + +def test_save_null_image(): + image = None + try: + save_image(image, "tmp.jpg") + except ValueError as e: + assert str(e) == "The image is not a valid NumPy array with shape (H, W, C)" + + +def test_save_empty_image(): + image = np.zeros((0, 0, 3), dtype=np.uint8) + try: + save_image(image, "tmp.jpg") + except ValueError as e: + assert str(e) == "The image is not a valid NumPy array with shape (H, W, C)" + + +def test_save_null_video(): + frames = None + try: + save_video(frames, "tmp.mp4") + except ValueError as e: + assert str(e) == "Frames must be a list of NumPy arrays" + +def test_save_empty_list(): + frames = [] + try: + save_video(frames, "tmp.mp4") + except ValueError as e: + assert str(e) == "Frames must be a list of NumPy arrays" + + +def test_save_invalid_frame(): + frames = [np.zeros((0, 0, 3), dtype=np.uint8)] + try: + save_video(frames, "tmp.mp4") + except ValueError as e: + assert str(e) == "Frame is not a valid NumPy array with shape (H, W, C)" + diff --git a/vision_agent/tools/tools.py b/vision_agent/tools/tools.py index 7d881921..486e21a2 100644 --- a/vision_agent/tools/tools.py +++ b/vision_agent/tools/tools.py @@ -1808,6 +1808,9 @@ def save_image(image: np.ndarray, file_path: str) -> None: """ from IPython.display import display + if not isinstance(image, np.ndarray) or (image.shape[0] == 0 and image.shape[1] == 0): + raise ValueError("The image is not a valid NumPy array with shape (H, W, C)") + pil_image = Image.fromarray(image.astype(np.uint8)).convert("RGB") display(pil_image) pil_image.save(file_path) @@ -1834,6 +1837,15 @@ def save_video( if fps <= 0: raise ValueError(f"fps must be greater than 0 got {fps}") + if not isinstance(frames, list) or len(frames) == 0: + raise ValueError("Frames must be a list of NumPy arrays") + + for frame in frames: + if not isinstance(frame, np.ndarray) or ( + frame.shape[0] == 0 and frame.shape[1] == 0 + ): + raise ValueError("The frame is not a valid NumPy array with shape (H, W, C)") + if output_video_path is None: output_video_path = tempfile.NamedTemporaryFile( delete=False, suffix=".mp4"