From 1596131ae16f26a711431dbcbab24159456113d6 Mon Sep 17 00:00:00 2001 From: Dillon Laird Date: Sun, 8 Sep 2024 18:56:26 -0700 Subject: [PATCH] added ClaudeVisionAgentCoder and fixed json parser --- vision_agent/agent/agent_utils.py | 48 ++++++++----- vision_agent/agent/vision_agent_coder.py | 89 +++++++++++++++++++----- 2 files changed, 102 insertions(+), 35 deletions(-) diff --git a/vision_agent/agent/agent_utils.py b/vision_agent/agent/agent_utils.py index eb951ccc..06fa79a4 100644 --- a/vision_agent/agent/agent_utils.py +++ b/vision_agent/agent/agent_utils.py @@ -14,6 +14,10 @@ def _extract_sub_json(json_str: str) -> Optional[Dict[str, Any]]: if match: json_str = match.group() try: + # remove trailing comma + trailing_bracket_pattern = r",\s+\}" + json_str = re.sub(trailing_bracket_pattern, "}", json_str, flags=re.DOTALL) + json_dict = json.loads(json_str) return json_dict # type: ignore except json.JSONDecodeError: @@ -21,29 +25,37 @@ def _extract_sub_json(json_str: str) -> Optional[Dict[str, Any]]: return None +def _find_markdown_json(json_str: str) -> str: + pattern = r"```json(.*?)```" + match = re.search(pattern, json_str, re.DOTALL) + if match: + return match.group(1).strip() + return json_str + + +def _strip_markdown_code(inp_str: str) -> str: + pattern = r"```python.*?```" + cleaned_str = re.sub(pattern, "", inp_str, flags=re.DOTALL) + return cleaned_str + + def extract_json(json_str: str) -> Dict[str, Any]: + json_str = json_str.replace("\n", " ").strip() + try: - json_str = json_str.replace("\n", " ") - json_dict = json.loads(json_str) + return json.loads(json_str) except json.JSONDecodeError: - if "```json" in json_str: - json_str = json_str[json_str.find("```json") + len("```json") :] - json_str = json_str[: json_str.find("```")] - elif "```" in json_str: - json_str = json_str[json_str.find("```") + len("```") :] - # get the last ``` not one from an intermediate string - json_str = json_str[: json_str.find("}```")] - try: - json_dict = json.loads(json_str) - except json.JSONDecodeError as e: - json_dict = _extract_sub_json(json_str) - if json_dict is not None: - return json_dict # type: ignore - error_msg = f"Could not extract JSON from the given str: {json_str}" + json_orig = json_str + json_str = _strip_markdown_code(json_str) + json_str = _find_markdown_json(json_str) + json_dict = _extract_sub_json(json_str) + + if json_dict is None: + error_msg = f"Could not extract JSON from the given str: {json_orig}" _LOGGER.exception(error_msg) - raise ValueError(error_msg) from e + raise ValueError(error_msg) - return json_dict # type: ignore + return json_dict # type: ignore def extract_code(code: str) -> str: diff --git a/vision_agent/agent/vision_agent_coder.py b/vision_agent/agent/vision_agent_coder.py index 8b6f9032..ce95864d 100644 --- a/vision_agent/agent/vision_agent_coder.py +++ b/vision_agent/agent/vision_agent_coder.py @@ -27,7 +27,14 @@ TEST_PLANS, USER_REQ, ) -from vision_agent.lmm import LMM, AzureOpenAILMM, Message, OllamaLMM, OpenAILMM +from vision_agent.lmm import ( + LMM, + AzureOpenAILMM, + ClaudeSonnetLMM, + Message, + OllamaLMM, + OpenAILMM, +) from vision_agent.tools.meta_tools import get_diff from vision_agent.utils import CodeInterpreterFactory, Execution from vision_agent.utils.execute import CodeInterpreter @@ -168,8 +175,8 @@ def pick_plan( ) tool_output = code_interpreter.exec_isolation(DefaultImports.prepend_imports(code)) tool_output_str = "" - if len(tool_output.logs.stdout) > 0: - tool_output_str = tool_output.logs.stdout[0] + if len(tool_output.text().strip()) > 0: + tool_output_str = tool_output.text().strip() if verbosity == 2: _print_code("Initial code and tests:", code) @@ -229,7 +236,7 @@ def pick_plan( if verbosity == 2: _print_code("Code and test after attempted fix:", code) - _LOGGER.info(f"Code execution result after attempt {count}") + _LOGGER.info(f"Code execution result after attempt {count + 1}") count += 1 @@ -387,7 +394,6 @@ def write_and_test_code( "code": DefaultImports.prepend_imports(code), "payload": { "test": test, - # "result": result.to_json(), }, } ) @@ -451,17 +457,32 @@ def debug_code( count = 0 while not success and count < 3: try: - fixed_code_and_test = extract_json( - debugger( # type: ignore - FIX_BUG.format( - code=code, - tests=test, - result="\n".join(result.text().splitlines()[-50:]), - feedback=format_memory(working_memory + new_working_memory), - ), - stream=False, - ) + # LLMs write worse code when it's in JSON, so we have it write JSON + # followed by code each wrapped in markdown blocks. + fixed_code_and_test_str = debugger( # type: ignore + FIX_BUG.format( + code=code, + tests=test, + result="\n".join(result.text().splitlines()[-50:]), + feedback=format_memory(working_memory + new_working_memory), + ), + stream=False, ) + fixed_code_and_test_str = cast(str, fixed_code_and_test_str) + fixed_code_and_test = extract_json(fixed_code_and_test_str) + code = extract_code(fixed_code_and_test_str) + if ( + "which_code" in fixed_code_and_test + and fixed_code_and_test["which_code"] == "test" + ): + fixed_code_and_test["code"] = "" + fixed_code_and_test["test"] = code + else: # for everything else always assume it's updating code + fixed_code_and_test["code"] = code + fixed_code_and_test["test"] = "" + if "which_code" in fixed_code_and_test: + del fixed_code_and_test["which_code"] + success = True except Exception as e: _LOGGER.exception(f"Error while extracting JSON: {e}") @@ -472,9 +493,9 @@ def debug_code( old_test = test if fixed_code_and_test["code"].strip() != "": - code = extract_code(fixed_code_and_test["code"]) + code = fixed_code_and_test["code"] if fixed_code_and_test["test"].strip() != "": - test = extract_code(fixed_code_and_test["test"]) + test = fixed_code_and_test["test"] new_working_memory.append( { @@ -876,6 +897,40 @@ def _log_plans(self, plans: Dict[str, Any], verbosity: int) -> None: ) +class ClaudeVisionAgentCoder(VisionAgentCoder): + def __init__( + self, + planner: Optional[LMM] = None, + coder: Optional[LMM] = None, + tester: Optional[LMM] = None, + debugger: Optional[LMM] = None, + tool_recommender: Optional[Sim] = None, + verbosity: int = 0, + report_progress_callback: Optional[Callable[[Dict[str, Any]], None]] = None, + code_sandbox_runtime: Optional[str] = None, + ) -> None: + # NOTE: Claude doesn't have an official JSON mode + self.planner = ClaudeSonnetLMM(temperature=0.0) if planner is None else planner + self.coder = ClaudeSonnetLMM(temperature=0.0) if coder is None else coder + self.tester = ClaudeSonnetLMM(temperature=0.0) if tester is None else tester + self.debugger = ( + ClaudeSonnetLMM(temperature=0.0) if debugger is None else debugger + ) + self.verbosity = verbosity + if self.verbosity > 0: + _LOGGER.setLevel(logging.INFO) + + # Anthropic does not offer any embedding models and instead recomends Voyage, + # we're using OpenAI's embedder for now. + self.tool_recommender = ( + Sim(T.TOOLS_DF, sim_key="desc") + if tool_recommender is None + else tool_recommender + ) + self.report_progress_callback = report_progress_callback + self.code_sandbox_runtime = code_sandbox_runtime + + class OllamaVisionAgentCoder(VisionAgentCoder): """VisionAgentCoder that uses Ollama models for planning, coding, testing.