Fine tuning support (#219)
* moved fine tuning to meta tools

* fix error messages

* move get_diff and add use_florence2_fine_tuning

* add fine tuning arg to florence2

* set notebook execute path to remote path

* remove comments

* fix bug where exec isolation wasn't setting resources

* ensure agent uses print to view results

* fixed bug with edit code errors

* fixed bug with edit code errors, and fixed replace code for fine tune

* add imports for new meta tools

* fixed type errors

* fix format issue

* fixed regex

* fix bug with upload return path
dillonalaird authored Sep 4, 2024
1 parent c64c02a commit 0f2f7aa
Showing 7 changed files with 202 additions and 164 deletions.
16 changes: 10 additions & 6 deletions vision_agent/agent/vision_agent.py
@@ -30,7 +30,7 @@ class BoilerplateCode:
pre_code = [
"from typing import *",
"from vision_agent.utils.execute import CodeInterpreter",
"from vision_agent.tools.meta_tools import Artifacts, open_code_artifact, create_code_artifact, edit_code_artifact, get_tool_descriptions, generate_vision_code, edit_vision_code, write_media_artifact",
"from vision_agent.tools.meta_tools import Artifacts, open_code_artifact, create_code_artifact, edit_code_artifact, get_tool_descriptions, generate_vision_code, edit_vision_code, write_media_artifact, florence2_fine_tuning, use_florence2_fine_tuning",
"artifacts = Artifacts('{remote_path}')",
"artifacts.load('{remote_path}')",
]
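For context, a minimal sketch of how this boilerplate is presumably seeded around the agent's code. This is an assumption, not the actual add_boilerplate implementation: it only illustrates the '{remote_path}' placeholders above being filled in and the pre_code being prepended; the path and user code are hypothetical.

# Hypothetical illustration only: substitute the remote artifacts path into
# pre_code and prepend it to the agent's code before execution.
remote_path = "/home/user/artifacts.pkl"   # hypothetical sandbox path
user_code = "print(artifacts.show())"      # hypothetical agent code
pre_code = [
    "from vision_agent.tools.meta_tools import Artifacts",
    "artifacts = Artifacts('{remote_path}')",
    "artifacts.load('{remote_path}')",
]
seeded_code = "\n".join(line.format(remote_path=remote_path) for line in pre_code)
seeded_code += "\n\n" + user_code
print(seeded_code)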
@@ -76,11 +76,16 @@ def run_conversation(orch: LMM, chat: List[Message]) -> Dict[str, Any]:

def run_code_action(
code: str, code_interpreter: CodeInterpreter, artifact_remote_path: str
) -> Execution:
return code_interpreter.exec_isolation(
) -> Tuple[Execution, str]:
result = code_interpreter.exec_isolation(
BoilerplateCode.add_boilerplate(code, remote_path=artifact_remote_path)
)

obs = str(result.logs)
if result.error:
obs += f"\n{result.error}"
return result, obs
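A hedged usage sketch of the new (Execution, observation) return value, assuming run_code_action is importable from this module and that CodeInterpreterFactory.new_instance() accepts default arguments; the code string and artifacts path are hypothetical.

from vision_agent.agent.vision_agent import run_code_action
from vision_agent.utils import CodeInterpreterFactory

with CodeInterpreterFactory.new_instance() as code_interpreter:
    # result is the full Execution; obs is the captured logs plus any error text,
    # ready to log or feed back to the agent as its observation.
    result, obs = run_code_action(
        "print('hello from the sandbox')",
        code_interpreter,
        "/home/user/artifacts.pkl",   # hypothetical remote artifacts path
    )
    print(obs)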


def parse_execution(response: str) -> Optional[str]:
code = None
@@ -192,7 +197,7 @@ def chat_with_code(
artifacts = Artifacts(WORKSPACE / "artifacts.pkl")

with CodeInterpreterFactory.new_instance(
code_sandbox_runtime=self.code_sandbox_runtime
code_sandbox_runtime=self.code_sandbox_runtime,
) as code_interpreter:
orig_chat = copy.deepcopy(chat)
int_chat = copy.deepcopy(chat)
@@ -260,10 +265,9 @@ def chat_with_code(
code_action = parse_execution(response["response"])

if code_action is not None:
result = run_code_action(
result, obs = run_code_action(
code_action, code_interpreter, str(remote_artifacts_path)
)
obs = str(result.logs)

if self.verbosity >= 1:
_LOGGER.info(obs)
10 changes: 1 addition & 9 deletions vision_agent/agent/vision_agent_coder.py
@@ -1,5 +1,4 @@
import copy
import difflib
import logging
import os
import sys
@@ -29,6 +28,7 @@
USER_REQ,
)
from vision_agent.lmm import LMM, AzureOpenAILMM, Message, OllamaLMM, OpenAILMM
from vision_agent.tools.meta_tools import get_diff
from vision_agent.utils import CodeInterpreterFactory, Execution
from vision_agent.utils.execute import CodeInterpreter
from vision_agent.utils.image_utils import b64_to_pil
@@ -63,14 +63,6 @@ def prepend_imports(code: str) -> str:
return DefaultImports.to_code_string() + "\n\n" + code


def get_diff(before: str, after: str) -> str:
return "".join(
difflib.unified_diff(
before.splitlines(keepends=True), after.splitlines(keepends=True)
)
)


def format_memory(memory: List[Dict[str, str]]) -> str:
output_str = ""
for i, m in enumerate(memory):
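The helper itself now lives in vision_agent.tools.meta_tools (its new definition appears later in this diff). A quick sketch of what it returns, with hypothetical before/after strings:

from vision_agent.tools.meta_tools import get_diff

before = "dets = florence2_phrase_grounding(prompt, image)\n"
after = "dets = florence2_phrase_grounding(prompt, image, 'ft-job-id')\n"
print(get_diff(before, after))
# prints, roughly:
# ---
# +++
# @@ -1 +1 @@
# -dets = florence2_phrase_grounding(prompt, image)
# +dets = florence2_phrase_grounding(prompt, image, 'ft-job-id')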
6 changes: 3 additions & 3 deletions vision_agent/agent/vision_agent_prompts.py
@@ -48,7 +48,7 @@
4| return dogs
[End of artifact]
AGENT: {"thoughts": "I have generated the code to detect the dogs in the image, I must now run the code to get the output.", "response": "<execute_python>from dog_detector import detect_dogs\n print(detect_dogs('/example/workspace/dog.jpg'))</execute_python>", "let_user_respond": false}
AGENT: {"thoughts": "I have generated the code to detect the dogs in the image, I must now run the code and print the results to get the output.", "response": "<execute_python>from dog_detector import detect_dogs\n print(detect_dogs('/example/workspace/dog.jpg'))</execute_python>", "let_user_respond": false}
OBSERVATION:
----- stdout -----
@@ -75,7 +75,7 @@
4| return dogs
[End of artifact]
AGENT: {"thoughts": "I have edited the code to detect only one dog, I must now run the code to get the output.", "response": "<execute_python>from dog_detector import detect_dogs\n print(detect_dogs('/path/to/images/dog.jpg'))</execute_python>", "let_user_respond": false}
AGENT: {"thoughts": "I have edited the code to detect only one dog, I must now run the code and print the results to get the output.", "response": "<execute_python>from dog_detector import detect_dogs\n print(detect_dogs('/path/to/images/dog.jpg'))</execute_python>", "let_user_respond": false}
OBSERVATION:
----- stdout -----
@@ -126,7 +126,7 @@
15| return count
[End of artifact]
AGENT: {"thoughts": "I have generated the code to count the workers with helmets in the image, I must now run the code to get the output and write the visualization to the artifacts so the user can see it.", "response": "<execute_python>from code import count_workers_with_helmets\n print(count_workers_with_helmets('/path/to/images/workers.png', 'workers_viz.png'))\n write_media_artifact(artifacts, 'workers_viz.png')</execute_python>", "let_user_respond": false}
AGENT: {"thoughts": "I have generated the code to count the workers with helmets in the image, I must now run the code and print the output and write the visualization to the artifacts so I can see the result and the user can see the visaulization.", "response": "<execute_python>from code import count_workers_with_helmets\n print(count_workers_with_helmets('/path/to/images/workers.png', 'workers_viz.png'))\n write_media_artifact(artifacts, 'workers_viz.png')</execute_python>", "let_user_respond": false}
OBSERVATION:
----- stdout -----
148 changes: 140 additions & 8 deletions vision_agent/tools/meta_tools.py
@@ -1,5 +1,7 @@
import difflib
import os
import pickle as pkl
import re
import subprocess
import tempfile
from pathlib import Path
@@ -8,10 +10,13 @@
from IPython.display import display

import vision_agent as va
from vision_agent.clients.landing_public_api import LandingPublicAPI
from vision_agent.lmm.types import Message
from vision_agent.tools.tool_utils import get_tool_documentation
from vision_agent.tools.tools import TOOL_DESCRIPTIONS
from vision_agent.tools.tools_types import BboxInput, BboxInputBase64, PromptTask
from vision_agent.utils.execute import Execution, MimeType
from vision_agent.utils.image_utils import convert_to_b64

# These tools are adapted from SWE-Agent https://github.com/princeton-nlp/SWE-agent

@@ -99,13 +104,14 @@ def load(self, file_path: Union[str, Path]) -> None:

def show(self) -> str:
"""Shows the artifacts that have been loaded and their remote save paths."""
out_str = "[Artifacts loaded]\n"
output_str = "[Artifacts loaded]\n"
for k in self.artifacts.keys():
out_str += (
output_str += (
f"Artifact {k} loaded to {str(self.remote_save_path.parent / k)}\n"
)
out_str += "[End of artifacts]\n"
return out_str
output_str += "[End of artifacts]\n"
print(output_str)
return output_str

def save(self, local_path: Optional[Union[str, Path]] = None) -> None:
save_path = (
@@ -135,7 +141,12 @@ def format_lines(lines: List[str], start_idx: int) -> str:


def view_lines(
lines: List[str], line_num: int, window_size: int, name: str, total_lines: int
lines: List[str],
line_num: int,
window_size: int,
name: str,
total_lines: int,
print_output: bool = True,
) -> str:
start = max(0, line_num - window_size)
end = min(len(lines), line_num + window_size)
@@ -148,7 +159,9 @@ def edit_code_artifact(
else f"[{len(lines) - end} more lines]"
)
)
print(return_str)

if print_output:
print(return_str)
return return_str


@@ -231,7 +244,7 @@ def edit_code_artifact(
new_content_lines = [
line if line.endswith("\n") else line + "\n" for line in new_content_lines
]
lines = artifacts[name].splitlines()
lines = artifacts[name].splitlines(keepends=True)
edited_lines = lines[:start] + new_content_lines + lines[end:]

cur_line = start + len(content.split("\n")) // 2
@@ -261,13 +274,20 @@ def edit_code_artifact(
DEFAULT_WINDOW_SIZE,
name,
total_lines,
print_output=False,
)
total_lines_edit = sum(1 for _ in edited_lines)
edited_view = view_lines(
edited_lines, cur_line, DEFAULT_WINDOW_SIZE, name, total_lines_edit
edited_lines,
cur_line,
DEFAULT_WINDOW_SIZE,
name,
total_lines_edit,
print_output=False,
)

error_msg += f"\n[This is how your edit would have looked like if applied]\n{edited_view}\n\n[This is the original code before your edit]\n{original_view}"
print(error_msg)
return error_msg

artifacts[name] = "".join(edited_lines)
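Why the splitlines(keepends=True) change above matters, in a tiny sketch: without keepends, rejoining the edited lines with "".join would silently drop the artifact's line breaks.

text = "a = 1\nb = 2\n"
print("".join(text.splitlines()))               # a = 1b = 2   (newlines lost)
print("".join(text.splitlines(keepends=True)))  # round-trips the original text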
@@ -390,13 +410,122 @@ def write_media_artifact(artifacts: Artifacts, local_path: str) -> str:
return f"[Media {Path(local_path).name} saved]"


def list_artifacts(artifacts: Artifacts) -> str:
"""Lists all the artifacts that have been loaded into the artifacts object."""
output_str = artifacts.show()
print(output_str)
return output_str


def get_tool_descriptions() -> str:
"""Returns a description of all the tools that `generate_vision_code` has access to.
Helpful for answering questions about what types of vision tasks you can do with
`generate_vision_code`."""
return TOOL_DESCRIPTIONS


def florence2_fine_tuning(bboxes: List[Dict[str, Any]], task: str) -> str:
"""'florence2_fine_tuning' is a tool that fine-tune florence2 to be able to detect
objects in an image based on a given dataset. It returns the fine tuning job id.
Parameters:
bboxes (List[BboxInput]): A list of BboxInput containing the
image path, labels and bounding boxes.
task (str): The florencev2 fine-tuning task. The options are
'phrase_grounding'.
Returns:
UUID: The fine-tuning job id; this id will be used to retrieve the
fine-tuned model.
Example
-------
>>> fine_tuning_job_id = florence2_fine_tuning(
[{'image_path': 'filename.png', 'labels': ['screw'], 'bboxes': [[370, 30, 560, 290]]},
{'image_path': 'filename.png', 'labels': ['screw'], 'bboxes': [[120, 0, 300, 170]]}],
"phrase_grounding"
)
"""
bboxes_input = [BboxInput.model_validate(bbox) for bbox in bboxes]
task_type = PromptTask[task.upper()]
fine_tuning_request = [
BboxInputBase64(
image=convert_to_b64(bbox_input.image_path),
filename=Path(bbox_input.image_path).name,
labels=bbox_input.labels,
bboxes=bbox_input.bboxes,
)
for bbox_input in bboxes_input
]
landing_api = LandingPublicAPI()
fine_tune_id = str(
landing_api.launch_fine_tuning_job("florencev2", task_type, fine_tuning_request)
)
print(f"[Florence2 fine tuning id: {fine_tune_id}]")
return fine_tune_id


def get_diff(before: str, after: str) -> str:
return "".join(
difflib.unified_diff(
before.splitlines(keepends=True), after.splitlines(keepends=True)
)
)


def use_florence2_fine_tuning(
artifacts: Artifacts, name: str, task: str, fine_tune_id: str
) -> str:
"""Replaces florence2 calls with the fine tuning id. This ensures that the code
utilizes the fine-tuned florence2 model. Returns the diff between the original
code and the new code.
Parameters:
artifacts (Artifacts): The artifacts object to edit the code from.
name (str): The name of the artifact to edit.
task (str): The task to fine tune the model for. The options are
'phrase_grounding'.
fine_tune_id (str): The fine tuning job id.
Examples
--------
>>> diff = use_florence2_fine_tuning(artifacts, "code.py", "phrase_grounding", "23b3b022-5ebf-4798-9373-20ef36429abf")
"""

task_to_fn = {"phrase_grounding": "florence2_phrase_grounding"}

if name not in artifacts:
output_str = f"[Artifact {name} does not exist]"
print(output_str)
return output_str

code = artifacts[name]
if task.lower() == "phrase_grounding":
pattern = r"florence2_phrase_grounding\(\s*([^\)]+)\)"

def replacer(match: re.Match) -> str:
arg = match.group(1) # capture all initial arguments
return f'florence2_phrase_grounding({arg}, "{fine_tune_id}")'

else:
raise ValueError(f"Task {task} is not supported.")

new_code = re.sub(pattern, replacer, code)

if new_code == code:
output_str = (
f"[Fine tuning task {task} function {task_to_fn[task]} not found in code]"
)
print(output_str)
return output_str

artifacts[name] = new_code

diff = get_diff(code, new_code)
print(diff)
return diff
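A minimal demonstration of the substitution performed above, with a hypothetical artifact line and fine-tune id:

import re

code = "dets = florence2_phrase_grounding('screw', image)"
fine_tune_id = "23b3b022-5ebf-4798-9373-20ef36429abf"

pattern = r"florence2_phrase_grounding\(\s*([^\)]+)\)"
new_code = re.sub(
    pattern,
    lambda m: f'florence2_phrase_grounding({m.group(1)}, "{fine_tune_id}")',
    code,
)
print(new_code)
# dets = florence2_phrase_grounding('screw', image, "23b3b022-5ebf-4798-9373-20ef36429abf")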


META_TOOL_DOCSTRING = get_tool_documentation(
[
get_tool_descriptions,
@@ -406,5 +535,8 @@ def get_tool_descriptions() -> str:
generate_vision_code,
edit_vision_code,
write_media_artifact,
florence2_fine_tuning,
use_florence2_fine_tuning,
list_artifacts,
]
)
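Putting the two new meta tools together, a hedged end-to-end sketch. The artifacts path, image paths, labels, and artifact name are hypothetical, and it assumes the artifacts file already contains a code.py that calls florence2_phrase_grounding.

from vision_agent.tools.meta_tools import (
    Artifacts,
    florence2_fine_tuning,
    use_florence2_fine_tuning,
)

artifacts = Artifacts("/home/user/artifacts.pkl")   # hypothetical remote save path
artifacts.load("/home/user/artifacts.pkl")

# 1. Launch a phrase-grounding fine-tuning job from labeled bounding boxes.
fine_tune_id = florence2_fine_tuning(
    [
        {"image_path": "screws1.png", "labels": ["screw"], "bboxes": [[370, 30, 560, 290]]},
        {"image_path": "screws2.png", "labels": ["screw"], "bboxes": [[120, 0, 300, 170]]},
    ],
    "phrase_grounding",
)

# 2. Rewrite florence2_phrase_grounding calls in code.py to use the fine-tuned model.
diff = use_florence2_fine_tuning(artifacts, "code.py", "phrase_grounding", fine_tune_id)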