Commit
check status and run prediction with fine tuned model
Dayof committed Aug 13, 2024
1 parent 62b6137 commit fd87fa5
Showing 7 changed files with 209 additions and 16 deletions.
2 changes: 1 addition & 1 deletion vision_agent/agent/vision_agent.py
@@ -28,7 +28,7 @@ class DefaultImports:
code = [
"from typing import *",
"from vision_agent.utils.execute import CodeInterpreter",
"from vision_agent.tools.meta_tools import generate_vision_code, edit_vision_code, open_file, create_file, scroll_up, scroll_down, edit_file, get_tool_descriptions, florencev2_fine_tuning",
"from vision_agent.tools.meta_tools import generate_vision_code, edit_vision_code, open_file, create_file, scroll_up, scroll_down, edit_file, get_tool_descriptions, florencev2_fine_tuning, florencev2_fine_tuned_object_detection, check_if_fine_tuned_florencev2_is_ready",
]

@staticmethod
15 changes: 15 additions & 0 deletions vision_agent/clients/http.py
@@ -44,3 +44,18 @@ def post(self, url: str, payload: Dict[str, Any]) -> Dict[str, Any]:
resp_text = response.text
_LOGGER.warning(f"Response seems incorrect: '{resp_text}'.")
return result

def get(self, url: str) -> Dict[str, Any]:
formatted_url = f"{self._base_endpoint}/{url}"
_LOGGER.info(f"Sending data to {formatted_url}")
try:
response = self._session.get(url=formatted_url, timeout=self._TIMEOUT)
response.raise_for_status()
result: Dict[str, Any] = response.json()
_LOGGER.info(json.dumps(result))
except (ConnectionError, Timeout, RequestException) as err:
_LOGGER.warning(f"Error: {err}.")
except json.JSONDecodeError:
resp_text = response.text
_LOGGER.warning(f"Response seems incorrect: '{resp_text}'.")
return result
6 changes: 5 additions & 1 deletion vision_agent/clients/landing_public_api.py
@@ -4,7 +4,7 @@

from vision_agent.clients.http import BaseHTTP
from vision_agent.utils.type_defs import LandingaiAPIKey
from vision_agent.tools.meta_tools_types import BboxInputBase64, PromptTask
from vision_agent.tools.meta_tools_types import BboxInputBase64, PromptTask, JobStatus


class LandingPublicAPI(BaseHTTP):
@@ -24,3 +24,7 @@ def launch_fine_tuning_job(
}
response = self.post(url, payload=data)
return UUID(response["jobId"])

def check_fine_tuning_job(self, job_id: UUID) -> JobStatus:
url = f"v1/agent/jobs/fine-tuning/{job_id}/status"
return JobStatus(self.get(url)["status"])
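
A hedged usage sketch of the new status check (it assumes LandingPublicAPI can be constructed without arguments, as in the existing fine-tuning launch flow, and that the job id comes from a prior launch_fine_tuning_job call; the UUID is an illustrative value):

    from uuid import UUID

    from vision_agent.clients.landing_public_api import LandingPublicAPI
    from vision_agent.tools.meta_tools_types import JobStatus

    # Job id returned earlier by launch_fine_tuning_job (illustrative value).
    job_id = UUID("381cd5f9-5dc4-472d-9260-f3bb89d31f83")

    api = LandingPublicAPI()
    status = api.check_fine_tuning_job(job_id)  # GET v1/agent/jobs/fine-tuning/<job_id>/status
    if status is JobStatus.SUCCEEDED:
        print("fine-tuned model is ready to use")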
7 changes: 6 additions & 1 deletion vision_agent/tools/__init__.py
@@ -1,6 +1,11 @@
from typing import Callable, List, Optional

from .meta_tools import META_TOOL_DOCSTRING, florencev2_fine_tuning
from .meta_tools import (
META_TOOL_DOCSTRING,
florencev2_fine_tuning,
florencev2_fine_tuned_object_detection,
check_if_fine_tuned_florencev2_is_ready,
)
from .prompts import CHOOSE_PARAMS, SYSTEM_PROMPT
from .tools import (
TOOL_DESCRIPTIONS,
113 changes: 109 additions & 4 deletions vision_agent/tools/meta_tools.py
@@ -4,13 +4,22 @@
from pathlib import Path
from typing import Any, Dict, List, Union

import numpy as np

import vision_agent as va
from vision_agent.lmm.types import Message
from vision_agent.tools.tool_utils import get_tool_documentation
from vision_agent.tools.tool_utils import get_tool_documentation, send_inference_request
from vision_agent.tools.tools import TOOL_DESCRIPTIONS
from vision_agent.utils.image_utils import convert_to_b64
from vision_agent.utils.image_utils import convert_to_b64, normalize_bbox
from vision_agent.clients.landing_public_api import LandingPublicAPI
from vision_agent.tools.meta_tools_types import BboxInput, BboxInputBase64, PromptTask
from vision_agent.tools.meta_tools_types import (
BboxInput,
BboxInputBase64,
PromptTask,
Florencev2FtRequest,
FineTuning,
JobStatus,
)

# These tools are adapted from SWE-Agent https://github.com/princeton-nlp/SWE-agent

@@ -384,7 +393,7 @@ def edit_file(file_path: str, start: int, end: int, content: str) -> str:

def get_tool_descriptions() -> str:
"""Returns a description of all the tools that `generate_vision_code` has access to.
Helpful for answerings questions about what types of vision tasks you can do with
Helpful for answering questions about what types of vision tasks you can do with
`generate_vision_code`."""
return TOOL_DESCRIPTIONS

@@ -429,6 +438,100 @@ def florencev2_fine_tuning(bboxes: List[Dict[str, Any]], task: str) -> UUID:
)


def check_if_fine_tuned_florencev2_is_ready(model_id: UUID) -> bool:
"""'check_if_fine_tuned_florencev2_is_ready' is a tool that checks whether
is possible to use a certain florencev2 model. It checks if the status
is SUCCEEDED.
Parameters:
model_id (UUID): The fine-tuned model id.
Returns:
bool: Whether the model is ready to be used. If this
is False, it's recommended to wait 5 seconds before checking again.
Example
-------
>>> check_if_fine_tuned_florencev2_is_ready(UUID("381cd5f9-5dc4-472d-9260-f3bb89d31f83"))
True
"""
# check if job succeeded first
landing_api = LandingPublicAPI()
status = landing_api.check_fine_tuning_job(model_id)
return status is JobStatus.SUCCEEDED
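
# Illustrative polling sketch (not part of this diff): wait for the fine-tuned
# model using the 5 second back-off recommended in the docstring above, where
# model_id is the UUID returned by florencev2_fine_tuning.
#
#     import time
#
#     while not check_if_fine_tuned_florencev2_is_ready(model_id):
#         time.sleep(5)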


def florencev2_fine_tuned_object_detection(
image: np.ndarray, prompt: str, model_id: UUID, task: str, model_is_ready: bool
) -> List[Dict[str, Any]]:
"""'florencev2_fine_tuned_object_detection' is a tool that uses a fine tuned model
to detect objects given a text prompt such as a phrase or class names separated by
commas. It returns a list of detected objects as labels and their location as
bounding boxes with score of 1.0.
Parameters:
image (np.ndarray): The image used to detect objects.
prompt (str): The prompt to help find objects in the image.
model_id (UUID): The fine-tuned model id.
task (str): The florencev2 fine-tuning task. The options are
CAPTION, CAPTION_TO_PHRASE_GROUNDING and OBJECT_DETECTION.
model_is_ready (bool): If the model is ready to be used. It's recommended
to get this value from the function check_if_fine_tuned_florencev2_is_ready.
Returns:
List[Dict[str, Any]]: A list of dictionaries containing the score, label, and
bounding box of the detected objects with normalized coordinates between 0
and 1 (xmin, ymin, xmax, ymax). xmin and ymin are the coordinates of the
top-left and xmax and ymax are the coordinates of the bottom-right of the
bounding box. The scores are always 1.0 and cannot be thresholded.
Example
-------
>>> florencev2_fine_tuned_object_detection(
        image,
        'person looking at a coyote',
        UUID("381cd5f9-5dc4-472d-9260-f3bb89d31f83"),
        task="CAPTION_TO_PHRASE_GROUNDING",
        model_is_ready=True
    )
[
    {'score': 1.0, 'label': 'person', 'bbox': [0.1, 0.11, 0.35, 0.4]},
    {'score': 1.0, 'label': 'coyote', 'bbox': [0.34, 0.21, 0.85, 0.5]},
]
"""
if not model_is_ready:
return []

task = PromptTask[task]
if task is PromptTask.OBJECT_DETECTION:
prompt = ""

data_obj = Florencev2FtRequest(
image=convert_to_b64(image),
task=task,
tool="florencev2_fine_tuning",
prompt=prompt,
fine_tuning=FineTuning(job_id=model_id),
)
data = data_obj.model_dump(by_alias=True)
metadata_payload = {"function_name": "florencev2_fine_tuned_object_detection"}
detections = send_inference_request(
data, "tools", v2=False, metadata_payload=metadata_payload
)

detections = detections[task.value]
return_data = []
image_size = image.shape[:2]
for i in range(len(detections["bboxes"])):
return_data.append(
{
"score": 1.0,
"label": detections["labels"][i],
"bbox": normalize_bbox(detections["bboxes"][i], image_size),
}
)
return return_data


META_TOOL_DOCSTRING = get_tool_documentation(
[
get_tool_descriptions,
@@ -443,5 +546,7 @@ def florencev2_fine_tuning(bboxes: List[Dict[str, Any]], task: str) -> UUID:
search_file,
find_file,
florencev2_fine_tuning,
florencev2_fine_tuned_object_detection,
check_if_fine_tuned_florencev2_is_ready,
]
)
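
Taken together, the commit enables a launch, poll, predict loop. A minimal end-to-end sketch (image paths, labels and boxes are placeholders; Pillow is used only to load the image into a numpy array; the bbox dict keys follow the BboxInput model and the task strings follow the PromptTask member names, both of which are assumptions here):

    import time

    import numpy as np
    from PIL import Image

    from vision_agent.tools.meta_tools import (
        check_if_fine_tuned_florencev2_is_ready,
        florencev2_fine_tuned_object_detection,
        florencev2_fine_tuning,
    )

    # 1. Launch a fine-tuning job from annotated images (dict keys are assumptions).
    bboxes = [
        {
            "image_path": "image1.jpg",
            "labels": ["person", "coyote"],
            "bboxes": [[10, 11, 35, 40], [34, 21, 85, 50]],
        }
    ]
    model_id = florencev2_fine_tuning(bboxes, "OBJECT_DETECTION")

    # 2. Poll until the job status is SUCCEEDED, waiting 5 seconds as recommended.
    while not check_if_fine_tuned_florencev2_is_ready(model_id):
        time.sleep(5)

    # 3. Run prediction with the fine-tuned model; the prompt is ignored for
    #    OBJECT_DETECTION, where the function clears it before sending.
    image = np.array(Image.open("image2.jpg"))
    detections = florencev2_fine_tuned_object_detection(
        image, "", model_id, task="OBJECT_DETECTION", model_is_ready=True
    )
    print(detections)  # [{'score': 1.0, 'label': ..., 'bbox': [xmin, ymin, xmax, ymax]}, ...]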
58 changes: 56 additions & 2 deletions vision_agent/tools/meta_tools_types.py
@@ -1,7 +1,8 @@
from uuid import UUID
from enum import Enum
from typing import List, Tuple
from typing import List, Tuple, Optional

from pydantic import BaseModel
from pydantic import BaseModel, ConfigDict, Field, field_serializer


class BboxInput(BaseModel):
@@ -28,3 +29,56 @@ class PromptTask(str, Enum):
""""""
OBJECT_DETECTION = "<OD>"
""""""


class FineTuning(BaseModel):
model_config = ConfigDict(populate_by_name=True)

job_id: UUID = Field(alias="jobId")

@field_serializer("job_id")
def serialize_job_id(self, job_id: UUID, _info):
return str(job_id)


class Florencev2FtRequest(BaseModel):
model_config = ConfigDict(populate_by_name=True)

image: str
task: PromptTask
tool: str
prompt: Optional[str] = ""
fine_tuning: Optional[FineTuning] = Field(None, alias="fineTuning")


class JobStatus(str, Enum):
"""The status of a fine-tuning job.
CREATED:
The job has been created and is waiting to be scheduled to run.
STARTING:
The job has started running, but has not yet entered the training phase.
TRAINING:
The job is training a model.
EVALUATING:
The job is evaluating the model and computing metrics.
PUBLISHING:
The job is exporting the artifact(s) to an external directory (s3 or local).
SUCCEEDED:
The job has finished, including training, evaluation and publishing the
artifact(s).
FAILED:
The job has failed for some internal reason; it can be due to resource
issues or the code itself.
STOPPED:
The job has been stopped by the user, either locally or in the cloud.
"""

CREATED = "CREATED"
STARTING = "STARTING"
TRAINING = "TRAINING"
EVALUATING = "EVALUATING"
PUBLISHING = "PUBLISHING"
SUCCEEDED = "SUCCEEDED"
FAILED = "FAILED"
STOPPED = "STOPPED"
24 changes: 17 additions & 7 deletions vision_agent/tools/tool_utils.py
@@ -15,9 +15,10 @@
from vision_agent.utils.type_defs import LandingaiAPIKey

_LOGGER = logging.getLogger(__name__)
_LND_API_KEY = LandingaiAPIKey().api_key
_LND_API_URL = "https://api.landing.ai/v1/agent/model"
_LND_API_URL_v2 = "https://api.landing.ai/v1/tools"
_LND_API_KEY = os.environ.get("LANDINGAI_API_KEY", LandingaiAPIKey().api_key)
_LND_BASE_URL = os.environ.get("LANDINGAI_URL", "https://api.landing.ai")
_LND_API_URL = f"{_LND_BASE_URL}/v1/agent/model"
_LND_API_URL_v2 = f"{_LND_BASE_URL}/v1/tools"
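
# Illustrative (not part of this diff): both values can now be overridden from
# the environment before vision_agent.tools is imported, e.g. to point the
# tools at a different deployment. The URL below is a hypothetical example.
#
#     import os
#     os.environ["LANDINGAI_API_KEY"] = "<your-api-key>"
#     os.environ["LANDINGAI_URL"] = "https://api.example-staging.landing.ai"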


class ToolCallTrace(BaseModel):
@@ -28,8 +29,13 @@ class ToolCallTrace(BaseModel):


def send_inference_request(
payload: Dict[str, Any], endpoint_name: str, v2: bool = False
payload: Dict[str, Any],
endpoint_name: str,
v2: bool = False,
metadata_payload: Optional[Dict[str, Any]] = None,
) -> Dict[str, Any]:
# TODO: runtime_tag and function_name should be in metadata_payload and not
# included in the service payload
try:
if runtime_tag := os.environ.get("RUNTIME_TAG", ""):
payload["runtime_tag"] = runtime_tag
@@ -62,9 +68,13 @@ def send_inference_request(
traceback_raw=[],
)
_LOGGER.error(f"Request failed: {res.status_code} {res.text}")
raise RemoteToolCallFailed(
payload["function_name"], res.status_code, res.text
)
# TODO: function_name should be in metadata_payload
function_name = "unknown"
if "function_name" in payload:
function_name = payload["function_name"]
elif metadata_payload is not None and "function_name" in metadata_payload:
function_name = metadata_payload["function_name"]
raise RemoteToolCallFailed(function_name, res.status_code, res.text)

resp = res.json()
tool_call_trace.response = resp
