From 88658aa9835fb49f5512d7289ce2d5f8afdd582e Mon Sep 17 00:00:00 2001 From: Camilo Zapata Date: Fri, 4 Oct 2024 16:14:55 -0500 Subject: [PATCH] fix countgd api to match the UI --- vision_agent/tools/tools.py | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/vision_agent/tools/tools.py b/vision_agent/tools/tools.py index 7d881921..b33df8ec 100644 --- a/vision_agent/tools/tools.py +++ b/vision_agent/tools/tools.py @@ -700,18 +700,22 @@ def countgd_counting( {'score': 0.98, 'label': 'flower', 'bbox': [0.44, 0.24, 0.49, 0.58}, ] """ - image_b64 = convert_to_b64(image) + buffer_bytes = numpy_to_bytes(image) + files = [("image", buffer_bytes)] prompt = prompt.replace(", ", " .") - payload = {"prompt": prompt, "image": image_b64} + payload = {"prompts": [prompt], "model": "countgd"} metadata = {"function_name": "countgd_counting"} - resp_data = send_task_inference_request(payload, "countgd", metadata=metadata) + resp_data = send_task_inference_request( + payload, "text-to-object-detection", files=files, metadata=metadata + ) + bboxes_per_frame = resp_data[0] bboxes_formatted = [ ODResponseData( label=bbox["label"], - bbox=list(map(lambda x: round(x, 2), bbox["bbox"])), + bbox=list(map(lambda x: round(x, 2), bbox["bounding_box"])), score=round(bbox["score"], 2), ) - for bbox in resp_data + for bbox in bboxes_per_frame ] filtered_bboxes = filter_bboxes_by_threshold(bboxes_formatted, box_threshold) return [bbox.model_dump() for bbox in filtered_bboxes]