Skip to content

Commit

Permalink
add params to florence2sam2 model
Browse files Browse the repository at this point in the history
  • Loading branch information
Camilo Iral committed Oct 8, 2024
1 parent cd0bf40 commit 9b0d352
Showing 1 changed file with 13 additions and 1 deletion.
14 changes: 13 additions & 1 deletion vision_agent/tools/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -473,7 +473,11 @@ def florence2_sam2_image(


def florence2_sam2_video_tracking(
- prompt: str, frames: List[np.ndarray], chunk_length: Optional[int] = None
+ prompt: str,
+ frames: List[np.ndarray],
+ chunk_length: Optional[int] = None,
+ iou_threshold: Optional[float] = None,
+ nms_threshold: Optional[float] = None,
) -> List[List[Dict[str, Any]]]:
"""'florence2_sam2_video_tracking' is a tool that can segment and track multiple
entities in a video given a text prompt such as category names or referring
Expand All @@ -486,6 +490,10 @@ def florence2_sam2_video_tracking(
frames (List[np.ndarray]): The list of frames to ground the prompt to.
chunk_length (Optional[int]): The number of frames to re-run florence2 to find
new objects.
+ iou_threshold (Optional[float]): Value between 0.1 and 1.0.
+ The IoU threshold value used to compare last_predictions and new_predictions objects.
+ nms_threshold (Optional[float]): Value between 0.1 and 1.0.
+ The non-maximum suppression threshold value used to filter the Florencev2 predictions.
Returns:
List[List[Dict[str, Any]]]: A list of list of dictionaries containing the label
Expand Down Expand Up @@ -520,6 +528,10 @@ def florence2_sam2_video_tracking(
}
if chunk_length is not None:
payload["chunk_length"] = chunk_length # type: ignore
+ if iou_threshold is not None:
+ payload["iou_threshold"] = iou_threshold # type: ignore
+ if nms_threshold is not None:
+ payload["nms_threshold"] = nms_threshold # type: ignore
data: Dict[str, Any] = send_inference_request(
payload, "florence2-sam2", files=files, v2=True
)
Expand Down

0 comments on commit 9b0d352

Please sign in to comment.