Skip to content

Commit

Permalink
Added optional fine_tune_id parameter to florence2_sam2_image
Browse files — browse the repository at this point in the history
  • Loading branch information
dillonalaird committed Sep 11, 2024
1 parent 847ca41 commit b304e48
Show file tree
Hide file tree
Showing 2 changed files with 41 additions and 2 deletions.
42 changes: 40 additions & 2 deletions vision_agent/tools/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -369,7 +369,9 @@ def grounding_sam(
return return_data


def florence2_sam2_image(prompt: str, image: np.ndarray) -> List[Dict[str, Any]]:
def florence2_sam2_image(
prompt: str, image: np.ndarray, fine_tune_id: Optional[str] = None
) -> List[Dict[str, Any]]:
"""'florence2_sam2_image' is a tool that can segment multiple objects given a text
prompt such as category names or referring expressions. The categories in the text
prompt are separated by commas. It returns a list of bounding boxes, label names,
Expand All @@ -378,6 +380,8 @@ def florence2_sam2_image(prompt: str, image: np.ndarray) -> List[Dict[str, Any]]
Parameters:
prompt (str): The prompt to ground to the image.
image (np.ndarray): The image to ground the prompt to.
fine_tune_id (Optional[str]): If you have a fine-tuned model, you can pass the
fine-tuned model ID here to use it.
Returns:
List[Dict[str, Any]]: A list of dictionaries containing the score, label,
Expand All @@ -403,8 +407,42 @@ def florence2_sam2_image(prompt: str, image: np.ndarray) -> List[Dict[str, Any]]
},
]
"""
buffer_bytes = numpy_to_bytes(image)
if fine_tune_id is not None:
image_b64 = convert_to_b64(image)
landing_api = LandingPublicAPI()
status = landing_api.check_fine_tuning_job(UUID(fine_tune_id))
if status is not JobStatus.SUCCEEDED:
raise FineTuneModelIsNotReady(
f"Fine-tuned model {fine_tune_id} is not ready yet"
)

data_obj = Florence2FtRequest(
image=image_b64,
task=PromptTask.PHRASE_GROUNDING,
tool="florencev2_fine_tuning",
prompt=prompt,
fine_tuning=FineTuning(
job_id=UUID(fine_tune_id),
postprocessing="sam2",
),
)
data = data_obj.model_dump(by_alias=True)
detections = send_inference_request(data, "tools", v2=False)
detections = detections["<CAPTION_TO_PHRASE_GROUNDING>"]
return_data = []
all_masks = np.array(detections["masks"])
for i in range(len(detections["bboxes"])):
return_data.append(
{
"score": 1.0,
"label": detections["labels"][i],
"bbox": detections["bboxes"][i],
"mask": all_masks[i, :, :].astype(np.uint8),
}
)
return return_data

buffer_bytes = numpy_to_bytes(image)
files = [("image", buffer_bytes)]
payload = {
"prompts": [s.strip() for s in prompt.split(",")],
Expand Down
1 change: 1 addition & 0 deletions vision_agent/tools/tools_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ class FineTuning(BaseModel):
model_config = ConfigDict(populate_by_name=True)

job_id: UUID = Field(alias="jobId")
postprocessing: Optional[str] = None

@field_serializer("job_id")
def serialize_job_id(self, job_id: UUID, _info: SerializationInfo) -> str:
Expand Down

0 comments on commit b304e48

Please sign in to comment.