From 075aa2f6db7d7618f6de57a3cc1d81355b7d4db9 Mon Sep 17 00:00:00 2001
From: Dillon Laird
Date: Mon, 1 Apr 2024 13:43:03 -0700
Subject: [PATCH] updated params for llm/lmm (#34)

* updated params for llm/lmm

* type checking fixes
---
 vision_agent/agent/vision_agent.py | 25 +++++++++++++++----------
 vision_agent/llm/llm.py            | 17 ++++++++++-------
 vision_agent/lmm/lmm.py            | 17 ++++++++++++++---
 3 files changed, 39 insertions(+), 20 deletions(-)

diff --git a/vision_agent/agent/vision_agent.py b/vision_agent/agent/vision_agent.py
index e9e6d66d..1aed604f 100644
--- a/vision_agent/agent/vision_agent.py
+++ b/vision_agent/agent/vision_agent.py
@@ -256,7 +256,6 @@ def retrieval(
     )
     if tool_id is None:
         return {}, ""
-    _LOGGER.info(f"\t(Tool ID, name): ({tool_id}, {tools[tool_id]['name']})")
 
     tool_instructions = tools[tool_id]
     tool_usage = tool_instructions["usage"]
@@ -265,7 +264,6 @@ def retrieval(
     parameters = choose_parameter(
         model, question, tool_usage, previous_log, reflections
     )
-    _LOGGER.info(f"\tParameters: {parameters} for {tool_name}")
     if parameters is None:
         return {}, ""
     tool_results = {"task": question, "tool_name": tool_name, "parameters": parameters}
@@ -290,7 +288,7 @@ def parse_tool_results(result: Dict[str, Union[Dict, List]]) -> Any:
     tool_results["call_results"] = call_results
 
     call_results_str = str(call_results)
-    _LOGGER.info(f"\tCall Results: {call_results_str}")
+    # _LOGGER.info(f"\tCall Results: {call_results_str}")
     return tool_results, call_results_str
 
 
@@ -344,7 +342,9 @@ def self_reflect(
 
 def parse_reflect(reflect: str) -> bool:
     # GPT-4V has a hard time following directions, so make the criteria less strict
-    return "finish" in reflect.lower() and len(reflect) < 100
+    return (
+        "finish" in reflect.lower() and len(reflect) < 100
+    ) or "finish" in reflect.lower()[-10:]
 
 
 def visualize_result(all_tool_results: List[Dict]) -> List[str]:
@@ -423,10 +423,16 @@ def __init__(
         verbose: bool = False,
     ):
         self.task_model = (
-            OpenAILLM(json_mode=True) if task_model is None else task_model
+            OpenAILLM(json_mode=True, temperature=0.1)
+            if task_model is None
+            else task_model
+        )
+        self.answer_model = (
+            OpenAILLM(temperature=0.1) if answer_model is None else answer_model
+        )
+        self.reflect_model = (
+            OpenAILMM(temperature=0.1) if reflect_model is None else reflect_model
         )
-        self.answer_model = OpenAILLM() if answer_model is None else answer_model
-        self.reflect_model = OpenAILMM() if reflect_model is None else reflect_model
         self.max_retries = max_retries
         self.tools = TOOLS
 
@@ -466,7 +472,6 @@ def chat_with_workflow(
 
         for _ in range(self.max_retries):
             task_list = create_tasks(self.task_model, question, self.tools, reflections)
-            _LOGGER.info(f"Task Dependency: {task_list}")
             task_depend = {"Original Quesiton": question}
             previous_log = ""
             answers = []
@@ -477,7 +482,6 @@ def chat_with_workflow(
             for task in task_list:
                 task_str = task["task"]
                 previous_log = str(task_depend)
-                _LOGGER.info(f"\tSubtask: {task_str}")
                 tool_results, call_results = retrieval(
                     self.task_model,
                     task_str,
@@ -492,6 +496,7 @@ def chat_with_workflow(
                 tool_results["answer"] = answer
                 all_tool_results.append(tool_results)
 
+                _LOGGER.info(f"\tCall Result: {call_results}")
                 _LOGGER.info(f"\tAnswer: {answer}")
                 answers.append({"task": task_str, "answer": answer})
                 task_depend[task["id"]]["answer"] = answer  # type: ignore
@@ -510,7 +515,7 @@ def chat_with_workflow(
                 final_answer,
                 visualized_images[0] if len(visualized_images) > 0 else image,
             )
-            _LOGGER.info(f"\tReflection: {reflection}")
+            _LOGGER.info(f"Reflection: {reflection}")
             if parse_reflect(reflection):
                 break
             else:
diff --git a/vision_agent/llm/llm.py b/vision_agent/llm/llm.py
index 374b58c9..e97bcdeb 100644
--- a/vision_agent/llm/llm.py
+++ b/vision_agent/llm/llm.py
@@ -1,6 +1,6 @@
 import json
 from abc import ABC, abstractmethod
-from typing import Callable, Dict, List, Mapping, Union, cast
+from typing import Any, Callable, Dict, List, Mapping, Union, cast
 
 from openai import OpenAI
 
@@ -31,30 +31,33 @@ class OpenAILLM(LLM):
     r"""An LLM class for any OpenAI LLM model."""
 
     def __init__(
-        self, model_name: str = "gpt-4-turbo-preview", json_mode: bool = False
+        self,
+        model_name: str = "gpt-4-turbo-preview",
+        json_mode: bool = False,
+        **kwargs: Any
     ):
         self.model_name = model_name
         self.client = OpenAI()
-        self.json_mode = json_mode
+        self.kwargs = kwargs
+        if json_mode:
+            self.kwargs["response_format"] = {"type": "json_object"}
 
     def generate(self, prompt: str) -> str:
-        kwargs = {"response_format": {"type": "json_object"}} if self.json_mode else {}
         response = self.client.chat.completions.create(
             model=self.model_name,
             messages=[
                 {"role": "user", "content": prompt},
             ],
-            **kwargs,  # type: ignore
+            **self.kwargs,
         )
 
         return cast(str, response.choices[0].message.content)
 
     def chat(self, chat: List[Dict[str, str]]) -> str:
-        kwargs = {"response_format": {"type": "json_object"}} if self.json_mode else {}
         response = self.client.chat.completions.create(
             model=self.model_name,
             messages=chat,  # type: ignore
-            **kwargs,
+            **self.kwargs,
         )
 
         return cast(str, response.choices[0].message.content)
diff --git a/vision_agent/lmm/lmm.py b/vision_agent/lmm/lmm.py
index 48d449f5..7ae65eb2 100644
--- a/vision_agent/lmm/lmm.py
+++ b/vision_agent/lmm/lmm.py
@@ -97,11 +97,15 @@ class OpenAILMM(LMM):
     r"""An LMM class for the OpenAI GPT-4 Vision model."""
 
     def __init__(
-        self, model_name: str = "gpt-4-vision-preview", max_tokens: int = 1024
+        self,
+        model_name: str = "gpt-4-vision-preview",
+        max_tokens: int = 1024,
+        **kwargs: Any,
     ):
         self.model_name = model_name
         self.max_tokens = max_tokens
         self.client = OpenAI()
+        self.kwargs = kwargs
 
     def __call__(
         self,
@@ -123,6 +127,13 @@ def chat(
 
         if image:
             extension = Path(image).suffix
+            if extension.lower() == ".jpeg" or extension.lower() == ".jpg":
+                extension = "jpg"
+            elif extension.lower() == ".png":
+                extension = "png"
+            else:
+                raise ValueError(f"Unsupported image extension: {extension}")
+
             encoded_image = encode_image(image)
             fixed_chat[0]["content"].append(  # type: ignore
                 {
@@ -135,7 +146,7 @@
             )
 
         response = self.client.chat.completions.create(
-            model=self.model_name, messages=fixed_chat, max_tokens=self.max_tokens  # type: ignore
+            model=self.model_name, messages=fixed_chat, max_tokens=self.max_tokens, **self.kwargs  # type: ignore
         )
         return cast(str, response.choices[0].message.content)
 
@@ -163,7 +174,7 @@ def generate(self, prompt: str, image: Optional[Union[str, Path]] = None) -> str
             )
 
         response = self.client.chat.completions.create(
-            model=self.model_name, messages=message, max_tokens=self.max_tokens  # type: ignore
+            model=self.model_name, messages=message, max_tokens=self.max_tokens, **self.kwargs  # type: ignore
        )
         return cast(str, response.choices[0].message.content)
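
After this patch, any extra keyword arguments passed to OpenAILLM / OpenAILMM are stored on the instance and forwarded to client.chat.completions.create on every call, so new OpenAI chat-completion parameters no longer need dedicated constructor fields. A minimal usage sketch (assumes OPENAI_API_KEY is set in the environment; the parameter values below are illustrative, not defaults imposed by the patch):

    from vision_agent.llm.llm import OpenAILLM
    from vision_agent.lmm.lmm import OpenAILMM

    # json_mode=True now just seeds self.kwargs with response_format;
    # every other kwarg rides along to chat.completions.create.
    task_llm = OpenAILLM(json_mode=True, temperature=0.1)
    reflect_lmm = OpenAILMM(temperature=0.1)

    # Any OpenAI chat-completion parameter can be threaded through, e.g.:
    llm = OpenAILLM(temperature=0.0, top_p=0.9, seed=42)
    print(llm.generate("Describe a bounding box in one sentence."))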