diff --git a/vision_agent/tools/__init__.py b/vision_agent/tools/__init__.py
index a401fb46..22453224 100644
--- a/vision_agent/tools/__init__.py
+++ b/vision_agent/tools/__init__.py
@@ -21,7 +21,7 @@
     depth_anything_v2,
     detr_segmentation,
     dpt_hybrid_midas,
-    extract_frames,
+    extract_frames_and_timestamps,
     florence2_image_caption,
     florence2_ocr,
     florence2_phrase_grounding,
diff --git a/vision_agent/tools/tools.py b/vision_agent/tools/tools.py
index 0e58049a..309e9ba2 100644
--- a/vision_agent/tools/tools.py
+++ b/vision_agent/tools/tools.py
@@ -240,10 +240,10 @@ def owl_v2_video(
     box_threshold: float = 0.10,
 ) -> List[List[Dict[str, Any]]]:
     """'owl_v2_video' will run owl_v2 on each frame of a video. It can detect multiple
-    objects per frame given a text prompt sucha s a category name or referring
-    expression. The categories in text prompt are separated by commas. It returns a list
-    of lists where each inner list contains the score, label, and bounding box of the
-    detections for that frame.
+    objects independently per frame given a text prompt such as a category name or
+    referring expression, but does not track objects across frames. The categories in
+    the text prompt are separated by commas. It returns a list of lists where each inner
+    list contains the score, label, and bounding box of the detections for that frame.
 
     Parameters:
         prompt (str): The prompt to ground to the video.
@@ -461,18 +461,19 @@ def florence2_sam2_image(
 
 
 def florence2_sam2_video_tracking(
-    prompt: str, frames: List[np.ndarray]
+    prompt: str, frames: List[np.ndarray], chunk_length: Optional[int] = None
 ) -> List[List[Dict[str, Any]]]:
     """'florence2_sam2_video_tracking' is a tool that can segment and track multiple
     entities in a video given a text prompt such as category names or referring
     expressions. You can optionally separate the categories in the text with commas. It
-    only tracks entities present in the first frame and only returns segmentation
-    masks. It is useful for tracking and counting without duplicating counts if they
-    appear in the first frame, always outputs scores of 1.0.
+    can find new objects every 'chunk_length' frames, which makes it useful for tracking
+    and counting without duplicating counts. It always outputs scores of 1.0.
 
     Parameters:
         prompt (str): The prompt to ground to the video.
         frames (List[np.ndarray]): The list of frames to ground the prompt to.
+        chunk_length (Optional[int]): The number of frames between each re-run of
+            florence2 to find new objects.
 
     Returns:
         List[List[Dict[str, Any]]]: A list of list of dictionaries containing the label
@@ -505,6 +506,8 @@ def florence2_sam2_video_tracking(
         "prompts": [s.strip() for s in prompt.split(",")],
         "function_name": "florence2_sam2_video_tracking",
     }
+    if chunk_length is not None:
+        payload["chunk_length"] = chunk_length
     data: Dict[str, Any] = send_inference_request(
         payload, "florence2-sam2", files=files, v2=True
     )
@@ -1570,12 +1573,14 @@ def closest_box_distance(
 # Utility and visualization functions
 
 
-def extract_frames(
+def extract_frames_and_timestamps(
     video_uri: Union[str, Path], fps: float = 1
-) -> List[Tuple[np.ndarray, float]]:
-    """'extract_frames' extracts frames from a video which can be a file path, url or
-    youtube link, returns a list of tuples (frame, timestamp), where timestamp is the
-    relative time in seconds where the frame was captured. The frame is a numpy array.
+) -> List[Dict[str, Union[np.ndarray, float]]]:
+    """'extract_frames_and_timestamps' extracts frames and timestamps from a video,
+    which can be a file path, url or youtube link. It returns a list of dictionaries
+    with keys "frame" and "timestamp", where "frame" is a numpy array containing the
+    frame image and "timestamp" is the relative time in seconds at which the frame
+    was captured.
 
     Parameters:
         video_uri (Union[str, Path]): The path to the video file, url or youtube link
@@ -1583,15 +1588,23 @@ def extract_frames(
             to 1.
 
     Returns:
-        List[Tuple[np.ndarray, float]]: A list of tuples containing the extracted frame
-        as a numpy array and the timestamp in seconds.
+        List[Dict[str, Union[np.ndarray, float]]]: A list of dictionaries containing the
+        extracted frame as a numpy array and the timestamp in seconds.
 
     Example
     -------
-        >>> extract_frames("path/to/video.mp4")
-        [(frame1, 0.0), (frame2, 0.5), ...]
+        >>> extract_frames_and_timestamps("path/to/video.mp4")
+        [{"frame": np.ndarray, "timestamp": 0.0}, ...]
     """
 
+    def reformat(
+        frames_and_timestamps: List[Tuple[np.ndarray, float]]
+    ) -> List[Dict[str, Union[np.ndarray, float]]]:
+        return [
+            {"frame": frame, "timestamp": timestamp}
+            for frame, timestamp in frames_and_timestamps
+        ]
+
     if str(video_uri).startswith(
         (
             "http://www.youtube.com/",
@@ -1613,16 +1626,16 @@ def extract_frames(
             raise Exception("No suitable video stream found")
 
         video_file_path = video.download(output_path=temp_dir)
-        return extract_frames_from_video(video_file_path, fps)
+        return reformat(extract_frames_from_video(video_file_path, fps))
     elif str(video_uri).startswith(("http", "https")):
         _, image_suffix = os.path.splitext(video_uri)
         with tempfile.NamedTemporaryFile(delete=False, suffix=image_suffix) as tmp_file:
             # Download the video and save it to the temporary file
            with urllib.request.urlopen(str(video_uri)) as response:
                tmp_file.write(response.read())
-        return extract_frames_from_video(tmp_file.name, fps)
+        return reformat(extract_frames_from_video(tmp_file.name, fps))
 
-    return extract_frames_from_video(str(video_uri), fps)
+    return reformat(extract_frames_from_video(str(video_uri), fps))
 
 
 def save_json(data: Any, file_path: str) -> None:
@@ -2026,7 +2039,6 @@ def overlay_counting_results(
     vit_image_classification,
     vit_nsfw_classification,
     countgd_counting,
-    florence2_image_caption,
     florence2_ocr,
     florence2_sam2_image,
     florence2_sam2_video_tracking,
@@ -2041,7 +2053,7 @@ def overlay_counting_results(
 ]
 
 UTIL_TOOLS = [
-    extract_frames,
+    extract_frames_and_timestamps,
     save_json,
     load_image,
     save_image,
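
For reviewers, a minimal usage sketch of the renamed function and the new chunk_length parameter. It assumes both functions are exported from vision_agent.tools (the __init__.py hunk shows the renamed export; florence2_sam2_video_tracking appears in FUNCTION_TOOLS above), and "video.mp4" and chunk_length=25 are hypothetical example values:

    from vision_agent.tools import (
        extract_frames_and_timestamps,
        florence2_sam2_video_tracking,
    )

    # Each element is now a dict, not a (frame, timestamp) tuple.
    frames_and_ts = extract_frames_and_timestamps("video.mp4", fps=1)  # hypothetical path
    frames = [d["frame"] for d in frames_and_ts]          # np.ndarray per frame
    timestamps = [d["timestamp"] for d in frames_and_ts]  # seconds relative to start

    # Re-run florence2 every 25 frames so objects entering mid-video are picked up.
    detections = florence2_sam2_video_tracking("person", frames, chunk_length=25)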