From 4a3c571404ca115ba5ebd6544e9a1f472169abd7 Mon Sep 17 00:00:00 2001 From: Shankar <90070882+shankar-vision-eng@users.noreply.github.com> Date: Wed, 29 May 2024 13:46:41 -0700 Subject: [PATCH] V3 improvements set2 (#99) * changes for fixing errors in test case execution and using LMM for planning * fix linting * returning dictionary instead of code --- vision_agent/agent/agent_coder.py | 12 +++++++++--- vision_agent/agent/vision_agent.py | 17 +++++++++++++---- vision_agent/agent/vision_agent_prompts.py | 12 +++++++----- 3 files changed, 29 insertions(+), 12 deletions(-) diff --git a/vision_agent/agent/agent_coder.py b/vision_agent/agent/agent_coder.py index bba539f2..980c28d4 100644 --- a/vision_agent/agent/agent_coder.py +++ b/vision_agent/agent/agent_coder.py @@ -67,11 +67,17 @@ def parse_file_name(s: str) -> str: return "".join([p for p in s.split(" ") if p.endswith(".png")]) -def write_program(question: str, feedback: str, model: LLM) -> str: +def write_program( + question: str, feedback: str, model: LLM, media: Optional[Union[str, Path]] = None +) -> str: prompt = PROGRAM.format( docstring=TOOL_DOCSTRING, question=question, feedback=feedback ) - completion = model(prompt) + if isinstance(model, OpenAILMM): + completion = model(prompt, images=[media] if media else None) + else: + completion = model(prompt) + return preprocess_data(completion) @@ -168,7 +174,7 @@ def chat( code = "" feedback = "" for _ in range(self.max_turns): - code = write_program(question, feedback, self.coder_agent) + code = write_program(question, feedback, self.coder_agent, media=media) if self.verbose: _CONSOLE.print( Syntax(code, "python", theme="gruvbox-dark", line_numbers=True) diff --git a/vision_agent/agent/vision_agent.py b/vision_agent/agent/vision_agent.py index f45ca7b4..d5e65169 100644 --- a/vision_agent/agent/vision_agent.py +++ b/vision_agent/agent/vision_agent.py @@ -21,6 +21,7 @@ USER_REQ, ) from vision_agent.llm import LLM, OpenAILLM +from vision_agent.lmm import LMM, 
OpenAILMM from vision_agent.tools import TOOL_DESCRIPTIONS, TOOLS_DF, UTILITIES_DOCSTRING from vision_agent.utils import Execute from vision_agent.utils.sim import Sim @@ -77,7 +78,8 @@ def write_plan( chat: List[Dict[str, str]], tool_desc: str, working_memory: str, - model: LLM, + model: Union[LLM, LMM], + media: Optional[List[Union[str, Path]]] = None, ) -> List[Dict[str, str]]: chat = copy.deepcopy(chat) if chat[-1]["role"] != "user": @@ -87,7 +89,10 @@ def write_plan( context = USER_REQ.format(user_request=user_request) prompt = PLAN.format(context=context, tool_desc=tool_desc, feedback=working_memory) chat[-1]["content"] = prompt - return extract_json(model.chat(chat))["plan"] # type: ignore + if isinstance(model, OpenAILMM): + return extract_json(model.chat(chat, images=media))["plan"] # type: ignore + else: + return extract_json(model.chat(chat))["plan"] # type: ignore def reflect( @@ -324,7 +329,7 @@ def __call__( input = [{"role": "user", "content": input}] results = self.chat_with_workflow(input, media) results.pop("working_memory") - return results["code"] # type: ignore + return results # type: ignore def chat_with_workflow( self, @@ -363,7 +368,11 @@ def chat_with_workflow( while not success and retries < self.max_retries: plan_i = write_plan( - chat, TOOL_DESCRIPTIONS, format_memory(working_memory), self.planner + chat, + TOOL_DESCRIPTIONS, + format_memory(working_memory), + self.planner, + media=[media] if media else None, ) plan_i_str = "\n-".join([e["instructions"] for e in plan_i]) if self.verbosity >= 1: diff --git a/vision_agent/agent/vision_agent_prompts.py b/vision_agent/agent/vision_agent_prompts.py index 6041bfc3..45e0310d 100644 --- a/vision_agent/agent/vision_agent_prompts.py +++ b/vision_agent/agent/vision_agent_prompts.py @@ -169,11 +169,13 @@ def find_text(image_path: str, text: str) -> str: 1. Verify the fundamental functionality under normal conditions. 2. 
Ensure each test case is well-documented with comments explaining the scenario it covers. 3. Your test case MUST run only on the given image which is {media} -4. DO NOT use any non-existent or dummy image or video files that are not provided by the user's instructions. -5. DO NOT mock any functions, you must test their functionality as is. -6. DO NOT assert the output value, run the code and verify it runs without any errors and assert only the output format or data structure. -7. DO NOT import the testing function as it will available in the testing environment. -8. Print the output of the function that is being tested. +4. Your test case MUST run only with the given values which is available in the question - {question} +5. DO NOT use any non-existent or dummy image or video files that are not provided by the user's instructions. +6. DO NOT mock any functions, you must test their functionality as is. +7. DO NOT assert the output value, run the code and assert only the output format or data structure. +8. DO NOT use try except block to handle the error, let the error be raised if the code is incorrect. +9. DO NOT import the testing function as it will available in the testing environment. +10. Print the output of the function that is being tested. """