Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add a new tool: determine if a bbox is contained within another bbox #59

Merged
merged 2 commits into from
Apr 22, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions vision_agent/agent/vision_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@

# Configure the root logger to write to stdout.
logging.basicConfig(stream=sys.stdout)
# Module-level logger, named after this module.
_LOGGER = logging.getLogger(__name__)
# Column-width cap handed to tabulate via `maxcolwidths` when progress tables
# are logged.
_MAX_TABULATE_COL_WIDTH = 80


def parse_json(s: str) -> Any:
Expand Down Expand Up @@ -614,7 +615,7 @@ def retrieval(

self.log_progress(
f"""Going to run the following tool(s) in sequence:
{tabulate([tool_results], headers="keys", tablefmt="mixed_grid")}"""
{tabulate(tabular_data=[tool_results], headers="keys", tablefmt="mixed_grid", maxcolwidths=_MAX_TABULATE_COL_WIDTH)}"""
dillonalaird marked this conversation as resolved.
Show resolved Hide resolved
)

def parse_tool_results(result: Dict[str, Union[Dict, List]]) -> Any:
Expand Down Expand Up @@ -660,6 +661,6 @@ def create_tasks(
task_list = []
self.log_progress(
f"""Planned tasks:
{tabulate(task_list, headers="keys", tablefmt="mixed_grid")}"""
{tabulate(task_list, headers="keys", tablefmt="mixed_grid", maxcolwidths=_MAX_TABULATE_COL_WIDTH)}"""
)
return task_list
69 changes: 59 additions & 10 deletions vision_agent/tools/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -250,7 +250,7 @@ def __call__(
iou_threshold: the threshold for intersection over union used in nms algorithm. It will suppress the boxes which have iou greater than this threshold.

Returns:
A list of dictionaries containing the labels, scores, and bboxes. Each dictionary contains the detection result for an image.
A dictionary containing the labels, scores, and bboxes, which is the detection result for the input image.
"""
image_size = get_image_size(image)
image_b64 = convert_to_b64(image)
Expand Down Expand Up @@ -346,7 +346,7 @@ def __call__(
iou_threshold: the threshold for intersection over union used in nms algorithm. It will suppress the boxes which have iou greater than this threshold.

Returns:
A list of dictionaries containing the labels, scores, bboxes and masks. Each dictionary contains the segmentation result for an image.
A dictionary containing the labels, scores, bboxes and masks for the input image.
"""
image_size = get_image_size(image)
image_b64 = convert_to_b64(image)
Expand All @@ -357,19 +357,15 @@ def __call__(
"kwargs": {"box_threshold": box_threshold, "iou_threshold": iou_threshold},
}
data: Dict[str, Any] = _send_inference_request(request_data, "tools")
ret_pred: Dict[str, List] = {"labels": [], "bboxes": [], "masks": []}
if "bboxes" in data:
ret_pred["bboxes"] = [
normalize_bbox(box, image_size) for box in data["bboxes"]
]
data["bboxes"] = [normalize_bbox(box, image_size) for box in data["bboxes"]]
if "masks" in data:
ret_pred["masks"] = [
data["masks"] = [
rle_decode(mask_rle=mask, shape=data["mask_shape"])
for mask in data["masks"]
]
ret_pred["labels"] = data["labels"]
ret_pred["scores"] = data["scores"]
return ret_pred
data.pop("mask_shape", None)
return data


class DINOv(Tool):
Expand Down Expand Up @@ -643,6 +639,58 @@ def __call__(self, mask1: Union[str, Path], mask2: Union[str, Path]) -> float:
return cast(float, round(iou, 2))


class BboxContains(Tool):
    """Tool reporting what fraction of a target box lies inside a region box.

    The score is the intersection area of the two boxes divided by the area of
    the target box (intersection-over-target, not intersection-over-union).
    A score of 1.0 means the target is fully inside the region; 0.0 means the
    two boxes do not overlap.
    """

    name = "bbox_contains_"
    description = "Given two bounding boxes, a target bounding box and a region bounding box, 'bbox_contains_' returns the intersection of the two bounding boxes over the target bounding box, reflects the percentage area of the target bounding box overlaps with the region bounding box. This is a good tool for determining if the region object contains the target object."
    usage = {
        "required_parameters": [
            {"name": "target", "type": "List[int]"},
            {"name": "target_class", "type": "str"},
            {"name": "region", "type": "List[int]"},
            {"name": "region_class", "type": "str"},
        ],
        "examples": [
            {
                "scenario": "Determine if the dog on the couch, bounding box of the dog: [0.2, 0.21, 0.34, 0.42], bounding box of the couch: [0.3, 0.31, 0.44, 0.52]",
                "parameters": {
                    "target": [0.2, 0.21, 0.34, 0.42],
                    "target_class": "dog",
                    "region": [0.3, 0.31, 0.44, 0.52],
                    "region_class": "couch",
                },
            },
            {
                "scenario": "Check if the kid is in the pool? bounding box of the kid: [0.2, 0.21, 0.34, 0.42], bounding box of the pool: [0.3, 0.31, 0.44, 0.52]",
                "parameters": {
                    "target": [0.2, 0.21, 0.34, 0.42],
                    "target_class": "kid",
                    "region": [0.3, 0.31, 0.44, 0.52],
                    "region_class": "pool",
                },
            },
        ],
    }

    def __call__(
        self, target: List[int], target_class: str, region: List[int], region_class: str
    ) -> Dict[str, Union[str, float]]:
        """Compute the share of the target box covered by the region box.

        Parameters:
            target: the target box as [x1, y1, x2, y2].
            target_class: label of the target object; echoed back in the result.
            region: the region box as [x1, y1, x2, y2].
            region_class: label of the region object; echoed back in the result.

        Returns:
            A dictionary with "target_class", "region_class" and "intersection",
            where "intersection" is the overlap area divided by the target area,
            rounded to two decimals. A target box with no positive area yields
            an "intersection" of 0.0.
        """
        tx1, ty1, tx2, ty2 = target
        rx1, ry1, rx2, ry2 = region

        # Width/height of the overlap rectangle; clamped to 0 when disjoint.
        overlap_w = max(0, min(tx2, rx2) - max(tx1, rx1))
        overlap_h = max(0, min(ty2, ry2) - max(ty1, ry1))
        target_area = (tx2 - tx1) * (ty2 - ty1)

        if target_area <= 0:
            # A degenerate target (point/line box) has nothing to cover;
            # report no containment instead of raising ZeroDivisionError.
            coverage = 0.0
        else:
            coverage = round(overlap_w * overlap_h / float(target_area), 2)

        return {
            "target_class": target_class,
            "region_class": region_class,
            "intersection": coverage,
        }


class BoxDistance(Tool):
name = "box_distance_"
description = (
Expand Down Expand Up @@ -757,6 +805,7 @@ def __call__(self, equation: str) -> float:
SegArea,
BboxIoU,
SegIoU,
BboxContains,
BoxDistance,
Calculator,
]
Expand Down
Loading