diff --git a/tests/unit/test_utils.py b/tests/unit/test_utils.py index 4db319f9..73471a30 100644 --- a/tests/unit/test_utils.py +++ b/tests/unit/test_utils.py @@ -1,4 +1,8 @@ -from vision_agent.agent.agent_utils import extract_code, extract_json +from vision_agent.agent.agent_utils import ( + extract_code, + extract_json, + remove_installs_from_code, +) def test_basic_json_extract(): @@ -43,3 +47,19 @@ def test_basic_json_extract(): a_code = extract_code(a) assert "def test_basic_json_extract():" in a_code assert "assert extract_json(a) == {" in a_code + + +def test_remove_installs_from_code(): + a = """import os +imoprt sys + +!pip install pandas + + +def test(): + print("!pip install dummy") +""" + out = remove_installs_from_code(a) + assert "import os" in out + assert "!pip install pandas" not in out + assert "!pip install dummy" in out diff --git a/vision_agent/agent/agent_utils.py b/vision_agent/agent/agent_utils.py index dc0debee..624ad608 100644 --- a/vision_agent/agent/agent_utils.py +++ b/vision_agent/agent/agent_utils.py @@ -77,3 +77,9 @@ def extract_code(code: str) -> str: if code.startswith("python\n"): code = code[len("python\n") :] return code + + +def remove_installs_from_code(code: str) -> str: + pattern = r"\n!pip install.*?(\n|\Z)\n" + code = re.sub(pattern, "", code, flags=re.DOTALL) + return code diff --git a/vision_agent/agent/vision_agent_coder.py b/vision_agent/agent/vision_agent_coder.py index aa4d83da..1e5030a2 100644 --- a/vision_agent/agent/vision_agent_coder.py +++ b/vision_agent/agent/vision_agent_coder.py @@ -13,7 +13,11 @@ import vision_agent.tools as T from vision_agent.agent import Agent -from vision_agent.agent.agent_utils import extract_code, extract_json +from vision_agent.agent.agent_utils import ( + extract_code, + extract_json, + remove_installs_from_code, +) from vision_agent.agent.vision_agent_coder_prompts import ( CODE, FIX_BUG, @@ -836,8 +840,8 @@ def chat_with_workflow( media=media_list, ) success = cast(bool, results["success"]) - code = cast(str, results["code"]) - test = cast(str, results["test"]) + code = remove_installs_from_code(cast(str, results["code"])) + test = remove_installs_from_code(cast(str, results["test"])) working_memory.extend(results["working_memory"]) # type: ignore plan.append({"code": code, "test": test, "plan": plan_i})