From d95fb46c92767bc965873c8ea9b811153b950b2e Mon Sep 17 00:00:00 2001 From: shankar-landing-ai Date: Tue, 3 Sep 2024 10:13:28 -0700 Subject: [PATCH] resolving pr comments on filter by score --- vision_agent/tools/tool_utils.py | 6 ------ vision_agent/tools/tools.py | 9 ++++++--- 2 files changed, 6 insertions(+), 9 deletions(-) diff --git a/vision_agent/tools/tool_utils.py b/vision_agent/tools/tool_utils.py index a14443bd..74bb74b6 100644 --- a/vision_agent/tools/tool_utils.py +++ b/vision_agent/tools/tool_utils.py @@ -233,9 +233,3 @@ def _call_post( trace = tool_call_trace.model_dump() trace["type"] = "tool_call" display({MimeType.APPLICATION_JSON: trace}, raw=True) - - -def filter_bboxes_by_threshold( - bboxes: BoundingBoxes, threshold: float -) -> BoundingBoxes: - return list(filter(lambda bbox: bbox.score >= threshold, bboxes)) diff --git a/vision_agent/tools/tools.py b/vision_agent/tools/tools.py index e0961398..ef9bac54 100644 --- a/vision_agent/tools/tools.py +++ b/vision_agent/tools/tools.py @@ -25,7 +25,6 @@ get_tools_info, send_inference_request, send_task_inference_request, - filter_bboxes_by_threshold, ) from vision_agent.tools.tools_types import ( BboxInput, @@ -547,7 +546,9 @@ def countgd_counting( ) bboxes_per_frame = resp_data[0] bboxes_formatted = [ODResponseData(**bbox) for bbox in bboxes_per_frame] - filtered_bboxes = filter_bboxes_by_threshold(bboxes_formatted, box_threshold) + filtered_bboxes = list( + filter(lambda bbox: bbox.score >= box_threshold, bboxes_formatted) + ) return [bbox.model_dump() for bbox in filtered_bboxes] @@ -599,7 +600,9 @@ def countgd_example_based_counting( ) bboxes_per_frame = resp_data[0] bboxes_formatted = [ODResponseData(**bbox) for bbox in bboxes_per_frame] - filtered_bboxes = filter_bboxes_by_threshold(bboxes_formatted, box_threshold) + filtered_bboxes = list( + filter(lambda bbox: bbox.score >= box_threshold, bboxes_formatted) + ) return [bbox.model_dump() for bbox in filtered_bboxes]