Skip to content

Commit

Permalink
Fix linter errors
Browse files Browse the repository at this point in the history
  • Loading branch information
AsiaCao committed Apr 3, 2024
1 parent f73ef58 commit e48df0e
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 12 deletions.
1 change: 0 additions & 1 deletion vision_agent/agent/vision_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -368,7 +368,6 @@ def visualize_result(all_tool_results: List[Dict]) -> List[str]:
continue

for param, call_result in zip(parameters, tool_result["call_results"]):

# calls can fail, so we need to check if the call was successful
if not isinstance(call_result, dict):
continue
Expand Down
22 changes: 11 additions & 11 deletions vision_agent/tools/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,15 +180,15 @@ def __call__(self, prompt: str, image: Union[str, Path, ImageType]) -> Dict:
"""
image_size = get_image_size(image)
image_b64 = convert_to_b64(image)
data = {
request_data = {
"prompt": prompt,
"image": image_b64,
"tool": "visual_grounding",
}
res = requests.post(
self._ENDPOINT,
headers={"Content-Type": "application/json"},
json=data,
json=request_data,
)
resp_json: Dict[str, Any] = res.json()
if (
Expand Down Expand Up @@ -273,31 +273,31 @@ def __call__(self, prompt: str, image: Union[str, ImageType]) -> Dict:
"""
image_size = get_image_size(image)
image_b64 = convert_to_b64(image)
data = {
request_data = {
"prompt": prompt,
"image": image_b64,
"tool": "visual_grounding_segment",
}
res = requests.post(
self._ENDPOINT,
headers={"Content-Type": "application/json"},
json=data,
json=request_data,
)
resp_json: Dict[str, Any] = res.json()
if (
"statusCode" in resp_json and resp_json["statusCode"] != 200
) or "statusCode" not in resp_json:
_LOGGER.error(f"Request failed: {resp_json}")
raise ValueError(f"Request failed: {resp_json}")
data = resp_json["data"]
ret_pred: Dict[str, List] = {"labels": [], "bboxes": [], "masks": []}
data: Dict[str, Any] = resp_json["data"]
ret_pred: Dict[str, List] = {"labels": [], "bboxes": [], "masks": []}
if "bboxes" in data:
data["bboxes"] = [
normalize_bbox(box, image_size) for box in data["bboxes"]
]
data["bboxes"] = [normalize_bbox(box, image_size) for box in data["bboxes"]]
if "masks" in data:
data["masks"] = [rle_decode(mask_rle=mask, shape=data["mask_shape"]) for mask in data["masks"][0]]
data["masks"] = [
rle_decode(mask_rle=mask, shape=data["mask_shape"])
for mask in data["masks"][0]
]
return ret_pred


Expand All @@ -306,7 +306,7 @@ class AgentGroundingSAM(GroundingSAM):
returns the file name. This makes it easier for agents to use.
"""

def __call__(self, prompt: List[str], image: Union[str, ImageType]) -> Dict:
def __call__(self, prompt: str, image: Union[str, ImageType]) -> Dict:
rets = super().__call__(prompt, image)
mask_files = []
for mask in rets["masks"]:
Expand Down

0 comments on commit e48df0e

Please sign in to comment.