Skip to content

Commit

Permalink
feat: florence2-ft video support (#253)
Browse files Browse the repository at this point in the history
* get first frame

* adjust postprocessing

* rle encoding

* add video support

* use video endpoint for florence2 instead of ft endpoint

* fix video-temporal-localization

* fix countgd

* hide florence2_phrase_grounding_video

---------

Co-authored-by: Dillon Laird <[email protected]>
  • Loading branch information
Dayof and dillonalaird authored Oct 4, 2024
1 parent b74e089 commit 64095a9
Show file tree
Hide file tree
Showing 12 changed files with 245 additions and 117 deletions.
3 changes: 0 additions & 3 deletions .github/workflows/ci_cd.yml
Original file line number Diff line number Diff line change
Expand Up @@ -83,9 +83,6 @@ jobs:
- name: Test with pytest
run: |
poetry run pytest -v tests/integ
- name: Test with pytest, dev env
run: |
LANDINGAI_API_KEY=$LANDINGAI_DEV_API_KEY LANDINGAI_URL=https://api.dev.landing.ai poetry run pytest -v tests/integration_dev
release:
name: Release
Expand Down
54 changes: 49 additions & 5 deletions tests/integ/test_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,8 @@
dpt_hybrid_midas,
florence2_image_caption,
florence2_ocr,
florence2_phrase_grounding,
florence2_phrase_grounding_image,
# florence2_phrase_grounding_video,
florence2_roberta_vqa,
florence2_sam2_image,
florence2_sam2_video_tracking,
Expand All @@ -31,6 +32,8 @@
template_match,
vit_image_classification,
vit_nsfw_classification,
countgd_counting,
countgd_example_based_counting,
)

FINE_TUNE_ID = "65ebba4a-88b7-419f-9046-0750e30250da"
Expand Down Expand Up @@ -92,19 +95,19 @@ def test_owl_v2_video():
assert 24 <= len([res["label"] for res in result[0]]) <= 26


