From 097d568b07402f182a1c3486c186ead16055ff23 Mon Sep 17 00:00:00 2001 From: Dillon Laird Date: Fri, 15 Mar 2024 17:04:11 -0700 Subject: [PATCH] updated llm interface --- vision_agent/__init__.py | 1 + vision_agent/llm/llm.py | 24 +++++++++++++++++++++++- 2 files changed, 24 insertions(+), 1 deletion(-) diff --git a/vision_agent/__init__.py b/vision_agent/__init__.py index 8d9c25b0..3704c363 100644 --- a/vision_agent/__init__.py +++ b/vision_agent/__init__.py @@ -2,3 +2,4 @@ from .emb import Embedder, OpenAIEmb, SentenceTransformerEmb, get_embedder from .llm import LLM, OpenAILLM from .lmm import LMM, LLaVALMM, OpenAILMM, get_lmm +from .agent import Agent diff --git a/vision_agent/llm/llm.py b/vision_agent/llm/llm.py index af90f1b0..f07b0611 100644 --- a/vision_agent/llm/llm.py +++ b/vision_agent/llm/llm.py @@ -1,6 +1,7 @@ import json from abc import ABC, abstractmethod -from typing import Mapping, cast +from typing import Dict, List, Mapping, Union, cast + from openai import OpenAI from vision_agent.tools import ( @@ -18,6 +19,14 @@ class LLM(ABC): def generate(self, prompt: str) -> str: pass + @abstractmethod + def chat(self, chat: List[Dict[str, str]]) -> str: + pass + + @abstractmethod + def __call__(self, input: Union[str, List[Dict[str, str]]]) -> str: + pass + class OpenAILLM(LLM): r"""An LLM class for any OpenAI LLM model.""" @@ -36,6 +45,19 @@ def generate(self, prompt: str) -> str: return cast(str, response.choices[0].message.content) + def chat(self, chat: List[Dict[str, str]]) -> str: + response = self.client.chat.completions.create( + model=self.model_name, + messages=chat, # type: ignore + ) + + return cast(str, response.choices[0].message.content) + + def __call__(self, input: Union[str, List[Dict[str, str]]]) -> str: + if isinstance(input, str): + return self.generate(input) + return self.chat(input) + def generate_classifier(self, prompt: str) -> ImageTool: prompt = CHOOSE_PARAMS.format(api_doc=CLIP.doc, question=prompt) response = self.client.chat.completions.create(