Skip to content

Commit

Permalink
updated name to florence2_sam2_video_tracking
Browse files Browse the repository at this point in the history
  • Loading branch information
dillonalaird committed Sep 6, 2024
1 parent f1d5f1f commit e414003
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 43 deletions.
3 changes: 1 addition & 2 deletions vision_agent/tools/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
florence2_phrase_grounding,
florence2_roberta_vqa,
florence2_sam2_image,
florence2_sam2_video,
florence2_sam2_video_tracking,
generate_pose_image,
generate_soft_edge_image,
get_tool_documentation,
Expand All @@ -47,7 +47,6 @@
overlay_heat_map,
overlay_segmentation_masks,
owl_v2_image,
owl_v2_image2,
owl_v2_video,
save_image,
save_json,
Expand Down
56 changes: 15 additions & 41 deletions vision_agent/tools/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,33 +176,6 @@ def owl_v2_image(
        {'score': 0.98, 'label': 'car', 'bbox': [0.2, 0.21, 0.45, 0.5]},
]
"""
image_size = image.shape[:2]
image_b64 = convert_to_b64(image)
request_data = {
"prompts": [s.strip() for s in prompt.split(",")],
"image": image_b64,
"confidence": box_threshold,
"function_name": "owl_v2",
}
data: Dict[str, Any] = send_inference_request(request_data, "owlv2", v2=True)
return_data = []
if data is not None:
for elt in data:
return_data.append(
{
"bbox": normalize_bbox(elt["bbox"], image_size), # type: ignore
"label": elt["label"], # type: ignore
"score": round(elt["score"], 2), # type: ignore
}
)
return return_data


def owl_v2_image2(
prompt: str,
image: np.ndarray,
box_threshold: float = 0.30,
) -> List[Dict[str, Any]]:
image_size = image.shape[:2]
buffer_bytes = numpy_to_bytes(image)
files = [("image", buffer_bytes)]
Expand Down Expand Up @@ -232,10 +205,11 @@ def owl_v2_video(
frames: List[np.ndarray],
box_threshold: float = 0.30,
) -> List[List[Dict[str, Any]]]:
"""'owl_v2_video' is a tool that can detect and count multiple objects given a text
prompt such as category names or referring expressions on videos. The categories in
text prompt are separated by commas. It returns a list of bounding boxes with
normalized coordinates, label names and associated probability scores per frame.
"""'owl_v2_video' will run owl_v2 on each frame of a video. It can detect multiple
    objects per frame given a text prompt such as a category name or referring
expression. The categories in text prompt are separated by commas. It returns a list
of lists where each inner list contains the score, label, and bounding box of the
detections for that frame.
Parameters:
prompt (str): The prompt to ground to the video.
Expand All @@ -244,7 +218,7 @@ def owl_v2_video(
to 0.30.
Returns:
List[List[Dict[str, Any]]]: A list of dictionaries per frame containing the
List[List[Dict[str, Any]]]: A list of lists of dictionaries containing the
score, label, and bounding box of the detected objects with normalized
coordinates between 0 and 1 (xmin, ymin, xmax, ymax). xmin and ymin are the
coordinates of the top-left and xmax and ymax are the coordinates of the
Expand Down Expand Up @@ -414,14 +388,14 @@ def florence2_sam2_image(prompt: str, image: np.ndarray) -> List[Dict[str, Any]]
return return_data


def florence2_sam2_video(
def florence2_sam2_video_tracking(
prompt: str, frames: List[np.ndarray]
) -> List[List[Dict[str, Any]]]:
"""'florence2_sam2_video' is a tool that can segment and track multiple entities
in a video given a text prompt such as category names or referring expressions. You
can optionally separate the categories in the text with commas. It only tracks
entities present in the first frame and only returns segmentation masks. It is
useful for tracking and counting without duplicating counts.
"""'florence2_sam2_video_tracking' is a tool that can segment and track multiple
entities in a video given a text prompt such as category names or referring
expressions. You can optionally separate the categories in the text with commas. It
only tracks entities present in the first frame and only returns segmentation
masks. It is useful for tracking and counting without duplicating counts.
Parameters:
prompt (str): The prompt to ground to the video.
Expand Down Expand Up @@ -456,7 +430,7 @@ def florence2_sam2_video(
files = [("video", buffer_bytes)]
payload = {
"prompts": [s.strip() for s in prompt.split(",")],
"function_name": "florence2_sam2_video",
"function_name": "florence2_sam2_video_tracking",
}
data: Dict[str, Any] = send_inference_request(
payload, "florence2-sam2", files=files, v2=True
Expand Down Expand Up @@ -1933,7 +1907,7 @@ def overlay_counting_results(

FUNCTION_TOOLS = [
owl_v2_image,
# owl_v2_video,
owl_v2_video,
ocr,
clip,
vit_image_classification,
Expand All @@ -1942,7 +1916,7 @@ def overlay_counting_results(
florence2_image_caption,
florence2_ocr,
florence2_sam2_image,
florence2_sam2_video,
florence2_sam2_video_tracking,
florence2_phrase_grounding,
ixc25_image_vqa,
ixc25_video_vqa,
Expand Down

0 comments on commit e414003

Please sign in to comment.