-
Notifications
You must be signed in to change notification settings - Fork 128
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
added more error handling for saving files
- Loading branch information
1 parent
96ad669
commit 2499aea
Showing
2 changed files
with
72 additions
and
16 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,25 +1,69 @@ | ||
# Generated by CodiumAI | ||
import os | ||
import tempfile | ||
from pathlib import Path | ||
|
||
import numpy as np | ||
|
||
from vision_agent.tools.tools import save_video | ||
from vision_agent.tools.tools import save_image, save_video | ||
|
||
|
||
class TestSaveVideo: | ||
def test_saves_frames_without_output_path(self): | ||
frames = [ | ||
np.random.randint(0, 256, (480, 640, 3), dtype=np.uint8) for _ in range(10) | ||
] | ||
output_path = save_video(frames) | ||
assert Path(output_path).exists() | ||
def test_saves_frames_without_output_path(): | ||
frames = [ | ||
np.random.randint(0, 256, (480, 640, 3), dtype=np.uint8) for _ in range(10) | ||
] | ||
output_path = save_video(frames) | ||
assert Path(output_path).exists() | ||
os.remove(output_path) | ||
|
||
|
||
def test_saves_frames_with_output_path(self, tmp_path): | ||
frames = [ | ||
np.random.randint(0, 256, (480, 640, 3), dtype=np.uint8) for _ in range(10) | ||
] | ||
video_output_path = str(tmp_path / "output.mp4") | ||
output_path = save_video(frames, video_output_path) | ||
def test_saves_frames_with_output_path(): | ||
frames = [ | ||
np.random.randint(0, 256, (480, 640, 3), dtype=np.uint8) for _ in range(10) | ||
] | ||
|
||
assert output_path == video_output_path | ||
with tempfile.TemporaryDirectory() as tmp_dir: | ||
video_output_path = Path(tmp_dir) / "output.mp4" | ||
output_path = save_video(frames, str(video_output_path)) | ||
|
||
assert output_path == str(video_output_path) | ||
assert Path(output_path).exists() | ||
|
||
|
||
def test_save_null_image(): | ||
image = None | ||
try: | ||
save_image(image, "tmp.jpg") | ||
except ValueError as e: | ||
assert str(e) == "The image is not a valid NumPy array with shape (H, W, C)" | ||
|
||
|
||
def test_save_empty_image(): | ||
image = np.zeros((0, 0, 3), dtype=np.uint8) | ||
try: | ||
save_image(image, "tmp.jpg") | ||
except ValueError as e: | ||
assert str(e) == "The image is not a valid NumPy array with shape (H, W, C)" | ||
|
||
|
||
def test_save_null_video(): | ||
frames = None | ||
try: | ||
save_video(frames, "tmp.mp4") | ||
except ValueError as e: | ||
assert str(e) == "Frames must be a list of NumPy arrays" | ||
|
||
def test_save_empty_list(): | ||
frames = [] | ||
try: | ||
save_video(frames, "tmp.mp4") | ||
except ValueError as e: | ||
assert str(e) == "Frames must be a list of NumPy arrays" | ||
|
||
|
||
def test_save_invalid_frame(): | ||
frames = [np.zeros((0, 0, 3), dtype=np.uint8)] | ||
try: | ||
save_video(frames, "tmp.mp4") | ||
except ValueError as e: | ||
assert str(e) == "Frame is not a valid NumPy array with shape (H, W, C)" | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters