Skip to content

Commit

Permalink
strip installs from code
Browse files Browse the repository at this point in the history
  • Loading branch information
dillonalaird committed Oct 4, 2024
1 parent 031485d commit 7797433
Show file tree
Hide file tree
Showing 3 changed files with 34 additions and 4 deletions.
22 changes: 21 additions & 1 deletion tests/unit/test_utils.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,8 @@
from vision_agent.agent.agent_utils import extract_code, extract_json
from vision_agent.agent.agent_utils import (
extract_code,
extract_json,
remove_installs_from_code,
)


def test_basic_json_extract():
Expand Down Expand Up @@ -43,3 +47,19 @@ def test_basic_json_extract():
a_code = extract_code(a)
assert "def test_basic_json_extract():" in a_code
assert "assert extract_json(a) == {" in a_code


def test_remove_installs_from_code():
a = """import os
imoprt sys
!pip install pandas
def test():
print("!pip install dummy")
"""
out = remove_installs_from_code(a)
assert "import os" in out
assert "!pip install pandas" not in out
assert "!pip install dummy" in out
6 changes: 6 additions & 0 deletions vision_agent/agent/agent_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,3 +77,9 @@ def extract_code(code: str) -> str:
if code.startswith("python\n"):
code = code[len("python\n") :]
return code


def remove_installs_from_code(code: str) -> str:
pattern = r"\n!pip install.*?(\n|\Z)\n"
code = re.sub(pattern, "", code, flags=re.DOTALL)
return code
10 changes: 7 additions & 3 deletions vision_agent/agent/vision_agent_coder.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,11 @@

import vision_agent.tools as T
from vision_agent.agent import Agent
from vision_agent.agent.agent_utils import extract_code, extract_json
from vision_agent.agent.agent_utils import (
extract_code,
extract_json,
remove_installs_from_code,
)
from vision_agent.agent.vision_agent_coder_prompts import (
CODE,
FIX_BUG,
Expand Down Expand Up @@ -836,8 +840,8 @@ def chat_with_workflow(
media=media_list,
)
success = cast(bool, results["success"])
code = cast(str, results["code"])
test = cast(str, results["test"])
code = remove_installs_from_code(cast(str, results["code"]))
test = remove_installs_from_code(cast(str, results["test"]))
working_memory.extend(results["working_memory"]) # type: ignore
plan.append({"code": code, "test": test, "plan": plan_i})

Expand Down

0 comments on commit 7797433

Please sign in to comment.