Skip to content

Commit

Permalink
Adding more tools (#84)
Browse files Browse the repository at this point in the history
* added different verbosity levels, better json parsing

* fix typing error

* log retrieved functions

* add distance functions

* fix types

* fix types

* fix formatting
  • Loading branch information
dillonalaird authored May 15, 2024
1 parent 987896d commit d8ef603
Show file tree
Hide file tree
Showing 3 changed files with 58 additions and 5 deletions.
8 changes: 5 additions & 3 deletions vision_agent/agent/vision_agent_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -235,9 +235,11 @@ def run_plan(
f"""
{tabulate(tabular_data=[task], headers="keys", tablefmt="mixed_grid", maxcolwidths=_MAX_TABULATE_COL_WIDTH)}"""
)
tool_info = "\n".join(
[e["doc"] for e in tool_recommender.top_k(task["instruction"])]
)
tools = tool_recommender.top_k(task["instruction"])
tool_info = "\n".join([e["doc"] for e in tools])

if verbosity == 2:
_LOGGER.info(f"Tools retrieved: {[e['desc'] for e in tools]}")

if long_term_memory is not None:
retrieved_ltm = "\n".join(
Expand Down
2 changes: 1 addition & 1 deletion vision_agent/tools/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@
CLIP,
OCR,
TOOLS,
BboxStats,
BboxIoU,
BboxStats,
BoxDistance,
Crop,
DINOv,
Expand Down
53 changes: 52 additions & 1 deletion vision_agent/tools/tools_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,13 @@
import tempfile
from importlib import resources
from pathlib import Path
from typing import Any, Callable, Dict, List, Tuple, Union
from typing import Any, Callable, Dict, List, Tuple, Union, cast

import numpy as np
import pandas as pd
import requests
from PIL import Image, ImageDraw, ImageFont
from scipy.spatial import distance # type: ignore

from vision_agent.tools.tool_utils import _send_inference_request
from vision_agent.utils import extract_frames_from_video
Expand Down Expand Up @@ -233,6 +234,54 @@ def ocr(image: np.ndarray) -> List[Dict[str, Any]]:
return output


def closest_mask_distance(mask1: np.ndarray, mask2: np.ndarray) -> float:
    """'closest_mask_distance' calculates the closest distance between two masks.

    The distance is the Euclidean distance, in pixel coordinates, between the
    nearest pair of foreground pixels with one pixel taken from each mask. Masks
    that share at least one foreground pixel have a distance of 0.

    Parameters:
        mask1 (np.ndarray): The first mask. Values are clipped to [0, 1] and every
            element that is non-zero after clipping counts as foreground.
        mask2 (np.ndarray): The second mask, interpreted the same way as mask1.

    Returns:
        float: The closest distance between the two masks.

    Raises:
        ValueError: If either mask contains no foreground pixels, since the
            distance is undefined in that case.

    Example
    -------
        >>> closest_mask_distance(mask1, mask2)
        2.0
    """

    # Clipping drops negative values to 0 so they are not treated as foreground.
    mask1_points = np.transpose(np.nonzero(np.clip(mask1, 0, 1)))
    mask2_points = np.transpose(np.nonzero(np.clip(mask2, 0, 1)))

    # Without this check np.min fails on a zero-size matrix with an opaque
    # "zero-size array to reduction operation" error.
    if len(mask1_points) == 0 or len(mask2_points) == 0:
        raise ValueError(
            "closest_mask_distance requires both masks to contain at least one "
            "non-zero pixel"
        )

    dist_matrix = distance.cdist(mask1_points, mask2_points, "euclidean")
    # float() converts the numpy scalar into the builtin float the signature promises.
    return float(np.min(dist_matrix))


def closest_box_distance(box1: List[float], box2: List[float]) -> float:
    """'closest_box_distance' calculates the closest distance between two bounding boxes.

    Each box is unpacked as four values [x1, y1, x2, y2]. The result is the
    length of the gap between the two boxes; boxes that overlap or touch have a
    distance of 0.

    Parameters:
        box1 (List[float]): The first bounding box.
        box2 (List[float]): The second bounding box.

    Returns:
        float: The closest distance between the two bounding boxes.

    Example
    -------
        >>> closest_box_distance([100, 100, 200, 200], [300, 300, 400, 400])
        141.4213562373095
    """

    x11, y11, x12, y12 = box1
    x21, y21, x22, y22 = box2

    # Along each axis the gap is positive on at most one side; 0 means the
    # projections of the boxes overlap on that axis.
    horizontal_distance = max(0, x21 - x12, x11 - x22)
    vertical_distance = max(0, y21 - y12, y11 - y22)

    # float() converts the numpy scalar into the builtin float the signature promises.
    return float(np.hypot(horizontal_distance, vertical_distance))


# Utility and visualization functions


Expand Down Expand Up @@ -429,6 +478,8 @@ def get_tools_df(funcs: List[Callable[..., Any]]) -> pd.DataFrame:
grounding_sam,
extract_frames,
ocr,
closest_mask_distance,
closest_box_distance,
load_image,
save_image,
overlay_bounding_boxes,
Expand Down

0 comments on commit d8ef603

Please sign in to comment.