Skip to content

Commit

Permalink
updated llm interface
Browse files Browse the repository at this point in the history
  • Loading branch information
dillonalaird committed Mar 16, 2024
1 parent 50348dc commit 097d568
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 1 deletion.
1 change: 1 addition & 0 deletions vision_agent/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,3 +2,4 @@
from .emb import Embedder, OpenAIEmb, SentenceTransformerEmb, get_embedder
from .llm import LLM, OpenAILLM
from .lmm import LMM, LLaVALMM, OpenAILMM, get_lmm
from .agent import Agent
24 changes: 23 additions & 1 deletion vision_agent/llm/llm.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import json
from abc import ABC, abstractmethod
from typing import Mapping, cast
from typing import Dict, List, Mapping, Union, cast

from openai import OpenAI

from vision_agent.tools import (
Expand All @@ -18,6 +19,14 @@ class LLM(ABC):
def generate(self, prompt: str) -> str:
pass

@abstractmethod
def chat(self, chat: List[Dict[str, str]]) -> str:
pass

@abstractmethod
def __call__(self, input: Union[str, List[Dict[str, str]]]) -> str:
pass


class OpenAILLM(LLM):
r"""An LLM class for any OpenAI LLM model."""
Expand All @@ -36,6 +45,19 @@ def generate(self, prompt: str) -> str:

return cast(str, response.choices[0].message.content)

def chat(self, chat: List[Dict[str, str]]) -> str:
response = self.client.chat.completions.create(
model=self.model_name,
messages=chat, # type: ignore
)

return cast(str, response.choices[0].message.content)

def __call__(self, input: Union[str, List[Dict[str, str]]]) -> str:
if isinstance(input, str):
return self.generate(input)
return self.chat(input)

def generate_classifier(self, prompt: str) -> ImageTool:
prompt = CHOOSE_PARAMS.format(api_doc=CLIP.doc, question=prompt)
response = self.client.chat.completions.create(
Expand Down

0 comments on commit 097d568

Please sign in to comment.