diff --git a/.github/workflows/ci_cd.yml b/.github/workflows/ci_cd.yml
index 85797e4c..d41e7592 100644
--- a/.github/workflows/ci_cd.yml
+++ b/.github/workflows/ci_cd.yml
@@ -7,7 +7,7 @@ on:
 
 jobs:
   unit_test:
-    name: Linting & Unit Test
+    name: Test
     strategy:
       matrix:
         python-version: [3.9, 3.11]
diff --git a/vision_agent/tools/tool_utils.py b/vision_agent/tools/tool_utils.py
index f4555f84..fa248876 100644
--- a/vision_agent/tools/tool_utils.py
+++ b/vision_agent/tools/tool_utils.py
@@ -3,6 +3,9 @@
 from typing import Any, Dict
 
 import requests
+from requests import Session
+from requests.adapters import HTTPAdapter
+from urllib3.util.retry import Retry
 
 from vision_agent.utils.type_defs import LandingaiAPIKey
 
@@ -11,20 +14,46 @@
 _LND_API_URL = "https://api.staging.landing.ai/v1/agent"
 
 
-def _send_inference_request(
+def send_inference_request(
     payload: Dict[str, Any], endpoint_name: str
 ) -> Dict[str, Any]:
     if runtime_tag := os.environ.get("RUNTIME_TAG", ""):
         payload["runtime_tag"] = runtime_tag
-    res = requests.post(
-        f"{_LND_API_URL}/model/{endpoint_name}",
+    url = f"{_LND_API_URL}/model/{endpoint_name}"
+    session = _create_requests_session(
+        url=url,
+        num_retry=3,
         headers={
             "Content-Type": "application/json",
             "apikey": _LND_API_KEY,
         },
-        json=payload,
-    )
+    )
+    res = session.post(url, json=payload)
     if res.status_code != 200:
-        _LOGGER.error(f"Request failed: {res.text}")
-        raise ValueError(f"Request failed: {res.text}")
+        _LOGGER.error(f"Request failed: {res.status_code} {res.text}")
+        raise ValueError(f"Request failed: {res.status_code} {res.text}")
     return res.json()["data"]  # type: ignore
+
+
+def _create_requests_session(
+    url: str, num_retry: int, headers: Dict[str, str]
+) -> Session:
+    """Create a requests session with retry"""
+    session = Session()
+    retries = Retry(
+        total=num_retry,
+        backoff_factor=2,
+        raise_on_redirect=True,
+        raise_on_status=False,
+        allowed_methods=["GET", "POST", "PUT"],
+        status_forcelist=[
+            408,  # Request Timeout
+            429,  # Too Many Requests (ie. rate limiter).
+            502,  # Bad Gateway
+            503,  # Service Unavailable (include cloud circuit breaker)
+            504,  # Gateway Timeout
+        ],
+    )
+    session.mount(url, HTTPAdapter(max_retries=retries if num_retry > 0 else 0))
+    session.headers.update(headers)
+    return session
diff --git a/vision_agent/tools/tools.py b/vision_agent/tools/tools.py
index 1eab2208..8acda4b2 100644
--- a/vision_agent/tools/tools.py
+++ b/vision_agent/tools/tools.py
@@ -15,7 +15,7 @@
 from PIL import Image, ImageDraw, ImageFont
 from pillow_heif import register_heif_opener  # type: ignore
 
-from vision_agent.tools.tool_utils import _send_inference_request
+from vision_agent.tools.tool_utils import send_inference_request
 from vision_agent.utils import extract_frames_from_video
 from vision_agent.utils.execute import FileSerializer, MimeType
 from vision_agent.utils.image_utils import (
@@ -105,7 +105,7 @@ def grounding_dino(
         ),
         "kwargs": {"box_threshold": box_threshold, "iou_threshold": iou_threshold},
     }
-    data: Dict[str, Any] = _send_inference_request(request_data, "tools")
+    data: Dict[str, Any] = send_inference_request(request_data, "tools")
     return_data = []
     for i in range(len(data["bboxes"])):
         return_data.append(
@@ -161,7 +161,7 @@ def owl_v2(
         "tool": "open_vocab_detection",
         "kwargs": {"box_threshold": box_threshold, "iou_threshold": iou_threshold},
     }
-    data: Dict[str, Any] = _send_inference_request(request_data, "tools")
+    data: Dict[str, Any] = send_inference_request(request_data, "tools")
     return_data = []
     for i in range(len(data["bboxes"])):
         return_data.append(
@@ -225,7 +225,7 @@ def grounding_sam(
         "tool": "visual_grounding_segment",
         "kwargs": {"box_threshold": box_threshold, "iou_threshold": iou_threshold},
     }
-    data: Dict[str, Any] = _send_inference_request(request_data, "tools")
+    data: Dict[str, Any] = send_inference_request(request_data, "tools")
     return_data = []
     for i in range(len(data["bboxes"])):
         return_data.append(
@@ -341,7 +341,7 @@ def loca_zero_shot_counting(image: np.ndarray) -> Dict[str, Any]:
         "image": image_b64,
         "tool": "zero_shot_counting",
     }
-    resp_data = _send_inference_request(data, "tools")
+    resp_data = send_inference_request(data, "tools")
     resp_data["heat_map"] = np.array(b64_to_pil(resp_data["heat_map"][0]))
     return resp_data
 
@@ -376,7 +376,7 @@ def loca_visual_prompt_counting(
         "prompt": bbox_str,
         "tool": "few_shot_counting",
     }
-    resp_data = _send_inference_request(data, "tools")
+    resp_data = send_inference_request(data, "tools")
     resp_data["heat_map"] = np.array(b64_to_pil(resp_data["heat_map"][0]))
     return resp_data
 
@@ -407,7 +407,7 @@ def git_vqa_v2(prompt: str, image: np.ndarray) -> str:
         "tool": "image_question_answering",
     }
 
-    answer = _send_inference_request(data, "tools")
+    answer = send_inference_request(data, "tools")
     return answer["text"][0]  # type: ignore
 
 
@@ -436,7 +436,7 @@ def clip(image: np.ndarray, classes: List[str]) -> Dict[str, Any]:
         "image": image_b64,
         "tool": "closed_set_image_classification",
     }
-    resp_data = _send_inference_request(data, "tools")
+    resp_data = send_inference_request(data, "tools")
     resp_data["scores"] = [round(prob, 4) for prob in resp_data["scores"]]
     return resp_data
 
@@ -463,7 +463,7 @@ def vit_image_classification(image: np.ndarray) -> Dict[str, Any]:
         "image": image_b64,
         "tool": "image_classification",
     }
-    resp_data = _send_inference_request(data, "tools")
+    resp_data = send_inference_request(data, "tools")
     resp_data["scores"] = [round(prob, 4) for prob in resp_data["scores"]]
     return resp_data
 
@@ -490,7 +490,7 @@ def vit_nsfw_classification(image: np.ndarray) -> Dict[str, Any]:
         "image": image_b64,
         "tool": "nsfw_image_classification",
     }
-    resp_data = _send_inference_request(data, "tools")
+    resp_data = send_inference_request(data, "tools")
     resp_data["scores"] = round(resp_data["scores"], 4)
     return resp_data
 
@@ -517,7 +517,7 @@ def blip_image_caption(image: np.ndarray) -> str:
         "tool": "image_captioning",
     }
 
-    answer = _send_inference_request(data, "tools")
+    answer = send_inference_request(data, "tools")
     return answer["text"][0]  # type: ignore
 
 