From e0587c3f2186461926ad5994347d60d21719a146 Mon Sep 17 00:00:00 2001 From: Yazhou Cao Date: Wed, 28 Feb 2024 22:18:20 +0800 Subject: [PATCH] Use public endpoint instead --- README.md | 4 ++-- vision_agent/config.py | 2 -- vision_agent/lmm/lmm.py | 16 ++++++++++++---- 3 files changed, 14 insertions(+), 8 deletions(-) delete mode 100644 vision_agent/config.py diff --git a/README.md b/README.md index 463e8a50..d40111e1 100644 --- a/README.md +++ b/README.md @@ -8,7 +8,7 @@ To get started you can create an LMM and start generating text from images. The ```python import vision_agent as va -model = va.lmm.get_model("llava") +model = va.lmm.get_lmm("llava") model.generate("Describe this image", "image.png") >>> "A yellow house with a green lawn." ``` @@ -24,7 +24,7 @@ import pandas as pd df = pd.DataFrame({"image_paths": ["image1.png", "image2.png", "image3.png"]}) ds = va.data.DataStore(df) -ds = ds.add_lmm(va.lmm.get_model("llava")) +ds = ds.add_lmm(va.lmm.get_lmm("llava")) ds = ds.add_embedder(va.emb.get_embedder("sentence-transformer")) ds = ds.add_column("descriptions", "Describe this image.") diff --git a/vision_agent/config.py b/vision_agent/config.py deleted file mode 100644 index 5d32d75d..00000000 --- a/vision_agent/config.py +++ /dev/null @@ -1,2 +0,0 @@ -BASETEN_API_KEY = "PRxjuebe.VQJQ7rCvswimP5y8GeSmZA03I4zw6dgB" -BASETEN_URL = "https://model-232pg41q.api.baseten.co/production/predict" diff --git a/vision_agent/lmm/lmm.py b/vision_agent/lmm/lmm.py index 0d31cd3d..205af81f 100644 --- a/vision_agent/lmm/lmm.py +++ b/vision_agent/lmm/lmm.py @@ -1,11 +1,16 @@ import base64 +import logging from abc import ABC, abstractmethod from pathlib import Path from typing import Any, Dict, List, Optional, Union, cast import requests -from vision_agent.config import BASETEN_API_KEY, BASETEN_URL +logging.basicConfig(level=logging.INFO) + +_LOGGER = logging.getLogger(__name__) + +_LLAVA_ENDPOINT = "https://cpvlqoxw6vhpdro27uhkvceady0kvvqk.lambda-url.us-east-2.on.aws" def encode_image(image: Union[str, Path]) -> str: @@ -31,11 +36,14 @@ def generate(self, prompt: str, image: Optional[Union[str, Path]] = None) -> str if image: data["image"] = encode_image(image) res = requests.post( - BASETEN_URL, - headers={"Authorization": f"Api-Key {BASETEN_API_KEY}"}, + _LLAVA_ENDPOINT, + headers={"Content-Type": "application/json"}, json=data, ) - return res.text + resp_json: Dict[str, Any] = res.json() + if resp_json["statusCode"] != 200: + _LOGGER.error(f"Request failed: {resp_json['data']}") + return cast(str, resp_json["data"]) class OpenAILMM(LMM):