Improve video usage #229

Merged · 12 commits · Sep 6, 2024
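The headline change in this PR: the owl_v2 tool is split into owl_v2_image and owl_v2_video, and florence2_sam2_video is renamed to florence2_sam2_video_tracking. A minimal sketch of the new calling convention, inferred from the updated integration tests below (the exact fields in each result dict beyond "label" are assumed):

```python
import numpy as np
import skimage as ski
from PIL import Image

from vision_agent.tools import (
    florence2_sam2_video_tracking,
    owl_v2_image,
    owl_v2_video,
)

# Single-image detection: one prompt, one image.
img = ski.data.coins()
detections = owl_v2_image(prompt="coin", image=img)
print([d["label"] for d in detections])

# The video variants take a list of RGB frames in a single call.
frames = [
    np.array(Image.fromarray(ski.data.coins()).convert("RGB")) for _ in range(10)
]
per_frame_detections = owl_v2_video(prompt="coin", frames=frames)  # one list per frame
tracks = florence2_sam2_video_tracking(prompt="coin", frames=frames)
```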
30 changes: 22 additions & 8 deletions tests/integ/test_tools.py
@@ -10,11 +10,11 @@
detr_segmentation,
dpt_hybrid_midas,
florence2_image_caption,
florence2_phrase_grounding,
florence2_ocr,
florence2_phrase_grounding,
florence2_roberta_vqa,
florence2_sam2_image,
florence2_sam2_video,
florence2_sam2_video_tracking,
generate_pose_image,
generate_soft_edge_image,
git_vqa_v2,
@@ -25,7 +25,8 @@
loca_visual_prompt_counting,
loca_zero_shot_counting,
ocr,
owl_v2,
owl_v2_image,
owl_v2_video,
template_match,
vit_image_classification,
vit_nsfw_classification,
@@ -53,14 +54,27 @@ def test_grounding_dino_tiny():
assert [res["label"] for res in result] == ["coin"] * 24


def test_owl():
def test_owl_v2_image():
img = ski.data.coins()
result = owl_v2(
result = owl_v2_image(
prompt="coin",
image=img,
)
assert len(result) == 25
assert [res["label"] for res in result] == ["coin"] * 25
assert 24 <= len(result) <= 26
assert [res["label"] for res in result] == ["coin"] * len(result)


def test_owl_v2_video():
frames = [
np.array(Image.fromarray(ski.data.coins()).convert("RGB")) for _ in range(10)
]
result = owl_v2_video(
prompt="coin",
frames=frames,
)

assert len(result) == 10
assert 24 <= len([res["label"] for res in result[0]]) <= 26


def test_object_detection():
Expand Down Expand Up @@ -108,7 +122,7 @@ def test_florence2_sam2_video():
frames = [
np.array(Image.fromarray(ski.data.coins()).convert("RGB")) for _ in range(10)
]
result = florence2_sam2_video(
result = florence2_sam2_video_tracking(
prompt="coin",
frames=frames,
)
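Note also that the exact-count assertion (len(result) == 25) is relaxed above to a 24–26 range, since detector output can drift slightly between model versions. A hedged sketch of the same idea as a reusable helper; assert_label_count is hypothetical, not part of this repo:

```python
def assert_label_count(result, label, expected, tolerance=1):
    # Hypothetical helper: allow small fluctuations in detection counts
    # instead of pinning an exact number that breaks on model updates.
    labels = [res["label"] for res in result]
    assert all(lbl == label for lbl in labels)
    assert expected - tolerance <= len(labels) <= expected + tolerance


# e.g. assert_label_count(owl_v2_image(prompt="coin", image=img), "coin", expected=25)
```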
5 changes: 1 addition & 4 deletions tests/integration_dev/test_tools.py
@@ -1,9 +1,6 @@
import skimage as ski

from vision_agent.tools import (
countgd_counting,
countgd_example_based_counting,
)
from vision_agent.tools import countgd_counting, countgd_example_based_counting


def test_countgd_counting() -> None:
48 changes: 41 additions & 7 deletions vision_agent/agent/vision_agent_coder_prompts.py
@@ -70,30 +70,64 @@
2. Create a dictionary where the keys are the tool names and the values are the tool outputs. Remove numpy arrays from the printed dictionary.
3. Your test case MUST run only on the given images which are {media}
4. Print this final dictionary.
5. For video input, sample at 1 FPS and use the first 10 frames only to reduce processing time.

**Example**:
--- EXAMPLE1 ---
plan1:
- Load the image from the provided file path 'image.jpg'.
- Use the 'owl_v2' tool with the prompt 'person' to detect and count the number of people in the image.
- Use the 'owl_v2_image' tool with the prompt 'person' to detect and count the number of people in the image.
plan2:
- Load the image from the provided file path 'image.jpg'.
- Use the 'grounding_sam' tool with the prompt 'person' to detect and count the number of people in the image.
- Use the 'florence2_sam2_image' tool with the prompt 'person' to detect and count the number of people in the image.
- Count the number of detected objects labeled as 'person'.
plan3:
- Load the image from the provided file path 'image.jpg'.
- Use the 'countgd_counting' tool to count the dominant foreground object, which in this case is people.

```python
from vision_agent.tools import load_image, owl_v2, grounding_sam, countgd_counting
from vision_agent.tools import load_image, owl_v2_image, florence2_sam2_image, countgd_counting
image = load_image("image.jpg")
owl_v2_out = owl_v2("person", image)
owl_v2_out = owl_v2_image("person", image)

gsam_out = grounding_sam("person", image)
gsam_out = [{{k: v for k, v in o.items() if k != "mask"}} for o in gsam_out]
f2s2_out = florence2_sam2_image("person", image)
# strip out the masks from the output because they don't provide useful information when printed
f2s2_out = [{{k: v for k, v in o.items() if k != "mask"}} for o in f2s2_out]

cgd_out = countgd_counting(image)

final_out = {{"owl_v2": owl_v2_out, "florencev2_object_detection": florencev2_out, "countgd_counting": cgd_out}}
final_out = {{"owl_v2_image": owl_v2_out, "florence2_sam2_image": f2s2, "countgd_counting": cgd_out}}
print(final_out)
```

--- EXAMPLE2 ---
plan1:
- Extract frames from 'video.mp4' at 1 FPS using the 'extract_frames' tool.
- Use the 'owl_v2_image' tool with the prompt 'person' to detect where the people are in the video.
plan2:
- Extract frames from 'video.mp4' at 1 FPS using the 'extract_frames' tool.
- Use the 'florence2_phrase_grounding' tool with the prompt 'person' to detect where the people are in the video.
plan3:
- Extract frames from 'video.mp4' at 1 FPS using the 'extract_frames' tool.
- Use the 'countgd_counting' tool to count the people in each frame of the video.


```python
from vision_agent.tools import extract_frames, owl_v2_image, florence2_phrase_grounding, countgd_counting

# sample at 1 FPS and use the first 10 frames to reduce processing time
frames = extract_frames("video.mp4", 1)
frames = [f[0] for f in frames][:10]

# plan1
owl_v2_out = [owl_v2_image("person", f) for f in frames]

# plan2
florence2_out = [florence2_phrase_grounding("person", f) for f in frames]

# plan3
countgd_out = [countgd_counting(f) for f in frames]

final_out = {{"owl_v2_image": owl_v2_out, "florencev2_object_detection": florencev2_out, "countgd_counting": cgd_out}}
print(final_out)
```
"""
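Beyond the per-image loop shown in EXAMPLE2, the new video tools can consume the whole frame list in a single call. A sketch under the same 1 FPS / first-10-frames convention, assuming extract_frames returns (frame, timestamp) pairs as the example code implies:

```python
from vision_agent.tools import extract_frames, florence2_sam2_video_tracking, owl_v2_video

# Sample at 1 FPS and keep the first 10 frames, per instruction 5 above.
frames = extract_frames("video.mp4", 1)
frames = [f[0] for f in frames][:10]

# owl_v2_video returns one detection list per frame; florence2_sam2_video_tracking
# additionally tracks masks across frames (behavior inferred from the tests above).
per_frame_detections = owl_v2_video("person", frames=frames)
tracks = florence2_sam2_video_tracking("person", frames=frames)
```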
5 changes: 3 additions & 2 deletions vision_agent/tools/__init__.py
@@ -27,7 +27,7 @@
florence2_phrase_grounding,
florence2_roberta_vqa,
florence2_sam2_image,
florence2_sam2_video,
florence2_sam2_video_tracking,
generate_pose_image,
generate_soft_edge_image,
get_tool_documentation,
@@ -46,7 +46,8 @@
overlay_counting_results,
overlay_heat_map,
overlay_segmentation_masks,
owl_v2,
owl_v2_image,
owl_v2_video,
save_image,
save_json,
save_video,