Skip to content

Commit 76bacbc

Browse files
yinghsienwu authored and copybara-github committed
feat: Use aiohttp in async APIs to lower latency
PiperOrigin-RevId: 766398744
1 parent ae2392c commit 76bacbc

File tree

6 files changed

+211
-139
lines changed

6 files changed

+211
-139
lines changed

google/genai/_api_client.py

Lines changed: 159 additions & 118 deletions
Original file line numberDiff line numberDiff line change
@@ -20,21 +20,20 @@
2020
"""
2121

2222
import asyncio
23-
from collections.abc import Awaitable, Generator
23+
from collections.abc import Generator
2424
import copy
2525
from dataclasses import dataclass
26-
import datetime
27-
import http
2826
import io
2927
import json
3028
import logging
3129
import math
30+
from multidict import CIMultiDictProxy
3231
import os
3332
import ssl
3433
import sys
3534
import threading
3635
import time
37-
from typing import Any, AsyncIterator, Optional, Tuple, Union
36+
from typing import Any, AsyncIterator, Optional, Tuple, Union, cast
3837
from urllib.parse import urlparse
3938
from urllib.parse import urlunparse
4039

@@ -44,6 +43,7 @@
4443
import google.auth.credentials
4544
from google.auth.credentials import Credentials
4645
from google.auth.transport.requests import Request
46+
import aiohttp # type: ignore
4747
import httpx
4848
from pydantic import BaseModel
4949
from pydantic import Field
@@ -53,7 +53,6 @@
5353
from . import errors
5454
from . import version
5555
from .types import HttpOptions
56-
from .types import HttpOptionsDict
5756
from .types import HttpOptionsOrDict
5857

5958

@@ -216,7 +215,7 @@ class HttpResponse:
216215

217216
def __init__(
218217
self,
219-
headers: Union[dict[str, str], httpx.Headers],
218+
headers: Union[dict[str, str], httpx.Headers, "CIMultiDictProxy[str]"],
220219
response_stream: Union[Any, str] = None,
221220
byte_stream: Union[Any, bytes] = None,
222221
):
@@ -271,8 +270,8 @@ async def async_segments(self) -> AsyncIterator[Any]:
271270
yield c
272271
else:
273272
# Iterator of objects retrieved from the API.
274-
if hasattr(self.response_stream, 'aiter_lines'):
275-
async for chunk in self.response_stream.aiter_lines():
273+
if hasattr(self.response_stream, 'content'):
274+
async for chunk in self.response_stream.content.iter_any():
276275
# This is aiohttp.ClientResponse.
277276
if chunk:
278277
# In async streaming mode, the chunk of JSON is prefixed with
@@ -324,26 +323,6 @@ def __del__(self) -> None:
324323
pass
325324

326325

327-
class AsyncHttpxClient(httpx.AsyncClient):
328-
"""Async httpx client."""
329-
330-
def __init__(self, **kwargs: Any) -> None:
331-
"""Initializes the httpx client."""
332-
kwargs.setdefault('follow_redirects', True)
333-
super().__init__(**kwargs)
334-
335-
def __del__(self) -> None:
336-
try:
337-
if self.is_closed:
338-
return
339-
except Exception:
340-
pass
341-
try:
342-
asyncio.get_running_loop().create_task(self.aclose())
343-
except Exception:
344-
pass
345-
346-
347326
class BaseApiClient:
348327
"""Client for calling HTTP APIs sending and receiving JSON."""
349328

@@ -480,12 +459,11 @@ def __init__(
480459
if self._http_options.headers is not None:
481460
_append_library_version_headers(self._http_options.headers)
482461

483-
client_args, async_client_args = self._ensure_ssl_ctx(self._http_options)
462+
client_args, _ = self._ensure_httpx_ssl_ctx(self._http_options)
484463
self._httpx_client = SyncHttpxClient(**client_args)
485-
self._async_httpx_client = AsyncHttpxClient(**async_client_args)
486464

487465
@staticmethod
488-
def _ensure_ssl_ctx(options: HttpOptions) -> (
466+
def _ensure_httpx_ssl_ctx(options: HttpOptions) -> (
489467
Tuple[dict[str, Any], dict[str, Any]]):
490468
"""Ensures the SSL context is present in the client args.
491469
@@ -541,6 +519,58 @@ def _maybe_set(
541519
_maybe_set(async_args, ctx),
542520
)
543521

