Skip to content

Commit

Permalink
fixed bug with edit code errors
Browse files Browse the repository at this point in the history
  • Loading branch information
dillonalaird committed Sep 3, 2024
1 parent 5b3c0f0 commit cc0e866
Showing 1 changed file with 23 additions and 6 deletions.
29 changes: 23 additions & 6 deletions vision_agent/tools/meta_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
import tempfile
from pathlib import Path
from typing import Any, Dict, List, Optional, Union
from uuid import UUID

from IPython.display import display

Expand Down Expand Up @@ -105,13 +104,14 @@ def load(self, file_path: Union[str, Path]) -> None:

def show(self) -> str:
"""Shows the artifacts that have been loaded and their remote save paths."""
out_str = "[Artifacts loaded]\n"
output_str = "[Artifacts loaded]\n"
for k in self.artifacts.keys():
out_str += (
output_str += (
f"Artifact {k} loaded to {str(self.remote_save_path.parent / k)}\n"
)
out_str += "[End of artifacts]\n"
return out_str
output_str += "[End of artifacts]\n"
print(output_str)
return output_str

def save(self, local_path: Optional[Union[str, Path]] = None) -> None:
save_path = (
Expand Down Expand Up @@ -237,7 +237,7 @@ def edit_code_artifact(
new_content_lines = [
line if line.endswith("\n") else line + "\n" for line in new_content_lines
]
lines = artifacts[name].splitlines()
lines = artifacts[name].splitlines(keepends=True)
edited_lines = lines[:start] + new_content_lines + lines[end:]

cur_line = start + len(content.split("\n")) // 2
Expand Down Expand Up @@ -274,6 +274,7 @@ def edit_code_artifact(
)

error_msg += f"\n[This is how your edit would have looked like if applied]\n{edited_view}\n\n[This is the original code before your edit]\n{original_view}"
print(error_msg)
return error_msg

artifacts[name] = "".join(edited_lines)
Expand Down Expand Up @@ -478,6 +479,16 @@ def use_florence2_fine_tuning(
--------
>>> diff = use_florence2_fine_tuning(artifacts, "code.py", "phrase_grounding", "23b3b022-5ebf-4798-9373-20ef36429abf")
"""

task_to_fn = {
"phrase_grounding": "florence2_phrase_grounding"
}

if name not in artifacts:
output_str = f"[Artifact {name} does not exist]"
print(output_str)
return output_str

code = artifacts[name]
if task.lower() == "phrase_grounding":
pattern = r'florence2_phrase_grounding\((".*?", .*?)\)'
Expand All @@ -489,6 +500,12 @@ def replacer(match):
raise ValueError(f"Task {task} is not supported.")

new_code = re.sub(pattern, replacer, code)

if new_code == code:
output_str = f"[Fine tuning task {task} function {task_to_fn[task]} not found in code]"
print(output_str)
return output_str

artifacts[name] = new_code

diff = get_diff(code, new_code)
Expand Down

0 comments on commit cc0e866

Please sign in to comment.