Skip to content

Commit

Permalink
fix video
Browse files Browse the repository at this point in the history
  • Loading branch information
Dayof committed Oct 2, 2024
1 parent 0b1c886 commit 6aa3a24
Show file tree
Hide file tree
Showing 3 changed files with 14 additions and 10 deletions.
4 changes: 2 additions & 2 deletions tests/integ/test_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ def test_florence2_phrase_grounding_video():
frames=frames,
)
assert len(result) == 10
assert 24 <= len([res["label"] for res in result[0]]) <= 26
assert 2 <= len([res["label"] for res in result[0]]) <= 26


def test_florence2_phrase_grounding_video_fine_tune_id():
Expand All @@ -138,7 +138,7 @@ def test_florence2_phrase_grounding_video_fine_tune_id():
fine_tune_id=FINE_TUNE_ID,
)
assert len(result) == 10
assert 24 <= len([res["label"] for res in result[0]]) <= 26
assert 16 <= len([res["label"] for res in result[0]]) <= 26


def test_template_match():
Expand Down
9 changes: 7 additions & 2 deletions vision_agent/tools/tool_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ def send_inference_request(
files: Optional[List[Tuple[Any, ...]]] = None,
v2: bool = False,
metadata_payload: Optional[Dict[str, Any]] = None,
is_form: bool = False,
) -> Any:
# TODO: runtime_tag and function_name should be metadata_payload and not included
# in the service payload
Expand Down Expand Up @@ -64,7 +65,7 @@ def send_inference_request(
elif metadata_payload is not None and "function_name" in metadata_payload:
function_name = metadata_payload["function_name"]

response = _call_post(url, payload, session, files, function_name)
response = _call_post(url, payload, session, files, function_name, is_form)

# TODO: consider making the response schema the same between below two sources
return response if "TOOL_ENDPOINT_AUTH" in os.environ else response["data"]
Expand All @@ -75,6 +76,7 @@ def send_task_inference_request(
task_name: str,
files: Optional[List[Tuple[Any, ...]]] = None,
metadata: Optional[Dict[str, Any]] = None,
is_form: bool = False,
) -> Any:
url = f"{_LND_API_URL_v2}/{task_name}"
headers = {"apikey": _LND_API_KEY}
Expand All @@ -87,7 +89,7 @@ def send_task_inference_request(
function_name = "unknown"
if metadata is not None and "function_name" in metadata:
function_name = metadata["function_name"]
response = _call_post(url, payload, session, files, function_name)
response = _call_post(url, payload, session, files, function_name, is_form)
return response["data"]


Expand Down Expand Up @@ -203,13 +205,16 @@ def _call_post(
session: Session,
files: Optional[List[Tuple[Any, ...]]] = None,
function_name: str = "unknown",
is_form: bool = False,
) -> Any:
files_in_b64 = None
if files:
files_in_b64 = [(file[0], b64encode(file[1]).decode("utf-8")) for file in files]
try:
if files is not None:
response = session.post(url, data=payload, files=files)
elif is_form:
response = session.post(url, data=payload)
else:
response = session.post(url, json=payload)

Expand Down
11 changes: 5 additions & 6 deletions vision_agent/tools/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,6 +201,7 @@ def owl_v2_image(
data,
"florence2-ft",
v2=True,
is_form=True,
metadata_payload={"function_name": "owl_v2_image"},
)
# get the first frame
Expand Down Expand Up @@ -432,6 +433,7 @@ def florence2_sam2_image(
req_data,
"florence2-ft",
v2=True,
is_form=True,
metadata_payload={"function_name": "florence2_sam2_image"},
)
# get the first frame
Expand Down Expand Up @@ -1193,6 +1195,7 @@ def florence2_phrase_grounding_image(
data,
"florence2-ft",
v2=True,
is_form=True,
metadata_payload={"function_name": "florence2_phrase_grounding_image"},
)
# get the first frame
Expand Down Expand Up @@ -1268,18 +1271,14 @@ def florence2_phrase_grounding_video(
)

data_obj = Florence2FtRequest(
video=buffer_bytes,
task=PromptTask.PHRASE_GROUNDING,
prompt=prompt,
job_id=UUID(fine_tune_id),
)
data = data_obj.model_dump(by_alias=True, exclude_none=True)
else:
data_obj = Florence2FtRequest(
video=buffer_bytes, task=PromptTask.PHRASE_GROUNDING, prompt=prompt
)
data = data_obj.model_dump(by_alias=True, exclude_none=True)
data_obj = Florence2FtRequest(task=PromptTask.PHRASE_GROUNDING, prompt=prompt)

data = data_obj.model_dump(by_alias=True, exclude_none=True, mode="json")
detections = send_inference_request(
data,
"florence2-ft",
Expand Down

0 comments on commit 6aa3a24

Please sign in to comment.