diff --git a/.github/workflows/ci_cd.yml b/.github/workflows/ci_cd.yml index 3576e10c..ce25f286 100644 --- a/.github/workflows/ci_cd.yml +++ b/.github/workflows/ci_cd.yml @@ -83,9 +83,6 @@ jobs: - name: Test with pytest run: | poetry run pytest -v tests/integ - - name: Test with pytest, dev env - run: | - LANDINGAI_API_KEY=$LANDINGAI_DEV_API_KEY LANDINGAI_URL=https://api.dev.landing.ai poetry run pytest -v tests/integration_dev release: name: Release diff --git a/tests/integ/test_tools.py b/tests/integ/test_tools.py index 8c01f78d..9958894d 100644 --- a/tests/integ/test_tools.py +++ b/tests/integ/test_tools.py @@ -32,6 +32,8 @@ template_match, vit_image_classification, vit_nsfw_classification, + countgd_counting, + countgd_example_based_counting, ) FINE_TUNE_ID = "65ebba4a-88b7-419f-9046-0750e30250da" @@ -387,3 +389,18 @@ def test_generate_hed(): ) assert result.shape == img.shape + + +def test_countgd_counting() -> None: + img = ski.data.coins() + result = countgd_counting(image=img, prompt="coin") + assert len(result) == 24 + + +def test_countgd_example_based_counting() -> None: + img = ski.data.coins() + result = countgd_example_based_counting( + visual_prompts=[[85, 106, 122, 145]], + image=img, + ) + assert len(result) == 24 diff --git a/tests/integration_dev/__init__.py b/tests/integration_dev/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/tests/integration_dev/test_tools.py b/tests/integration_dev/test_tools.py deleted file mode 100644 index 246c5642..00000000 --- a/tests/integration_dev/test_tools.py +++ /dev/null @@ -1,18 +0,0 @@ -import skimage as ski - -from vision_agent.tools import countgd_counting, countgd_example_based_counting - - -def test_countgd_counting() -> None: - img = ski.data.coins() - result = countgd_counting(image=img, prompt="coin") - assert len(result) == 24 - - -def test_countgd_example_based_counting() -> None: - img = ski.data.coins() - result = countgd_example_based_counting( - visual_prompts=[[85, 106, 122, 145]], - image=img, - ) - assert len(result) == 24 diff --git a/vision_agent/tools/tools.py b/vision_agent/tools/tools.py index 6943a0ff..67f78307 100644 --- a/vision_agent/tools/tools.py +++ b/vision_agent/tools/tools.py @@ -700,22 +700,18 @@ def countgd_counting( {'score': 0.98, 'label': 'flower', 'bbox': [0.44, 0.24, 0.49, 0.58}, ] """ - buffer_bytes = numpy_to_bytes(image) - files = [("image", buffer_bytes)] + image_b64 = convert_to_b64(image) prompt = prompt.replace(", ", " .") - payload = {"prompts": [prompt], "model": "countgd"} + payload = {"prompt": prompt, "image": image_b64} metadata = {"function_name": "countgd_counting"} - resp_data = send_task_inference_request( - payload, "text-to-object-detection", files=files, metadata=metadata - ) - bboxes_per_frame = resp_data[0] + resp_data = send_task_inference_request(payload, "countgd", metadata=metadata) bboxes_formatted = [ ODResponseData( label=bbox["label"], - bbox=list(map(lambda x: round(x, 2), bbox["bounding_box"])), + bbox=list(map(lambda x: round(x, 2), bbox["bbox"])), score=round(bbox["score"], 2), ) - for bbox in bboxes_per_frame + for bbox in resp_data ] filtered_bboxes = filter_bboxes_by_threshold(bboxes_formatted, box_threshold) return [bbox.model_dump() for bbox in filtered_bboxes]