Skip to content

Commit

Permalink
feat: check status and run prediction with fine tuned model (#198)
Browse files Browse the repository at this point in the history
* check status and run prediction with fine tuned model

* fine-tuning to tools

* raise exception when model is not ready
  • Loading branch information
Dayof authored Aug 26, 2024
1 parent 5811595 commit 50e97d1
Show file tree
Hide file tree
Showing 11 changed files with 274 additions and 89 deletions.
2 changes: 2 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,8 @@ line_length = 88
profile = "black"

[tool.mypy]
plugins = "pydantic.mypy"

exclude = "tests"
show_error_context = true
pretty = true
Expand Down
2 changes: 1 addition & 1 deletion vision_agent/agent/vision_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ class DefaultImports:
code = [
"from typing import *",
"from vision_agent.utils.execute import CodeInterpreter",
"from vision_agent.tools.meta_tools import generate_vision_code, edit_vision_code, open_file, create_file, scroll_up, scroll_down, edit_file, get_tool_descriptions, florencev2_fine_tuning",
"from vision_agent.tools.meta_tools import generate_vision_code, edit_vision_code, open_file, create_file, scroll_up, scroll_down, edit_file, get_tool_descriptions",
]

@staticmethod
Expand Down
18 changes: 15 additions & 3 deletions vision_agent/clients/http.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@

from requests import Session
from requests.adapters import HTTPAdapter
from requests.exceptions import ConnectionError, RequestException, Timeout

_LOGGER = logging.getLogger(__name__)

Expand Down Expand Up @@ -38,9 +37,22 @@ def post(self, url: str, payload: Dict[str, Any]) -> Dict[str, Any]:
response.raise_for_status()
result: Dict[str, Any] = response.json()
_LOGGER.info(json.dumps(result))
except (ConnectionError, Timeout, RequestException) as err:
_LOGGER.warning(f"Error: {err}.")
except json.JSONDecodeError:
resp_text = response.text
_LOGGER.warning(f"Response seems incorrect: '{resp_text}'.")
raise
return result

def get(self, url: str) -> Dict[str, Any]:
    """Send a GET request to a path under the base endpoint and return its JSON body.

    Args:
        url: Path appended to ``self._base_endpoint`` (joined with a ``/``).

    Returns:
        The parsed JSON body of the response.

    Raises:
        json.JSONDecodeError: If the response body cannot be parsed as JSON.
            The raw response text is logged as a warning before re-raising.

    Any error raised by ``response.raise_for_status()`` is not caught here and
    propagates to the caller (presumably ``requests.exceptions.HTTPError``,
    assuming ``self._session`` is a ``requests.Session`` -- confirm).
    """
    endpoint = f"{self._base_endpoint}/{url}"
    _LOGGER.info(f"Sending data to {endpoint}")
    try:
        resp = self._session.get(url=endpoint, timeout=self._TIMEOUT)
        # Fail early on error status codes instead of parsing an error page.
        resp.raise_for_status()
        body: Dict[str, Any] = resp.json()
        _LOGGER.info(json.dumps(body))
    except json.JSONDecodeError:
        # Log the raw body so a malformed response can be diagnosed.
        raw_text = resp.text
        _LOGGER.warning(f"Response seems incorrect: '{raw_text}'.")
        raise
    return body
14 changes: 13 additions & 1 deletion vision_agent/clients/landing_public_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,12 @@
from uuid import UUID
from typing import List

from requests.exceptions import HTTPError

from vision_agent.clients.http import BaseHTTP
from vision_agent.utils.type_defs import LandingaiAPIKey
from vision_agent.tools.meta_tools_types import BboxInputBase64, PromptTask
from vision_agent.utils.exceptions import FineTuneModelNotFound
from vision_agent.tools.tools_types import BboxInputBase64, PromptTask, JobStatus


class LandingPublicAPI(BaseHTTP):
Expand All @@ -24,3 +27,12 @@ def launch_fine_tuning_job(
}
response = self.post(url, payload=data)
return UUID(response["jobId"])

def check_fine_tuning_job(self, job_id: UUID) -> JobStatus:
    """Fetch the current status of a fine-tuning job.

    Args:
        job_id: Identifier of the fine-tuning job to look up.

    Returns:
        The job status, built from the ``"status"`` field of the response.

    Raises:
        FineTuneModelNotFound: If the status endpoint answers with HTTP 404.
        HTTPError: For any other HTTP error status; re-raised unchanged.
    """
    url = f"v1/agent/jobs/fine-tuning/{job_id}/status"
    try:
        get_job = self.get(url)
    except HTTPError as err:
        # Compare against None explicitly: a requests Response is falsy for
        # 4xx/5xx status codes, so a plain truthiness test would skip the 404.
        if err.response is not None and err.response.status_code == 404:
            raise FineTuneModelNotFound() from err
        # Previously non-404 errors were swallowed here and the function then
        # failed on the unbound ``get_job``; surface the real HTTP error.
        raise
    return JobStatus(get_job["status"])
4 changes: 3 additions & 1 deletion vision_agent/tools/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
from typing import Callable, List, Optional

from .meta_tools import META_TOOL_DOCSTRING, florencev2_fine_tuning
from .meta_tools import (
META_TOOL_DOCSTRING,
)
from .prompts import CHOOSE_PARAMS, SYSTEM_PROMPT
from .tools import (
TOOL_DESCRIPTIONS,
Expand Down
48 changes: 2 additions & 46 deletions vision_agent/tools/meta_tools.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,13 @@
import os
import subprocess
from uuid import UUID
from pathlib import Path
from typing import Any, Dict, List, Union

import vision_agent as va
from vision_agent.lmm.types import Message
from vision_agent.tools.tool_utils import get_tool_documentation
from vision_agent.tools.tools import TOOL_DESCRIPTIONS
from vision_agent.utils.image_utils import convert_to_b64
from vision_agent.clients.landing_public_api import LandingPublicAPI
from vision_agent.tools.meta_tools_types import BboxInput, BboxInputBase64, PromptTask


# These tools are adapted from SWE-Agent https://github.com/princeton-nlp/SWE-agent

Expand Down Expand Up @@ -384,51 +381,11 @@ def edit_file(file_path: str, start: int, end: int, content: str) -> str:

def get_tool_descriptions() -> str:
"""Returns a description of all the tools that `generate_vision_code` has access to.
Helpful for answerings questions about what types of vision tasks you can do with
Helpful for answering questions about what types of vision tasks you can do with
`generate_vision_code`."""
return TOOL_DESCRIPTIONS


def florencev2_fine_tuning(bboxes: List[Dict[str, Any]], task: str) -> UUID:
    """'florencev2_fine_tuning' is a tool that fine-tunes florencev2 to be able
    to detect objects in an image based on a given dataset. It returns the
    fine-tuning job id.

    Parameters:
        bboxes (List[BboxInput]): A list of BboxInput containing the
            image path, labels and bounding boxes.
        task (PromptTask): The florencev2 fine-tuning task. The options are
            CAPTION, CAPTION_TO_PHRASE_GROUNDING and OBJECT_DETECTION.

    Returns:
        UUID: The fine-tuning job id; this id will be used to retrieve the
            fine-tuned model.

    Example
    -------
        >>> fine_tuning_job_id = florencev2_fine_tuning(
            [{'image_path': 'filename.png', 'labels': ['screw'], 'bboxes': [[370, 30, 560, 290]]},
             {'image_path': 'filename.png', 'labels': ['screw'], 'bboxes': [[120, 0, 300, 170]]}],
             "OBJECT_DETECTION"
        )
    """
    # Validate every entry before resolving the task or encoding any image.
    validated = list(map(BboxInput.model_validate, bboxes))
    prompt_task = PromptTask[task]

    encoded_samples = []
    for sample in validated:
        path = sample.image_path
        encoded_samples.append(
            BboxInputBase64(
                image=convert_to_b64(path),
                # Keep only the text after the last "/" as the file name.
                filename=path.rsplit("/", 1)[-1],
                labels=sample.labels,
                bboxes=sample.bboxes,
            )
        )

    client = LandingPublicAPI()
    return client.launch_fine_tuning_job("florencev2", prompt_task, encoded_samples)


META_TOOL_DOCSTRING = get_tool_documentation(
[
get_tool_descriptions,
Expand All @@ -442,6 +399,5 @@ def florencev2_fine_tuning(bboxes: List[Dict[str, Any]], task: str) -> UUID:
search_dir,
search_file,
find_file,
florencev2_fine_tuning,
]
)
30 changes: 0 additions & 30 deletions vision_agent/tools/meta_tools_types.py

This file was deleted.

24 changes: 17 additions & 7 deletions vision_agent/tools/tool_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,10 @@
from vision_agent.utils.type_defs import LandingaiAPIKey

_LOGGER = logging.getLogger(__name__)
_LND_API_KEY = LandingaiAPIKey().api_key
_LND_API_URL = "https://api.landing.ai/v1/agent/model"
_LND_API_URL_v2 = "https://api.landing.ai/v1/tools"
_LND_API_KEY = os.environ.get("LANDINGAI_API_KEY", LandingaiAPIKey().api_key)
_LND_BASE_URL = os.environ.get("LANDINGAI_URL", "https://api.landing.ai")
_LND_API_URL = f"{_LND_BASE_URL}/v1/agent/model"
_LND_API_URL_v2 = f"{_LND_BASE_URL}/v1/tools"


class ToolCallTrace(BaseModel):
Expand All @@ -28,8 +29,13 @@ class ToolCallTrace(BaseModel):


def send_inference_request(
payload: Dict[str, Any], endpoint_name: str, v2: bool = False
payload: Dict[str, Any],
endpoint_name: str,
v2: bool = False,
metadata_payload: Optional[Dict[str, Any]] = None,
) -> Dict[str, Any]:
# TODO: runtime_tag and function_name should be metadata_payload and now included
# in the service payload
try:
if runtime_tag := os.environ.get("RUNTIME_TAG", ""):
payload["runtime_tag"] = runtime_tag
Expand Down Expand Up @@ -62,9 +68,13 @@ def send_inference_request(
traceback_raw=[],
)
_LOGGER.error(f"Request failed: {res.status_code} {res.text}")
raise RemoteToolCallFailed(
payload["function_name"], res.status_code, res.text
)
# TODO: function_name should be in metadata_payload
function_name = "unknown"
if "function_name" in payload:
function_name = payload["function_name"]
elif metadata_payload is not None and "function_name" in metadata_payload:
function_name = metadata_payload["function_name"]
raise RemoteToolCallFailed(function_name, res.status_code, res.text)

resp = res.json()
tool_call_trace.response = resp
Expand Down
124 changes: 124 additions & 0 deletions vision_agent/tools/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import json
import logging
import tempfile
from uuid import UUID
from pathlib import Path
from importlib import resources
from typing import Any, Dict, List, Optional, Tuple, Union, cast
Expand All @@ -21,6 +22,7 @@
get_tools_df,
get_tools_info,
)
from vision_agent.utils.exceptions import FineTuneModelIsNotReady
from vision_agent.utils import extract_frames_from_video
from vision_agent.utils.execute import FileSerializer, MimeType
from vision_agent.utils.image_utils import (
Expand All @@ -32,6 +34,15 @@
convert_quad_box_to_bbox,
rle_decode,
)
from vision_agent.tools.tools_types import (
BboxInput,
BboxInputBase64,
PromptTask,
Florencev2FtRequest,
FineTuning,
JobStatus,
)
from vision_agent.clients.landing_public_api import LandingPublicAPI

register_heif_opener()

Expand Down Expand Up @@ -1286,6 +1297,119 @@ def overlay_heat_map(
return np.array(combined)


# TODO: add this function to the imports so that it is picked up by the agent
def florencev2_fine_tuning(bboxes: List[Dict[str, Any]], task: str) -> UUID:
    """'florencev2_fine_tuning' is a tool that fine-tunes florencev2 to be able
    to detect objects in an image based on a given dataset. It returns the
    fine-tuning job id.

    Parameters:
        bboxes (List[BboxInput]): A list of BboxInput containing the
            image path, labels and bounding boxes.
        task (PromptTask): The florencev2 fine-tuning task. The options are
            CAPTION, CAPTION_TO_PHRASE_GROUNDING and OBJECT_DETECTION.

    Returns:
        UUID: The fine-tuning job id; this id will be used to retrieve the
            fine-tuned model.

    Example
    -------
        >>> fine_tuning_job_id = florencev2_fine_tuning(
            [{'image_path': 'filename.png', 'labels': ['screw'], 'bboxes': [[370, 30, 560, 290]]},
             {'image_path': 'filename.png', 'labels': ['screw'], 'bboxes': [[120, 0, 300, 170]]}],
             "OBJECT_DETECTION"
        )
    """
    # Validate every entry before resolving the task or encoding any image.
    validated = list(map(BboxInput.model_validate, bboxes))
    prompt_task = PromptTask[task]

    encoded_samples = []
    for sample in validated:
        path = sample.image_path
        encoded_samples.append(
            BboxInputBase64(
                image=convert_to_b64(path),
                # Keep only the text after the last "/" as the file name.
                filename=path.rsplit("/", 1)[-1],
                labels=sample.labels,
                bboxes=sample.bboxes,
            )
        )

    client = LandingPublicAPI()
    return client.launch_fine_tuning_job("florencev2", prompt_task, encoded_samples)


# TODO: add this function to the imports so that it is picked up by the agent
def florencev2_fine_tuned_object_detection(
    image: np.ndarray, prompt: str, model_id: UUID, task: str
) -> List[Dict[str, Any]]:
    """'florencev2_fine_tuned_object_detection' is a tool that uses a fine tuned model
    to detect objects given a text prompt such as a phrase or class names separated by
    commas. It returns a list of detected objects as labels and their location as
    bounding boxes with score of 1.0.

    Parameters:
        image (np.ndarray): The image used to detect objects.
        prompt (str): The prompt to help find objects in the image. It is replaced
            by an empty string when the task is OBJECT_DETECTION.
        model_id (UUID): The fine-tuned model id.
        task (str): The name of the florencev2 fine-tuning task (a PromptTask
            member). The options are CAPTION, CAPTION_TO_PHRASE_GROUNDING and
            OBJECT_DETECTION.

    Returns:
        List[Dict[str, Any]]: A list of dictionaries containing the score, label, and
            bounding box of the detected objects with normalized coordinates between 0
            and 1 (xmin, ymin, xmax, ymax). xmin and ymin are the coordinates of the
            top-left and xmax and ymax are the coordinates of the bottom-right of the
            bounding box. The scores are always 1.0 and cannot be thresholded.

    Raises:
        FineTuneModelIsNotReady: If the fine-tuning job for `model_id` has not
            succeeded.

    Example
    -------
        >>> florencev2_fine_tuned_object_detection(
            image,
            'person looking at a coyote',
            UUID("381cd5f9-5dc4-472d-9260-f3bb89d31f83"),
            "CAPTION_TO_PHRASE_GROUNDING"
        )
        [
            {'score': 1.0, 'label': 'person', 'bbox': [0.1, 0.11, 0.35, 0.4]},
            {'score': 1.0, 'label': 'coyote', 'bbox': [0.34, 0.21, 0.85, 0.5]},
        ]
    """
    # Refuse to run inference until the fine-tuning job has succeeded.
    landing_api = LandingPublicAPI()
    status = landing_api.check_fine_tuning_job(model_id)
    if status is not JobStatus.SUCCEEDED:
        raise FineTuneModelIsNotReady()

    # Keep the resolved enum in its own name instead of rebinding the `task`
    # string parameter to a value of a different type.
    prompt_task = PromptTask[task]
    if prompt_task is PromptTask.OBJECT_DETECTION:
        prompt = ""

    data_obj = Florencev2FtRequest(
        image=convert_to_b64(image),
        task=prompt_task,
        tool="florencev2_fine_tuning",
        prompt=prompt,
        fine_tuning=FineTuning(job_id=model_id),
    )
    data = data_obj.model_dump(by_alias=True)
    # The function name travels in the metadata so that a failed request can be
    # attributed to this tool without adding it to the service payload.
    metadata_payload = {"function_name": "florencev2_fine_tuned_object_detection"}
    response = send_inference_request(
        data, "tools", v2=False, metadata_payload=metadata_payload
    )

    # The response is keyed by the task value; each entry holds parallel
    # "bboxes" and "labels" lists.
    task_detections = response[prompt_task.value]
    labels = task_detections["labels"]
    image_size = image.shape[:2]
    return [
        {
            "score": 1.0,
            "label": labels[i],
            "bbox": normalize_bbox(bbox, image_size),
        }
        for i, bbox in enumerate(task_detections["bboxes"])
    ]


TOOLS = [
owl_v2,
grounding_sam,
Expand Down
Loading

0 comments on commit 50e97d1

Please sign in to comment.