From e48df0e4812048a4c19dd0887ddf5555657c6eea Mon Sep 17 00:00:00 2001 From: Yazhou Cao Date: Wed, 3 Apr 2024 10:20:00 -0700 Subject: [PATCH] Fix linter errors --- vision_agent/agent/vision_agent.py | 1 - vision_agent/tools/tools.py | 22 +++++++++++----------- 2 files changed, 11 insertions(+), 12 deletions(-) diff --git a/vision_agent/agent/vision_agent.py b/vision_agent/agent/vision_agent.py index 1aed604f..a072903c 100644 --- a/vision_agent/agent/vision_agent.py +++ b/vision_agent/agent/vision_agent.py @@ -368,7 +368,6 @@ def visualize_result(all_tool_results: List[Dict]) -> List[str]: continue for param, call_result in zip(parameters, tool_result["call_results"]): - # calls can fail, so we need to check if the call was successful if not isinstance(call_result, dict): continue diff --git a/vision_agent/tools/tools.py b/vision_agent/tools/tools.py index 50f5c9d5..32c250eb 100644 --- a/vision_agent/tools/tools.py +++ b/vision_agent/tools/tools.py @@ -180,7 +180,7 @@ def __call__(self, prompt: str, image: Union[str, Path, ImageType]) -> Dict: """ image_size = get_image_size(image) image_b64 = convert_to_b64(image) - data = { + request_data = { "prompt": prompt, "image": image_b64, "tool": "visual_grounding", @@ -188,7 +188,7 @@ def __call__(self, prompt: str, image: Union[str, Path, ImageType]) -> Dict: res = requests.post( self._ENDPOINT, headers={"Content-Type": "application/json"}, - json=data, + json=request_data, ) resp_json: Dict[str, Any] = res.json() if ( @@ -273,7 +273,7 @@ def __call__(self, prompt: str, image: Union[str, ImageType]) -> Dict: """ image_size = get_image_size(image) image_b64 = convert_to_b64(image) - data = { + request_data = { "prompt": prompt, "image": image_b64, "tool": "visual_grounding_segment", @@ -281,7 +281,7 @@ def __call__(self, prompt: str, image: Union[str, ImageType]) -> Dict: res = requests.post( self._ENDPOINT, headers={"Content-Type": "application/json"}, - json=data, + json=request_data, ) resp_json: Dict[str, Any] = res.json() if ( @@ -289,15 +289,15 @@ def __call__(self, prompt: str, image: Union[str, ImageType]) -> Dict: ) or "statusCode" not in resp_json: _LOGGER.error(f"Request failed: {resp_json}") raise ValueError(f"Request failed: {resp_json}") - data = resp_json["data"] - ret_pred: Dict[str, List] = {"labels": [], "bboxes": [], "masks": []} data: Dict[str, Any] = resp_json["data"] + ret_pred: Dict[str, List] = {"labels": [], "bboxes": [], "masks": []} if "bboxes" in data: - data["bboxes"] = [ - normalize_bbox(box, image_size) for box in data["bboxes"] - ] + data["bboxes"] = [normalize_bbox(box, image_size) for box in data["bboxes"]] if "masks" in data: - data["masks"] = [rle_decode(mask_rle=mask, shape=data["mask_shape"]) for mask in data["masks"][0]] + data["masks"] = [ + rle_decode(mask_rle=mask, shape=data["mask_shape"]) + for mask in data["masks"][0] + ] return ret_pred @@ -306,7 +306,7 @@ class AgentGroundingSAM(GroundingSAM): returns the file name. This makes it easier for agents to use. """ - def __call__(self, prompt: List[str], image: Union[str, ImageType]) -> Dict: + def __call__(self, prompt: str, image: Union[str, ImageType]) -> Dict: rets = super().__call__(prompt, image) mask_files = [] for mask in rets["masks"]: