import json
from abc import ABC, abstractmethod
from typing import cast

from vision_agent.tools import (
    CHOOSE_PARAMS,
    CLIP,
    SYSTEM_PROMPT,
    GroundingDINO,
    GroundingSAM,
    ImageTool,
)


class LLM(ABC):
    """Abstract interface for text-generating language models."""

    @abstractmethod
    def generate(self, prompt: str) -> str:
        """Return the model's completion for ``prompt``."""
        pass


class OpenAILLM(LLM):
    r"""An LLM class for any OpenAI LLM model."""

    def __init__(self, model_name: str = "gpt-4-turbo-preview"):
        """Create a client for ``model_name``.

        The ``openai`` package is imported lazily so that merely importing
        this module does not require the dependency (or an API key) unless
        this backend is actually instantiated.
        """
        from openai import OpenAI

        self.model_name = model_name
        self.client = OpenAI()

    def generate(self, prompt: str) -> str:
        """Return the chat completion for ``prompt`` as a single user turn."""
        response = self.client.chat.completions.create(
            model=self.model_name,
            messages=[
                {"role": "user", "content": prompt},
            ],
        )
        # message.content is Optional in the SDK's types; the cast keeps the
        # declared ``str`` return without a runtime check.
        return cast(str, response.choices[0].message.content)

    def generate_classifier(self, prompt: str) -> ImageTool:
        """Ask the model to choose CLIP parameters for ``prompt`` and build the tool.

        The model is forced into JSON mode and its reply is expected to carry
        a "Parameters" key whose value is splatted into ``CLIP``.

        Raises:
            KeyError: if the model's JSON reply lacks a "Parameters" key.
        """
        # Keep the filled template in its own name instead of rebinding the
        # ``prompt`` parameter — clearer, and consistent typing.
        filled_prompt = CHOOSE_PARAMS.format(api_doc=CLIP.doc, question=prompt)
        response = self.client.chat.completions.create(
            model=self.model_name,
            response_format={"type": "json_object"},
            messages=[
                {"role": "system", "content": SYSTEM_PROMPT},
                {"role": "user", "content": filled_prompt},
            ],
        )

        params = json.loads(cast(str, response.choices[0].message.content))[
            "Parameters"
        ]
        return CLIP(**params)
def generate_detector(self, params: str) -> ImageTool: + params = CHOOSE_PARAMS.format(api_doc=GroundingDINO.doc, question=params) + response = self.client.chat.completions.create( + model=self.model_name, + response_format={"type": "json_object"}, + messages=[ + {"role": "system", "content": SYSTEM_PROMPT}, + {"role": "user", "content": params}, + ], + ) + + params = json.loads(cast(str, response.choices[0].message.content))[ + "Parameters" + ] + return GroundingDINO(**params) + + def generate_segmentor(self, params: str) -> ImageTool: + params = CHOOSE_PARAMS.format(api_doc=GroundingSAM.doc, question=params) + response = self.client.chat.completions.create( + model=self.model_name, + response_format={"type": "json_object"}, + messages=[ + {"role": "system", "content": SYSTEM_PROMPT}, + {"role": "user", "content": params}, + ], + ) + + params = json.loads(cast(str, response.choices[0].message.content))[ + "Parameters" + ] + return GroundingSAM(**params)