Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fine tuning support #219

Merged
merged 16 commits into from
Sep 4, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 10 additions & 6 deletions vision_agent/agent/vision_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ class BoilerplateCode:
pre_code = [
"from typing import *",
"from vision_agent.utils.execute import CodeInterpreter",
"from vision_agent.tools.meta_tools import Artifacts, open_code_artifact, create_code_artifact, edit_code_artifact, get_tool_descriptions, generate_vision_code, edit_vision_code, write_media_artifact",
"from vision_agent.tools.meta_tools import Artifacts, open_code_artifact, create_code_artifact, edit_code_artifact, get_tool_descriptions, generate_vision_code, edit_vision_code, write_media_artifact, florence2_fine_tuning, use_florence2_fine_tuning",
"artifacts = Artifacts('{remote_path}')",
"artifacts.load('{remote_path}')",
]
Expand Down Expand Up @@ -76,11 +76,16 @@ def run_conversation(orch: LMM, chat: List[Message]) -> Dict[str, Any]:

def run_code_action(
code: str, code_interpreter: CodeInterpreter, artifact_remote_path: str
) -> Execution:
return code_interpreter.exec_isolation(
) -> Tuple[Execution, str]:
result = code_interpreter.exec_isolation(
BoilerplateCode.add_boilerplate(code, remote_path=artifact_remote_path)
)

obs = str(result.logs)
if result.error:
obs += f"\n{result.error}"
return result, obs


