Skip to content

Commit

Permalink
Improve video usage (#229)
Browse files Browse the repository at this point in the history
* add owlv2 video

* update doc extract_frames to include urls

* fix countgd return decimal places

* fixed return types

* prompt tests to run faster

* testing owlv2_video

* updated name to florence2_sam2_video_tracking

* lowered threshold

* ran isort

* fix mypy errors

* fix tests'

* fix tests
  • Loading branch information
dillonalaird authored Sep 6, 2024
1 parent 272e845 commit c438d9b
Show file tree
Hide file tree
Showing 6 changed files with 193 additions and 60 deletions.
30 changes: 22 additions & 8 deletions tests/integ/test_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,11 @@
detr_segmentation,
dpt_hybrid_midas,
florence2_image_caption,
florence2_phrase_grounding,
florence2_ocr,
florence2_phrase_grounding,
florence2_roberta_vqa,
florence2_sam2_image,
florence2_sam2_video,
florence2_sam2_video_tracking,
generate_pose_image,
generate_soft_edge_image,
git_vqa_v2,
Expand All @@ -25,7 +25,8 @@
loca_visual_prompt_counting,
loca_zero_shot_counting,
ocr,
owl_v2,
owl_v2_image,
owl_v2_video,
template_match,
vit_image_classification,
vit_nsfw_classification,
Expand Down Expand Up @@ -53,14 +54,27 @@ def test_grounding_dino_tiny():
assert [res["label"] for res in result] == ["coin"] * 24


def test_owl():
def test_owl_v2_image():
img = ski.data.coins()
result = owl_v2(
result = owl_v2_image(
prompt="coin",
image=img,
)
assert len(result) == 25
assert [res["label"] for res in result] == ["coin"] * 25
assert 24 <= len(result) <= 26
assert [res["label"] for res in result] == ["coin"] * len(result)


def test_owl_v2_video():
frames = [
np.array(Image.fromarray(ski.data.coins()).convert("RGB")) for _ in range(10)
]
result = owl_v2_video(
prompt="coin",
frames=frames,
)

assert len(result) == 10
assert 24 <= len([res["label"] for res in result[0]]) <= 26


def test_object_detection():
Expand Down Expand Up @@ -108,7 +122,7 @@ def test_florence2_sam2_video():
frames = [
np.array(Image.fromarray(ski.data.coins()).convert("RGB")) for _ in range(10)
]
result = florence2_sam2_video(
result = florence2_sam2_video_tracking(
prompt="coin",
frames=frames,
)
Expand Down
5 changes: 1 addition & 4 deletions tests/integration_dev/test_tools.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,6 @@
import skimage as ski

from vision_agent.tools import (
countgd_counting,
countgd_example_based_counting,
)
from vision_agent.tools import countgd_counting, countgd_example_based_counting


def test_countgd_counting() -> None:
Expand Down
48 changes: 41 additions & 7 deletions vision_agent/agent/vision_agent_coder_prompts.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,30 +70,64 @@
2. Create a dictionary where the keys are the tool name and the values are the tool outputs. Remove numpy arrays from the printed dictionary.
3. Your test case MUST run only on the given images which are {media}
4. Print this final dictionary.
5. For video input, sample at 1 FPS and use the first 10 frames only to reduce processing time.
**Example**:
--- EXAMPLE1 ---
plan1:
- Load the image from the provided file path 'image.jpg'.
- Use the 'owl_v2' tool with the prompt 'person' to detect and count the number of people in the image.
- Use the 'owl_v2_image' tool with the prompt 'person' to detect and count the number of people in the image.
plan2:
- Load the image from the provided file path 'image.jpg'.
- Use the 'grounding_sam' tool with the prompt 'person' to detect and count the number of people in the image.
- Use the 'florence2_sam2_image' tool with the prompt 'person' to detect and count the number of people in the image.
- Count the number of detected objects labeled as 'person'.
plan3:
- Load the image from the provided file path 'image.jpg'.
- Use the 'countgd_counting' tool to count the dominant foreground object, which in this case is people.
```python
from vision_agent.tools import load_image, owl_v2, grounding_sam, countgd_counting
from vision_agent.tools import load_image, owl_v2_image, florence2_sam2_image, countgd_counting
image = load_image("image.jpg")
owl_v2_out = owl_v2("person", image)
owl_v2_out = owl_v2_image("person", image)
gsam_out = grounding_sam("person", image)
gsam_out = [{{k: v for k, v in o.items() if k != "mask"}} for o in gsam_out]
f2s2_out = florence2_sam2_image("person", image)
# strip out the masks from the output becuase they don't provide useful information when printed
f2s2_out = [{{k: v for k, v in o.items() if k != "mask"}} for o in f2s2_out]
cgd_out = countgd_counting(image)
final_out = {{"owl_v2": owl_v2_out, "florencev2_object_detection": florencev2_out, "countgd_counting": cgd_out}}
final_out = {{"owl_v2_image": owl_v2_out, "florence2_sam2_image": f2s2, "countgd_counting": cgd_out}}
print(final_out)
--- EXAMPLE2 ---
plan1:
- Extract frames from 'video.mp4' at 10 FPS using the 'extract_frames' tool.
- Use the 'owl_v2_image' tool with the prompt 'person' to detect where the people are in the video.
plan2:
- Extract frames from 'video.mp4' at 10 FPS using the 'extract_frames' tool.
- Use the 'florence2_phrase_grounding' tool with the prompt 'person' to detect where the people are in the video.
plan3:
- Extract frames from 'video.mp4' at 10 FPS using the 'extract_frames' tool.
- Use the 'countgd_counting' tool with the prompt 'person' to detect where the people are in the video.
```python
from vision_agent.tools import extract_frames, owl_v2_image, florence2_phrase_grounding, countgd_counting
# sample at 1 FPS and use the first 10 frames to reduce processing time
frames = extract_frames("video.mp4", 1)
frames = [f[0] for f in frames][:10]
# plan1
owl_v2_out = [owl_v2_image("person", f) for f in frames]
# plan2
florence2_out = [florence2_phrase_grounding("person", f) for f in frames]
# plan3
countgd_out = [countgd_counting(f) for f in frames]
final_out = {{"owl_v2_image": owl_v2_out, "florencev2_object_detection": florencev2_out, "countgd_counting": cgd_out}}
print(final_out)
```
"""
Expand Down
5 changes: 3 additions & 2 deletions vision_agent/tools/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
florence2_phrase_grounding,
florence2_roberta_vqa,
florence2_sam2_image,
florence2_sam2_video,
florence2_sam2_video_tracking,
generate_pose_image,
generate_soft_edge_image,
get_tool_documentation,
Expand All @@ -46,7 +46,8 @@
overlay_counting_results,
overlay_heat_map,
overlay_segmentation_masks,
owl_v2,
owl_v2_image,
owl_v2_video,
save_image,
save_json,
save_video,
Expand Down
Loading

0 comments on commit c438d9b

Please sign in to comment.