Skip to content

Commit

Permalink
linter
Browse files Browse the repository at this point in the history
  • Loading branch information
Dayof committed Sep 2, 2024
1 parent becc7f3 commit f8e05ee
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 16 deletions.
12 changes: 6 additions & 6 deletions vision_agent/tools/tool_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ def send_inference_request(
files: Optional[List[Tuple[Any, ...]]] = None,
v2: bool = False,
metadata_payload: Optional[Dict[str, Any]] = None,
) -> Dict[str, Any]:
) -> Any:
    # TODO: runtime_tag and function_name should be moved to metadata_payload
    # and not included in the service payload
if runtime_tag := os.environ.get("RUNTIME_TAG", ""):
Expand Down Expand Up @@ -70,11 +70,11 @@ def send_inference_request(

def send_task_inference_request(
payload: Dict[str, Any],
endpoint_name: str,
task_name: str,
files: Optional[List[Tuple[Any, ...]]] = None,
metadata: Optional[Dict[str, Any]] = None,
) -> Dict[str, Any]:
url = f"{_LND_API_URL_v2}/{endpoint_name}"
) -> Any:
url = f"{_LND_API_URL_v2}/{task_name}"
headers = {"apikey": _LND_API_KEY}
session = _create_requests_session(
url=url,
Expand Down Expand Up @@ -201,7 +201,7 @@ def _call_post(
session: Session,
files: Optional[List[Tuple[Any, ...]]] = None,
function_name: str = "unknown",
) -> dict[str, Any]:
) -> Any:
try:
tool_call_trace = ToolCallTrace(
endpoint_url=url,
Expand Down Expand Up @@ -238,4 +238,4 @@ def _call_post(
def filter_bboxes_by_threshold(
    bboxes: BoundingBoxes, threshold: float
) -> BoundingBoxes:
    """Keep only the bounding boxes whose confidence score meets the threshold.

    Parameters:
        bboxes: Detected bounding boxes; each element exposes a ``score``
            attribute (model confidence, presumably in [0, 1] — confirm
            against ``ODResponseData``).
        threshold: Minimum score a box must have to be kept (inclusive).

    Returns:
        A new list containing only the boxes with ``score >= threshold``;
        the input list is not modified.
    """
    # filter() keeps the box objects themselves; the earlier map()-based
    # version returned a list of booleans instead of the boxes.
    return list(filter(lambda bbox: bbox.score >= threshold, bboxes))
24 changes: 14 additions & 10 deletions vision_agent/tools/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -458,7 +458,7 @@ def loca_zero_shot_counting(image: np.ndarray) -> Dict[str, Any]:
"image": image_b64,
"function_name": "loca_zero_shot_counting",
}
resp_data = send_inference_request(data, "loca", v2=True)
resp_data: dict[str, Any] = send_inference_request(data, "loca", v2=True)
resp_data["heat_map"] = np.array(resp_data["heat_map"][0]).astype(np.uint8)
return resp_data

Expand Down Expand Up @@ -501,7 +501,7 @@ def loca_visual_prompt_counting(
"bbox": list(map(int, denormalize_bbox(bbox, image_size))),
"function_name": "loca_visual_prompt_counting",
}
resp_data = send_inference_request(data, "loca", v2=True)
resp_data: dict[str, Any] = send_inference_request(data, "loca", v2=True)
resp_data["heat_map"] = np.array(resp_data["heat_map"][0]).astype(np.uint8)
return resp_data

Expand Down Expand Up @@ -542,12 +542,13 @@ def countgd_counting(
files = [("image", buffer_bytes)]
payload = {"prompts": [prompt], "model": "countgd"}
metadata = {"function_name": "countgd_counting"}
resp_data: List[Dict[str, Any]] = send_task_inference_request(
resp_data = send_task_inference_request(
payload, "text-to-object-detection", files=files, metadata=metadata
)
bboxes_per_frame = resp_data[0]
bboxes_formatted = [ODResponseData(**bbox) for bbox in bboxes_per_frame]
return filter_bboxes_by_threshold(bboxes_formatted, box_threshold)
filtered_bboxes = filter_bboxes_by_threshold(bboxes_formatted, box_threshold)
return [bbox.model_dump() for bbox in filtered_bboxes]


def countgd_example_based_counting(
Expand Down Expand Up @@ -591,14 +592,15 @@ def countgd_example_based_counting(
visual_prompts = [
denormalize_bbox(bbox, image.shape[:2]) for bbox in visual_prompts
]
payload = {"visual_prompts": json.loads(visual_prompts), "model": "countgd"}
payload = {"visual_prompts": json.dumps(visual_prompts), "model": "countgd"}
metadata = {"function_name": "countgd_example_based_counting"}
resp_data: List[Dict[str, Any]] = send_task_inference_request(
resp_data = send_task_inference_request(
payload, "visual-prompts-to-object-detection", files=files, metadata=metadata
)
bboxes_per_frame = resp_data[0]
bboxes_formatted = [ODResponseData(**bbox) for bbox in bboxes_per_frame]
return filter_bboxes_by_threshold(bboxes_formatted, box_threshold)
filtered_bboxes = filter_bboxes_by_threshold(bboxes_formatted, box_threshold)
return [bbox.model_dump() for bbox in filtered_bboxes]


def florence2_roberta_vqa(prompt: str, image: np.ndarray) -> str:
Expand Down Expand Up @@ -746,7 +748,7 @@ def clip(image: np.ndarray, classes: List[str]) -> Dict[str, Any]:
"tool": "closed_set_image_classification",
"function_name": "clip",
}
resp_data = send_inference_request(data, "tools")
resp_data: dict[str, Any] = send_inference_request(data, "tools")
resp_data["scores"] = [round(prob, 4) for prob in resp_data["scores"]]
return resp_data

Expand Down Expand Up @@ -774,7 +776,7 @@ def vit_image_classification(image: np.ndarray) -> Dict[str, Any]:
"tool": "image_classification",
"function_name": "vit_image_classification",
}
resp_data = send_inference_request(data, "tools")
resp_data: dict[str, Any] = send_inference_request(data, "tools")
resp_data["scores"] = [round(prob, 4) for prob in resp_data["scores"]]
return resp_data

Expand All @@ -801,7 +803,9 @@ def vit_nsfw_classification(image: np.ndarray) -> Dict[str, Any]:
"image": image_b64,
"function_name": "vit_nsfw_classification",
}
resp_data = send_inference_request(data, "nsfw-classification", v2=True)
resp_data: dict[str, Any] = send_inference_request(
data, "nsfw-classification", v2=True
)
resp_data["score"] = round(resp_data["score"], 4)
return resp_data

Expand Down

0 comments on commit f8e05ee

Please sign in to comment.