From 3af591da70a51e660d696d2443df4f1e0b97be10 Mon Sep 17 00:00:00 2001
From: Dillon Laird
Date: Mon, 15 Apr 2024 14:19:23 -0700
Subject: [PATCH] added json mode for lmm, upgraded gpt-4-turbo

---
 vision_agent/agent/vision_agent.py | 24 +++++++++++++++---------
 vision_agent/llm/llm.py            |  2 +-
 vision_agent/lmm/lmm.py            | 12 ++++++++----
 3 files changed, 24 insertions(+), 14 deletions(-)

diff --git a/vision_agent/agent/vision_agent.py b/vision_agent/agent/vision_agent.py
index a04e65eb..8e02db0a 100644
--- a/vision_agent/agent/vision_agent.py
+++ b/vision_agent/agent/vision_agent.py
@@ -255,7 +255,7 @@ def self_reflect(
 ) -> str:
     prompt = VISION_AGENT_REFLECTION.format(
         question=question,
-        tools=format_tools(tools),
+        tools=format_tools({k: v["description"] for k, v in tools.items()}),
         tool_results=str(tool_result),
         final_answer=final_answer,
     )
@@ -268,11 +268,16 @@
     return reflect_model(prompt)
 
 
-def parse_reflect(reflect: str) -> bool:
-    # GPT-4V has a hard time following directions, so make the criteria less strict
-    return (
+def parse_reflect(reflect: str) -> Dict[str, Any]:
+    try:
+        return parse_json(reflect)
+    except Exception:
+        _LOGGER.error(f"Failed to parse JSON reflection: {reflect}")
+    # LMMs have a hard time following directions, so make the criteria less strict
+    finish = (
         "finish" in reflect.lower() and len(reflect) < 100
     ) or "finish" in reflect.lower()[-10:]
+    return {"Finish": finish, "Reflection": reflect}
 
 
 def visualize_result(all_tool_results: List[Dict]) -> List[str]:
@@ -389,7 +394,7 @@ def __init__(
             OpenAILLM(temperature=0.1) if answer_model is None else answer_model
         )
         self.reflect_model = (
-            OpenAILMM(temperature=0.1) if reflect_model is None else reflect_model
+            OpenAILMM(json_mode=True, temperature=0.1) if reflect_model is None else reflect_model
         )
         self.max_retries = max_retries
         self.tools = TOOLS
@@ -485,13 +490,14 @@ def chat_with_workflow(
                 visualized_output[0] if len(visualized_output) > 0 else image,
             )
             self.log_progress(f"Reflection: {reflection}")
-            if parse_reflect(reflection):
+            parsed_reflection = parse_reflect(reflection)
+            if parsed_reflection["Finish"]:
                 break
             else:
-                reflections += "\n" + reflection
-            # '' is a symbol to indicate the end of the chat, which is useful for streaming logs.
+                reflections += "\n" + parsed_reflection["Reflection"]
+        # '' is a symbol to indicate the end of the chat, which is useful for streaming logs.
         self.log_progress(
-            f"The Vision Agent has concluded this chat. {final_answer}"
+            f"The Vision Agent has concluded this chat. 
{final_answer}"
         )
 
         if visualize_output:
diff --git a/vision_agent/llm/llm.py b/vision_agent/llm/llm.py
index 9022ef73..9352f58b 100644
--- a/vision_agent/llm/llm.py
+++ b/vision_agent/llm/llm.py
@@ -33,7 +33,7 @@ class OpenAILLM(LLM):
 
     def __init__(
         self,
-        model_name: str = "gpt-4-turbo-preview",
+        model_name: str = "gpt-4-turbo",
         api_key: Optional[str] = None,
         json_mode: bool = False,
         **kwargs: Any
diff --git a/vision_agent/lmm/lmm.py b/vision_agent/lmm/lmm.py
index 0d63b158..738ac004 100644
--- a/vision_agent/lmm/lmm.py
+++ b/vision_agent/lmm/lmm.py
@@ -99,9 +99,10 @@ class OpenAILMM(LMM):
 
     def __init__(
         self,
-        model_name: str = "gpt-4-vision-preview",
+        model_name: str = "gpt-4-turbo",
         api_key: Optional[str] = None,
         max_tokens: int = 1024,
+        json_mode: bool = False,
         **kwargs: Any,
     ):
         if not api_key:
@@ -111,7 +112,10 @@ def __init__(
 
         self.client = OpenAI(api_key=api_key)
         self.model_name = model_name
-        self.max_tokens = max_tokens
+        if "max_tokens" not in kwargs:
+            kwargs["max_tokens"] = max_tokens
+        if json_mode:
+            kwargs["response_format"] = {"type": "json_object"}
         self.kwargs = kwargs
 
     def __call__(
@@ -153,7 +157,7 @@ def chat(
             )
 
         response = self.client.chat.completions.create(
-            model=self.model_name, messages=fixed_chat, max_tokens=self.max_tokens, **self.kwargs  # type: ignore
+            model=self.model_name, messages=fixed_chat, **self.kwargs  # type: ignore
         )
         return cast(str, response.choices[0].message.content)
 
@@ -181,7 +185,7 @@ def generate(self, prompt: str, image: Optional[Union[str, Path]] = None) -> str
             )
 
         response = self.client.chat.completions.create(
-            model=self.model_name, messages=message, max_tokens=self.max_tokens, **self.kwargs  # type: ignore
+            model=self.model_name, messages=message, **self.kwargs  # type: ignore
        )
         return cast(str, response.choices[0].message.content)
 
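Note: the sketch below illustrates what the new json_mode flag amounts to at the
OpenAI API level, and why parse_reflect still keeps a lenient fallback even with
JSON mode on. It is a minimal standalone example, not code from this repository;
the prompt text and variable names are illustrative.

    import json

    from openai import OpenAI

    client = OpenAI()  # assumes OPENAI_API_KEY is set in the environment

    # With response_format={"type": "json_object"}, the chat completions API
    # constrains the model to emit a single valid JSON object. Note the API
    # requires the word "JSON" to appear somewhere in the messages, which is
    # why the reflection prompt must ask for JSON explicitly.
    response = client.chat.completions.create(
        model="gpt-4-turbo",
        response_format={"type": "json_object"},
        max_tokens=1024,
        messages=[
            {
                "role": "user",
                "content": (
                    "Reflect on whether the answer is complete. Reply in JSON "
                    'with keys "Finish" (boolean) and "Reflection" (string).'
                ),
            }
        ],
    )

    # Mirrors the parse_reflect pattern: trust the JSON payload, but keep a
    # fallback, since JSON mode guarantees well-formed JSON, not that the
    # model used the requested keys.
    raw = response.choices[0].message.content
    try:
        finish = bool(json.loads(raw).get("Finish", False))
    except (json.JSONDecodeError, TypeError):
        finish = "finish" in raw.lower()

Folding max_tokens and response_format into kwargs at construction time (rather
than storing max_tokens on self) keeps chat and generate down to a single
**self.kwargs pass-through, and lets callers override max_tokens per instance.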