def parse_execution(response: str) -> Optional[str]:
code = None
Expand Down Expand Up @@ -192,7 +197,7 @@ def chat_with_code(
artifacts = Artifacts(WORKSPACE / "artifacts.pkl")

with CodeInterpreterFactory.new_instance(
code_sandbox_runtime=self.code_sandbox_runtime
code_sandbox_runtime=self.code_sandbox_runtime,
) as code_interpreter:
orig_chat = copy.deepcopy(chat)
int_chat = copy.deepcopy(chat)
Expand Down Expand Up @@ -260,10 +265,9 @@ def chat_with_code(
code_action = parse_execution(response["response"])

if code_action is not None:
result = run_code_action(
result, obs = run_code_action(
code_action, code_interpreter, str(remote_artifacts_path)
)
obs = str(result.logs)

if self.verbosity >= 1:
_LOGGER.info(obs)
Expand Down
10 changes: 1 addition & 9 deletions vision_agent/agent/vision_agent_coder.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import copy
import difflib
import logging
import os
import sys
Expand Down Expand Up @@ -29,6 +28,7 @@
USER_REQ,
)
from vision_agent.lmm import LMM, AzureOpenAILMM, Message, OllamaLMM, OpenAILMM
from vision_agent.tools.meta_tools import get_diff
from vision_agent.utils import CodeInterpreterFactory, Execution
from vision_agent.utils.execute import CodeInterpreter
from vision_agent.utils.image_utils import b64_to_pil
Expand Down Expand Up @@ -63,14 +63,6 @@ def prepend_imports(code: str) -> str:
return DefaultImports.to_code_string() + "\n\n" + code


def get_diff(before: str, after: str) -> str:
    """Returns the unified diff between two pieces of text as a single string.

    Parameters:
        before (str): The original text.
        after (str): The modified text.

    Returns:
        str: The unified diff of the two texts, or an empty string when they
            are identical.
    """
    # Keep line endings so that joining the diff lines reproduces them verbatim.
    old_lines = before.splitlines(keepends=True)
    new_lines = after.splitlines(keepends=True)
    diff_lines = difflib.unified_diff(old_lines, new_lines)
    return "".join(diff_lines)


def format_memory(memory: List[Dict[str, str]]) -> str:
output_str = ""
for i, m in enumerate(memory):
Expand Down
6 changes: 3 additions & 3 deletions vision_agent/agent/vision_agent_prompts.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@
4| return dogs
[End of artifact]

AGENT: {"thoughts": "I have generated the code to detect the dogs in the image, I must now run the code to get the output.", "response": "<execute_python>from dog_detector import detect_dogs\n print(detect_dogs('/example/workspace/dog.jpg'))</execute_python>", "let_user_respond": false}
AGENT: {"thoughts": "I have generated the code to detect the dogs in the image, I must now run the code and print the results to get the output.", "response": "<execute_python>from dog_detector import detect_dogs\n print(detect_dogs('/example/workspace/dog.jpg'))</execute_python>", "let_user_respond": false}

OBSERVATION:
----- stdout -----
Expand All @@ -75,7 +75,7 @@
4| return dogs
[End of artifact]

AGENT: {"thoughts": "I have edited the code to detect only one dog, I must now run the code to get the output.", "response": "<execute_python>from dog_detector import detect_dogs\n print(detect_dogs('/path/to/images/dog.jpg'))</execute_python>", "let_user_respond": false}
AGENT: {"thoughts": "I have edited the code to detect only one dog, I must now run the code and print the results to get the output.", "response": "<execute_python>from dog_detector import detect_dogs\n print(detect_dogs('/path/to/images/dog.jpg'))</execute_python>", "let_user_respond": false}

OBSERVATION:
----- stdout -----
Expand Down Expand Up @@ -126,7 +126,7 @@
15| return count
[End of artifact]

AGENT: {"thoughts": "I have generated the code to count the workers with helmets in the image, I must now run the code to get the output and write the visualization to the artifacts so the user can see it.", "response": "<execute_python>from code import count_workers_with_helmets\n print(count_workers_with_helmets('/path/to/images/workers.png', 'workers_viz.png'))\n write_media_artifact(artifacts, 'workers_viz.png')</execute_python>", "let_user_respond": false}
AGENT: {"thoughts": "I have generated the code to count the workers with helmets in the image, I must now run the code and print the output and write the visualization to the artifacts so I can see the result and the user can see the visualization.", "response": "<execute_python>from code import count_workers_with_helmets\n print(count_workers_with_helmets('/path/to/images/workers.png', 'workers_viz.png'))\n write_media_artifact(artifacts, 'workers_viz.png')</execute_python>", "let_user_respond": false}

OBSERVATION:
----- stdout -----
Expand Down
148 changes: 140 additions & 8 deletions vision_agent/tools/meta_tools.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import difflib
import os
import pickle as pkl
import re
import subprocess
import tempfile
from pathlib import Path
Expand All @@ -8,10 +10,13 @@
from IPython.display import display

import vision_agent as va
from vision_agent.clients.landing_public_api import LandingPublicAPI
from vision_agent.lmm.types import Message
from vision_agent.tools.tool_utils import get_tool_documentation
from vision_agent.tools.tools import TOOL_DESCRIPTIONS
from vision_agent.tools.tools_types import BboxInput, BboxInputBase64, PromptTask
from vision_agent.utils.execute import Execution, MimeType
from vision_agent.utils.image_utils import convert_to_b64

# These tools are adapted from SWE-Agent https://github.com/princeton-nlp/SWE-agent

Expand Down Expand Up @@ -99,13 +104,14 @@ def load(self, file_path: Union[str, Path]) -> None:

def show(self) -> str:
"""Shows the artifacts that have been loaded and their remote save paths."""
out_str = "[Artifacts loaded]\n"
output_str = "[Artifacts loaded]\n"
for k in self.artifacts.keys():
out_str += (
output_str += (
f"Artifact {k} loaded to {str(self.remote_save_path.parent / k)}\n"
)
out_str += "[End of artifacts]\n"
return out_str
output_str += "[End of artifacts]\n"
print(output_str)
return output_str

def save(self, local_path: Optional[Union[str, Path]] = None) -> None:
save_path = (
Expand Down Expand Up @@ -135,7 +141,12 @@ def format_lines(lines: List[str], start_idx: int) -> str:


def view_lines(
lines: List[str], line_num: int, window_size: int, name: str, total_lines: int
lines: List[str],
line_num: int,
window_size: int,
name: str,
total_lines: int,
print_output: bool = True,
) -> str:
start = max(0, line_num - window_size)
end = min(len(lines), line_num + window_size)
Expand All @@ -148,7 +159,9 @@ def view_lines(
else f"[{len(lines) - end} more lines]"
)
)
print(return_str)

if print_output:
print(return_str)
return return_str


Expand Down Expand Up @@ -231,7 +244,7 @@ def edit_code_artifact(
new_content_lines = [
line if line.endswith("\n") else line + "\n" for line in new_content_lines
]
lines = artifacts[name].splitlines()
lines = artifacts[name].splitlines(keepends=True)
edited_lines = lines[:start] + new_content_lines + lines[end:]

cur_line = start + len(content.split("\n")) // 2
Expand Down Expand Up @@ -261,13 +274,20 @@ def edit_code_artifact(
DEFAULT_WINDOW_SIZE,
name,
total_lines,
print_output=False,
)
total_lines_edit = sum(1 for _ in edited_lines)
edited_view = view_lines(
edited_lines, cur_line, DEFAULT_WINDOW_SIZE, name, total_lines_edit
edited_lines,
cur_line,
DEFAULT_WINDOW_SIZE,
name,
total_lines_edit,
print_output=False,
)

error_msg += f"\n[This is how your edit would have looked like if applied]\n{edited_view}\n\n[This is the original code before your edit]\n{original_view}"
print(error_msg)
return error_msg

artifacts[name] = "".join(edited_lines)
Expand Down Expand Up @@ -390,13 +410,122 @@ def write_media_artifact(artifacts: Artifacts, local_path: str) -> str:
return f"[Media {Path(local_path).name} saved]"


def list_artifacts(artifacts: Artifacts) -> str:
    """Lists all the artifacts that have been loaded into the artifacts object.

    Parameters:
        artifacts (Artifacts): The artifacts object to list.

    Returns:
        str: The listing produced by `artifacts.show()`.
    """
    # `Artifacts.show()` prints the listing itself before returning it, so
    # printing the returned string again here would emit the listing twice.
    return artifacts.show()


def get_tool_descriptions() -> str:
    """Returns a description of every tool that `generate_vision_code` has access
    to. Helpful for answering questions about what types of vision tasks you can
    do with `generate_vision_code`."""
    # The descriptions are imported from `vision_agent.tools.tools`; this
    # function only hands them back.
    descriptions: str = TOOL_DESCRIPTIONS
    return descriptions


def florence2_fine_tuning(bboxes: List[Dict[str, Any]], task: str) -> str:
    """'florence2_fine_tuning' is a tool that fine-tunes florence2 to be able to detect
    objects in an image based on a given dataset. It returns the fine tuning job id.

    Parameters:
        bboxes (List[Dict[str, Any]]): A list of dictionaries, one per image, each
            containing the 'image_path', 'labels' and 'bboxes'. Every dictionary
            is validated as a BboxInput.
        task (str): The florence2 fine-tuning task. The options are
            'phrase_grounding'.

    Returns:
        str: The fine tuning job id, this id will be used to retrieve the fine
            tuned model.

    Example
    -------
        >>> fine_tuning_job_id = florence2_fine_tuning(
            [{'image_path': 'filename.png', 'labels': ['screw'], 'bboxes': [[370, 30, 560, 290]]},
             {'image_path': 'filename.png', 'labels': ['screw'], 'bboxes': [[120, 0, 300, 170]]}],
             "phrase_grounding"
        )
    """
    bboxes_input = [BboxInput.model_validate(bbox) for bbox in bboxes]
    # The task name is upper-cased before the lookup, so it is case-insensitive.
    # NOTE(review): an unknown task presumably raises KeyError here (assuming
    # PromptTask is an Enum) -- confirm against tools_types.
    task_type = PromptTask[task.upper()]
    # Images are sent inline as base64 together with their file name and labels.
    fine_tuning_request = [
        BboxInputBase64(
            image=convert_to_b64(bbox_input.image_path),
            filename=Path(bbox_input.image_path).name,
            labels=bbox_input.labels,
            bboxes=bbox_input.bboxes,
        )
        for bbox_input in bboxes_input
    ]
    landing_api = LandingPublicAPI()
    fine_tune_id = str(
        landing_api.launch_fine_tuning_job("florencev2", task_type, fine_tuning_request)
    )
    # Printed so the id shows up in the captured stdout of the executed code.
    print(f"[Florence2 fine tuning id: {fine_tune_id}]")
    return fine_tune_id


def get_diff(before: str, after: str) -> str:
    """Returns the unified diff between two pieces of text as a single string.

    Parameters:
        before (str): The original text.
        after (str): The modified text.

    Returns:
        str: The unified diff of the two texts, or an empty string when they
            are identical.
    """
    # Keep line endings so that joining the diff lines reproduces them verbatim.
    old_lines = before.splitlines(keepends=True)
    new_lines = after.splitlines(keepends=True)
    diff_lines = difflib.unified_diff(old_lines, new_lines)
    return "".join(diff_lines)


def use_florence2_fine_tuning(
    artifacts: Artifacts, name: str, task: str, fine_tune_id: str
) -> str:
    """Replaces florence2 calls with the fine tuning id. This ensures that the code
    utilizes the fine-tuned florence2 model. Returns the diff between the original
    code and the new code.

    Parameters:
        artifacts (Artifacts): The artifacts object to edit the code from.
        name (str): The name of the artifact to edit.
        task (str): The task to fine tune the model for. The options are
            'phrase_grounding'.
        fine_tune_id (str): The fine tuning job id.

    Returns:
        str: The unified diff between the original and the edited code, or a
            bracketed message when the artifact does not exist or the code
            contains no call to the function for the given task.

    Raises:
        ValueError: If the task is not supported.

    Examples
    --------
        >>> diff = use_florence2_fine_tuning(artifacts, "code.py", "phrase_grounding", "23b3b022-5ebf-4798-9373-20ef36429abf")
    """

    task_to_fn = {"phrase_grounding": "florence2_phrase_grounding"}

    if name not in artifacts:
        output_str = f"[Artifact {name} does not exist]"
        print(output_str)
        return output_str

    # Normalize the task once so the support check and the function-name lookup
    # agree; looking up the raw task raised KeyError for e.g. "Phrase_Grounding".
    task_key = task.lower()
    if task_key not in task_to_fn:
        raise ValueError(f"Task {task} is not supported.")
    fn_name = task_to_fn[task_key]

    code = artifacts[name]
    # Captures everything between the parentheses up to the first ")".
    # NOTE(review): calls whose arguments contain nested parentheses are only
    # matched up to the first closing parenthesis.
    pattern = rf"{re.escape(fn_name)}\(\s*([^\)]+)\)"

    def replacer(match: re.Match) -> str:
        # Drop trailing whitespace and a trailing comma so that a call written
        # as `f(a,\n    b,\n)` does not become the invalid `f(a,\n    b,\n, "id")`.
        args = match.group(1).rstrip()
        if args.endswith(","):
            args = args[:-1].rstrip()
        return f'{fn_name}({args}, "{fine_tune_id}")'

    new_code = re.sub(pattern, replacer, code)

    if new_code == code:
        output_str = f"[Fine tuning task {task} function {fn_name} not found in code]"
        print(output_str)
        return output_str

    artifacts[name] = new_code

    diff = get_diff(code, new_code)
    print(diff)
    return diff


META_TOOL_DOCSTRING = get_tool_documentation(
[
get_tool_descriptions,
Expand All @@ -406,5 +535,8 @@ def get_tool_descriptions() -> str:
generate_vision_code,
edit_vision_code,
write_media_artifact,
florence2_fine_tuning,
use_florence2_fine_tuning,
list_artifacts,
]
)
Loading
Loading