Skip to content

Commit

Permalink
refactor: switch model endpoints (#54)
Browse files Browse the repository at this point in the history
* Switch the host of model endpoint to api.dev.landing.ai
* DRY/Abstract out the inference code in tools
* Introduce LandingaiAPIKey and support loading from .env file
* Add integration tests for four model tools
* Minor tweaks/fixes
* Remove dead code
* Bump the minor version to 0.1.0
  • Loading branch information
humpydonkey authored Apr 15, 2024
1 parent 7588639 commit 35190ed
Show file tree
Hide file tree
Showing 7 changed files with 253 additions and 67 deletions.
121 changes: 119 additions & 2 deletions poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 3 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ build-backend = "poetry.core.masonry.api"

[tool.poetry]
name = "vision-agent"
version = "0.0.53"
version = "0.1.0"
description = "Toolset for Vision Agent"
authors = ["Landing AI <[email protected]>"]
readme = "README.md"
Expand All @@ -31,6 +31,7 @@ typing_extensions = "4.*"
moviepy = "1.*"
opencv-python-headless = "4.*"
tabulate = "^0.9.0"
pydantic-settings = "^2.2.1"

[tool.poetry.group.dev.dependencies]
autoflake = "1.*"
Expand All @@ -49,6 +50,7 @@ mkdocs = "^1.5.3"
mkdocstrings = {extras = ["python"], version = "^0.23.0"}
mkdocs-material = "^9.4.2"
types-tabulate = "^0.9.0.20240106"
scikit-image = "<0.23.1"

[tool.pytest.ini_options]
log_cli = true
Expand Down
42 changes: 42 additions & 0 deletions tests/test_tools.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
import skimage as ski
from PIL import Image

from vision_agent.tools.tools import CLIP, GroundingDINO, GroundingSAM, ImageCaption


def test_grounding_dino():
img = Image.fromarray(ski.data.coins())
result = GroundingDINO()(
prompt="coin",
image=img,
)
assert result["labels"] == ["coin"] * 24
assert len(result["bboxes"]) == 24
assert len(result["scores"]) == 24


def test_grounding_sam():
img = Image.fromarray(ski.data.coins())
result = GroundingSAM()(
prompt="coin",
image=img,
)
assert result["labels"] == ["coin"] * 24
assert len(result["bboxes"]) == 24
assert len(result["scores"]) == 24
assert len(result["masks"]) == 24


def test_clip():
img = Image.fromarray(ski.data.coins())
result = CLIP()(
prompt="coins",
image=img,
)
assert result["scores"] == [1.0]


def test_image_caption():
img = Image.fromarray(ski.data.coins())
result = ImageCaption()(image=img)
assert result["text"] == ["a black and white photo of a coin"]
13 changes: 13 additions & 0 deletions tests/test_type_defs.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
import os

from vision_agent.type_defs import LandingaiAPIKey


def test_load_api_credential_from_env_var():
actual = LandingaiAPIKey()
assert actual.api_key is not None

os.environ["landingai_api_key"] = "land_sk_123"
actual = LandingaiAPIKey()
assert actual.api_key == "land_sk_123"
del os.environ["landingai_api_key"]
2 changes: 1 addition & 1 deletion vision_agent/agent/vision_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -476,7 +476,7 @@ def chat_with_workflow(
reflections += "\n" + reflection
# '<END>' is a symbol to indicate the end of the chat, which is useful for streaming logs.
self.log_progress(
f"The Vision Agent has concluded this chat. <ANSWER>{final_answer}</<ANSWER>"
f"The Vision Agent has concluded this chat. <ANSWER>{final_answer}</ANSWER>"
)

if visualize_output:
Expand Down
Loading

0 comments on commit 35190ed

Please sign in to comment.