diff --git a/tests/test_vision_agent.py b/tests/test_vision_agent.py
new file mode 100644
index 00000000..98df1f3f
--- /dev/null
+++ b/tests/test_vision_agent.py
@@ -0,0 +1,31 @@
+from vision_agent.agent.vision_agent import sample_n_evenly_spaced
+
+
+def test_sample_n_evenly_spaced_side_cases():
+    # Test for empty input
+    assert sample_n_evenly_spaced([], 0) == []
+    assert sample_n_evenly_spaced([], 1) == []
+
+    # Test for n = 0
+    assert sample_n_evenly_spaced([1, 2, 3, 4], 0) == []
+
+    # Test for n < 0 and n > len(lst)
+    assert sample_n_evenly_spaced([1, 2, 3, 4], -1) == []
+    assert sample_n_evenly_spaced([1, 2, 3, 4], 5) == [1, 2, 3, 4]
+
+
+def test_sample_n_evenly_spaced_even_cases():
+    assert sample_n_evenly_spaced([1, 2, 3, 4, 5, 6], 2) == [1, 6]
+    assert sample_n_evenly_spaced([1, 2, 3, 4, 5, 6], 3) == [1, 3, 6]
+    assert sample_n_evenly_spaced([1, 2, 3, 4, 5, 6], 4) == [1, 3, 4, 6]
+    assert sample_n_evenly_spaced([1, 2, 3, 4, 5, 6], 5) == [1, 2, 3, 5, 6]
+    assert sample_n_evenly_spaced([1, 2, 3, 4, 5, 6], 6) == [1, 2, 3, 4, 5, 6]
+
+
+def test_sample_n_evenly_spaced_odd_cases():
+    assert sample_n_evenly_spaced([1, 2, 3, 4, 5, 6, 7], 2) == [1, 7]
+    assert sample_n_evenly_spaced([1, 2, 3, 4, 5, 6, 7], 3) == [1, 4, 7]
+    assert sample_n_evenly_spaced([1, 2, 3, 4, 5, 6, 7], 4) == [1, 3, 5, 7]
+    assert sample_n_evenly_spaced([1, 2, 3, 4, 5, 6, 7], 5) == [1, 3, 4, 5, 7]
+    assert sample_n_evenly_spaced([1, 2, 3, 4, 5, 6, 7], 6) == [1, 2, 3, 5, 6, 7]
+    assert sample_n_evenly_spaced([1, 2, 3, 4, 5, 6, 7], 7) == [1, 2, 3, 4, 5, 6, 7]
diff --git a/vision_agent/agent/vision_agent.py b/vision_agent/agent/vision_agent.py
index 93218e6c..514f8fa5 100644
--- a/vision_agent/agent/vision_agent.py
+++ b/vision_agent/agent/vision_agent.py
@@ -366,6 +366,20 @@ def _handle_viz_tools(
     return image_to_data
 
 
+def sample_n_evenly_spaced(lst: Sequence, n: int) -> Sequence:
+    if n <= 0:
+        return []
+    elif len(lst) == 0:
+        return []
+    elif n == 1:
+        return [lst[0]]
+    elif n >= len(lst):
+        return lst
+
+    spacing = (len(lst) - 1) / (n - 1)
+    return [lst[round(spacing * i)] for i in range(n)]
+
+
 def visualize_result(all_tool_results: List[Dict]) -> Sequence[Union[str, Path]]:
     image_to_data: Dict[str, Dict] = {}
     for tool_result in all_tool_results:
@@ -584,7 +598,7 @@ def chat_with_workflow(
             visualized_output = visualize_result(all_tool_results)
             all_tool_results.append({"visualized_output": visualized_output})
             if len(visualized_output) > 0:
-                reflection_images = visualized_output
+                reflection_images = sample_n_evenly_spaced(visualized_output, 3)
             elif image is not None:
                 reflection_images = [image]
             else:
diff --git a/vision_agent/lmm/lmm.py b/vision_agent/lmm/lmm.py
index a1fcc3c2..cc8861bd 100644
--- a/vision_agent/lmm/lmm.py
+++ b/vision_agent/lmm/lmm.py
@@ -9,10 +9,7 @@
 import requests
 from openai import AzureOpenAI, OpenAI
 
-from vision_agent.tools import (
-    CHOOSE_PARAMS,
-    SYSTEM_PROMPT,
-)
+from vision_agent.tools import CHOOSE_PARAMS, SYSTEM_PROMPT
 
 _LOGGER = logging.getLogger(__name__)
 
diff --git a/vision_agent/tools/__init__.py b/vision_agent/tools/__init__.py
index 60870b56..10daf7eb 100644
--- a/vision_agent/tools/__init__.py
+++ b/vision_agent/tools/__init__.py
@@ -12,12 +12,12 @@
     GroundingDINO,
     GroundingSAM,
     ImageCaption,
-    ZeroShotCounting,
-    VisualPromptCounting,
-    VisualQuestionAnswering,
     ImageQuestionAnswering,
     SegArea,
     SegIoU,
     Tool,
+    VisualPromptCounting,
+    VisualQuestionAnswering,
+    ZeroShotCounting,
     register_tool,
 )
diff --git a/vision_agent/tools/tools.py b/vision_agent/tools/tools.py
index 3bf2bfbf..32a998db 100644
--- a/vision_agent/tools/tools.py
+++ b/vision_agent/tools/tools.py
@@ -17,9 +17,9 @@
     normalize_bbox,
     rle_decode,
 )
+from vision_agent.lmm import OpenAILMM
 from vision_agent.tools.video import extract_frames_from_video
 from vision_agent.type_defs import LandingaiAPIKey
-from vision_agent.lmm import OpenAILMM
 
 _LOGGER = logging.getLogger(__name__)
 _LND_API_KEY = LandingaiAPIKey().api_key