-
Notifications
You must be signed in to change notification settings - Fork 128
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
6 changed files
with
191 additions
and
44 deletions.
There are no files selected for viewing
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,42 @@ | ||
import json | ||
import logging | ||
from typing import Any, Dict, Optional | ||
|
||
from requests import Session | ||
from requests.adapters import HTTPAdapter | ||
from requests.exceptions import ConnectionError, RequestException, Timeout | ||
|
||
_LOGGER = logging.getLogger(__name__) | ||
|
||
|
||
class BaseHTTP: | ||
_TIMEOUT = 30 # seconds | ||
_MAX_RETRIES = 3 | ||
|
||
def __init__(self, base_endpoint: str, *, headers: Optional[Dict[str, Any]] = None) -> None: | ||
self._headers = headers | ||
if headers is None: | ||
self._headers = { | ||
"Content-Type": "application/json", | ||
} | ||
self._base_endpoint = base_endpoint | ||
self._session = Session() | ||
self._session.headers.update(self._headers) | ||
self._session.mount(self._base_endpoint, HTTPAdapter(max_retries=self._MAX_RETRIES)) | ||
|
||
def post(self, url: str, payload: Dict[str, Any]) -> None: | ||
formatted_url = f"{self._base_endpoint}/{url}" | ||
_LOGGER.info(f"Sending data to {formatted_url}") | ||
try: | ||
response = self._session.post( | ||
url=formatted_url, | ||
json=payload, | ||
timeout=self._TIMEOUT | ||
) | ||
response.raise_for_status() | ||
_LOGGER.info(json.dumps(response.json())) | ||
except (ConnectionError, Timeout, RequestException) as err: | ||
_LOGGER.warning(f"Error: {err}.") | ||
except json.JSONDecodeError: | ||
resp_text = response.text | ||
_LOGGER.warning(f"Response seems incorrect: '{resp_text}'.") |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,26 @@ | ||
import os | ||
from uuid import UUID | ||
from typing import List | ||
|
||
from vision_agent.clients.http import BaseHTTP | ||
from vision_agent.utils.type_defs import LandingaiAPIKey | ||
from vision_agent.tools.tool_types import BboxInputBase64 | ||
|
||
|
||
class LandingPublicAPI(BaseHTTP): | ||
def __init__(self) -> None: | ||
landing_url = os.environ.get("LANDINGAI_URL", "https://api.dev.landing.ai") | ||
landing_api_key = os.environ.get("LANDINGAI_API_KEY", LandingaiAPIKey().api_key) | ||
headers = {"Content-Type": "application/json", "apikey": landing_api_key} | ||
super().__init__(base_endpoint=landing_url, headers=headers) | ||
|
||
def launch_fine_tuning_job( | ||
self, model_name: str, task: str, bboxes: List[BboxInputBase64] | ||
) -> UUID: | ||
url = "v1/agent/jobs/fine-tuning" | ||
data = { | ||
"model": {"name": model_name, "task": task}, | ||
"bboxes": [bbox.model_dump(by_alias=True) for bbox in bboxes] | ||
} | ||
response = self.post(url, payload=data) | ||
return UUID(response["jobId"]) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,20 @@ | ||
from typing import List, Tuple | ||
|
||
from nptyping import UInt8, NDArray, Shape | ||
from pydantic import BaseModel, ConfigDict | ||
|
||
|
||
class BboxInput(BaseModel): | ||
model_config = ConfigDict(arbitrary_types_allowed=True) | ||
|
||
image: NDArray[Shape["Height, Width, 3"], UInt8] | ||
filename: str | ||
labels: List[str] | ||
bboxes: List[Tuple[int, int, int, int]] | ||
|
||
|
||
class BboxInputBase64(BaseModel): | ||
image: str | ||
filename: str | ||
labels: List[str] | ||
bboxes: List[Tuple[int, int, int, int]] |
Oops, something went wrong.