From 8ae675cf7fe85007041028380956f5230e16706b Mon Sep 17 00:00:00 2001 From: Dillon Laird Date: Wed, 11 Sep 2024 14:55:18 -0700 Subject: [PATCH] add generic OD fine tuning --- vision_agent/tools/meta_tools.py | 62 +++++++++++++++++--------------- 1 file changed, 33 insertions(+), 29 deletions(-) diff --git a/vision_agent/tools/meta_tools.py b/vision_agent/tools/meta_tools.py index 3ec227f8..2c23856c 100644 --- a/vision_agent/tools/meta_tools.py +++ b/vision_agent/tools/meta_tools.py @@ -471,16 +471,15 @@ def get_tool_descriptions() -> str: return TOOL_DESCRIPTIONS -def florence2_fine_tuning(bboxes: List[Dict[str, Any]], task: str) -> str: +def object_detection_fine_tuning(bboxes: List[Dict[str, Any]]) -> str: """DO NOT use this function unless the user has supplied you with bboxes. - 'florence2_fine_tuning' is a tool that fine-tune florence2 to be able to detect - objects in an image based on a given dataset. It returns the fine tuning job id. + 'object_detection_fine_tuning' is a tool that fine-tunes object detection models to + be able to detect objects in an image based on a given dataset. It returns the fine + tuning job id. Parameters: bboxes (List[BboxInput]): A list of BboxInput containing the image path, labels and bounding boxes. The coordinates are unnormalized. - task (str): The florencev2 fine-tuning task. The options are - 'phrase_grounding'. Returns: str: The fine tuning job id, this id will used to retrieve the fine tuned @@ -488,12 +487,13 @@ def florence2_fine_tuning(bboxes: List[Dict[str, Any]], task: str) -> str: Example ------- - >>> fine_tuning_job_id = florencev2_fine_tuning( + >>> fine_tuning_job_id = object_detection_fine_tuning( [{'image_path': 'filename.png', 'labels': ['screw'], 'bboxes': [[370, 30, 560, 290]]}, {'image_path': 'filename.png', 'labels': ['screw'], 'bboxes': [[120, 0, 300, 170]]}], "phrase_grounding" ) """ + task = "phrase_grounding" bboxes_input = [BboxInput.model_validate(bbox) for bbox in bboxes] task_type = PromptTask[task.upper()] fine_tuning_request = [ @@ -569,48 +569,52 @@ def edit_replacer(match: re.Match) -> str: return new_code -def use_florence2_fine_tuning( - artifacts: Artifacts, name: str, task: str, fine_tune_id: str +def use_object_detection_fine_tuning( + artifacts: Artifacts, name: str, fine_tune_id: str ) -> str: - """Replaces florence2 calls with the fine tuning id. This ensures that the code - utilizes the fined tuned florence2 model. Returns the diff between the original - code and the new code. + """Replaces calls to 'owl_v2_image', 'florence2_phrase_detection' and + 'florence2_sam2_image' with the fine tuning id. This ensures that the code utilizes + the fined tuned florence2 model. Returns the diff between the original code and the + new code. Parameters: artifacts (Artifacts): The artifacts object to edit the code from. name (str): The name of the artifact to edit. - task (str): The task to fine tune the model for. The options are - 'phrase_grounding'. fine_tune_id (str): The fine tuning job id. Examples -------- - >>> diff = use_florence2_fine_tuning(artifacts, "code.py", "phrase_grounding", "23b3b022-5ebf-4798-9373-20ef36429abf") + >>> diff = use_object_detection_fine_tuning(artifacts, "code.py", "23b3b022-5ebf-4798-9373-20ef36429abf") """ - task_to_fn = {"phrase_grounding": "florence2_phrase_grounding"} - if name not in artifacts: output_str = f"[Artifact {name} does not exist]" print(output_str) return output_str code = artifacts[name] - if task.lower() == "phrase_grounding": - pattern = r"florence2_phrase_grounding\(\s*([^\)]+)\)" - - def replacer(match: re.Match) -> str: - arg = match.group(1) # capture all initial arguments - return f'florence2_phrase_grounding({arg}, "{fine_tune_id}")' - - else: - raise ValueError(f"Task {task} is not supported.") + patterns = [ + ( + r"florence2_phrase_grounding\(\s*([^\)]+)\s*\)", + lambda match: f'florence2_phrase_grounding({match.group(1)}, "{fine_tune_id}")', + ), + ( + r"owl_v2_image\(\s*([^\)]+)\s*\)", + lambda match: f'owl_v2_image({match.group(1)}, "{fine_tune_id}")', + ), + ( + r"florence2_sam2_image\(\s*([^\)]+)\s*\)", + lambda match: f'florence2_sam2_image({match.group(1)}, "{fine_tune_id}")', + ), + ] - new_code = re.sub(pattern, replacer, code) + new_code = code + for pattern, replacer in patterns: + new_code = re.sub(pattern, replacer, new_code) if new_code == code: output_str = ( - f"[Fine tuning task {task} function {task_to_fn[task]} not found in code]" + f"[No function calls to replace with fine tuning id in artifact {name}]" ) print(output_str) return output_str @@ -632,8 +636,8 @@ def replacer(match: re.Match) -> str: edit_vision_code, write_media_artifact, view_media_artifact, - florence2_fine_tuning, - use_florence2_fine_tuning, + object_detection_fine_tuning, + use_object_detection_fine_tuning, list_artifacts, ] )