From 35190edd9ff736d0f03736361dd91d5e899b71a0 Mon Sep 17 00:00:00 2001 From: Asia <2736300+humpydonkey@users.noreply.github.com> Date: Mon, 15 Apr 2024 16:04:14 -0700 Subject: [PATCH] refactor: switch model endpoints (#54) * Switch the host of model endpoint to api.dev.landing.ai * DRY/Abstract out the inference code in tools * Introduce LandingaiAPIKey and support loading from .env file * Add integration tests for four model tools * Minor tweaks/fixes * Remove dead code * Bump the minor version to 0.1.0 --- poetry.lock | 121 ++++++++++++++++++++++++++++- pyproject.toml | 4 +- tests/test_tools.py | 42 ++++++++++ tests/test_type_defs.py | 13 ++++ vision_agent/agent/vision_agent.py | 2 +- vision_agent/tools/tools.py | 90 +++++++-------------- vision_agent/type_defs.py | 48 ++++++++++++ 7 files changed, 253 insertions(+), 67 deletions(-) create mode 100644 tests/test_tools.py create mode 100644 tests/test_type_defs.py create mode 100644 vision_agent/type_defs.py diff --git a/poetry.lock b/poetry.lock index a67d2e98..a0c62cab 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1,4 +1,4 @@ -# This file is automatically @generated by Poetry 1.7.1 and should not be changed by hand. +# This file is automatically @generated by Poetry 1.8.2 and should not be changed by hand. [[package]] name = "annotated-types" @@ -651,6 +651,25 @@ files = [ {file = "joblib-1.3.2.tar.gz", hash = "sha256:92f865e621e17784e7955080b6d042489e3b8e294949cc44c6eac304f59772b1"}, ] +[[package]] +name = "lazy-loader" +version = "0.4" +description = "Makes it easy to load subpackages and functions on demand." +optional = false +python-versions = ">=3.7" +files = [ + {file = "lazy_loader-0.4-py3-none-any.whl", hash = "sha256:342aa8e14d543a154047afb4ba8ef17f5563baad3fc610d7b15b213b0f119efc"}, + {file = "lazy_loader-0.4.tar.gz", hash = "sha256:47c75182589b91a4e1a85a136c074285a5ad4d9f39c63e0d7fb76391c4574cd1"}, +] + +[package.dependencies] +packaging = "*" + +[package.extras] +dev = ["changelist (==0.5)"] +lint = ["pre-commit (==3.7.0)"] +test = ["pytest (>=7.4)", "pytest-cov (>=4.1)"] + [[package]] name = "markdown" version = "3.6" @@ -1595,6 +1614,25 @@ files = [ [package.dependencies] typing-extensions = ">=4.6.0,<4.7.0 || >4.7.0" +[[package]] +name = "pydantic-settings" +version = "2.2.1" +description = "Settings management using Pydantic" +optional = false +python-versions = ">=3.8" +files = [ + {file = "pydantic_settings-2.2.1-py3-none-any.whl", hash = "sha256:0235391d26db4d2190cb9b31051c4b46882d28a51533f97440867f012d4da091"}, + {file = "pydantic_settings-2.2.1.tar.gz", hash = "sha256:00b9f6a5e95553590434c0fa01ead0b216c3e10bc54ae02e37f359948643c5ed"}, +] + +[package.dependencies] +pydantic = ">=2.3.0" +python-dotenv = ">=0.21.0" + +[package.extras] +toml = ["tomli (>=2.0.1)"] +yaml = ["pyyaml (>=6.0.1)"] + [[package]] name = "pyflakes" version = "2.5.0" @@ -1675,6 +1713,20 @@ files = [ [package.dependencies] six = ">=1.5" +[[package]] +name = "python-dotenv" +version = "1.0.1" +description = "Read key-value pairs from a .env file and set them as environment variables" +optional = false +python-versions = ">=3.8" +files = [ + {file = "python-dotenv-1.0.1.tar.gz", hash = "sha256:e324ee90a023d808f1959c46bcbc04446a10ced277783dc6ee09987c37ec10ca"}, + {file = "python_dotenv-1.0.1-py3-none-any.whl", hash = "sha256:f7b63ef50f1b690dddf550d03497b66d609393b40b564ed0d674909a68ebf16a"}, +] + +[package.extras] +cli = ["click (>=5.0)"] + [[package]] name = "pytz" version = "2024.1" @@ -2035,6 +2087,54 @@ tensorflow = ["safetensors[numpy]", 
"tensorflow (>=2.11.0)"] testing = ["h5py (>=3.7.0)", "huggingface_hub (>=0.12.1)", "hypothesis (>=6.70.2)", "pytest (>=7.2.0)", "pytest-benchmark (>=4.0.0)", "safetensors[numpy]", "setuptools_rust (>=1.5.2)"] torch = ["safetensors[numpy]", "torch (>=1.10)"] +[[package]] +name = "scikit-image" +version = "0.22.0" +description = "Image processing in Python" +optional = false +python-versions = ">=3.9" +files = [ + {file = "scikit_image-0.22.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:74ec5c1d4693506842cc7c9487c89d8fc32aed064e9363def7af08b8f8cbb31d"}, + {file = "scikit_image-0.22.0-cp310-cp310-macosx_12_0_arm64.whl", hash = "sha256:a05ae4fe03d802587ed8974e900b943275548cde6a6807b785039d63e9a7a5ff"}, + {file = "scikit_image-0.22.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:6a92dca3d95b1301442af055e196a54b5a5128c6768b79fc0a4098f1d662dee6"}, + {file = "scikit_image-0.22.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3663d063d8bf2fb9bdfb0ca967b9ee3b6593139c860c7abc2d2351a8a8863938"}, + {file = "scikit_image-0.22.0-cp310-cp310-win_amd64.whl", hash = "sha256:ebdbdc901bae14dab637f8d5c99f6d5cc7aaf4a3b6f4003194e003e9f688a6fc"}, + {file = "scikit_image-0.22.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:95d6da2d8a44a36ae04437c76d32deb4e3c993ffc846b394b9949fd8ded73cb2"}, + {file = "scikit_image-0.22.0-cp311-cp311-macosx_12_0_arm64.whl", hash = "sha256:2c6ef454a85f569659b813ac2a93948022b0298516b757c9c6c904132be327e2"}, + {file = "scikit_image-0.22.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e87872f067444ee90a00dd49ca897208308645382e8a24bd3e76f301af2352cd"}, + {file = "scikit_image-0.22.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c5c378db54e61b491b9edeefff87e49fcf7fdf729bb93c777d7a5f15d36f743e"}, + {file = "scikit_image-0.22.0-cp311-cp311-win_amd64.whl", hash = "sha256:2bcb74adb0634258a67f66c2bb29978c9a3e222463e003b67ba12056c003971b"}, + {file = "scikit_image-0.22.0-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:003ca2274ac0fac252280e7179ff986ff783407001459ddea443fe7916e38cff"}, + {file = "scikit_image-0.22.0-cp312-cp312-macosx_12_0_arm64.whl", hash = "sha256:cf3c0c15b60ae3e557a0c7575fbd352f0c3ce0afca562febfe3ab80efbeec0e9"}, + {file = "scikit_image-0.22.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f5b23908dd4d120e6aecb1ed0277563e8cbc8d6c0565bdc4c4c6475d53608452"}, + {file = "scikit_image-0.22.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:be79d7493f320a964f8fcf603121595ba82f84720de999db0fcca002266a549a"}, + {file = "scikit_image-0.22.0-cp312-cp312-win_amd64.whl", hash = "sha256:722b970aa5da725dca55252c373b18bbea7858c1cdb406e19f9b01a4a73b30b2"}, + {file = "scikit_image-0.22.0-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:22318b35044cfeeb63ee60c56fc62450e5fe516228138f1d06c7a26378248a86"}, + {file = "scikit_image-0.22.0-cp39-cp39-macosx_12_0_arm64.whl", hash = "sha256:9e801c44a814afdadeabf4dffdffc23733e393767958b82319706f5fa3e1eaa9"}, + {file = "scikit_image-0.22.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c472a1fb3665ec5c00423684590631d95f9afcbc97f01407d348b821880b2cb3"}, + {file = "scikit_image-0.22.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3b7a6c89e8d6252332121b58f50e1625c35f7d6a85489c0b6b7ee4f5155d547a"}, + {file = "scikit_image-0.22.0-cp39-cp39-win_amd64.whl", hash = 
"sha256:5071b8f6341bfb0737ab05c8ab4ac0261f9e25dbcc7b5d31e5ed230fd24a7929"}, + {file = "scikit_image-0.22.0.tar.gz", hash = "sha256:018d734df1d2da2719087d15f679d19285fce97cd37695103deadfaef2873236"}, +] + +[package.dependencies] +imageio = ">=2.27" +lazy_loader = ">=0.3" +networkx = ">=2.8" +numpy = ">=1.22" +packaging = ">=21" +pillow = ">=9.0.1" +scipy = ">=1.8" +tifffile = ">=2022.8.12" + +[package.extras] +build = ["Cython (>=0.29.32)", "build", "meson-python (>=0.14)", "ninja", "numpy (>=1.22)", "packaging (>=21)", "pythran", "setuptools (>=67)", "spin (==0.6)", "wheel"] +data = ["pooch (>=1.6.0)"] +developer = ["pre-commit", "tomli"] +docs = ["PyWavelets (>=1.1.1)", "dask[array] (>=2022.9.2)", "ipykernel", "ipywidgets", "kaleido", "matplotlib (>=3.5)", "myst-parser", "numpydoc (>=1.6)", "pandas (>=1.5)", "plotly (>=5.10)", "pooch (>=1.6)", "pydata-sphinx-theme (>=0.14.1)", "pytest-runner", "scikit-learn (>=1.1)", "seaborn (>=0.11)", "sphinx (>=7.2)", "sphinx-copybutton", "sphinx-gallery (>=0.14)", "sphinx_design (>=0.5)", "tifffile (>=2022.8.12)"] +optional = ["PyWavelets (>=1.1.1)", "SimpleITK", "astropy (>=5.0)", "cloudpickle (>=0.2.1)", "dask[array] (>=2021.1.0)", "matplotlib (>=3.5)", "pooch (>=1.6.0)", "pyamg", "scikit-learn (>=1.1)"] +test = ["asv", "matplotlib (>=3.5)", "numpydoc (>=1.5)", "pooch (>=1.6.0)", "pytest (>=7.0)", "pytest-cov (>=2.11.0)", "pytest-faulthandler", "pytest-localserver"] + [[package]] name = "scikit-learn" version = "1.4.1.post1" @@ -2217,6 +2317,23 @@ files = [ {file = "threadpoolctl-3.4.0.tar.gz", hash = "sha256:f11b491a03661d6dd7ef692dd422ab34185d982466c49c8f98c8f716b5c93196"}, ] +[[package]] +name = "tifffile" +version = "2024.2.12" +description = "Read and write TIFF files" +optional = false +python-versions = ">=3.9" +files = [ + {file = "tifffile-2024.2.12-py3-none-any.whl", hash = "sha256:870998f82fbc94ff7c3528884c1b0ae54863504ff51dbebea431ac3fa8fb7c21"}, + {file = "tifffile-2024.2.12.tar.gz", hash = "sha256:4920a3ec8e8e003e673d3c6531863c99eedd570d1b8b7e141c072ed78ff8030d"}, +] + +[package.dependencies] +numpy = "*" + +[package.extras] +all = ["defusedxml", "fsspec", "imagecodecs (>=2023.8.12)", "lxml", "matplotlib", "zarr"] + [[package]] name = "tokenizers" version = "0.15.2" @@ -2677,4 +2794,4 @@ testing = ["big-O", "jaraco.functools", "jaraco.itertools", "more-itertools", "p [metadata] lock-version = "2.0" python-versions = ">=3.9,<3.12" -content-hash = "dd185bf30f7f4f00b7cece1bdb9c5183b07ca8544982c4a630c6da281c5d2ae7" +content-hash = "37e5b5d42e2c18f1d5741fc2efd9aa4d31dbd4413d09a7d43462e6a1a669531d" diff --git a/pyproject.toml b/pyproject.toml index 7fca8c0a..e741a4c2 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "poetry.core.masonry.api" [tool.poetry] name = "vision-agent" -version = "0.0.53" +version = "0.1.0" description = "Toolset for Vision Agent" authors = ["Landing AI "] readme = "README.md" @@ -31,6 +31,7 @@ typing_extensions = "4.*" moviepy = "1.*" opencv-python-headless = "4.*" tabulate = "^0.9.0" +pydantic-settings = "^2.2.1" [tool.poetry.group.dev.dependencies] autoflake = "1.*" @@ -49,6 +50,7 @@ mkdocs = "^1.5.3" mkdocstrings = {extras = ["python"], version = "^0.23.0"} mkdocs-material = "^9.4.2" types-tabulate = "^0.9.0.20240106" +scikit-image = "<0.23.1" [tool.pytest.ini_options] log_cli = true diff --git a/tests/test_tools.py b/tests/test_tools.py new file mode 100644 index 00000000..ab2d20b7 --- /dev/null +++ b/tests/test_tools.py @@ -0,0 +1,42 @@ +import skimage as ski +from PIL import 
Image + +from vision_agent.tools.tools import CLIP, GroundingDINO, GroundingSAM, ImageCaption + + +def test_grounding_dino(): + img = Image.fromarray(ski.data.coins()) + result = GroundingDINO()( + prompt="coin", + image=img, + ) + assert result["labels"] == ["coin"] * 24 + assert len(result["bboxes"]) == 24 + assert len(result["scores"]) == 24 + + +def test_grounding_sam(): + img = Image.fromarray(ski.data.coins()) + result = GroundingSAM()( + prompt="coin", + image=img, + ) + assert result["labels"] == ["coin"] * 24 + assert len(result["bboxes"]) == 24 + assert len(result["scores"]) == 24 + assert len(result["masks"]) == 24 + + +def test_clip(): + img = Image.fromarray(ski.data.coins()) + result = CLIP()( + prompt="coins", + image=img, + ) + assert result["scores"] == [1.0] + + +def test_image_caption(): + img = Image.fromarray(ski.data.coins()) + result = ImageCaption()(image=img) + assert result["text"] == ["a black and white photo of a coin"] diff --git a/tests/test_type_defs.py b/tests/test_type_defs.py new file mode 100644 index 00000000..273681f6 --- /dev/null +++ b/tests/test_type_defs.py @@ -0,0 +1,13 @@ +import os + +from vision_agent.type_defs import LandingaiAPIKey + + +def test_load_api_credential_from_env_var(): + actual = LandingaiAPIKey() + assert actual.api_key is not None + + os.environ["landingai_api_key"] = "land_sk_123" + actual = LandingaiAPIKey() + assert actual.api_key == "land_sk_123" + del os.environ["landingai_api_key"] diff --git a/vision_agent/agent/vision_agent.py b/vision_agent/agent/vision_agent.py index 67bce0a1..10c98735 100644 --- a/vision_agent/agent/vision_agent.py +++ b/vision_agent/agent/vision_agent.py @@ -476,7 +476,7 @@ def chat_with_workflow( reflections += "\n" + reflection # '' is a symbol to indicate the end of the chat, which is useful for streaming logs. self.log_progress( - f"The Vision Agent has concluded this chat. {final_answer}" + f"The Vision Agent has concluded this chat. {final_answer}" ) if visualize_output: diff --git a/vision_agent/tools/tools.py b/vision_agent/tools/tools.py index 2c686c43..792f1ba1 100644 --- a/vision_agent/tools/tools.py +++ b/vision_agent/tools/tools.py @@ -12,8 +12,11 @@ from vision_agent.image_utils import convert_to_b64, get_image_size from vision_agent.tools.video import extract_frames_from_video +from vision_agent.type_defs import LandingaiAPIKey _LOGGER = logging.getLogger(__name__) +_LND_API_KEY = LandingaiAPIKey().api_key +_LND_API_URL = "https://api.dev.landing.ai/v1/agent" def normalize_bbox( @@ -80,8 +83,6 @@ class CLIP(Tool): [{"labels": ["red line", "yellow dot"], "scores": [0.98, 0.02]}] """ - _ENDPOINT = "https://soi4ewr6fjqqdf5vuss6rrilee0kumxq.lambda-url.us-east-2.on.aws" - name = "clip_" description = "'clip_' is a tool that can classify any image given a set of input names or tags. It returns a list of the input names along with their probability scores." 
usage = { @@ -125,23 +126,9 @@ def __call__(self, prompt: str, image: Union[str, ImageType]) -> Dict: "image": image_b64, "tool": "closed_set_image_classification", } - res = requests.post( - self._ENDPOINT, - headers={"Content-Type": "application/json"}, - json=data, - ) - resp_json: Dict[str, Any] = res.json() - if ( - "statusCode" in resp_json and resp_json["statusCode"] != 200 - ) or "statusCode" not in resp_json: - _LOGGER.error(f"Request failed: {resp_json}") - raise ValueError(f"Request failed: {resp_json}") - - resp_json["data"]["scores"] = [ - round(prob, 4) for prob in resp_json["data"]["scores"] - ] - - return resp_json["data"] # type: ignore + resp_data = _send_inference_request(data, "tools") + resp_data["scores"] = [round(prob, 4) for prob in resp_data["scores"]] + return resp_data class ImageCaption(Tool): @@ -156,8 +143,6 @@ class ImageCaption(Tool): {'text': ['a box of orange and white socks']} """ - _ENDPOINT = "https://soi4ewr6fjqqdf5vuss6rrilee0kumxq.lambda-url.us-east-2.on.aws" - name = "image_caption_" description = "'image_caption_' is a tool that can caption an image based on its contents or tags. It returns a text describing the image" usage = { @@ -197,19 +182,7 @@ def __call__(self, image: Union[str, ImageType]) -> Dict: "image": image_b64, "tool": "image_captioning", } - res = requests.post( - self._ENDPOINT, - headers={"Content-Type": "application/json"}, - json=data, - ) - resp_json: Dict[str, Any] = res.json() - if ( - "statusCode" in resp_json and resp_json["statusCode"] != 200 - ) or "statusCode" not in resp_json: - _LOGGER.error(f"Request failed: {resp_json}") - raise ValueError(f"Request failed: {resp_json}") - - return resp_json["data"] # type: ignore + return _send_inference_request(data, "tools") class GroundingDINO(Tool): @@ -226,8 +199,6 @@ class GroundingDINO(Tool): 'scores': [0.98, 0.02]}] """ - _ENDPOINT = "https://soi4ewr6fjqqdf5vuss6rrilee0kumxq.lambda-url.us-east-2.on.aws" - name = "grounding_dino_" description = "'grounding_dino_' is a tool that can detect arbitrary objects with inputs such as category names or referring expressions. It returns a list of bounding boxes, label names and associated probability scores." usage = { @@ -290,24 +261,13 @@ def __call__( "tool": "visual_grounding", "kwargs": {"box_threshold": box_threshold, "iou_threshold": iou_threshold}, } - res = requests.post( - self._ENDPOINT, - headers={"Content-Type": "application/json"}, - json=request_data, - ) - resp_json: Dict[str, Any] = res.json() - if ( - "statusCode" in resp_json and resp_json["statusCode"] != 200 - ) or "statusCode" not in resp_json: - _LOGGER.error(f"Request failed: {resp_json}") - raise ValueError(f"Request failed: {resp_json}") - data: Dict[str, Any] = resp_json["data"] + data: Dict[str, Any] = _send_inference_request(request_data, "tools") if "bboxes" in data: data["bboxes"] = [normalize_bbox(box, image_size) for box in data["bboxes"]] if "scores" in data: data["scores"] = [round(score, 2) for score in data["scores"]] if "labels" in data: - data["labels"] = [label for label in data["labels"]] + data["labels"] = list(data["labels"]) data["size"] = (image_size[1], image_size[0]) return data @@ -335,8 +295,6 @@ class GroundingSAM(Tool): [1, 1, 1, ..., 1, 1, 1]], dtype=uint8)]}] """ - _ENDPOINT = "https://soi4ewr6fjqqdf5vuss6rrilee0kumxq.lambda-url.us-east-2.on.aws" - name = "grounding_sam_" description = "'grounding_sam_' is a tool that can detect arbitrary objects with inputs such as category names or referring expressions. 
It returns a list of bounding boxes, label names and masks file names and associated probability scores." usage = { @@ -399,18 +357,7 @@ def __call__( "tool": "visual_grounding_segment", "kwargs": {"box_threshold": box_threshold, "iou_threshold": iou_threshold}, } - res = requests.post( - self._ENDPOINT, - headers={"Content-Type": "application/json"}, - json=request_data, - ) - resp_json: Dict[str, Any] = res.json() - if ( - "statusCode" in resp_json and resp_json["statusCode"] != 200 - ) or "statusCode" not in resp_json: - _LOGGER.error(f"Request failed: {resp_json}") - raise ValueError(f"Request failed: {resp_json}") - data: Dict[str, Any] = resp_json["data"] + data: Dict[str, Any] = _send_inference_request(request_data, "tools") ret_pred: Dict[str, List] = {"labels": [], "bboxes": [], "masks": []} if "bboxes" in data: ret_pred["bboxes"] = [ @@ -714,3 +661,20 @@ def __call__(self, equation: str) -> float: ) if (hasattr(c, "name") and hasattr(c, "description") and hasattr(c, "usage")) } + + +def _send_inference_request( + payload: Dict[str, Any], endpoint_name: str +) -> Dict[str, Any]: + res = requests.post( + f"{_LND_API_URL}/model/{endpoint_name}", + headers={ + "Content-Type": "application/json", + "apikey": _LND_API_KEY, + }, + json=payload, + ) + if res.status_code != 200: + _LOGGER.error(f"Request failed: {res.text}") + raise ValueError(f"Request failed: {res.text}") + return res.json()["data"] # type: ignore diff --git a/vision_agent/type_defs.py b/vision_agent/type_defs.py new file mode 100644 index 00000000..1667bf7a --- /dev/null +++ b/vision_agent/type_defs.py @@ -0,0 +1,48 @@ +from pydantic import Field, field_validator +from pydantic_settings import BaseSettings + + +class LandingaiAPIKey(BaseSettings): + """The API key of a user in a particular organization in LandingLens. + It supports loading from environment variables or .env files. + The supported name of the environment variables are (case-insensitive): + - LANDINGAI_API_KEY + + Environment variables will always take priority over values loaded from a dotenv file. + """ + + api_key: str = Field( + default="land_sk_hw34v3tyEc35OAhP8F7hnGnrDv2C8hD2ycMyq0aMkVS1H40D22", + alias="LANDINGAI_API_KEY", + description="The API key of LandingAI.", + ) + + @field_validator("api_key") + @classmethod + def is_api_key_valid(cls, key: str) -> str: + """Check if the API key is a v2 key.""" + if not key: + raise InvalidApiKeyError(f"LandingAI API key is required, but it's {key}") + if not key.startswith("land_sk_"): + raise InvalidApiKeyError( + f"LandingAI API key (v2) must start with 'land_sk_' prefix, but it's {key}. See https://support.landing.ai/docs/api-key for more information." + ) + return key + + class Config: + env_file = ".env" + env_prefix = "landingai_" + case_sensitive = False + extra = "ignore" + + +class InvalidApiKeyError(Exception): + """Exception raised when the an invalid API key is provided. This error could be raised from any SDK code, not limited to a HTTP client.""" + + def __init__(self, message: str): + self.message = f"""{message} +For more information, see https://landing-ai.github.io/landingai-python/landingai.html#manage-api-credentials""" + super().__init__(self.message) + + def __str__(self) -> str: + return self.message
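
For context, a minimal sketch of how the refactored pieces fit together, assuming only what the diff above introduces (`LandingaiAPIKey` in `vision_agent/type_defs.py` and the shared `_send_inference_request` helper in `vision_agent/tools/tools.py`); the real module ships a default trial key, while this sketch requires the environment variable, and the base64 image string is a placeholder:

    # Sketch of the flow added by this PR: key loading via pydantic-settings,
    # plus one shared POST helper that every tool delegates to.
    import logging
    from typing import Any, Dict

    import requests
    from pydantic import Field
    from pydantic_settings import BaseSettings

    _LOGGER = logging.getLogger(__name__)
    _LND_API_URL = "https://api.dev.landing.ai/v1/agent"


    class LandingaiAPIKey(BaseSettings):
        """Loads the key from LANDINGAI_API_KEY (env var or .env file, case-insensitive)."""

        # The diff ships a default key; here the env var is required for the sketch.
        api_key: str = Field(alias="LANDINGAI_API_KEY")

        class Config:
            env_file = ".env"
            env_prefix = "landingai_"
            case_sensitive = False
            extra = "ignore"


    def _send_inference_request(payload: Dict[str, Any], endpoint_name: str) -> Dict[str, Any]:
        """POST a tool payload to the shared model endpoint and return its 'data' field."""
        res = requests.post(
            f"{_LND_API_URL}/model/{endpoint_name}",
            headers={"Content-Type": "application/json", "apikey": LandingaiAPIKey().api_key},
            json=payload,
        )
        if res.status_code != 200:
            _LOGGER.error(f"Request failed: {res.text}")
            raise ValueError(f"Request failed: {res.text}")
        return res.json()["data"]


    if __name__ == "__main__":
        # A CLIP-style call then reduces to building a payload and delegating:
        payload = {
            "prompt": "red line, yellow dot",
            "image": "<base64-encoded image>",  # produced by convert_to_b64 in the real tools
            "tool": "closed_set_image_classification",
        }
        print(_send_inference_request(payload, "tools"))

With this helper in place, each tool's `__call__` (CLIP, ImageCaption, GroundingDINO, GroundingSAM) only builds its payload and delegates, which is what the repeated `requests.post` blocks removed in the hunks above reduce to.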