Skip to content

Commit

Permalink
rename functions to make them easier to understand by llm
Browse files Browse the repository at this point in the history
  • Loading branch information
dillonalaird committed Sep 22, 2024
1 parent 94f9501 commit 85e2e8a
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 22 deletions.
2 changes: 1 addition & 1 deletion vision_agent/tools/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
depth_anything_v2,
detr_segmentation,
dpt_hybrid_midas,
extract_frames,
extract_frames_and_timestamps,
florence2_image_caption,
florence2_ocr,
florence2_phrase_grounding,
Expand Down
54 changes: 33 additions & 21 deletions vision_agent/tools/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -240,10 +240,10 @@ def owl_v2_video(
box_threshold: float = 0.10,
) -> List[List[Dict[str, Any]]]:
"""'owl_v2_video' will run owl_v2 on each frame of a video. It can detect multiple
objects per frame given a text prompt sucha s a category name or referring
expression. The categories in text prompt are separated by commas. It returns a list
of lists where each inner list contains the score, label, and bounding box of the
detections for that frame.
objects independently per frame given a text prompt such as a category name or
referring expression but does not track objects across frames. The categories in
text prompt are separated by commas. It returns a list of lists where each inner
list contains the score, label, and bounding box of the detections for that frame.
Parameters:
prompt (str): The prompt to ground to the video.
Expand Down Expand Up @@ -461,18 +461,19 @@ def florence2_sam2_image(


def florence2_sam2_video_tracking(
prompt: str, frames: List[np.ndarray]
prompt: str, frames: List[np.ndarray], chunk_length: Optional[int] = None
) -> List[List[Dict[str, Any]]]:
"""'florence2_sam2_video_tracking' is a tool that can segment and track multiple
entities in a video given a text prompt such as category names or referring
expressions. You can optionally separate the categories in the text with commas. It
only tracks entities present in the first frame and only returns segmentation
masks. It is useful for tracking and counting without duplicating counts if they
appear in the first frame, always outputs scores of 1.0.
can find new objects every 'chunk_length' frames and is useful for tracking and
counting without duplicating counts and always outputs scores of 1.0.
Parameters:
prompt (str): The prompt to ground to the video.
frames (List[np.ndarray]): The list of frames to ground the prompt to.
chunk_length (Optional[int]): The number of frames between re-runs of florence2
to find new objects.
Returns:
List[List[Dict[str, Any]]]: A list of list of dictionaries containing the label
Expand Down Expand Up @@ -505,6 +506,8 @@ def florence2_sam2_video_tracking(
"prompts": [s.strip() for s in prompt.split(",")],
"function_name": "florence2_sam2_video_tracking",
}
if chunk_length is not None:
payload["chunk_length"] = chunk_length
data: Dict[str, Any] = send_inference_request(
payload, "florence2-sam2", files=files, v2=True
)
Expand Down Expand Up @@ -1570,28 +1573,38 @@ def closest_box_distance(
# Utility and visualization functions


def extract_frames(
def extract_frames_and_timestamps(
video_uri: Union[str, Path], fps: float = 1
) -> List[Tuple[np.ndarray, float]]:
"""'extract_frames' extracts frames from a video which can be a file path, url or
youtube link, returns a list of tuples (frame, timestamp), where timestamp is the
relative time in seconds where the frame was captured. The frame is a numpy array.
) -> List[Dict[str, Union[np.ndarray, float]]]:
"""'extract_frames_and_timestamps' extracts frames and timestamps from a video
which can be a file path, url or youtube link, returns a list of dictionaries
with keys "frame" and "timestamp" where "frame" is a numpy array and "timestamp" is
the relative time in seconds at which the frame was captured.
Parameters:
video_uri (Union[str, Path]): The path to the video file, url or youtube link
fps (float, optional): The frame rate per second to extract the frames. Defaults
to 1.
Returns:
List[Tuple[np.ndarray, float]]: A list of tuples containing the extracted frame
as a numpy array and the timestamp in seconds.
List[Dict[str, Union[np.ndarray, float]]]: A list of dictionaries containing the
extracted frame as a numpy array and the timestamp in seconds.
Example
-------
>>> extract_frames_and_timestamps("path/to/video.mp4")
[(frame1, 0.0), (frame2, 0.5), ...]
[{"frame": np.ndarray, "timestamp": 0.0}, ...]
"""

def reformat(
frames_and_timestamps: List[Tuple[np.ndarray, float]]
) -> List[Dict[str, Union[np.ndarray, float]]]:
return [
{"frame": frame, "timestamp": timestamp}
for frame, timestamp in frames_and_timestamps
]

if str(video_uri).startswith(
(
"http://www.youtube.com/",
Expand All @@ -1613,16 +1626,16 @@ def extract_frames(
raise Exception("No suitable video stream found")
video_file_path = video.download(output_path=temp_dir)

return extract_frames_from_video(video_file_path, fps)
return reformat(extract_frames_from_video(video_file_path, fps))
elif str(video_uri).startswith(("http", "https")):
_, image_suffix = os.path.splitext(video_uri)
with tempfile.NamedTemporaryFile(delete=False, suffix=image_suffix) as tmp_file:
# Download the video and save it to the temporary file
with urllib.request.urlopen(str(video_uri)) as response:
tmp_file.write(response.read())
return extract_frames_from_video(tmp_file.name, fps)
return reformat(extract_frames_from_video(tmp_file.name, fps))

return extract_frames_from_video(str(video_uri), fps)
return reformat(extract_frames_from_video(str(video_uri), fps))


def save_json(data: Any, file_path: str) -> None:
Expand Down Expand Up @@ -2026,7 +2039,6 @@ def overlay_counting_results(
vit_image_classification,
vit_nsfw_classification,
countgd_counting,
florence2_image_caption,
florence2_ocr,
florence2_sam2_image,
florence2_sam2_video_tracking,
Expand All @@ -2041,7 +2053,7 @@ def overlay_counting_results(
]

UTIL_TOOLS = [
extract_frames,
extract_frames_and_timestamps,
save_json,
load_image,
save_image,
Expand Down

0 comments on commit 85e2e8a

Please sign in to comment.