From d0bf79e2e4916153fba652187703fc45a8903cec Mon Sep 17 00:00:00 2001 From: Dillon Laird Date: Tue, 3 Sep 2024 07:52:27 -0700 Subject: [PATCH] add fine tuning arg to florence2 --- vision_agent/tools/tools.py | 14 ++++++++++---- vision_agent/tools/tools_types.py | 1 + 2 files changed, 11 insertions(+), 4 deletions(-) diff --git a/vision_agent/tools/tools.py b/vision_agent/tools/tools.py index 92a47a99..1ea19f62 100644 --- a/vision_agent/tools/tools.py +++ b/vision_agent/tools/tools.py @@ -762,7 +762,9 @@ def florence2_image_caption(image: np.ndarray, detail_caption: bool = True) -> s return answer[task] # type: ignore -def florence2_phrase_grounding(prompt: str, image: np.ndarray, fine_tune_id: Optional[str] = None) -> List[Dict[str, Any]]: +def florence2_phrase_grounding( + prompt: str, image: np.ndarray, fine_tune_id: Optional[str] = None +) -> List[Dict[str, Any]]: """'florence2_phrase_grounding' is a tool that can detect multiple objects given a text prompt which can be object names or caption. You can optionally separate the object names in the text with commas. It returns a list @@ -772,6 +774,8 @@ def florence2_phrase_grounding(prompt: str, image: np.ndarray, fine_tune_id: Opt Parameters: prompt (str): The prompt to ground to the image. image (np.ndarray): The image to used to detect objects + fine_tune_id (Optional[str]): If you have a fine-tuned model, you can pass the + fine-tuned model ID here to use it. Returns: List[Dict[str, Any]]: A list of dictionaries containing the score, label, and @@ -795,14 +799,16 @@ def florence2_phrase_grounding(prompt: str, image: np.ndarray, fine_tune_id: Opt landing_api = LandingPublicAPI() status = landing_api.check_fine_tuning_job(UUID(fine_tune_id)) if status is not JobStatus.SUCCEEDED: - raise FineTuneModelIsNotReady(f"Fine-tuned model {fine_tune_id} is not ready yet") + raise FineTuneModelIsNotReady( + f"Fine-tuned model {fine_tune_id} is not ready yet" + ) data_obj = Florence2FtRequest( image=image_b64, task=PromptTask.PHRASE_GROUNDING, - tool="florence2_fine_tuning", + tool="florencev2_fine_tuning", prompt=prompt, - fine_tuning=FineTuning(job_id=UUID(fine_tune_id)) + fine_tuning=FineTuning(job_id=UUID(fine_tune_id)), ) data = data_obj.model_dump(by_alias=True) detections = send_inference_request(data, "tools", v2=False) diff --git a/vision_agent/tools/tools_types.py b/vision_agent/tools/tools_types.py index 20d178d7..eb436d94 100644 --- a/vision_agent/tools/tools_types.py +++ b/vision_agent/tools/tools_types.py @@ -20,6 +20,7 @@ class BboxInputBase64(BaseModel): class PromptTask(str, Enum): """Valid task prompts options for the Florence2 model.""" + PHRASE_GROUNDING = ""