Skip to content

Commit

Permalink
Add failure retry for model inference requests (#149)
Browse files Browse the repository at this point in the history
* Add failure retry for model inference requests
  • Loading branch information
humpydonkey authored Jun 21, 2024
1 parent 9e87cc1 commit d05a667
Show file tree
Hide file tree
Showing 3 changed files with 50 additions and 17 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/ci_cd.yml
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ on:

jobs:
unit_test:
name: Linting & Unit Test
name: Test
strategy:
matrix:
python-version: [3.9, 3.11]
Expand Down
43 changes: 38 additions & 5 deletions vision_agent/tools/tool_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,9 @@
from typing import Any, Dict

import requests
from requests import Session
from requests.adapters import HTTPAdapter
from urllib3.util.retry import Retry

from vision_agent.utils.type_defs import LandingaiAPIKey

Expand All @@ -11,20 +14,50 @@
_LND_API_URL = "https://api.staging.landing.ai/v1/agent"


def send_inference_request(
    payload: Dict[str, Any], endpoint_name: str
) -> Dict[str, Any]:
    """POST an inference request to the Landing AI model endpoint, with retry.

    Parameters:
        payload: JSON-serializable request body sent to the model endpoint.
        endpoint_name: Name of the model endpoint, appended to the base URL.

    Returns:
        The ``data`` field of the JSON response.

    Raises:
        ValueError: If the response status code is not 200 after all retries
            are exhausted (transient 408/429/5xx responses are retried by the
            session created in ``_create_requests_session``).
    """
    # Optional tag used by the backend to attribute requests to a runtime.
    if runtime_tag := os.environ.get("RUNTIME_TAG", ""):
        payload["runtime_tag"] = runtime_tag
    url = f"{_LND_API_URL}/model/{endpoint_name}"
    # Use a session with automatic retry/backoff instead of a bare
    # requests.post, so transient server/rate-limit failures are retried.
    session = _create_requests_session(
        url=url,
        num_retry=3,
        headers={
            "Content-Type": "application/json",
            "apikey": _LND_API_KEY,
        },
    )
    res = session.post(url, json=payload)
    if res.status_code != 200:
        # Include the status code so callers can distinguish failure modes.
        _LOGGER.error(f"Request failed: {res.status_code} {res.text}")
        raise ValueError(f"Request failed: {res.status_code} {res.text}")
    return res.json()["data"]  # type: ignore


def _create_requests_session(
    url: str, num_retry: int, headers: Dict[str, str]
) -> Session:
    """Build a ``requests.Session`` that retries transient failures.

    Parameters:
        url: URL prefix the retrying adapter is mounted on.
        num_retry: Total number of retries; if not positive, retries are
            disabled entirely.
        headers: Default headers applied to every request on the session.

    Returns:
        A configured ``Session`` with retry/backoff behavior attached.
    """
    # Retry on transient statuses only; redirects raise, other statuses
    # are returned to the caller rather than raised here.
    transient_statuses = [
        408,  # Request Timeout
        429,  # Too Many Requests (i.e. rate limiter).
        502,  # Bad Gateway
        503,  # Service Unavailable (include cloud circuit breaker)
        504,  # Gateway Timeout
    ]
    retry_policy = Retry(
        total=num_retry,
        backoff_factor=2,
        raise_on_redirect=True,
        raise_on_status=False,
        allowed_methods=["GET", "POST", "PUT"],
        status_forcelist=transient_statuses,
    )
    adapter = HTTPAdapter(max_retries=retry_policy if num_retry > 0 else 0)
    new_session = Session()
    new_session.mount(url, adapter)
    new_session.headers.update(headers)
    return new_session
22 changes: 11 additions & 11 deletions vision_agent/tools/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from PIL import Image, ImageDraw, ImageFont
from pillow_heif import register_heif_opener # type: ignore

from vision_agent.tools.tool_utils import _send_inference_request
from vision_agent.tools.tool_utils import send_inference_request
from vision_agent.utils import extract_frames_from_video
from vision_agent.utils.execute import FileSerializer, MimeType
from vision_agent.utils.image_utils import (
Expand Down Expand Up @@ -105,7 +105,7 @@ def grounding_dino(
),
"kwargs": {"box_threshold": box_threshold, "iou_threshold": iou_threshold},
}
data: Dict[str, Any] = _send_inference_request(request_data, "tools")
data: Dict[str, Any] = send_inference_request(request_data, "tools")
return_data = []
for i in range(len(data["bboxes"])):
return_data.append(
Expand Down Expand Up @@ -161,7 +161,7 @@ def owl_v2(
"tool": "open_vocab_detection",
"kwargs": {"box_threshold": box_threshold, "iou_threshold": iou_threshold},
}
data: Dict[str, Any] = _send_inference_request(request_data, "tools")
data: Dict[str, Any] = send_inference_request(request_data, "tools")
return_data = []
for i in range(len(data["bboxes"])):
return_data.append(
Expand Down Expand Up @@ -225,7 +225,7 @@ def grounding_sam(
"tool": "visual_grounding_segment",
"kwargs": {"box_threshold": box_threshold, "iou_threshold": iou_threshold},
}
data: Dict[str, Any] = _send_inference_request(request_data, "tools")
data: Dict[str, Any] = send_inference_request(request_data, "tools")
return_data = []
for i in range(len(data["bboxes"])):
return_data.append(
Expand Down Expand Up @@ -341,7 +341,7 @@ def loca_zero_shot_counting(image: np.ndarray) -> Dict[str, Any]:
"image": image_b64,
"tool": "zero_shot_counting",
}
resp_data = _send_inference_request(data, "tools")
resp_data = send_inference_request(data, "tools")
resp_data["heat_map"] = np.array(b64_to_pil(resp_data["heat_map"][0]))
return resp_data

Expand Down Expand Up @@ -376,7 +376,7 @@ def loca_visual_prompt_counting(
"prompt": bbox_str,
"tool": "few_shot_counting",
}
resp_data = _send_inference_request(data, "tools")
resp_data = send_inference_request(data, "tools")
resp_data["heat_map"] = np.array(b64_to_pil(resp_data["heat_map"][0]))
return resp_data

Expand Down Expand Up @@ -407,7 +407,7 @@ def git_vqa_v2(prompt: str, image: np.ndarray) -> str:
"tool": "image_question_answering",
}

answer = _send_inference_request(data, "tools")
answer = send_inference_request(data, "tools")
return answer["text"][0] # type: ignore


Expand Down Expand Up @@ -436,7 +436,7 @@ def clip(image: np.ndarray, classes: List[str]) -> Dict[str, Any]:
"image": image_b64,
"tool": "closed_set_image_classification",
}
resp_data = _send_inference_request(data, "tools")
resp_data = send_inference_request(data, "tools")
resp_data["scores"] = [round(prob, 4) for prob in resp_data["scores"]]
return resp_data

Expand All @@ -463,7 +463,7 @@ def vit_image_classification(image: np.ndarray) -> Dict[str, Any]:
"image": image_b64,
"tool": "image_classification",
}
resp_data = _send_inference_request(data, "tools")
resp_data = send_inference_request(data, "tools")
resp_data["scores"] = [round(prob, 4) for prob in resp_data["scores"]]
return resp_data

Expand All @@ -490,7 +490,7 @@ def vit_nsfw_classification(image: np.ndarray) -> Dict[str, Any]:
"image": image_b64,
"tool": "nsfw_image_classification",
}
resp_data = _send_inference_request(data, "tools")
resp_data = send_inference_request(data, "tools")
resp_data["scores"] = round(resp_data["scores"], 4)
return resp_data

Expand All @@ -517,7 +517,7 @@ def blip_image_caption(image: np.ndarray) -> str:
"tool": "image_captioning",
}

answer = _send_inference_request(data, "tools")
answer = send_inference_request(data, "tools")
return answer["text"][0] # type: ignore


Expand Down

0 comments on commit d05a667

Please sign in to comment.