From 7b2e87f6288f70dfc657ddb8eaf5edb3f672d7e6 Mon Sep 17 00:00:00 2001 From: Dillon Laird Date: Thu, 10 Oct 2024 18:36:38 -0700 Subject: [PATCH] fix type issues --- vision_agent/agent/vision_agent_coder.py | 9 ++++----- vision_agent/agent/vision_agent_planner.py | 2 +- vision_agent/tools/meta_tools.py | 3 ++- 3 files changed, 7 insertions(+), 7 deletions(-) diff --git a/vision_agent/agent/vision_agent_coder.py b/vision_agent/agent/vision_agent_coder.py index a265ec22..a682a31c 100644 --- a/vision_agent/agent/vision_agent_coder.py +++ b/vision_agent/agent/vision_agent_coder.py @@ -14,7 +14,6 @@ _MAX_TABULATE_COL_WIDTH, DefaultImports, extract_code, - extract_json, extract_tag, format_memory, print_code, @@ -85,7 +84,7 @@ def strip_function_calls(code: str, exclusions: Optional[List[str]] = None) -> s for node in nodes_to_remove: node.parent.remove(node) cleaned_code = red.dumps().strip() - return cleaned_code + return cleaned_code if isinstance(cleaned_code, str) else code def write_code( @@ -286,8 +285,8 @@ def debug_code( stream=False, ) fixed_code_and_test_str = cast(str, fixed_code_and_test_str) - thoughts = extract_tag(fixed_code_and_test_str, "thoughts") - thoughts = thoughts if thoughts is not None else "" + thoughts_tag = extract_tag(fixed_code_and_test_str, "thoughts") + thoughts = thoughts_tag if thoughts_tag is not None else "" fixed_code = extract_tag(fixed_code_and_test_str, "code") fixed_test = extract_tag(fixed_code_and_test_str, "test") @@ -312,7 +311,7 @@ def debug_code( new_working_memory.append( { "code": f"{code}\n{test}", - "feedback": cast(str, thoughts), + "feedback": thoughts, "edits": get_diff(f"{old_code}\n{old_test}", f"{code}\n{test}"), } ) diff --git a/vision_agent/agent/vision_agent_planner.py b/vision_agent/agent/vision_agent_planner.py index fb746e86..09d9090a 100644 --- a/vision_agent/agent/vision_agent_planner.py +++ b/vision_agent/agent/vision_agent_planner.py @@ -124,7 +124,7 @@ def write_plans( count = 0 while not _check_plan_format(plans) and count < 3: - _LOGGER.info(f"Invalid plan format. Retrying.") + _LOGGER.info("Invalid plan format. Retrying.") plans = extract_json(model(chat, stream=False)) # type: ignore count += 1 if count == 3: diff --git a/vision_agent/tools/meta_tools.py b/vision_agent/tools/meta_tools.py index 32d6d8e1..8706d2a0 100644 --- a/vision_agent/tools/meta_tools.py +++ b/vision_agent/tools/meta_tools.py @@ -731,7 +731,8 @@ def use_extra_vision_agent_args( if custom_tool_names is not None: node.value[1].value.append(f"custom_tool_names={custom_tool_names}") - return red.dumps().strip() + cleaned_code = red.dumps().strip() + return cleaned_code if isinstance(cleaned_code, str) else code def use_object_detection_fine_tuning(