Skip to content

Commit

Permalink
Add llava generate (#4)
Browse files Browse the repository at this point in the history
added llava generate
  • Loading branch information
dillonalaird authored Feb 23, 2024
1 parent 84b94c6 commit 57097b0
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 4 deletions.
2 changes: 2 additions & 0 deletions lmm_tools/config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
BASETEN_API_KEY = "PRxjuebe.VQJQ7rCvswimP5y8GeSmZA03I4zw6dgB"
BASETEN_URL = "https://model-232pg41q.api.baseten.co/production/predict"
18 changes: 14 additions & 4 deletions lmm_tools/lmm/lmm.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
import base64
import requests
from abc import ABC, abstractmethod
from pathlib import Path
from typing import Any, Dict, List, Optional, Union, cast
from lmm_tools.config import BASETEN_API_KEY, BASETEN_URL


def encode_image(image: Union[str, Path]) -> str:
Expand All @@ -12,7 +14,7 @@ def encode_image(image: Union[str, Path]) -> str:

class LMM(ABC):
@abstractmethod
def generate(self, prompt: str, image: Optional[Union[str, Path]]) -> str:
def generate(self, prompt: str, image: Optional[Union[str, Path]] = None) -> str:
pass


Expand All @@ -22,8 +24,16 @@ class LLaVALMM(LMM):
def __init__(self, name: str):
self.name = name

def generate(self, prompt: str, image: Optional[Union[str, Path]]) -> str:
raise NotImplementedError("LLaVA LMM not implemented yet")
def generate(self, prompt: str, image: Optional[Union[str, Path]] = None) -> str:
data = {"prompt": prompt}
if image:
data["image"] = encode_image(image)
res = requests.post(
BASETEN_URL,
headers={"Authorization": f"Api-Key {BASETEN_API_KEY}"},
json=data,
)
return res.text


class OpenAILMM(LMM):
Expand All @@ -35,7 +45,7 @@ def __init__(self, name: str):
self.name = name
self.client = OpenAI()

def generate(self, prompt: str, image: Optional[Union[str, Path]]) -> str:
def generate(self, prompt: str, image: Optional[Union[str, Path]] = None) -> str:
message: List[Dict[str, Any]] = [
{
"role": "user",
Expand Down

0 comments on commit 57097b0

Please sign in to comment.