Skip to content

Commit

Permalink
Limit number of frames sent to reflection (#64)
Browse files Browse the repository at this point in the history
* add sample for frames

* add test cases

* updated function

* black and isort
  • Loading branch information
dillonalaird authored Apr 25, 2024
1 parent 63c90a1 commit ad02c67
Show file tree
Hide file tree
Showing 5 changed files with 51 additions and 9 deletions.
31 changes: 31 additions & 0 deletions tests/test_vision_agent.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
from vision_agent.agent.vision_agent import sample_n_evenly_spaced


def test_sample_n_evenly_spaced_side_cases():
    """Exercise the guard branches: empty input, n <= 0, n == 1, n > len."""
    # Empty input yields an empty result regardless of n.
    assert sample_n_evenly_spaced([], 0) == []
    assert sample_n_evenly_spaced([], 1) == []

    # Non-positive n yields an empty result.
    assert sample_n_evenly_spaced([1, 2, 3, 4], 0) == []
    assert sample_n_evenly_spaced([1, 2, 3, 4], -1) == []

    # n = 1 returns only the first element.
    assert sample_n_evenly_spaced([1, 2, 3, 4], 1) == [1]

    # n larger than the input returns every element.
    assert sample_n_evenly_spaced([1, 2, 3, 4], 5) == [1, 2, 3, 4]


def test_sample_n_evenly_spaced_even_cases():
    """Sample 2 through 6 items from a six-element (even-length) list."""
    values = [1, 2, 3, 4, 5, 6]
    # Maps the requested sample count to the elements expected back.
    expected_by_count = {
        2: [1, 6],
        3: [1, 3, 6],
        4: [1, 3, 4, 6],
        5: [1, 2, 3, 5, 6],
        6: [1, 2, 3, 4, 5, 6],
    }
    for count, expected in expected_by_count.items():
        assert sample_n_evenly_spaced(values, count) == expected


def test_sample_n_evenly_spaced_odd_cases():
    """Sample 2 through 7 items from a seven-element (odd-length) list."""
    values = [1, 2, 3, 4, 5, 6, 7]
    # Maps the requested sample count to the elements expected back.
    expected_by_count = {
        2: [1, 7],
        3: [1, 4, 7],
        4: [1, 3, 5, 7],
        5: [1, 3, 4, 5, 7],
        6: [1, 2, 3, 5, 6, 7],
        7: [1, 2, 3, 4, 5, 6, 7],
    }
    for count, expected in expected_by_count.items():
        assert sample_n_evenly_spaced(values, count) == expected
16 changes: 15 additions & 1 deletion vision_agent/agent/vision_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -366,6 +366,20 @@ def _handle_viz_tools(
return image_to_data


def sample_n_evenly_spaced(lst: Sequence, n: int) -> Sequence:
    """Return up to ``n`` elements of ``lst`` taken at evenly spaced indices.

    The first and last elements are always included when ``2 <= n < len(lst)``.
    The result is always a new list; the input is never returned by reference.

    Args:
        lst: The sequence to sample from.
        n: The maximum number of elements to return.

    Returns:
        An empty list if ``n <= 0`` or ``lst`` is empty, ``[lst[0]]`` if
        ``n == 1``, a copy of all elements if ``n >= len(lst)``, otherwise a
        list of ``n`` elements picked at evenly spaced positions.
    """
    if n <= 0 or len(lst) == 0:
        return []
    if n == 1:
        return [lst[0]]
    if n >= len(lst):
        # Copy rather than alias the input, so every branch returns a fresh
        # list and callers can mutate the result without touching ``lst``.
        return list(lst)

    # Fractional step between consecutive sampled indices; i = 0 maps to the
    # first index and i = n - 1 maps to the last.
    spacing = (len(lst) - 1) / (n - 1)
    # round() resolves exact halves to the nearest even index (e.g. 2.5 -> 2).
    return [lst[round(spacing * i)] for i in range(n)]


def visualize_result(all_tool_results: List[Dict]) -> Sequence[Union[str, Path]]:
image_to_data: Dict[str, Dict] = {}
for tool_result in all_tool_results:
Expand Down Expand Up @@ -584,7 +598,7 @@ def chat_with_workflow(
visualized_output = visualize_result(all_tool_results)
all_tool_results.append({"visualized_output": visualized_output})
if len(visualized_output) > 0:
reflection_images = visualized_output
reflection_images = sample_n_evenly_spaced(visualized_output, 3)
elif image is not None:
reflection_images = [image]
else:
Expand Down
5 changes: 1 addition & 4 deletions vision_agent/lmm/lmm.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,7 @@
import requests
from openai import AzureOpenAI, OpenAI

from vision_agent.tools import (
CHOOSE_PARAMS,
SYSTEM_PROMPT,
)
from vision_agent.tools import CHOOSE_PARAMS, SYSTEM_PROMPT

_LOGGER = logging.getLogger(__name__)

Expand Down
6 changes: 3 additions & 3 deletions vision_agent/tools/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,12 @@
GroundingDINO,
GroundingSAM,
ImageCaption,
ZeroShotCounting,
VisualPromptCounting,
VisualQuestionAnswering,
ImageQuestionAnswering,
SegArea,
SegIoU,
Tool,
VisualPromptCounting,
VisualQuestionAnswering,
ZeroShotCounting,
register_tool,
)
2 changes: 1 addition & 1 deletion vision_agent/tools/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,9 @@
normalize_bbox,
rle_decode,
)
from vision_agent.lmm import OpenAILMM
from vision_agent.tools.video import extract_frames_from_video
from vision_agent.type_defs import LandingaiAPIKey
from vision_agent.lmm import OpenAILMM

_LOGGER = logging.getLogger(__name__)
_LND_API_KEY = LandingaiAPIKey().api_key
Expand Down

0 comments on commit ad02c67

Please sign in to comment.