diff --git a/vision_agent/agent/vision_agent.py b/vision_agent/agent/vision_agent.py index bbd2c1a5..3287a174 100644 --- a/vision_agent/agent/vision_agent.py +++ b/vision_agent/agent/vision_agent.py @@ -33,6 +33,7 @@ logging.basicConfig(stream=sys.stdout) _LOGGER = logging.getLogger(__name__) +_MAX_TABULATE_COL_WIDTH = 80 def parse_json(s: str) -> Any: @@ -614,7 +615,7 @@ def retrieval( self.log_progress( f"""Going to run the following tool(s) in sequence: -{tabulate([tool_results], headers="keys", tablefmt="mixed_grid")}""" +{tabulate(tabular_data=[tool_results], headers="keys", tablefmt="mixed_grid", maxcolwidths=_MAX_TABULATE_COL_WIDTH)}""" ) def parse_tool_results(result: Dict[str, Union[Dict, List]]) -> Any: @@ -660,6 +661,6 @@ def create_tasks( task_list = [] self.log_progress( f"""Planned tasks: -{tabulate(task_list, headers="keys", tablefmt="mixed_grid")}""" +{tabulate(task_list, headers="keys", tablefmt="mixed_grid", maxcolwidths=_MAX_TABULATE_COL_WIDTH)}""" ) return task_list diff --git a/vision_agent/tools/tools.py b/vision_agent/tools/tools.py index 6d2a7b47..7657a362 100644 --- a/vision_agent/tools/tools.py +++ b/vision_agent/tools/tools.py @@ -250,7 +250,7 @@ def __call__( iou_threshold: the threshold for intersection over union used in nms algorithm. It will suppress the boxes which have iou greater than this threshold. Returns: - A list of dictionaries containing the labels, scores, and bboxes. Each dictionary contains the detection result for an image. + A dictionary containing the labels, scores, and bboxes, which is the detection result for the input image. """ image_size = get_image_size(image) image_b64 = convert_to_b64(image) @@ -346,7 +346,7 @@ def __call__( iou_threshold: the threshold for intersection over union used in nms algorithm. It will suppress the boxes which have iou greater than this threshold. Returns: - A list of dictionaries containing the labels, scores, bboxes and masks. 
class BboxContains(Tool):
    """Tool reporting how much of a target bounding box lies inside a region box.

    The score is ``area(target ∩ region) / area(target)``, rounded to two
    decimals. It is a fraction of the *target* box only, so it is not
    symmetric and is not an intersection-over-union.
    """

    name = "bbox_contains_"
    description = "Given two bounding boxes, a target bounding box and a region bounding box, 'bbox_contains_' returns the intersection of the two bounding boxes over the target bounding box, reflects the percentage area of the target bounding box overlaps with the region bounding box. This is a good tool for determining if the region object contains the target object."
    # NOTE(review): the declared parameter type is "List[int]" while the
    # examples below pass fractional coordinates — confirm which one the
    # consumers of `usage` rely on before changing either.
    usage = {
        "required_parameters": [
            {"name": "target", "type": "List[int]"},
            {"name": "target_class", "type": "str"},
            {"name": "region", "type": "List[int]"},
            {"name": "region_class", "type": "str"},
        ],
        "examples": [
            {
                "scenario": "Determine if the dog on the couch, bounding box of the dog: [0.2, 0.21, 0.34, 0.42], bounding box of the couch: [0.3, 0.31, 0.44, 0.52]",
                "parameters": {
                    "target": [0.2, 0.21, 0.34, 0.42],
                    "target_class": "dog",
                    "region": [0.3, 0.31, 0.44, 0.52],
                    "region_class": "couch",
                },
            },
            {
                "scenario": "Check if the kid is in the pool? bounding box of the kid: [0.2, 0.21, 0.34, 0.42], bounding box of the pool: [0.3, 0.31, 0.44, 0.52]",
                "parameters": {
                    "target": [0.2, 0.21, 0.34, 0.42],
                    "target_class": "kid",
                    "region": [0.3, 0.31, 0.44, 0.52],
                    "region_class": "pool",
                },
            },
        ],
    }

    def __call__(
        self, target: List[int], target_class: str, region: List[int], region_class: str
    ) -> Dict[str, Union[str, float]]:
        """Compute the fraction of ``target``'s area covered by ``region``.

        Args:
            target: four coordinates unpacked as ``x1, y1, x2, y2``.
            target_class: label of the target object; echoed in the result.
            region: four coordinates unpacked as ``x1, y1, x2, y2``.
            region_class: label of the region object; echoed in the result.

        Returns:
            A dictionary with keys ``"target_class"``, ``"region_class"`` and
            ``"intersection"``; the latter is the overlap area divided by the
            target area, rounded to two decimals.

        Raises:
            ValueError: if either box does not unpack into exactly four values.
        """
        tgt_x1, tgt_y1, tgt_x2, tgt_y2 = target
        reg_x1, reg_y1, reg_x2, reg_y2 = region

        target_area = (tgt_x2 - tgt_x1) * (tgt_y2 - tgt_y1)
        if target_area <= 0:
            # A degenerate target (zero width/height or inverted corners) has
            # no area to cover; report 0.0 instead of dividing by zero.
            coverage = 0.0
        else:
            # Clamp each extent at 0 so disjoint boxes yield no overlap.
            overlap_w = max(0, min(tgt_x2, reg_x2) - max(tgt_x1, reg_x1))
            overlap_h = max(0, min(tgt_y2, reg_y2) - max(tgt_y1, reg_y1))
            coverage = round(overlap_w * overlap_h / float(target_area), 2)

        return {
            "target_class": target_class,
            "region_class": region_class,
            "intersection": coverage,
        }