Skip to content

Commit

Permalink
add fine tuning arg to florence2
Browse files Browse the repository at this point in the history
  • Loading branch information
dillonalaird committed Sep 3, 2024
1 parent b9e7541 commit d0bf79e
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 4 deletions.
14 changes: 10 additions & 4 deletions vision_agent/tools/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -762,7 +762,9 @@ def florence2_image_caption(image: np.ndarray, detail_caption: bool = True) -> s
return answer[task] # type: ignore


def florence2_phrase_grounding(prompt: str, image: np.ndarray, fine_tune_id: Optional[str] = None) -> List[Dict[str, Any]]:
def florence2_phrase_grounding(
prompt: str, image: np.ndarray, fine_tune_id: Optional[str] = None
) -> List[Dict[str, Any]]:
"""'florence2_phrase_grounding' is a tool that can detect multiple
objects given a text prompt which can be object names or caption. You
can optionally separate the object names in the text with commas. It returns a list
Expand All @@ -772,6 +774,8 @@ def florence2_phrase_grounding(prompt: str, image: np.ndarray, fine_tune_id: Opt
Parameters:
prompt (str): The prompt to ground to the image.
image (np.ndarray): The image to used to detect objects
fine_tune_id (Optional[str]): If you have a fine-tuned model, you can pass the
fine-tuned model ID here to use it.
Returns:
List[Dict[str, Any]]: A list of dictionaries containing the score, label, and
Expand All @@ -795,14 +799,16 @@ def florence2_phrase_grounding(prompt: str, image: np.ndarray, fine_tune_id: Opt
landing_api = LandingPublicAPI()
status = landing_api.check_fine_tuning_job(UUID(fine_tune_id))
if status is not JobStatus.SUCCEEDED:
raise FineTuneModelIsNotReady(f"Fine-tuned model {fine_tune_id} is not ready yet")
raise FineTuneModelIsNotReady(
f"Fine-tuned model {fine_tune_id} is not ready yet"
)

data_obj = Florence2FtRequest(
image=image_b64,
task=PromptTask.PHRASE_GROUNDING,
tool="florence2_fine_tuning",
tool="florencev2_fine_tuning",
prompt=prompt,
fine_tuning=FineTuning(job_id=UUID(fine_tune_id))
fine_tuning=FineTuning(job_id=UUID(fine_tune_id)),
)
data = data_obj.model_dump(by_alias=True)
detections = send_inference_request(data, "tools", v2=False)
Expand Down
1 change: 1 addition & 0 deletions vision_agent/tools/tools_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ class BboxInputBase64(BaseModel):

class PromptTask(str, Enum):
"""Valid task prompts options for the Florence2 model."""

PHRASE_GROUNDING = "<CAPTION_TO_PHRASE_GROUNDING>"


Expand Down

0 comments on commit d0bf79e

Please sign in to comment.