Skip to content

Commit

Permalink
add generic OD fine tuning
Browse files Browse the repository at this point in the history
  • Loading branch information
dillonalaird committed Sep 11, 2024
1 parent b304e48 commit 8ae675c
Showing 1 changed file with 33 additions and 29 deletions.
62 changes: 33 additions & 29 deletions vision_agent/tools/meta_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -471,29 +471,29 @@ def get_tool_descriptions() -> str:
return TOOL_DESCRIPTIONS


def florence2_fine_tuning(bboxes: List[Dict[str, Any]], task: str) -> str:
def object_detection_fine_tuning(bboxes: List[Dict[str, Any]]) -> str:
"""DO NOT use this function unless the user has supplied you with bboxes.
'florence2_fine_tuning' is a tool that fine-tune florence2 to be able to detect
objects in an image based on a given dataset. It returns the fine tuning job id.
'object_detection_fine_tuning' is a tool that fine-tunes object detection models to
be able to detect objects in an image based on a given dataset. It returns the fine
tuning job id.
Parameters:
bboxes (List[BboxInput]): A list of BboxInput containing the image path, labels
and bounding boxes. The coordinates are unnormalized.
task (str): The florencev2 fine-tuning task. The options are
'phrase_grounding'.
Returns:
str: The fine tuning job id; this id will be used to retrieve the fine-tuned
model.
Example
-------
>>> fine_tuning_job_id = florencev2_fine_tuning(
>>> fine_tuning_job_id = object_detection_fine_tuning(
[{'image_path': 'filename.png', 'labels': ['screw'], 'bboxes': [[370, 30, 560, 290]]},
{'image_path': 'filename.png', 'labels': ['screw'], 'bboxes': [[120, 0, 300, 170]]}],
"phrase_grounding"
)
"""
task = "phrase_grounding"
bboxes_input = [BboxInput.model_validate(bbox) for bbox in bboxes]
task_type = PromptTask[task.upper()]
fine_tuning_request = [
Expand Down Expand Up @@ -569,48 +569,52 @@ def edit_replacer(match: re.Match) -> str:
return new_code


def use_florence2_fine_tuning(
artifacts: Artifacts, name: str, task: str, fine_tune_id: str
def use_object_detection_fine_tuning(
artifacts: Artifacts, name: str, fine_tune_id: str
) -> str:
"""Replaces florence2 calls with the fine tuning id. This ensures that the code
utilizes the fined tuned florence2 model. Returns the diff between the original
code and the new code.
"""Replaces calls to 'owl_v2_image', 'florence2_phrase_grounding' and
'florence2_sam2_image' with calls that also pass the fine tuning id. This ensures
that the code utilizes the fine-tuned model. Returns the diff between the original
code and the new code.
Parameters:
artifacts (Artifacts): The artifacts object to edit the code from.
name (str): The name of the artifact to edit.
task (str): The task to fine tune the model for. The options are
'phrase_grounding'.
fine_tune_id (str): The fine tuning job id.
Examples
--------
>>> diff = use_florence2_fine_tuning(artifacts, "code.py", "phrase_grounding", "23b3b022-5ebf-4798-9373-20ef36429abf")
>>> diff = use_object_detection_fine_tuning(artifacts, "code.py", "23b3b022-5ebf-4798-9373-20ef36429abf")
"""

task_to_fn = {"phrase_grounding": "florence2_phrase_grounding"}

if name not in artifacts:
output_str = f"[Artifact {name} does not exist]"
print(output_str)
return output_str

code = artifacts[name]
if task.lower() == "phrase_grounding":
pattern = r"florence2_phrase_grounding\(\s*([^\)]+)\)"

def replacer(match: re.Match) -> str:
arg = match.group(1) # capture all initial arguments
return f'florence2_phrase_grounding({arg}, "{fine_tune_id}")'

else:
raise ValueError(f"Task {task} is not supported.")
patterns = [
(
r"florence2_phrase_grounding\(\s*([^\)]+)\s*\)",
lambda match: f'florence2_phrase_grounding({match.group(1)}, "{fine_tune_id}")',
),
(
r"owl_v2_image\(\s*([^\)]+)\s*\)",
lambda match: f'owl_v2_image({match.group(1)}, "{fine_tune_id}")',
),
(
r"florence2_sam2_image\(\s*([^\)]+)\s*\)",
lambda match: f'florence2_sam2_image({match.group(1)}, "{fine_tune_id}")',
),
]

new_code = re.sub(pattern, replacer, code)
new_code = code
for pattern, replacer in patterns:
new_code = re.sub(pattern, replacer, new_code)

if new_code == code:
output_str = (
f"[Fine tuning task {task} function {task_to_fn[task]} not found in code]"
f"[No function calls to replace with fine tuning id in artifact {name}]"
)
print(output_str)
return output_str
Expand All @@ -632,8 +636,8 @@ def replacer(match: re.Match) -> str:
edit_vision_code,
write_media_artifact,
view_media_artifact,
florence2_fine_tuning,
use_florence2_fine_tuning,
object_detection_fine_tuning,
use_object_detection_fine_tuning,
list_artifacts,
]
)

0 comments on commit 8ae675c

Please sign in to comment.