Skip to content

Commit

Permalink
move get_diff and add use_florence2_fine_tuning
Browse files Browse the repository at this point in the history
  • Loading branch information
dillonalaird committed Sep 3, 2024
1 parent 4f32079 commit b9e7541
Show file tree
Hide file tree
Showing 2 changed files with 66 additions and 15 deletions.
10 changes: 1 addition & 9 deletions vision_agent/agent/vision_agent_coder.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import copy
import difflib
import logging
import os
import sys
Expand Down Expand Up @@ -29,6 +28,7 @@
USER_REQ,
)
from vision_agent.lmm import LMM, AzureOpenAILMM, Message, OllamaLMM, OpenAILMM
from vision_agent.tools.meta_tools import get_diff
from vision_agent.utils import CodeInterpreterFactory, Execution
from vision_agent.utils.execute import CodeInterpreter
from vision_agent.utils.image_utils import b64_to_pil
Expand Down Expand Up @@ -63,14 +63,6 @@ def prepend_imports(code: str) -> str:
return DefaultImports.to_code_string() + "\n\n" + code


def get_diff(before: str, after: str) -> str:
return "".join(
difflib.unified_diff(
before.splitlines(keepends=True), after.splitlines(keepends=True)
)
)


def format_memory(memory: List[Dict[str, str]]) -> str:
output_str = ""
for i, m in enumerate(memory):
Expand Down
71 changes: 65 additions & 6 deletions vision_agent/tools/meta_tools.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import difflib
import os
import pickle as pkl
import re
import subprocess
import tempfile
from pathlib import Path
Expand Down Expand Up @@ -394,14 +396,21 @@ def write_media_artifact(artifacts: Artifacts, local_path: str) -> str:
return f"[Media {Path(local_path).name} saved]"


def list_artifacts(artifacts: Artifacts) -> str:
"""Lists all the artifacts that have been loaded into the artifacts object."""
output_str = artifacts.show()
print(output_str)
return output_str


def get_tool_descriptions() -> str:
"""Returns a description of all the tools that `generate_vision_code` has access to.
Helpful for answering questions about what types of vision tasks you can do with
`generate_vision_code`."""
return TOOL_DESCRIPTIONS


def florence2_fine_tuning(bboxes: List[Dict[str, Any]], task: str) -> UUID:
def florence2_fine_tuning(bboxes: List[Dict[str, Any]], task: str) -> str:
"""'florence2_fine_tuning' is a tool that fine-tune florence2 to be able to detect
objects in an image based on a given dataset. It returns the fine tuning job id.
Expand All @@ -420,26 +429,73 @@ def florence2_fine_tuning(bboxes: List[Dict[str, Any]], task: str) -> UUID:
>>> fine_tuning_job_id = florencev2_fine_tuning(
[{'image_path': 'filename.png', 'labels': ['screw'], 'bboxes': [[370, 30, 560, 290]]},
{'image_path': 'filename.png', 'labels': ['screw'], 'bboxes': [[120, 0, 300, 170]]}],
"OBJECT_DETECTION"
"phrase_grounding"
)
"""
bboxes_input = [BboxInput.model_validate(bbox) for bbox in bboxes]
task_type = PromptTask(task.upper())
task_type = PromptTask[task.upper()]
fine_tuning_request = [
BboxInputBase64(
image=convert_to_b64(bbox_input.image_path),
filename=bbox_input.image_path.split("/")[-1],
filename=Path(bbox_input.image_path).name,
labels=bbox_input.labels,
bboxes=bbox_input.bboxes,
)
for bbox_input in bboxes_input
]
landing_api = LandingPublicAPI()
return landing_api.launch_fine_tuning_job(
"florencev2", task_type, fine_tuning_request
# fine_tune_id = str(landing_api.launch_fine_tuning_job(
# "florencev2", task_type, fine_tuning_request
# ))
fine_tune_id = "23b3b022-5ebf-4798-9373-20ef36429abf"
print(f"[Florence2 fine tuning id: {fine_tune_id}]")
return fine_tune_id


def get_diff(before: str, after: str) -> str:
return "".join(
difflib.unified_diff(
before.splitlines(keepends=True), after.splitlines(keepends=True)
)
)


def use_florence2_fine_tuning(
artifacts: Artifacts, name: str, task: str, fine_tune_id: str
) -> str:
"""Replaces florence2 calls with the fine tuning id. This ensures that the code
utilizes the fined tuned florence2 model. Returns the diff between the original
code and the new code.
Parameters:
artifacts (Artifacts): The artifacts object to edit the code from.
name (str): The name of the artifact to edit.
task (str): The task to fine tune the model for. The options are
'phrase_grounding'.
fine_tune_id (str): The fine tuning job id.
Examples
--------
>>> diff = use_florence2_fine_tuning(artifacts, "code.py", "phrase_grounding", "23b3b022-5ebf-4798-9373-20ef36429abf")
"""
code = artifacts[name]
if task.lower() == "phrase_grounding":
pattern = r'florence2_phrase_grounding\((".*?", .*?)\)'

def replacer(match):
return f'florence2_phrase_grounding({match.group(1)}, "{fine_tune_id}")'

else:
raise ValueError(f"Task {task} is not supported.")

new_code = re.sub(pattern, replacer, code)
artifacts[name] = new_code

diff = get_diff(code, new_code)
print(diff)
return diff


META_TOOL_DOCSTRING = get_tool_documentation(
[
get_tool_descriptions,
Expand All @@ -449,5 +505,8 @@ def florence2_fine_tuning(bboxes: List[Dict[str, Any]], task: str) -> UUID:
generate_vision_code,
edit_vision_code,
write_media_artifact,
florence2_fine_tuning,
use_florence2_fine_tuning,
list_artifacts,
]
)

0 comments on commit b9e7541

Please sign in to comment.