From 3f0296548485e617ca8966c172ec9d978352a015 Mon Sep 17 00:00:00 2001 From: shankar_ws3 Date: Tue, 9 Apr 2024 10:21:02 -0700 Subject: [PATCH] fix mypy errors --- vision_agent/tools/tools.py | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/vision_agent/tools/tools.py b/vision_agent/tools/tools.py index 96aed1bd..36a78f58 100644 --- a/vision_agent/tools/tools.py +++ b/vision_agent/tools/tools.py @@ -124,7 +124,7 @@ def __call__(self, prompt: str, image: Union[str, ImageType]) -> Dict: round(prob, 4) for prob in resp_json["data"]["scores"] ] - return resp_json["data"] + return resp_json["data"] # type: ignore class GroundingDINO(Tool): @@ -340,8 +340,14 @@ class AgentGroundingSAM(GroundingSAM): returns the file name. This makes it easier for agents to use. """ - def __call__(self, prompt: str, image: Union[str, ImageType]) -> Dict: - rets = super().__call__(prompt, image) + def __call__( + self, + prompt: str, + image: Union[str, ImageType], + box_threshold: float = 0.2, + iou_threshold: float = 0.75, + ) -> Dict: + rets = super().__call__(prompt, image, box_threshold, iou_threshold) mask_files = [] for mask in rets["masks"]: with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp: