From 7fb2a80ec63a8e79ac45cb15f49be05998f17987 Mon Sep 17 00:00:00 2001 From: Dillon Laird Date: Wed, 24 Apr 2024 15:43:01 -0700 Subject: [PATCH 1/4] add sample for frames --- vision_agent/agent/vision_agent.py | 15 ++++++++++++++- 1 file changed, 14 insertions(+), 1 deletion(-) diff --git a/vision_agent/agent/vision_agent.py b/vision_agent/agent/vision_agent.py index 93218e6c..7e87cba0 100644 --- a/vision_agent/agent/vision_agent.py +++ b/vision_agent/agent/vision_agent.py @@ -366,6 +366,19 @@ def _handle_viz_tools( return image_to_data +def sample_n_evenly_spaced(lst: Sequence, n: int) -> Sequence: + if n <= 0: + return [] + + if len(lst) <= n: + return lst + + interval = len(lst) // n if len(lst) % 2 == 0 else len(lst) // n + 1 + indices = list(range(len(lst))) + picked_indices = sorted([indices[(i * interval) % len(lst)] for i in range(n)]) + return [lst[i] for i in picked_indices] + + def visualize_result(all_tool_results: List[Dict]) -> Sequence[Union[str, Path]]: image_to_data: Dict[str, Dict] = {} for tool_result in all_tool_results: @@ -584,7 +597,7 @@ def chat_with_workflow( visualized_output = visualize_result(all_tool_results) all_tool_results.append({"visualized_output": visualized_output}) if len(visualized_output) > 0: - reflection_images = visualized_output + reflection_images = sample_n_evenly_spaced(visualized_output, 3) elif image is not None: reflection_images = [image] else: From adad87a5c89e0fafd562c5d51bc975021fe9e344 Mon Sep 17 00:00:00 2001 From: Dillon Laird Date: Wed, 24 Apr 2024 16:25:21 -0700 Subject: [PATCH 2/4] add test cases --- tests/test_vision_agent.py | 31 +++++++++++++++++++++++++++++++ 1 file changed, 31 insertions(+) create mode 100644 tests/test_vision_agent.py diff --git a/tests/test_vision_agent.py b/tests/test_vision_agent.py new file mode 100644 index 00000000..78da24e6 --- /dev/null +++ b/tests/test_vision_agent.py @@ -0,0 +1,31 @@ +from vision_agent.agent.vision_agent import sample_n_evenly_spaced + + +def 
test_sample_n_evenly_spaced_side_cases(): + # Test for empty input + assert sample_n_evenly_spaced([], 0) == [] + assert sample_n_evenly_spaced([], 1) == [] + + # Test for n = 0 + assert sample_n_evenly_spaced([1, 2, 3, 4], 0) == [] + + # Test for negative n and n larger than the input + assert sample_n_evenly_spaced([1, 2, 3, 4], -1) == [] + assert sample_n_evenly_spaced([1, 2, 3, 4], 5) == [1, 2, 3, 4] + + +def test_sample_n_evenly_spaced_even_cases(): + assert sample_n_evenly_spaced([1, 2, 3, 4, 5, 6], 2) == [1, 4] + assert sample_n_evenly_spaced([1, 2, 3, 4, 5, 6], 3) == [1, 3, 5] + assert sample_n_evenly_spaced([1, 2, 3, 4, 5, 6], 4) == [1, 2, 3, 4] + assert sample_n_evenly_spaced([1, 2, 3, 4, 5, 6], 5) == [1, 2, 3, 4, 5] + assert sample_n_evenly_spaced([1, 2, 3, 4, 5, 6], 6) == [1, 2, 3, 4, 5, 6] + + +def test_sample_n_evenly_spaced_odd_cases(): + assert sample_n_evenly_spaced([1, 2, 3, 4, 5, 6, 7], 2) == [1, 5] + assert sample_n_evenly_spaced([1, 2, 3, 4, 5, 6, 7], 3) == [1, 4, 7] + assert sample_n_evenly_spaced([1, 2, 3, 4, 5, 6, 7], 4) == [1, 3, 5, 7] + assert sample_n_evenly_spaced([1, 2, 3, 4, 5, 6, 7], 5) == [1, 2, 3, 5, 7] + assert sample_n_evenly_spaced([1, 2, 3, 4, 5, 6, 7], 6) == [1, 2, 3, 4, 5, 7] + assert sample_n_evenly_spaced([1, 2, 3, 4, 5, 6, 7], 7) == [1, 2, 3, 4, 5, 6, 7] From 5e3c8cd44117f51f94e0dcd1e97a5e5a19fe9e5f Mon Sep 17 00:00:00 2001 From: Dillon Laird Date: Wed, 24 Apr 2024 17:50:29 -0700 Subject: [PATCH 3/4] updated function' --- tests/test_vision_agent.py | 14 +++++++------- vision_agent/agent/vision_agent.py | 14 ++++++++------ 2 files changed, 15 insertions(+), 13 deletions(-) diff --git a/tests/test_vision_agent.py b/tests/test_vision_agent.py index 78da24e6..98df1f3f 100644 --- a/tests/test_vision_agent.py +++ b/tests/test_vision_agent.py @@ -15,17 +15,17 @@ def test_sample_n_evenly_spaced_side_cases(): def test_sample_n_evenly_spaced_even_cases(): - assert sample_n_evenly_spaced([1, 2, 3, 4, 5, 6], 2) == [1, 4] - assert sample_n_evenly_spaced([1, 2, 3, 4, 5, 
6], 3) == [1, 3, 5] - assert sample_n_evenly_spaced([1, 2, 3, 4, 5, 6], 4) == [1, 2, 3, 4] - assert sample_n_evenly_spaced([1, 2, 3, 4, 5, 6], 5) == [1, 2, 3, 4, 5] + assert sample_n_evenly_spaced([1, 2, 3, 4, 5, 6], 2) == [1, 6] + assert sample_n_evenly_spaced([1, 2, 3, 4, 5, 6], 3) == [1, 3, 6] + assert sample_n_evenly_spaced([1, 2, 3, 4, 5, 6], 4) == [1, 3, 4, 6] + assert sample_n_evenly_spaced([1, 2, 3, 4, 5, 6], 5) == [1, 2, 3, 5, 6] assert sample_n_evenly_spaced([1, 2, 3, 4, 5, 6], 6) == [1, 2, 3, 4, 5, 6] def test_sample_n_evenly_spaced_odd_cases(): - assert sample_n_evenly_spaced([1, 2, 3, 4, 5, 6, 7], 2) == [1, 5] + assert sample_n_evenly_spaced([1, 2, 3, 4, 5, 6, 7], 2) == [1, 7] assert sample_n_evenly_spaced([1, 2, 3, 4, 5, 6, 7], 3) == [1, 4, 7] assert sample_n_evenly_spaced([1, 2, 3, 4, 5, 6, 7], 4) == [1, 3, 5, 7] - assert sample_n_evenly_spaced([1, 2, 3, 4, 5, 6, 7], 5) == [1, 2, 3, 5, 7] - assert sample_n_evenly_spaced([1, 2, 3, 4, 5, 6, 7], 6) == [1, 2, 3, 4, 5, 7] + assert sample_n_evenly_spaced([1, 2, 3, 4, 5, 6, 7], 5) == [1, 3, 4, 5, 7] + assert sample_n_evenly_spaced([1, 2, 3, 4, 5, 6, 7], 6) == [1, 2, 3, 5, 6, 7] assert sample_n_evenly_spaced([1, 2, 3, 4, 5, 6, 7], 7) == [1, 2, 3, 4, 5, 6, 7] diff --git a/vision_agent/agent/vision_agent.py b/vision_agent/agent/vision_agent.py index 7e87cba0..e6f5818d 100644 --- a/vision_agent/agent/vision_agent.py +++ b/vision_agent/agent/vision_agent.py @@ -369,14 +369,16 @@ def _handle_viz_tools( def sample_n_evenly_spaced(lst: Sequence, n: int) -> Sequence: if n <= 0: return [] - - if len(lst) <= n: + elif len(lst) == 0: + return [] + elif n == 1: + return [lst[0]] + elif n >= len(lst): return lst - interval = len(lst) // n if len(lst) % 2 == 0 else len(lst) // n + 1 - indices = list(range(len(lst))) - picked_indices = sorted([indices[(i * interval) % len(lst)] for i in range(n)]) - return [lst[i] for i in picked_indices] + spacing = (len(lst) - 1) / (n - 1) + return [lst[round(spacing * i)] for i in 
range(n)] + def visualize_result(all_tool_results: List[Dict]) -> Sequence[Union[str, Path]]: From 3601d86c246f374f7f4052ed14b609c74cfc2931 Mon Sep 17 00:00:00 2001 From: Dillon Laird Date: Wed, 24 Apr 2024 18:11:30 -0700 Subject: [PATCH 4/4] black and isort --- vision_agent/agent/vision_agent.py | 1 - vision_agent/lmm/lmm.py | 5 +---- vision_agent/tools/__init__.py | 6 +++--- vision_agent/tools/tools.py | 2 +- 4 files changed, 5 insertions(+), 9 deletions(-) diff --git a/vision_agent/agent/vision_agent.py b/vision_agent/agent/vision_agent.py index e6f5818d..514f8fa5 100644 --- a/vision_agent/agent/vision_agent.py +++ b/vision_agent/agent/vision_agent.py @@ -380,7 +380,6 @@ def sample_n_evenly_spaced(lst: Sequence, n: int) -> Sequence: return [lst[round(spacing * i)] for i in range(n)] - def visualize_result(all_tool_results: List[Dict]) -> Sequence[Union[str, Path]]: image_to_data: Dict[str, Dict] = {} for tool_result in all_tool_results: diff --git a/vision_agent/lmm/lmm.py b/vision_agent/lmm/lmm.py index a1fcc3c2..cc8861bd 100644 --- a/vision_agent/lmm/lmm.py +++ b/vision_agent/lmm/lmm.py @@ -9,10 +9,7 @@ import requests from openai import AzureOpenAI, OpenAI -from vision_agent.tools import ( - CHOOSE_PARAMS, - SYSTEM_PROMPT, -) +from vision_agent.tools import CHOOSE_PARAMS, SYSTEM_PROMPT _LOGGER = logging.getLogger(__name__) diff --git a/vision_agent/tools/__init__.py b/vision_agent/tools/__init__.py index 60870b56..10daf7eb 100644 --- a/vision_agent/tools/__init__.py +++ b/vision_agent/tools/__init__.py @@ -12,12 +12,12 @@ GroundingDINO, GroundingSAM, ImageCaption, - ZeroShotCounting, - VisualPromptCounting, - VisualQuestionAnswering, ImageQuestionAnswering, SegArea, SegIoU, Tool, + VisualPromptCounting, + VisualQuestionAnswering, + ZeroShotCounting, register_tool, ) diff --git a/vision_agent/tools/tools.py b/vision_agent/tools/tools.py index 3bf2bfbf..32a998db 100644 --- a/vision_agent/tools/tools.py +++ b/vision_agent/tools/tools.py @@ -17,9 +17,9 @@ 
normalize_bbox, rle_decode, ) +from vision_agent.lmm import OpenAILMM from vision_agent.tools.video import extract_frames_from_video from vision_agent.type_defs import LandingaiAPIKey -from vision_agent.lmm import OpenAILMM _LOGGER = logging.getLogger(__name__) _LND_API_KEY = LandingaiAPIKey().api_key