522+
@staticmethod
523+
def _ensure_aiohttp_ssl_ctx(options: HttpOptions) -> dict[str, Any]:
524+
"""Ensures the SSL context is present in the async client args.
525+
526+
Creates a default SSL context if one is not provided.
527+
528+
Args:
529+
options: The http options to check for SSL context.
530+
531+
Returns:
532+
A dict of args for an async aiohttp ClientSession._request call.
533+
"""
534+
535+
verify = 'verify'
536+
async_args = options.async_client_args
537+
ctx = async_args.get(verify) if async_args else None
538+
539+
if not ctx:
540+
# Initialize the SSL context for the aiohttp client.
541+
# Unlike requests, the aiohttp package does not automatically pull in the
542+
# environment variables SSL_CERT_FILE or SSL_CERT_DIR. They need to be
543+
# enabled explicitly. Instead of 'verify' at client level in httpx,
544+
# aiohttp uses 'ssl' at request level.
545+
ctx = ssl.create_default_context(
546+
cafile=os.environ.get('SSL_CERT_FILE', certifi.where()),
547+
capath=os.environ.get('SSL_CERT_DIR'),
548+
)
549+
550+
def _maybe_set(
551+
args: Optional[dict[str, Any]],
552+
ctx: ssl.SSLContext,
553+
) -> dict[str, Any]:
554+
"""Sets the SSL context in the client args if not set.
555+
556+
Does not override the SSL context if it is already set.
557+
558+
Args:
559+
args: The client args to check for SSL context.
560+
ctx: The SSL context to set.
561+
562+
Returns:
563+
The client args with the SSL context included.
564+
"""
565+
if not args or not args.get(verify):
566+
args = (args or {}).copy()
567+
args['ssl'] = ctx
568+
else:
569+
args['ssl'] = args.pop(verify)
570+
return args
571+
572+
return _maybe_set(async_args, ctx)
573+
544574
def _websocket_base_url(self) -> str:
545575
url_parts = urlparse(self._http_options.base_url)
546576
return url_parts._replace(scheme='wss').geturl() # type: ignore[arg-type, return-value]
@@ -736,34 +766,35 @@ async def _async_request(
736766
else:
737767
data = http_request.data
738768

769+
async_client_args = self._ensure_aiohttp_ssl_ctx(self._http_options)
739770
if stream:
740-
httpx_request = self._async_httpx_client.build_request(
771+
session = aiohttp.ClientSession()
772+
response = await session.request(
741773
method=http_request.method,
742774
url=http_request.url,
743-
content=data,
744775
headers=http_request.headers,
745-
timeout=http_request.timeout,
746-
)
747-
response = await self._async_httpx_client.send(
748-
httpx_request,
749-
stream=stream,
776+
data=data,
777+
timeout=aiohttp.ClientTimeout(connect=http_request.timeout),
778+
**async_client_args,
750779
)
751780
await errors.APIError.raise_for_async_response(response)
752781
return HttpResponse(
753-
response.headers, response if stream else [response.text]
782+
response.headers, response
754783
)
755784
else:
756-
response = await self._async_httpx_client.request(
757-
method=http_request.method,
758-
url=http_request.url,
759-
headers=http_request.headers,
760-
content=data,
761-
timeout=http_request.timeout,
762-
)
763-
await errors.APIError.raise_for_async_response(response)
764-
return HttpResponse(
765-
response.headers, response if stream else [response.text]
766-
)
785+
async with aiohttp.ClientSession() as session: #
786+
response = await session.request(
787+
method=http_request.method,
788+
url=http_request.url,
789+
headers=http_request.headers,
790+
data=data,
791+
timeout=aiohttp.ClientTimeout(connect=http_request.timeout),
792+
**async_client_args,
793+
)
794+
await errors.APIError.raise_for_async_response(response)
795+
return HttpResponse(
796+
response.headers, response if stream else [await response.text()]
797+
)
767798

768799
def get_read_only_http_options(self) -> dict[str, Any]:
769800
if isinstance(self._http_options, BaseModel):
@@ -1048,68 +1079,77 @@ async def _async_upload_fd(
10481079
"""
10491080
offset = 0
10501081
# Upload the file in chunks
1051-
while True:
1052-
if isinstance(file, io.IOBase):
1053-
file_chunk = file.read(CHUNK_SIZE)
1054-
else:
1055-
file_chunk = await file.read(CHUNK_SIZE)
1056-
chunk_size = 0
1057-
if file_chunk:
1058-
chunk_size = len(file_chunk)
1059-
upload_command = 'upload'
1060-
# If last chunk, finalize the upload.
1061-
if chunk_size + offset >= upload_size:
1062-
upload_command += ', finalize'
1063-
http_options = http_options if http_options else self._http_options
1064-
timeout = (
1065-
http_options.get('timeout')
1066-
if isinstance(http_options, dict)
1067-
else http_options.timeout
1068-
)
1069-
if timeout is None:
1070-
# Per request timeout is not configured. Check the global timeout.
1082+
async with aiohttp.ClientSession() as session:
1083+
while True:
1084+
if isinstance(file, io.IOBase):
1085+
file_chunk = file.read(CHUNK_SIZE)
1086+
else:
1087+
file_chunk = await file.read(CHUNK_SIZE)
1088+
chunk_size = 0
1089+
if file_chunk:
1090+
chunk_size = len(file_chunk)
1091+
upload_command = 'upload'
1092+
# If last chunk, finalize the upload.
1093+
if chunk_size + offset >= upload_size:
1094+
upload_command += ', finalize'
1095+
http_options = http_options if http_options else self._http_options
10711096
timeout = (
1072-
self._http_options.timeout
1073-
if isinstance(self._http_options, dict)
1074-
else self._http_options.timeout
1097+
http_options.get('timeout')
1098+
if isinstance(http_options, dict)
1099+
else http_options.timeout
10751100
)
1076-
timeout_in_seconds = _get_timeout_in_seconds(timeout)
1077-
upload_headers = {
1078-
'X-Goog-Upload-Command': upload_command,
1079-
'X-Goog-Upload-Offset': str(offset),
1080-
'Content-Length': str(chunk_size),
1081-
}
1082-
_populate_server_timeout_header(upload_headers, timeout_in_seconds)
1083-
1084-
retry_count = 0
1085-
while retry_count < MAX_RETRY_COUNT:
1086-
response = await self._async_httpx_client.request(
1087-
method='POST',
1088-
url=upload_url,
1089-
content=file_chunk,
1090-
headers=upload_headers,
1091-
timeout=timeout_in_seconds,
1092-
)
1093-
if response.headers.get('x-goog-upload-status'):
1094-
break
1095-
delay_seconds = INITIAL_RETRY_DELAY * (DELAY_MULTIPLIER**retry_count)
1096-
retry_count += 1
1097-
time.sleep(delay_seconds)
1098-
1099-
offset += chunk_size
1100-
if response.headers.get('x-goog-upload-status') != 'active':
1101-
break # upload is complete or it has been interrupted.
1101+
if timeout is None:
1102+
# Per request timeout is not configured. Check the global timeout.
1103+
timeout = (
1104+
self._http_options.timeout
1105+
if isinstance(self._http_options, dict)
1106+
else self._http_options.timeout
1107+
)
1108+
timeout_in_seconds = _get_timeout_in_seconds(timeout)
1109+
upload_headers = {
1110+
'X-Goog-Upload-Command': upload_command,
1111+
'X-Goog-Upload-Offset': str(offset),
1112+
'Content-Length': str(chunk_size),
1113+
}
1114+
_populate_server_timeout_header(upload_headers, timeout_in_seconds)
1115+
1116+
retry_count = 0
1117+
response = None
1118+
while retry_count < MAX_RETRY_COUNT:
1119+
response = await session.request(
1120+
method='POST',
1121+
url=upload_url,
1122+
data=file_chunk,
1123+
headers=upload_headers,
1124+
timeout=aiohttp.ClientTimeout(connect=timeout_in_seconds),
1125+
)
11021126

1103-
if upload_size <= offset: # Status is not finalized.
1127+
if response.headers.get('X-Goog-Upload-Status'):
1128+
break
1129+
delay_seconds = INITIAL_RETRY_DELAY * (DELAY_MULTIPLIER**retry_count)
1130+
retry_count += 1
1131+
time.sleep(delay_seconds)
1132+
1133+
offset += chunk_size
1134+
if (
1135+
response is not None
1136+
and response.headers.get('X-Goog-Upload-Status') != 'active'
1137+
):
1138+
break # upload is complete or it has been interrupted.
1139+
1140+
if upload_size <= offset: # Status is not finalized.
1141+
raise ValueError(
1142+
f'All content has been uploaded, but the upload status is not'
1143+
f' finalized.'
1144+
)
1145+
if (
1146+
response is not None
1147+
and response.headers.get('X-Goog-Upload-Status') != 'final'
1148+
):
11041149
raise ValueError(
1105-
'All content has been uploaded, but the upload status is not'
1106-
f' finalized.'
1150+
'Failed to upload file: Upload status is not finalized.'
11071151
)
1108-
if response.headers.get('x-goog-upload-status') != 'final':
1109-
raise ValueError(
1110-
'Failed to upload file: Upload status is not finalized.'
1111-
)
1112-
return HttpResponse(response.headers, response_stream=[response.text])
1152+
return HttpResponse(response.headers, response_stream=[await response.text()])
11131153

11141154
async def async_download_file(
11151155
self,
@@ -1137,18 +1177,19 @@ async def async_download_file(
11371177
else:
11381178
data = http_request.data
11391179

1140-
response = await self._async_httpx_client.request(
1141-
method=http_request.method,
1142-
url=http_request.url,
1143-
headers=http_request.headers,
1144-
content=data,
1145-
timeout=http_request.timeout,
1146-
)
1147-
await errors.APIError.raise_for_async_response(response)
1180+
async with aiohttp.ClientSession() as session:
1181+
response = await session.request(
1182+
method=http_request.method,
1183+
url=http_request.url,
1184+
headers=http_request.headers,
1185+
data=data,
1186+
timeout=aiohttp.ClientTimeout(connect=http_request.timeout),
1187+
)
1188+
await errors.APIError.raise_for_async_response(response)
11481189

1149-
return HttpResponse(
1150-
response.headers, byte_stream=[response.read()]
1151-
).byte_stream[0]
1190+
return HttpResponse(
1191+
response.headers, byte_stream=[await response.read()]
1192+
).byte_stream[0]
11521193

11531194
# This method does nothing in the real api client. It is used in the
11541195
# replay_api_client to verify the response from the SDK method matches the

0 commit comments

Comments
 (0)