From bef5913aa4cfa93b4c593e1583305c0e11b4d367 Mon Sep 17 00:00:00 2001
From: Dillon Laird
Date: Sun, 24 Mar 2024 20:38:56 -0700
Subject: [PATCH] update tools

---
 vision_agent/tools/tools.py | 77 +++++++++++++++++++++++++++++++++++++
 1 file changed, 77 insertions(+)

diff --git a/vision_agent/tools/tools.py b/vision_agent/tools/tools.py
index fdcece18..1fbbd182 100644
--- a/vision_agent/tools/tools.py
+++ b/vision_agent/tools/tools.py
@@ -180,6 +180,7 @@ def __call__(self, prompt: str, image: Union[str, Path, ImageType]) -> List[Dict
                 ]
             if "scores" in elt:
                 elt["scores"] = [round(score, 2) for score in elt["scores"]]
+            elt["size"] = (image_size[1], image_size[0])
         return cast(List[Dict], resp_data)
 
 
@@ -341,6 +342,80 @@ def __call__(self, bbox: List[float], image: Union[str, Path]) -> str:
         return tmp.name
 
 
+class BboxArea(Tool):
+    """Compute the pixel area of each bounding box in a detection result."""
+
+    name = "bbox_area_"
+    description = "'bbox_area_' returns the area of the bounding box in pixels normalized to 2 decimal places."
+    usage = {
+        "required_parameters": [{"name": "bboxes", "type": "List[Dict]"}],
+        "examples": [
+            {
+                "scenario": "If you want to calculate the area of the bounding box [0.2, 0.21, 0.34, 0.42]",
+                "parameters": {
+                    "bboxes": [
+                        {
+                            "labels": ["dog"],
+                            "bboxes": [[0.2, 0.21, 0.34, 0.42]],
+                            "size": (480, 640),
+                        }
+                    ]
+                },
+            }
+        ],
+    }
+
+    def __call__(self, bboxes: List[Dict]) -> List[Dict]:
+        """Return one ``{"area", "label"}`` dict per bounding box.
+
+        Each element of ``bboxes`` must provide ``"size"`` as (height, width)
+        plus parallel ``"labels"`` and ``"bboxes"`` lists; every box is
+        (x1, y1, x2, y2). The box extent is scaled by the image width and
+        height, so coordinates are presumably normalized to [0, 1].
+        """
+        areas = []
+        for elt in bboxes:
+            height, width = elt["size"]
+            for label, bbox in zip(elt["labels"], elt["bboxes"]):
+                x1, y1, x2, y2 = bbox
+                areas.append(
+                    {
+                        "area": round((x2 - x1) * (y2 - y1) * width * height, 2),
+                        "label": label,
+                    }
+                )
+        return areas
+
+
+class SegArea(Tool):
+    """Compute the area of a segmentation mask stored in an image file."""
+
+    name = "seg_area_"
+    description = "'seg_area_' returns the area of the segmentation mask in pixels normalized to 2 decimal places."
+    usage = {
+        "required_parameters": [{"name": "masks", "type": "str"}],
+        "examples": [
+            {
+                "scenario": "If you want to calculate the area of the segmentation mask, pass the masks file name.",
+                "parameters": {"masks": "mask_file.jpg"},
+            },
+        ],
+    }
+
+    def __call__(self, masks: Union[str, Path]) -> float:
+        """Return the sum of the mask's pixel values divided by 255.
+
+        This equals the foreground pixel count only for a single-channel
+        mask whose pixels are 0 or 255 (assumed; confirm with the producer).
+        """
+        # Close the image file deterministically once its pixels have been
+        # copied into the array.
+        with Image.open(str(masks)) as pil_mask:
+            np_mask = np.array(pil_mask)  # type: ignore
+        # Convert the NumPy scalar to the builtin float the signature promises.
+        return float(round(np.sum(np_mask) / 255, 2))
+
+
 class Add(Tool):
     name = "add_"
     description = "'add_' returns the sum of all the arguments passed to it, normalized to 2 decimal places."
@@ -418,6 +493,8 @@ def __call__(self, input: List[int]) -> float:
             AgentGroundingSAM,
             Counter,
             Crop,
+            BboxArea,
+            SegArea,
             Add,
             Subtract,
             Multiply,