From 9e279146c83e6e67cc24fe17826eaa3c2fd57ddc Mon Sep 17 00:00:00 2001 From: Dillon Laird Date: Fri, 23 Feb 2024 09:05:26 -0800 Subject: [PATCH] added llava generate --- lmm_tools/config.py | 2 ++ lmm_tools/lmm/lmm.py | 18 ++++++++++++++---- 2 files changed, 16 insertions(+), 4 deletions(-) create mode 100644 lmm_tools/config.py diff --git a/lmm_tools/config.py b/lmm_tools/config.py new file mode 100644 index 00000000..5d32d75d --- /dev/null +++ b/lmm_tools/config.py @@ -0,0 +1,2 @@ +BASETEN_API_KEY = "PRxjuebe.VQJQ7rCvswimP5y8GeSmZA03I4zw6dgB" +BASETEN_URL = "https://model-232pg41q.api.baseten.co/production/predict" diff --git a/lmm_tools/lmm/lmm.py b/lmm_tools/lmm/lmm.py index c5615114..2bef55fb 100644 --- a/lmm_tools/lmm/lmm.py +++ b/lmm_tools/lmm/lmm.py @@ -1,7 +1,9 @@ import base64 +import requests from abc import ABC, abstractmethod from pathlib import Path from typing import Any, Dict, List, Optional, Union, cast +from lmm_tools.config import BASETEN_API_KEY, BASETEN_URL def encode_image(image: Union[str, Path]) -> str: @@ -12,7 +14,7 @@ def encode_image(image: Union[str, Path]) -> str: class LMM(ABC): @abstractmethod - def generate(self, prompt: str, image: Optional[Union[str, Path]]) -> str: + def generate(self, prompt: str, image: Optional[Union[str, Path]] = None) -> str: pass @@ -22,8 +24,16 @@ class LLaVALMM(LMM): def __init__(self, name: str): self.name = name - def generate(self, prompt: str, image: Optional[Union[str, Path]]) -> str: - raise NotImplementedError("LLaVA LMM not implemented yet") + def generate(self, prompt: str, image: Optional[Union[str, Path]] = None) -> str: + data = {"prompt": prompt} + if image: + data["image"] = encode_image(image) + res = requests.post( + BASETEN_URL, + headers={"Authorization": f"Api-Key {BASETEN_API_KEY}"}, + json=data, + ) + return res.text class OpenAILMM(LMM): @@ -35,7 +45,7 @@ def __init__(self, name: str): self.name = name self.client = OpenAI() - def generate(self, prompt: str, image: Optional[Union[str, Path]]) -> str: + def generate(self, prompt: str, image: Optional[Union[str, Path]] = None) -> str: message: List[Dict[str, Any]] = [ { "role": "user",