Skip to content

Commit

Permalink
add florencev2 fine tuning
Browse files Browse the repository at this point in the history
  • Loading branch information
Dayof committed Aug 6, 2024
1 parent 66a1478 commit 687203f
Show file tree
Hide file tree
Showing 6 changed files with 191 additions and 44 deletions.
Empty file.
42 changes: 42 additions & 0 deletions vision_agent/clients/http.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
import json
import logging
from typing import Any, Dict, Optional

from requests import Session
from requests.adapters import HTTPAdapter
from requests.exceptions import ConnectionError, RequestException, Timeout

_LOGGER = logging.getLogger(__name__)


class BaseHTTP:
    """Small JSON-over-HTTP client bound to a single base endpoint.

    A ``requests`` session is created once per instance, carries the default
    headers, and has an ``HTTPAdapter`` mounted on the base endpoint with
    ``_MAX_RETRIES`` as its retry setting.
    """

    _TIMEOUT = 30  # seconds
    _MAX_RETRIES = 3

    def __init__(
        self, base_endpoint: str, *, headers: Optional[Dict[str, Any]] = None
    ) -> None:
        """Create the session used for every request.

        Args:
            base_endpoint: URL prefix that request paths are appended to.
            headers: Headers applied to the session. When omitted, a JSON
                ``Content-Type`` header is used.
        """
        if headers is None:
            headers = {"Content-Type": "application/json"}
        self._headers = headers
        self._base_endpoint = base_endpoint
        self._session = Session()
        self._session.headers.update(self._headers)
        self._session.mount(
            self._base_endpoint, HTTPAdapter(max_retries=self._MAX_RETRIES)
        )

    def post(self, url: str, payload: Dict[str, Any]) -> Optional[Dict[str, Any]]:
        """POST ``payload`` as JSON to ``<base_endpoint>/<url>``.

        The call is best-effort: transport errors, non-2xx statuses and
        undecodable bodies are logged as warnings instead of being raised.

        Args:
            url: Path appended to the base endpoint (joined with a ``/``).
            payload: JSON-serialisable request body.

        Returns:
            The decoded JSON response body on success (expected to be a JSON
            object), or ``None`` when the request failed or the body could not
            be decoded.
        """
        formatted_url = f"{self._base_endpoint}/{url}"
        _LOGGER.info("Sending data to %s", formatted_url)
        try:
            response = self._session.post(
                url=formatted_url,
                json=payload,
                timeout=self._TIMEOUT,
            )
            response.raise_for_status()
        except (ConnectionError, Timeout, RequestException) as err:
            _LOGGER.warning("Error: %s.", err)
            return None

        # Decode separately from the request: the decode error raised by
        # ``response.json()`` is a ValueError in every variant, and in recent
        # requests releases it is also a RequestException, so sharing one
        # ``try`` with the handler above would make this branch unreachable.
        try:
            result: Dict[str, Any] = response.json()
        except ValueError:
            _LOGGER.warning("Response seems incorrect: '%s'.", response.text)
            return None

        _LOGGER.info(json.dumps(result))
        return result
26 changes: 26 additions & 0 deletions vision_agent/clients/landing_public_api.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
import os
from uuid import UUID
from typing import List

from vision_agent.clients.http import BaseHTTP
from vision_agent.utils.type_defs import LandingaiAPIKey
from vision_agent.tools.tool_types import BboxInputBase64


class LandingPublicAPI(BaseHTTP):
    """HTTP client for the LandingAI public API.

    The endpoint comes from the ``LANDINGAI_URL`` environment variable and the
    API key from ``LANDINGAI_API_KEY``; each has a fallback when unset.
    """

    def __init__(self) -> None:
        """Read the endpoint and API key, then set up the underlying session."""
        landing_url = os.environ.get("LANDINGAI_URL", "https://api.dev.landing.ai")
        landing_api_key = os.environ.get("LANDINGAI_API_KEY", LandingaiAPIKey().api_key)
        # The key is sent on every request through the ``apikey`` header.
        headers = {"Content-Type": "application/json", "apikey": landing_api_key}
        super().__init__(base_endpoint=landing_url, headers=headers)

    def launch_fine_tuning_job(
        self, model_name: str, task: str, bboxes: List[BboxInputBase64]
    ) -> UUID:
        """Submit a fine-tuning job and return its identifier.

        Args:
            model_name: Sent as ``model.name`` in the request body.
            task: Sent as ``model.task`` in the request body.
            bboxes: Annotated images; each is dumped with ``by_alias=True``.

        Returns:
            The ``jobId`` field of the response, parsed as a UUID.

        Raises:
            RuntimeError: If no response body with a ``jobId`` was obtained.
            ValueError: If ``jobId`` is not a valid UUID string.
        """
        url = "v1/agent/jobs/fine-tuning"
        data = {
            "model": {"name": model_name, "task": task},
            "bboxes": [bbox.model_dump(by_alias=True) for bbox in bboxes],
        }
        response = self.post(url, payload=data)
        # ``post`` can come back without a parsed body; fail with a clear
        # message instead of an opaque TypeError/KeyError on the subscript.
        if not isinstance(response, dict) or "jobId" not in response:
            raise RuntimeError(
                f"Fine-tuning request to '{url}' did not return a job id."
            )
        return UUID(response["jobId"])
1 change: 1 addition & 0 deletions vision_agent/tools/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
florencev2_image_caption,
florencev2_object_detection,
florencev2_roberta_vqa,
florencev2_fine_tuning,
generate_pose_image,
generate_soft_edge_image,
get_tool_documentation,
Expand Down
20 changes: 20 additions & 0 deletions vision_agent/tools/tool_types.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
from typing import List, Tuple

from nptyping import UInt8, NDArray, Shape
from pydantic import BaseModel, ConfigDict


class BboxInput(BaseModel):
    """Bounding-box annotations for one image supplied as an in-memory array."""

    # The image field is an nptyping array type, so pydantic is told to accept
    # arbitrary (non-pydantic) types.
    model_config = ConfigDict(arbitrary_types_allowed=True)

    # Annotated as a uint8 array of shape (Height, Width, 3).
    image: NDArray[Shape["Height, Width, 3"], UInt8]
    filename: str
    # NOTE(review): presumably one label per entry in ``bboxes`` — confirm.
    labels: List[str]
    # Four integers per box; the coordinate convention is not stated here.
    bboxes: List[Tuple[int, int, int, int]]


class BboxInputBase64(BaseModel):
    """Bounding-box annotations for one image supplied as a string."""

    # NOTE(review): presumably the base64-encoded image, going by the class
    # name — confirm the exact encoding with the producer of this model.
    image: str
    filename: str
    # NOTE(review): presumably one label per entry in ``bboxes`` — confirm.
    labels: List[str]
    # Four integers per box; the coordinate convention is not stated here.
    bboxes: List[Tuple[int, int, int, int]]
Loading

0 comments on commit 687203f

Please sign in to comment.