Skip to content

Commit

Permalink
updated prompts
Browse files Browse the repository at this point in the history
  • Loading branch information
dillonalaird committed Aug 21, 2024
1 parent 968b3df commit c4e50c8
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 14 deletions.
4 changes: 3 additions & 1 deletion vision_agent/agent/vision_agent_coder_prompts.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
"plan1":
[
{{
"thoughts": str # your thought process for this plan
"instructions": str # what you should do in this task associated with a tool
}}
],
Expand Down Expand Up @@ -127,7 +128,8 @@
**Instructions**:
1. Given the plans, image, and tool outputs, decide which plan is the best to achieve the user request.
2. Output a JSON object with the following format:
2. Try solving the problem yourself given the image and pick the plan which matches your solution the best.
3. Output a JSON object with the following format:
{{
"thoughts": str # your thought process for choosing the best plan
"best_plan": str # the best plan you have chosen
Expand Down
40 changes: 27 additions & 13 deletions vision_agent/tools/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -303,8 +303,8 @@ def florence2_sam2_video(
) -> List[List[Dict[str, Any]]]:
"""'florence2_sam2_video' is a tool that can segment and track multiple objects
in a video given a text prompt such as category names or referring expressions. The
categories in the text prompt are separated by commas. It returns tracked objects
as masks, labels, and scores for each frame.
categories in the text prompt are separated by commas. It is useful for tracking
and counting across frames without counting duplicates.
Parameters:
prompt (str): The prompt to ground to the video.
Expand Down Expand Up @@ -421,12 +421,19 @@ def loca_zero_shot_counting(image: np.ndarray) -> Dict[str, Any]:
Returns:
Dict[str, Any]: A dictionary containing the key 'count' and the count as a
value. E.g. {count: 12}.
value, e.g. {count: 12} and a heat map for visualization purposes.
Example
-------
>>> loca_zero_shot_counting(image)
{'count': 45},
{'count': 83,
'heat_map': array([[ 0, 0, 0, ..., 0, 0, 0],
[ 0, 0, 0, ..., 0, 0, 0],
[ 0, 0, 0, ..., 0, 0, 1],
...,
[ 0, 0, 0, ..., 30, 35, 41],
[ 0, 0, 0, ..., 41, 47, 53],
[ 0, 0, 0, ..., 53, 59, 64]], dtype=uint8)}
"""

image_b64 = convert_to_b64(image)
Expand All @@ -451,12 +458,19 @@ def loca_visual_prompt_counting(
Returns:
Dict[str, Any]: A dictionary containing the key 'count' and the count as a
value. E.g. {count: 12}.
value, e.g. {count: 12} and a heat map for visualization purposes.
Example
-------
>>> loca_visual_prompt_counting(image, {"bbox": [0.1, 0.1, 0.4, 0.42]})
{'count': 45},
{'count': 83,
'heat_map': array([[ 0, 0, 0, ..., 0, 0, 0],
[ 0, 0, 0, ..., 0, 0, 0],
[ 0, 0, 0, ..., 0, 0, 1],
...,
[ 0, 0, 0, ..., 30, 35, 41],
[ 0, 0, 0, ..., 41, 47, 53],
[ 0, 0, 0, ..., 53, 59, 64]], dtype=uint8)}
"""

image_size = get_image_size(image)
Expand Down Expand Up @@ -1138,7 +1152,7 @@ def closest_box_distance(


def extract_frames(
video_uri: Union[str, Path], fps: float = 0.5
video_uri: Union[str, Path], fps: float = 1
) -> List[Tuple[np.ndarray, float]]:
"""'extract_frames' extracts frames from a video which can be a file path or youtube
link, returns a list of tuples (frame, timestamp), where timestamp is the relative
Expand All @@ -1147,7 +1161,7 @@ def extract_frames(
Parameters:
video_uri (Union[str, Path]): The path to the video file or youtube link
fps (float, optional): The frame rate per second to extract the frames. Defaults
to 0.5.
to 1.
Returns:
List[Tuple[np.ndarray, float]]: A list of tuples containing the extracted frame
Expand Down Expand Up @@ -1249,7 +1263,7 @@ def save_image(image: np.ndarray, file_path: str) -> None:


def save_video(
frames: List[np.ndarray], output_video_path: Optional[str] = None, fps: float = 4
frames: List[np.ndarray], output_video_path: Optional[str] = None, fps: float = 1
) -> str:
"""'save_video' is a utility function that saves a list of frames as a mp4 video file on disk.
Expand Down Expand Up @@ -1500,21 +1514,21 @@ def overlay_heat_map(

TOOLS = [
owl_v2,
grounding_sam,
extract_frames,
ocr,
clip,
vit_image_classification,
vit_nsfw_classification,
loca_zero_shot_counting,
loca_visual_prompt_counting,
florence2_roberta_vqa,
florence2_image_caption,
florence2_ocr,
florence2_sam2_image,
florence2_sam2_video,
ixc25_image_vqa,
ixc25_video_vqa,
detr_segmentation,
depth_anything_v2,
generate_soft_edge_image,
dpt_hybrid_midas,
generate_pose_image,
closest_mask_distance,
closest_box_distance,
Expand Down

0 comments on commit c4e50c8

Please sign in to comment.