def test_florence2_phrase_grounding():
def test_florence2_phrase_grounding_image():
img = ski.data.coins()
result = florence2_phrase_grounding(
result = florence2_phrase_grounding_image(
image=img,
prompt="coin",
)
assert len(result) == 25
assert [res["label"] for res in result] == ["coin"] * 25


def test_florence2_phrase_grounding_fine_tune_id():
def test_florence2_phrase_grounding_image_fine_tune_id():
img = ski.data.coins()
result = florence2_phrase_grounding(
result = florence2_phrase_grounding_image(
prompt="coin",
image=img,
fine_tune_id=FINE_TUNE_ID,
Expand All @@ -114,6 +117,32 @@ def test_florence2_phrase_grounding_fine_tune_id():
assert [res["label"] for res in result] == ["coin"] * len(result)


# def test_florence2_phrase_grounding_video():
# frames = [
# np.array(Image.fromarray(ski.data.coins()).convert("RGB")) for _ in range(10)
# ]
# result = florence2_phrase_grounding_video(
# prompt="coin",
# frames=frames,
# )
# assert len(result) == 10
# assert 2 <= len([res["label"] for res in result[0]]) <= 26


# def test_florence2_phrase_grounding_video_fine_tune_id():
# frames = [
# np.array(Image.fromarray(ski.data.coins()).convert("RGB")) for _ in range(10)
# ]
# # this calls a fine-tuned florence2 model which is going to be worse at this task
# result = florence2_phrase_grounding_video(
# prompt="coin",
# frames=frames,
# fine_tune_id=FINE_TUNE_ID,
# )
# assert len(result) == 10
# assert 16 <= len([res["label"] for res in result[0]]) <= 26


def test_template_match():
img = ski.data.coins()
result = template_match(
Expand Down Expand Up @@ -360,3 +389,18 @@ def test_generate_hed():
)

assert result.shape == img.shape


def test_countgd_counting() -> None:
img = ski.data.coins()
result = countgd_counting(image=img, prompt="coin")
assert len(result) == 24


def test_countgd_example_based_counting() -> None:
img = ski.data.coins()
result = countgd_example_based_counting(
visual_prompts=[[85, 106, 122, 145]],
image=img,
)
assert len(result) == 24
Empty file removed tests/integration_dev/__init__.py
Empty file.
18 changes: 0 additions & 18 deletions tests/integration_dev/test_tools.py

This file was deleted.

16 changes: 8 additions & 8 deletions tests/unit/test_meta_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,41 +33,41 @@ def test_use_object_detection_fine_tuning_none():

def test_use_object_detection_fine_tuning():
artifacts = Artifacts("test")
code = """florence2_phrase_grounding('one', image1)
code = """florence2_phrase_grounding_image('one', image1)
owl_v2_image('two', image2)
florence2_sam2_image('three', image3)"""
expected_code = """florence2_phrase_grounding("one", image1, "123")
expected_code = """florence2_phrase_grounding_image("one", image1, "123")
owl_v2_image("two", image2, "123")
florence2_sam2_image("three", image3, "123")"""
artifacts["code"] = code

output = use_object_detection_fine_tuning(artifacts, "code", "123")
assert 'florence2_phrase_grounding("one", image1, "123")' in output
assert 'florence2_phrase_grounding_image("one", image1, "123")' in output
assert 'owl_v2_image("two", image2, "123")' in output
assert 'florence2_sam2_image("three", image3, "123")' in output
assert artifacts["code"] == expected_code


def test_use_object_detection_fine_tuning_twice():
artifacts = Artifacts("test")
code = """florence2_phrase_grounding('one', image1)
code = """florence2_phrase_grounding_image('one', image1)
owl_v2_image('two', image2)
florence2_sam2_image('three', image3)"""
expected_code1 = """florence2_phrase_grounding("one", image1, "123")
expected_code1 = """florence2_phrase_grounding_image("one", image1, "123")
owl_v2_image("two", image2, "123")
florence2_sam2_image("three", image3, "123")"""
expected_code2 = """florence2_phrase_grounding("one", image1, "456")
expected_code2 = """florence2_phrase_grounding_image("one", image1, "456")
owl_v2_image("two", image2, "456")
florence2_sam2_image("three", image3, "456")"""
artifacts["code"] = code
output = use_object_detection_fine_tuning(artifacts, "code", "123")
assert 'florence2_phrase_grounding("one", image1, "123")' in output
assert 'florence2_phrase_grounding_image("one", image1, "123")' in output
assert 'owl_v2_image("two", image2, "123")' in output
assert 'florence2_sam2_image("three", image3, "123")' in output
assert artifacts["code"] == expected_code1

output = use_object_detection_fine_tuning(artifacts, "code", "456")
assert 'florence2_phrase_grounding("one", image1, "456")' in output
assert 'florence2_phrase_grounding_image("one", image1, "456")' in output
assert 'owl_v2_image("two", image2, "456")' in output
assert 'florence2_sam2_image("three", image3, "456")' in output
assert artifacts["code"] == expected_code2
10 changes: 5 additions & 5 deletions vision_agent/agent/vision_agent_coder_prompts.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,15 +101,15 @@
- Use the 'owl_v2_video' tool with the prompt 'person' to detect where the people are in the video.
plan2:
- Extract frames from 'video.mp4' at 10 FPS using the 'extract_frames_and_timestamps' tool.
- Use the 'florence2_phrase_grounding' tool with the prompt 'person' to detect where the people are in the video.
- Use the 'florence2_phrase_grounding_image' tool with the prompt 'person' to detect where the people are in the video.
plan3:
- Extract frames from 'video.mp4' at 10 FPS using the 'extract_frames_and_timestamps' tool.
- Use the 'florence2_sam2_video_tracking' tool with the prompt 'person' to detect where the people are in the video.
```python
import numpy as np
from vision_agent.tools import extract_frames_and_timestamps, owl_v2_video, florence2_phrase_grounding, florence2_sam2_video_tracking
from vision_agent.tools import extract_frames_and_timestamps, owl_v2_video, florence2_phrase_grounding_image, florence2_sam2_video_tracking
# sample at 1 FPS and use the first 10 frames to reduce processing time
frames = extract_frames_and_timestamps("video.mp4", 1)
Expand Down Expand Up @@ -143,7 +143,7 @@ def get_counts(preds):
owl_v2_counts = get_counts(owl_v2_out)
# plan2
florence2_out = [florence2_phrase_grounding("person", f) for f in frames]
florence2_out = [florence2_phrase_grounding_image("person", f) for f in frames]
florence2_counts = get_counts(florence2_out)
# plan3
Expand All @@ -153,13 +153,13 @@ def get_counts(preds):
final_out = {{
"owl_v2_video": owl_v2_out,
"florence2_phrase_grounding": florence2_out,
"florence2_phrase_grounding_image": florence2_out,
"florence2_sam2_video_tracking": f2s2_out,
}}
counts = {{
"owl_v2_video": owl_v2_counts,
"florence2_phrase_grounding": florence2_counts,
"florence2_phrase_grounding_image": florence2_counts,
"florence2_sam2_video_tracking": f2s2_counts,
}}
Expand Down
14 changes: 7 additions & 7 deletions vision_agent/agent/vision_agent_prompts.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,10 +131,10 @@
OBSERVATION:
[Artifact code.py]
0|from vision_agent.tools import load_image, florence2_phrase_grounding, closest_box_distance, overlay_bounding_boxes, save_image
0|from vision_agent.tools import load_image, florence2_phrase_grounding_image, closest_box_distance, overlay_bounding_boxes, save_image
1|def count_workers_with_helmets(image_path: str, output_path: str):
2| image = load_image(image_path)
3| detections = florence2_phrase_grounding("worker, helmet", image)
3| detections = florence2_phrase_grounding_image("worker, helmet", image)
4| workers = [d for d in detections if d['label'] == 'worker']
5| helmets = [d for d in detections if d['label'] == 'helmet']
6| count = 0
Expand Down Expand Up @@ -166,18 +166,18 @@
OBSERVATION:
[Fine tuning id: 23b3b022-5ebf-4798-9373-20ef36429abf]
AGENT: {"thoughts": "The model has finished fine tuning, I will now replace the original florence2_phrase_grounding call with the fine tuning id.", "response": "I will not update the code to use the fine tuned model. <execute_python>use_object_detection_fine_tuning(artifacts, "code.py", "23b3b022-5ebf-4798-9373-20ef36429abf")</execute_python>", "let_user_respond": false}
AGENT: {"thoughts": "The model has finished fine tuning, I will now replace the original florence2_phrase_grounding_image call with the fine tuning id.", "response": "I will not update the code to use the fine tuned model. <execute_python>use_object_detection_fine_tuning(artifacts, "code.py", "23b3b022-5ebf-4798-9373-20ef36429abf")</execute_python>", "let_user_respond": false}
OBSERVATION:
[Artifact code.py edits]
---
+++
@@ -1,7 +1,7 @@
from vision_agent.tools import load_image, florence2_phrase_grounding, closest_box_distance, overlay_bounding_boxes, save_image
from vision_agent.tools import load_image, florence2_phrase_grounding_image, closest_box_distance, overlay_bounding_boxes, save_image
def count_workers_with_helmets(image_path: str, output_path: str):
image = load_image(image_path)
- detections = florence2_phrase_grounding("worker, helmet", image)
+ detections = florence2_phrase_grounding("worker, helmet", image, "23b3b022-5ebf-4798-9373-20ef36429abf")
- detections = florence2_phrase_grounding_image("worker, helmet", image)
+ detections = florence2_phrase_grounding_image("worker, helmet", image, "23b3b022-5ebf-4798-9373-20ef36429abf")
workers = [d for d in detections if d['label'] == 'worker']
helmets = [d for d in detections if d['label'] == 'helmet']
count = 0
Expand All @@ -189,5 +189,5 @@ def count_workers_with_helmets(image_path: str, output_path: str):
----- stdout -----
3
AGENT: {"thoughts": "Three workers with helmets are detected, I will show this to the user and ask them if the result looks good.", "response": "I have updated the code using the fine tuned florence2_phrase_grounding model to count the workers wearing helmets in code.py and saved the visualization under 'workers_viz.png'.", "let_user_respond": true}
AGENT: {"thoughts": "Three workers with helmets are detected, I will show this to the user and ask them if the result looks good.", "response": "I have updated the code using the fine tuned florence2_phrase_grounding_image model to count the workers wearing helmets in code.py and saved the visualization under 'workers_viz.png'.", "let_user_respond": true}
"""
2 changes: 1 addition & 1 deletion vision_agent/tools/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
extract_frames_and_timestamps,
florence2_image_caption,
florence2_ocr,
florence2_phrase_grounding,
florence2_phrase_grounding_image,
florence2_roberta_vqa,
florence2_sam2_image,
florence2_sam2_video_tracking,
Expand Down
8 changes: 6 additions & 2 deletions vision_agent/tools/meta_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -668,8 +668,12 @@ def use_object_detection_fine_tuning(

patterns_with_fine_tune_id = [
(
r'florence2_phrase_grounding\(\s*["\']([^"\']+)["\']\s*,\s*([^,]+)(?:,\s*["\'][^"\']+["\'])?\s*\)',
lambda match: f'florence2_phrase_grounding("{match.group(1)}", {match.group(2)}, "{fine_tune_id}")',
r'florence2_phrase_grounding_image\(\s*["\']([^"\']+)["\']\s*,\s*([^,]+)(?:,\s*["\'][^"\']+["\'])?\s*\)',
lambda match: f'florence2_phrase_grounding_image("{match.group(1)}", {match.group(2)}, "{fine_tune_id}")',
),
(
r'florence2_phrase_grounding_video\(\s*["\']([^"\']+)["\']\s*,\s*([^,]+)(?:,\s*["\'][^"\']+["\'])?\s*\)',
lambda match: f'florence2_phrase_grounding_video("{match.group(1)}", {match.group(2)}, "{fine_tune_id}")',
),
(
r'owl_v2_image\(\s*["\']([^"\']+)["\']\s*,\s*([^,]+)(?:,\s*["\'][^"\']+["\'])?\s*\)',
Expand Down
13 changes: 9 additions & 4 deletions vision_agent/tools/tool_utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import os
import inspect
import logging
import os
from base64 import b64encode
from typing import Any, Callable, Dict, List, MutableMapping, Optional, Tuple

Expand Down Expand Up @@ -37,8 +37,9 @@ def send_inference_request(
files: Optional[List[Tuple[Any, ...]]] = None,
v2: bool = False,
metadata_payload: Optional[Dict[str, Any]] = None,
is_form: bool = False,
) -> Any:
# TODO: runtime_tag and function_name should be metadata_payload and now included
# TODO: runtime_tag and function_name should be metadata_payload and not included
# in the service payload
if runtime_tag := os.environ.get("RUNTIME_TAG", ""):
payload["runtime_tag"] = runtime_tag
Expand All @@ -64,7 +65,7 @@ def send_inference_request(
elif metadata_payload is not None and "function_name" in metadata_payload:
function_name = metadata_payload["function_name"]

response = _call_post(url, payload, session, files, function_name)
response = _call_post(url, payload, session, files, function_name, is_form)

# TODO: consider making the response schema the same between below two sources
return response if "TOOL_ENDPOINT_AUTH" in os.environ else response["data"]
Expand All @@ -75,6 +76,7 @@ def send_task_inference_request(
task_name: str,
files: Optional[List[Tuple[Any, ...]]] = None,
metadata: Optional[Dict[str, Any]] = None,
is_form: bool = False,
) -> Any:
url = f"{_LND_API_URL_v2}/{task_name}"
headers = {"apikey": _LND_API_KEY}
Expand All @@ -87,7 +89,7 @@ def send_task_inference_request(
function_name = "unknown"
if metadata is not None and "function_name" in metadata:
function_name = metadata["function_name"]
response = _call_post(url, payload, session, files, function_name)
response = _call_post(url, payload, session, files, function_name, is_form)
return response["data"]


Expand Down Expand Up @@ -203,13 +205,16 @@ def _call_post(
session: Session,
files: Optional[List[Tuple[Any, ...]]] = None,
function_name: str = "unknown",
is_form: bool = False,
) -> Any:
files_in_b64 = None
if files:
files_in_b64 = [(file[0], b64encode(file[1]).decode("utf-8")) for file in files]
try:
if files is not None:
response = session.post(url, data=payload, files=files)
elif is_form:
response = session.post(url, data=payload)
else:
response = session.post(url, json=payload)

Expand Down
Loading

0 comments on commit 64095a9

Please sign in to comment.