Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adding more tools #84

Merged
merged 7 commits into from
May 15, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 5 additions & 3 deletions vision_agent/agent/vision_agent_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -235,9 +235,11 @@ def run_plan(
f"""
{tabulate(tabular_data=[task], headers="keys", tablefmt="mixed_grid", maxcolwidths=_MAX_TABULATE_COL_WIDTH)}"""
)
tool_info = "\n".join(
[e["doc"] for e in tool_recommender.top_k(task["instruction"])]
)
tools = tool_recommender.top_k(task["instruction"])
tool_info = "\n".join([e["doc"] for e in tools])

if verbosity == 2:
_LOGGER.info(f"Tools retrieved: {[e['desc'] for e in tools]}")

if long_term_memory is not None:
retrieved_ltm = "\n".join(
Expand Down
2 changes: 1 addition & 1 deletion vision_agent/tools/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@
CLIP,
OCR,
TOOLS,
BboxStats,
BboxIoU,
BboxStats,
BoxDistance,
Crop,
DINOv,
Expand Down
53 changes: 52 additions & 1 deletion vision_agent/tools/tools_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,13 @@
import tempfile
from importlib import resources
from pathlib import Path
from typing import Any, Callable, Dict, List, Tuple, Union
from typing import Any, Callable, Dict, List, Tuple, Union, cast

import numpy as np
import pandas as pd
import requests
from PIL import Image, ImageDraw, ImageFont
from scipy.spatial import distance # type: ignore

from vision_agent.tools.tool_utils import _send_inference_request
from vision_agent.utils import extract_frames_from_video
Expand Down Expand Up @@ -233,6 +234,54 @@ def ocr(image: np.ndarray) -> List[Dict[str, Any]]:
return output


def closest_mask_distance(mask1: np.ndarray, mask2: np.ndarray) -> float:
"""'closest_mask_distance' calculates the closest distance between two masks.

Parameters:
mask1 (np.ndarray): The first mask.
mask2 (np.ndarray): The second mask.

Returns:
float: The closest distance between the two masks.

Example
-------
>>> closest_mask_distance(mask1, mask2)
0.5
"""

mask1 = np.clip(mask1, 0, 1)
mask2 = np.clip(mask2, 0, 1)
mask1_points = np.transpose(np.nonzero(mask1))
mask2_points = np.transpose(np.nonzero(mask2))
dist_matrix = distance.cdist(mask1_points, mask2_points, "euclidean")
return cast(float, np.min(dist_matrix))


def closest_box_distance(box1: List[float], box2: List[float]) -> float:
"""'closest_box_distance' calculates the closest distance between two bounding boxes.

Parameters:
box1 (List[float]): The first bounding box.
box2 (List[float]): The second bounding box.

Returns:
float: The closest distance between the two bounding boxes.

Example
-------
>>> closest_box_distance([100, 100, 200, 200], [300, 300, 400, 400])
141.42
"""

x11, y11, x12, y12 = box1
x21, y21, x22, y22 = box2

horizontal_distance = np.max([0, x21 - x12, x11 - x22])
vertical_distance = np.max([0, y21 - y12, y11 - y22])
return cast(float, np.sqrt(horizontal_distance**2 + vertical_distance**2))


# Utility and visualization functions


Expand Down Expand Up @@ -429,6 +478,8 @@ def get_tools_df(funcs: List[Callable[..., Any]]) -> pd.DataFrame:
grounding_sam,
extract_frames,
ocr,
closest_mask_distance,
closest_box_distance,
load_image,
save_image,
overlay_bounding_boxes,
Expand Down
Loading