Skip to content

Commit

Permalink
add distance functions
Browse files Browse the repository at this point in the history
  • Loading branch information
dillonalaird committed May 15, 2024
1 parent 16a9285 commit 4f764ad
Showing 1 changed file with 51 additions and 0 deletions.
51 changes: 51 additions & 0 deletions vision_agent/tools/tools_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import pandas as pd
import requests
from PIL import Image, ImageDraw, ImageFont
from scipy.spatial import distance

from vision_agent.tools.tool_utils import _send_inference_request
from vision_agent.utils import extract_frames_from_video
Expand Down Expand Up @@ -233,6 +234,54 @@ def ocr(image: np.ndarray) -> List[Dict[str, Any]]:
return output


def closest_mask_distance(mask1: np.ndarray, mask2: np.ndarray) -> float:
"""'closest_mask_distance' calculates the closest distance between two masks.
Parameters:
mask1 (np.ndarray): The first mask.
mask2 (np.ndarray): The second mask.
Returns:
float: The closest distance between the two masks.
Example
-------
>>> closest_mask_distance(mask1, mask2)
0.5
"""

mask1 = np.clip(mask1, 0, 1)
mask2 = np.clip(mask2, 0, 1)
mask1_points = np.transpose(np.nonzero(mask1))
mask2_points = np.transpose(np.nonzero(mask2))
dist_matrix = distance.cdist(mask1_points, mask2_points, "euclidean")
return np.min(dist_matrix)


def closest_box_distance(box1: List[float], box2: List[float]) -> float:
"""'closest_box_distance' calculates the closest distance between two bounding boxes.
Parameters:
box1 (List[float]): The first bounding box.
box2 (List[float]): The second bounding box.
Returns:
float: The closest distance between the two bounding boxes.
Example
-------
>>> closest_box_distance([100, 100, 200, 200], [300, 300, 400, 400])
141.42
"""

x11, y11, x12, y12 = box1
x21, y21, x22, y22 = box2

horizontal_distance = np.max([0, x21 - x12, x11 - x22])
vertical_distance = np.max([0, y21 - y12, y11 - y22])
return np.sqrt(horizontal_distance ** 2 + vertical_distance ** 2)


# Utility and visualization functions


Expand Down Expand Up @@ -429,6 +478,8 @@ def get_tools_df(funcs: List[Callable[..., Any]]) -> pd.DataFrame:
grounding_sam,
extract_frames,
ocr,
closest_mask_distance,
closest_box_distance,
load_image,
save_image,
overlay_bounding_boxes,
Expand Down

0 comments on commit 4f764ad

Please sign in to comment.