Skip to content

Commit

Permalink
update tools
Browse files Browse the repository at this point in the history
  • Loading branch information
dillonalaird committed Mar 25, 2024
1 parent a9c4f57 commit bef5913
Showing 1 changed file with 50 additions and 0 deletions.
50 changes: 50 additions & 0 deletions vision_agent/tools/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,6 +180,7 @@ def __call__(self, prompt: str, image: Union[str, Path, ImageType]) -> List[Dict
]
if "scores" in elt:
elt["scores"] = [round(score, 2) for score in elt["scores"]]
elt["size"] = (image_size[1], image_size[0])
return cast(List[Dict], resp_data)


Expand Down Expand Up @@ -341,6 +342,53 @@ def __call__(self, bbox: List[float], image: Union[str, Path]) -> str:
return tmp.name


class BboxArea(Tool):
name = "bbox_area_"
description = "'bbox_area_' returns the area of the bounding box in pixels normalized to 2 decimal places."
usage = {
"required_parameters": [{"name": "bbox", "type": "List[int]"}],
"examples": [
{
"scenario": "If you want to calculate the area of the bounding box [0, 0, 100, 100]",
"parameters": {"bboxes": [0.2, 0.21, 0.34, 0.42]},
}
],
}

def __call__(self, bboxes: List[Dict]) -> List[Dict]:
areas = []
for elt in bboxes:
height, width = elt["size"]
for label, bbox in zip(elt["labels"], elt["bboxes"]):
x1, y1, x2, y2 = bbox
areas.append(
{
"area": round((x2 - x1) * (y2 - y1) * width * height, 2),
"label": label,
}
)
return areas


class SegArea(Tool):
name = "seg_area_"
description = "'seg_area_' returns the area of the segmentation mask in pixels normalized to 2 decimal places."
usage = {
"required_parameters": [{"name": "masks", "type": "str"}],
"examples": [
{
"scenario": "If you want to calculate the area of the segmentation mask, pass the masks file name.",
"parameters": {"masks": "mask_file.jpg"},
},
],
}

def __call__(self, masks: Union[str, Path]) -> float:
pil_mask = Image.open(str(masks))
np_mask = np.array(pil_mask) # type: ignore
return round(np.sum(np_mask) / 255, 2)


class Add(Tool):
name = "add_"
description = "'add_' returns the sum of all the arguments passed to it, normalized to 2 decimal places."
Expand Down Expand Up @@ -418,6 +466,8 @@ def __call__(self, input: List[int]) -> float:
AgentGroundingSAM,
Counter,
Crop,
BboxArea,
SegArea,
Add,
Subtract,
Multiply,
Expand Down

0 comments on commit bef5913

Please sign in to comment.