diff --git a/vision_agent/tools/tools.py b/vision_agent/tools/tools.py index 32c250eb..164ca6ea 100644 --- a/vision_agent/tools/tools.py +++ b/vision_agent/tools/tools.py @@ -292,12 +292,15 @@ def __call__(self, prompt: str, image: Union[str, ImageType]) -> Dict: data: Dict[str, Any] = resp_json["data"] ret_pred: Dict[str, List] = {"labels": [], "bboxes": [], "masks": []} if "bboxes" in data: - data["bboxes"] = [normalize_bbox(box, image_size) for box in data["bboxes"]] + ret_pred["bboxes"] = [ + normalize_bbox(box, image_size) for box in data["bboxes"] + ] if "masks" in data: - data["masks"] = [ + ret_pred["masks"] = [ rle_decode(mask_rle=mask, shape=data["mask_shape"]) - for mask in data["masks"][0] + for mask in data["masks"] ] + ret_pred["labels"] = data["labels"] return ret_pred