diff --git a/vision_agent/tools/tools.py b/vision_agent/tools/tools.py index 71646c45..ed48fafe 100644 --- a/vision_agent/tools/tools.py +++ b/vision_agent/tools/tools.py @@ -473,7 +473,11 @@ def florence2_sam2_image( def florence2_sam2_video_tracking( - prompt: str, frames: List[np.ndarray], chunk_length: Optional[int] = None + prompt: str, + frames: List[np.ndarray], + chunk_length: Optional[int] = None, + iou_threshold: Optional[float] = None, + nms_threshold: Optional[float] = None, ) -> List[List[Dict[str, Any]]]: """'florence2_sam2_video_tracking' is a tool that can segment and track multiple entities in a video given a text prompt such as category names or referring @@ -486,6 +490,10 @@ def florence2_sam2_video_tracking( frames (List[np.ndarray]): The list of frames to ground the prompt to. chunk_length (Optional[int]): The number of frames to re-run florence2 to find new objects. + iou_threshold (Optional[float]): Value between 0.1 and 1.0. + The IoU threshold value used to compare last_predictions and new_predictions objects. + nms_threshold (Optional[float]): Value between 0.1 and 1.0. + The non-maximum suppression threshold value used to filter the Florence2 predictions. Returns: List[List[Dict[str, Any]]]: A list of list of dictionaries containing the label @@ -520,6 +528,10 @@ def florence2_sam2_video_tracking( } if chunk_length is not None: payload["chunk_length"] = chunk_length # type: ignore + if iou_threshold is not None: + payload["iou_threshold"] = iou_threshold # type: ignore + if nms_threshold is not None: + payload["nms_threshold"] = nms_threshold # type: ignore data: Dict[str, Any] = send_inference_request( payload, "florence2-sam2", files=files, v2=True )