Skip to content

Commit 3febe3a

Browse files
committed
fix countgd
1 parent 8a8c5d1 commit 3febe3a

File tree

5 files changed

+22
-30
lines changed

5 files changed

+22
-30
lines changed

.github/workflows/ci_cd.yml

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -83,9 +83,6 @@ jobs:
8383
- name: Test with pytest
8484
run: |
8585
poetry run pytest -v tests/integ
86-
- name: Test with pytest, dev env
87-
run: |
88-
LANDINGAI_API_KEY=$LANDINGAI_DEV_API_KEY LANDINGAI_URL=https://api.dev.landing.ai poetry run pytest -v tests/integration_dev
8986
9087
release:
9188
name: Release

tests/integ/test_tools.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,8 @@
3232
template_match,
3333
vit_image_classification,
3434
vit_nsfw_classification,
35+
countgd_counting,
36+
countgd_example_based_counting,
3537
)
3638

3739
FINE_TUNE_ID = "65ebba4a-88b7-419f-9046-0750e30250da"
@@ -387,3 +389,18 @@ def test_generate_hed():
387389
)
388390

389391
assert result.shape == img.shape
392+
393+
394+
def test_countgd_counting() -> None:
395+
img = ski.data.coins()
396+
result = countgd_counting(image=img, prompt="coin")
397+
assert len(result) == 24
398+
399+
400+
def test_countgd_example_based_counting() -> None:
401+
img = ski.data.coins()
402+
result = countgd_example_based_counting(
403+
visual_prompts=[[85, 106, 122, 145]],
404+
image=img,
405+
)
406+
assert len(result) == 24

tests/integration_dev/__init__.py

Whitespace-only changes.

tests/integration_dev/test_tools.py

Lines changed: 0 additions & 18 deletions
This file was deleted.

vision_agent/tools/tools.py

Lines changed: 5 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -700,22 +700,18 @@ def countgd_counting(
700700
{'score': 0.98, 'label': 'flower', 'bbox': [0.44, 0.24, 0.49, 0.58},
701701
]
702702
"""
703-
buffer_bytes = numpy_to_bytes(image)
704-
files = [("image", buffer_bytes)]
703+
image_b64 = convert_to_b64(image)
705704
prompt = prompt.replace(", ", " .")
706-
payload = {"prompts": [prompt], "model": "countgd"}
705+
payload = {"prompt": prompt, "image": image_b64}
707706
metadata = {"function_name": "countgd_counting"}
708-
resp_data = send_task_inference_request(
709-
payload, "text-to-object-detection", files=files, metadata=metadata
710-
)
711-
bboxes_per_frame = resp_data[0]
707+
resp_data = send_task_inference_request(payload, "countgd", metadata=metadata)
712708
bboxes_formatted = [
713709
ODResponseData(
714710
label=bbox["label"],
715-
bbox=list(map(lambda x: round(x, 2), bbox["bounding_box"])),
711+
bbox=list(map(lambda x: round(x, 2), bbox["bbox"])),
716712
score=round(bbox["score"], 2),
717713
)
718-
for bbox in bboxes_per_frame
714+
for bbox in resp_data
719715
]
720716
filtered_bboxes = filter_bboxes_by_threshold(bboxes_formatted, box_threshold)
721717
return [bbox.model_dump() for bbox in filtered_bboxes]

0 commit comments

Comments
 (0)