diff --git a/vision_agent/tools/tools.py b/vision_agent/tools/tools.py
index 3167d3be..53259e27 100644
--- a/vision_agent/tools/tools.py
+++ b/vision_agent/tools/tools.py
@@ -369,7 +369,9 @@ def grounding_sam(
     return return_data
 
 
-def florence2_sam2_image(prompt: str, image: np.ndarray) -> List[Dict[str, Any]]:
+def florence2_sam2_image(
+    prompt: str, image: np.ndarray, fine_tune_id: Optional[str] = None
+) -> List[Dict[str, Any]]:
     """'florence2_sam2_image' is a tool that can segment multiple objects given a text
     prompt such as category names or referring expressions. The categories in the text
     prompt are separated by commas. It returns a list of bounding boxes, label names,
@@ -378,6 +380,8 @@ def florence2_sam2_image(prompt: str, image: np.ndarray) -> List[Dict[str, Any]]
     Parameters:
         prompt (str): The prompt to ground to the image.
         image (np.ndarray): The image to ground the prompt to.
+        fine_tune_id (Optional[str]): If you have a fine-tuned model, you can pass the
+            fine-tuned model ID here to use it.
 
     Returns:
         List[Dict[str, Any]]: A list of dictionaries containing the score, label,
@@ -403,8 +407,42 @@ def florence2_sam2_image(prompt: str, image: np.ndarray) -> List[Dict[str, Any]]
             },
         ]
     """
-    buffer_bytes = numpy_to_bytes(image)
+    if fine_tune_id is not None:
+        image_b64 = convert_to_b64(image)
+        landing_api = LandingPublicAPI()
+        status = landing_api.check_fine_tuning_job(UUID(fine_tune_id))
+        if status is not JobStatus.SUCCEEDED:
+            raise FineTuneModelIsNotReady(
+                f"Fine-tuned model {fine_tune_id} is not ready yet"
+            )
+        data_obj = Florence2FtRequest(
+            image=image_b64,
+            task=PromptTask.PHRASE_GROUNDING,
+            tool="florencev2_fine_tuning",
+            prompt=prompt,
+            fine_tuning=FineTuning(
+                job_id=UUID(fine_tune_id),
+                postprocessing="sam2",
+            ),
+        )
+        data = data_obj.model_dump(by_alias=True)
+        detections = send_inference_request(data, "tools", v2=False)
+        detections = detections["<CAPTION_TO_PHRASE_GROUNDING>"]
+        return_data = []
+        all_masks = np.array(detections["masks"])
+        for i in range(len(detections["bboxes"])):
+            return_data.append(
+                {
+                    "score": 1.0,
+                    "label": detections["labels"][i],
+                    "bbox": detections["bboxes"][i],
+                    "mask": all_masks[i, :, :].astype(np.uint8),
+                }
+            )
+        return return_data
+
+    buffer_bytes = numpy_to_bytes(image)
     files = [("image", buffer_bytes)]
     payload = {
         "prompts": [s.strip() for s in prompt.split(",")],
diff --git a/vision_agent/tools/tools_types.py b/vision_agent/tools/tools_types.py
index 6ebcf468..aa0e430f 100644
--- a/vision_agent/tools/tools_types.py
+++ b/vision_agent/tools/tools_types.py
@@ -28,6 +28,7 @@ class FineTuning(BaseModel):
     model_config = ConfigDict(populate_by_name=True)
 
     job_id: UUID = Field(alias="jobId")
+    postprocessing: Optional[str] = None
 
     @field_serializer("job_id")
     def serialize_job_id(self, job_id: UUID, _info: SerializationInfo) -> str:
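For context, a minimal usage sketch of the new `fine_tune_id` parameter follows. It is not part of the patch: it assumes `vision_agent` is installed with valid LandingAI credentials configured, that `florence2_sam2_image` is exported from `vision_agent.tools`, and that a Florence2 fine-tuning job has already reached `SUCCEEDED`. The prompt labels, dummy image, and job UUID are placeholders chosen purely for illustration.

```python
# Usage sketch for the new fine_tune_id parameter (placeholder values; the
# calls hit hosted endpoints, so real credentials and a real job ID are needed).
import numpy as np

from vision_agent.tools import florence2_sam2_image

image = np.zeros((512, 512, 3), dtype=np.uint8)  # stand-in for a real RGB image

# Default path: generic Florence2 + SAM2 segmentation, unchanged by this diff.
dets = florence2_sam2_image("screw, washer", image)

# Fine-tuned path: routes the request through the Florence2 fine-tuning
# endpoint with postprocessing="sam2"; raises FineTuneModelIsNotReady unless
# the fine-tuning job is in the SUCCEEDED state.
dets_ft = florence2_sam2_image(
    "screw, washer",
    image,
    fine_tune_id="00000000-0000-0000-0000-000000000000",  # placeholder job UUID
)

for det in dets_ft:
    print(det["label"], det["bbox"], det["mask"].shape)
```

Omitting `fine_tune_id` preserves the existing behavior (the generic endpoint via `numpy_to_bytes`), so the parameter is a purely additive, backward-compatible change.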
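The `tools_types.py` change is easiest to see from the serialized payload. The sketch below uses only the fields visible in this diff (the UUID is again a placeholder) and shows, approximately, what `model_dump(by_alias=True)` produces for a `FineTuning` instance with the new optional field set.

```python
# Serialization sketch for the extended FineTuning model (placeholder UUID).
from uuid import UUID

from vision_agent.tools.tools_types import FineTuning

ft = FineTuning(
    job_id=UUID("00000000-0000-0000-0000-000000000000"),
    postprocessing="sam2",
)

# by_alias=True keeps the "jobId" wire name, the field_serializer turns the
# UUID into a string, and the new optional field rides alongside it, roughly:
# {'jobId': '00000000-0000-0000-0000-000000000000', 'postprocessing': 'sam2'}
print(ft.model_dump(by_alias=True))
```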