Skip to content

Commit

Permalink
fix countgd
Browse files Browse the repository at this point in the history
  • Loading branch information
Dayof committed Oct 4, 2024
1 parent 8a8c5d1 commit 3febe3a
Show file tree
Hide file tree
Showing 5 changed files with 22 additions and 30 deletions.
3 changes: 0 additions & 3 deletions .github/workflows/ci_cd.yml
Original file line number Diff line number Diff line change
Expand Up @@ -83,9 +83,6 @@ jobs:
- name: Test with pytest
run: |
poetry run pytest -v tests/integ
- name: Test with pytest, dev env
run: |
LANDINGAI_API_KEY=$LANDINGAI_DEV_API_KEY LANDINGAI_URL=https://api.dev.landing.ai poetry run pytest -v tests/integration_dev
release:
name: Release
Expand Down
17 changes: 17 additions & 0 deletions tests/integ/test_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,8 @@
template_match,
vit_image_classification,
vit_nsfw_classification,
countgd_counting,
countgd_example_based_counting,
)

FINE_TUNE_ID = "65ebba4a-88b7-419f-9046-0750e30250da"
Expand Down Expand Up @@ -387,3 +389,18 @@ def test_generate_hed():
)

assert result.shape == img.shape


def test_countgd_counting() -> None:
img = ski.data.coins()
result = countgd_counting(image=img, prompt="coin")
assert len(result) == 24


def test_countgd_example_based_counting() -> None:
img = ski.data.coins()
result = countgd_example_based_counting(
visual_prompts=[[85, 106, 122, 145]],
image=img,
)
assert len(result) == 24
Empty file removed tests/integration_dev/__init__.py
Empty file.
18 changes: 0 additions & 18 deletions tests/integration_dev/test_tools.py

This file was deleted.

14 changes: 5 additions & 9 deletions vision_agent/tools/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -700,22 +700,18 @@ def countgd_counting(
{'score': 0.98, 'label': 'flower', 'bbox': [0.44, 0.24, 0.49, 0.58},
]
"""
buffer_bytes = numpy_to_bytes(image)
files = [("image", buffer_bytes)]
image_b64 = convert_to_b64(image)
prompt = prompt.replace(", ", " .")
payload = {"prompts": [prompt], "model": "countgd"}
payload = {"prompt": prompt, "image": image_b64}
metadata = {"function_name": "countgd_counting"}
resp_data = send_task_inference_request(
payload, "text-to-object-detection", files=files, metadata=metadata
)
bboxes_per_frame = resp_data[0]
resp_data = send_task_inference_request(payload, "countgd", metadata=metadata)
bboxes_formatted = [
ODResponseData(
label=bbox["label"],
bbox=list(map(lambda x: round(x, 2), bbox["bounding_box"])),
bbox=list(map(lambda x: round(x, 2), bbox["bbox"])),
score=round(bbox["score"], 2),
)
for bbox in bboxes_per_frame
for bbox in resp_data
]
filtered_bboxes = filter_bboxes_by_threshold(bboxes_formatted, box_threshold)
return [bbox.model_dump() for bbox in filtered_bboxes]
Expand Down

0 comments on commit 3febe3a

Please sign in to comment.