diff --git a/vision_agent/tools/tools.py b/vision_agent/tools/tools.py index f66f576c..403c1c15 100644 --- a/vision_agent/tools/tools.py +++ b/vision_agent/tools/tools.py @@ -253,14 +253,15 @@ def __call__(self, prompt: List[str], image: Union[str, ImageType]) -> List[Dict _LOGGER.error(f"Request failed: {resp_json}") raise ValueError(f"Request failed: {resp_json}") resp_data = resp_json["data"] - ret_pred = {"labels": [], "bboxes": [], "masks": []} + ret_pred: Dict[str, List] = {"labels": [], "bboxes": [], "masks": []} for pred in resp_data["preds"]: encoded_mask = pred["encoded_mask"] mask = rle_decode(mask_rle=encoded_mask, shape=pred["mask_shape"]) ret_pred["labels"].append(pred["label_name"]) ret_pred["bboxes"].append(normalize_bbox(pred["bbox"], image_size)) ret_pred["masks"].append(mask) - return [ret_pred] + ret_preds = [ret_pred] + return ret_preds class AgentGroundingSAM(GroundingSAM):