add florence2 fine tune to owl_v2 args
dillonalaird committed Sep 11, 2024
1 parent 2472112 commit 847ca41
Showing 1 changed file with 40 additions and 6 deletions.
46 changes: 40 additions & 6 deletions vision_agent/tools/tools.py
@@ -149,6 +149,7 @@ def owl_v2_image(
    prompt: str,
    image: np.ndarray,
    box_threshold: float = 0.10,
    fine_tune_id: Optional[str] = None,
) -> List[Dict[str, Any]]:
    """'owl_v2_image' is a tool that can detect and count multiple objects given a text
    prompt such as category names or referring expressions on images. The categories in
@@ -160,6 +161,8 @@ def owl_v2_image(
        image (np.ndarray): The image to ground the prompt to.
        box_threshold (float, optional): The threshold for the box detection. Defaults
            to 0.10.
        fine_tune_id (Optional[str]): If you have a fine-tuned model, you can pass the
            fine-tuned model ID here to use it.

    Returns:
        List[Dict[str, Any]]: A list of dictionaries containing the score, label, and
@@ -176,7 +179,38 @@ def owl_v2_image(
            {'score': 0.98, 'label': 'car', 'bbox': [0.2, 0.21, 0.45, 0.5]},
        ]
    """

    image_size = image.shape[:2]

    if fine_tune_id is not None:
        image_b64 = convert_to_b64(image)
        landing_api = LandingPublicAPI()
        status = landing_api.check_fine_tuning_job(UUID(fine_tune_id))
        if status is not JobStatus.SUCCEEDED:
            raise FineTuneModelIsNotReady(
                f"Fine-tuned model {fine_tune_id} is not ready yet"
            )

        data_obj = Florence2FtRequest(
            image=image_b64,
            task=PromptTask.PHRASE_GROUNDING,
            tool="florencev2_fine_tuning",
            prompt=prompt,
            fine_tuning=FineTuning(job_id=UUID(fine_tune_id)),
        )
        data = data_obj.model_dump(by_alias=True)
        detections = send_inference_request(data, "tools", v2=False)
        detections = detections["<CAPTION_TO_PHRASE_GROUNDING>"]
        bboxes_formatted = [
            ODResponseData(
                label=detections["labels"][i],
                bbox=normalize_bbox(detections["bboxes"][i], image_size),
                score=1.0,
            )
            for i in range(len(detections["bboxes"]))
        ]
        return [bbox.model_dump() for bbox in bboxes_formatted]

    buffer_bytes = numpy_to_bytes(image)
    files = [("image", buffer_bytes)]
    payload = {
@@ -1119,13 +1153,13 @@ def florence2_phrase_grounding(
    return_data = []
    for i in range(len(detections["bboxes"])):
        return_data.append(
            {
                "score": 1.0,
                "label": detections["labels"][i],
                "bbox": normalize_bbox(detections["bboxes"][i], image_size),
            }
            ODResponseData(
                label=detections["labels"][i],
                bbox=normalize_bbox(detections["bboxes"][i], image_size),
                score=1.0,
            )
        )
    return return_data
    return [bbox.model_dump() for bbox in return_data]


def florence2_ocr(image: np.ndarray) -> List[Dict[str, Any]]:
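
Usage sketch (not part of the diff): with this change, owl_v2_image can route detection through a fine-tuned Florence-2 model by passing the fine-tuning job ID. The snippet below is a minimal illustration, assuming owl_v2_image is importable from vision_agent.tools; the image and job ID are placeholders, and the job must have reached SUCCEEDED status or the call raises FineTuneModelIsNotReady.

    import numpy as np

    from vision_agent.tools import owl_v2_image

    # Placeholder inputs for illustration only.
    image = np.zeros((480, 640, 3), dtype=np.uint8)
    job_id = "00000000-0000-0000-0000-000000000000"  # hypothetical fine-tuning job ID

    # Without fine_tune_id the tool calls the stock OWLv2 model; with it, the
    # request is sent to the fine-tuned Florence-2 phrase-grounding endpoint.
    detections = owl_v2_image("car", image, fine_tune_id=job_id)

    # Each entry looks like {"label": "car", "bbox": [x1, y1, x2, y2], "score": 1.0},
    # with bbox coordinates normalized to the image size.
    print(detections)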
