Skip to content

Commit

Permalink
Add a new tool: determine if a bbox is contained within another bbox (#…
Browse files Browse the repository at this point in the history
…59)

* Add a new bounding box contains tool

* Fix format
  • Loading branch information
AsiaCao authored Apr 22, 2024
1 parent c505b4e commit eb101f0
Show file tree
Hide file tree
Showing 2 changed files with 62 additions and 12 deletions.
5 changes: 3 additions & 2 deletions vision_agent/agent/vision_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@

logging.basicConfig(stream=sys.stdout)
_LOGGER = logging.getLogger(__name__)
_MAX_TABULATE_COL_WIDTH = 80


def parse_json(s: str) -> Any:
Expand Down Expand Up @@ -614,7 +615,7 @@ def retrieval(

self.log_progress(
f"""Going to run the following tool(s) in sequence:
{tabulate([tool_results], headers="keys", tablefmt="mixed_grid")}"""
{tabulate(tabular_data=[tool_results], headers="keys", tablefmt="mixed_grid", maxcolwidths=_MAX_TABULATE_COL_WIDTH)}"""
)

def parse_tool_results(result: Dict[str, Union[Dict, List]]) -> Any:
Expand Down Expand Up @@ -660,6 +661,6 @@ def create_tasks(
task_list = []
self.log_progress(
f"""Planned tasks:
{tabulate(task_list, headers="keys", tablefmt="mixed_grid")}"""
{tabulate(task_list, headers="keys", tablefmt="mixed_grid", maxcolwidths=_MAX_TABULATE_COL_WIDTH)}"""
)
return task_list
69 changes: 59 additions & 10 deletions vision_agent/tools/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -250,7 +250,7 @@ def __call__(
iou_threshold: the threshold for intersection over union used in nms algorithm. It will suppress the boxes which have iou greater than this threshold.
Returns:
A list of dictionaries containing the labels, scores, and bboxes. Each dictionary contains the detection result for an image.
A dictionary containing the labels, scores, and bboxes, which is the detection result for the input image.
"""
image_size = get_image_size(image)
image_b64 = convert_to_b64(image)
Expand Down Expand Up @@ -346,7 +346,7 @@ def __call__(
iou_threshold: the threshold for intersection over union used in nms algorithm. It will suppress the boxes which have iou greater than this threshold.
Returns:
A list of dictionaries containing the labels, scores, bboxes and masks. Each dictionary contains the segmentation result for an image.
A dictionary containing the labels, scores, bboxes and masks for the input image.
"""
image_size = get_image_size(image)
image_b64 = convert_to_b64(image)
Expand All @@ -357,19 +357,15 @@ def __call__(
"kwargs": {"box_threshold": box_threshold, "iou_threshold": iou_threshold},
}
data: Dict[str, Any] = _send_inference_request(request_data, "tools")
ret_pred: Dict[str, List] = {"labels": [], "bboxes": [], "masks": []}
if "bboxes" in data:
ret_pred["bboxes"] = [
normalize_bbox(box, image_size) for box in data["bboxes"]
]
data["bboxes"] = [normalize_bbox(box, image_size) for box in data["bboxes"]]
if "masks" in data:
ret_pred["masks"] = [
data["masks"] = [
rle_decode(mask_rle=mask, shape=data["mask_shape"])
for mask in data["masks"]
]
ret_pred["labels"] = data["labels"]
ret_pred["scores"] = data["scores"]
return ret_pred
data.pop("mask_shape", None)
return data


class DINOv(Tool):
Expand Down Expand Up @@ -643,6 +639,58 @@ def __call__(self, mask1: Union[str, Path], mask2: Union[str, Path]) -> float:
return cast(float, round(iou, 2))


class BboxContains(Tool):
name = "bbox_contains_"
description = "Given two bounding boxes, a target bounding box and a region bounding box, 'bbox_contains_' returns the intersection of the two bounding boxes over the target bounding box, reflects the percentage area of the target bounding box overlaps with the region bounding box. This is a good tool for determining if the region object contains the target object."
usage = {
"required_parameters": [
{"name": "target", "type": "List[int]"},
{"name": "target_class", "type": "str"},
{"name": "region", "type": "List[int]"},
{"name": "region_class", "type": "str"},
],
"examples": [
{
"scenario": "Determine if the dog on the couch, bounding box of the dog: [0.2, 0.21, 0.34, 0.42], bounding box of the couch: [0.3, 0.31, 0.44, 0.52]",
"parameters": {
"target": [0.2, 0.21, 0.34, 0.42],
"target_class": "dog",
"region": [0.3, 0.31, 0.44, 0.52],
"region_class": "couch",
},
},
{
"scenario": "Check if the kid is in the pool? bounding box of the kid: [0.2, 0.21, 0.34, 0.42], bounding box of the pool: [0.3, 0.31, 0.44, 0.52]",
"parameters": {
"target": [0.2, 0.21, 0.34, 0.42],
"target_class": "kid",
"region": [0.3, 0.31, 0.44, 0.52],
"region_class": "pool",
},
},
],
}

def __call__(
self, target: List[int], target_class: str, region: List[int], region_class: str
) -> Dict[str, Union[str, float]]:
x1, y1, x2, y2 = target
x3, y3, x4, y4 = region
xA = max(x1, x3)
yA = max(y1, y3)
xB = min(x2, x4)
yB = min(y2, y4)
inter_area = max(0, xB - xA) * max(0, yB - yA)
boxa_area = (x2 - x1) * (y2 - y1)
iou = inter_area / float(boxa_area)
area = round(iou, 2)
return {
"target_class": target_class,
"region_class": region_class,
"intersection": area,
}


class BoxDistance(Tool):
name = "box_distance_"
description = (
Expand Down Expand Up @@ -757,6 +805,7 @@ def __call__(self, equation: str) -> float:
SegArea,
BboxIoU,
SegIoU,
BboxContains,
BoxDistance,
Calculator,
]
Expand Down

0 comments on commit eb101f0

Please sign in to comment.