Skip to content

Commit

Permalink
Use public endpoint instead
Browse files Browse the repository at this point in the history
  • Loading branch information
AsiaCao committed Feb 28, 2024
1 parent 4beea6d commit e0587c3
Show file tree
Hide file tree
Showing 3 changed files with 14 additions and 8 deletions.
4 changes: 2 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ To get started you can create an LMM and start generating text from images. The
```python
import vision_agent as va

-model = va.lmm.get_model("llava")
+model = va.lmm.get_lmm("llava")
model.generate("Describe this image", "image.png")
>>> "A yellow house with a green lawn."
```
Expand All @@ -24,7 +24,7 @@ import pandas as pd

df = pd.DataFrame({"image_paths": ["image1.png", "image2.png", "image3.png"]})
ds = va.data.DataStore(df)
-ds = ds.add_lmm(va.lmm.get_model("llava"))
+ds = ds.add_lmm(va.lmm.get_lmm("llava"))
ds = ds.add_embedder(va.emb.get_embedder("sentence-transformer"))

ds = ds.add_column("descriptions", "Describe this image.")
Expand Down
2 changes: 0 additions & 2 deletions vision_agent/config.py

This file was deleted.

16 changes: 12 additions & 4 deletions vision_agent/lmm/lmm.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,16 @@
 import base64
+import logging
 from abc import ABC, abstractmethod
 from pathlib import Path
 from typing import Any, Dict, List, Optional, Union, cast

 import requests

-from vision_agent.config import BASETEN_API_KEY, BASETEN_URL
+logging.basicConfig(level=logging.INFO)
+
+_LOGGER = logging.getLogger(__name__)
+
+_LLAVA_ENDPOINT = "https://cpvlqoxw6vhpdro27uhkvceady0kvvqk.lambda-url.us-east-2.on.aws"


def encode_image(image: Union[str, Path]) -> str:
Expand All @@ -31,11 +36,14 @@ def generate(self, prompt: str, image: Optional[Union[str, Path]] = None) -> str
         if image:
             data["image"] = encode_image(image)
         res = requests.post(
-            BASETEN_URL,
-            headers={"Authorization": f"Api-Key {BASETEN_API_KEY}"},
+            _LLAVA_ENDPOINT,
+            headers={"Content-Type": "application/json"},
             json=data,
         )
-        return res.text
+        resp_json: Dict[str, Any] = res.json()
+        if resp_json["statusCode"] != 200:
+            _LOGGER.error(f"Request failed: {resp_json['data']}")
+        return cast(str, resp_json["data"])


class OpenAILMM(LMM):
Expand Down

0 comments on commit e0587c3

Please sign in to comment.