Skip to content

feat: Use aiohttp in async APIs to lower latency #930

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
277 changes: 159 additions & 118 deletions google/genai/_api_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,21 +20,20 @@
"""

import asyncio
from collections.abc import Awaitable, Generator
from collections.abc import Generator
import copy
from dataclasses import dataclass
import datetime
import http
import io
import json
import logging
import math
from multidict import CIMultiDictProxy
import os
import ssl
import sys
import threading
import time
from typing import Any, AsyncIterator, Optional, Tuple, Union
from typing import Any, AsyncIterator, Optional, Tuple, Union, cast
from urllib.parse import urlparse
from urllib.parse import urlunparse

Expand All @@ -44,6 +43,7 @@
import google.auth.credentials
from google.auth.credentials import Credentials
from google.auth.transport.requests import Request
import aiohttp # type: ignore
import httpx
from pydantic import BaseModel
from pydantic import Field
Expand All @@ -53,7 +53,6 @@
from . import errors
from . import version
from .types import HttpOptions
from .types import HttpOptionsDict
from .types import HttpOptionsOrDict


Expand Down Expand Up @@ -216,7 +215,7 @@ class HttpResponse:

def __init__(
self,
headers: Union[dict[str, str], httpx.Headers],
headers: Union[dict[str, str], httpx.Headers, "CIMultiDictProxy[str]"],
response_stream: Union[Any, str] = None,
byte_stream: Union[Any, bytes] = None,
):
Expand Down Expand Up @@ -271,8 +270,8 @@ async def async_segments(self) -> AsyncIterator[Any]:
yield c
else:
# Iterator of objects retrieved from the API.
if hasattr(self.response_stream, 'aiter_lines'):
async for chunk in self.response_stream.aiter_lines():
if hasattr(self.response_stream, 'content'):
async for chunk in self.response_stream.content.iter_any():
# This is httpx.Response.
if chunk:
# In async streaming mode, the chunk of JSON is prefixed with
Expand Down Expand Up @@ -324,26 +323,6 @@ def __del__(self) -> None:
pass


class AsyncHttpxClient(httpx.AsyncClient):
  """httpx.AsyncClient that follows redirects and closes itself on GC."""

  def __init__(self, **kwargs: Any) -> None:
    """Creates the client, defaulting to following redirects."""
    if 'follow_redirects' not in kwargs:
      kwargs['follow_redirects'] = True
    super().__init__(**kwargs)

  def __del__(self) -> None:
    # Best-effort cleanup: schedule aclose() on the running event loop if
    # the client is still open. Every failure is swallowed on purpose —
    # finalizers must never raise, and there may be no running loop at
    # interpreter shutdown.
    try:
      already_closed = self.is_closed
    except Exception:
      already_closed = False
    if already_closed:
      return
    try:
      asyncio.get_running_loop().create_task(self.aclose())
    except Exception:
      pass


class BaseApiClient:
"""Client for calling HTTP APIs sending and receiving JSON."""

Expand Down Expand Up @@ -480,12 +459,11 @@ def __init__(
if self._http_options.headers is not None:
_append_library_version_headers(self._http_options.headers)

client_args, async_client_args = self._ensure_ssl_ctx(self._http_options)
client_args, _ = self._ensure_httpx_ssl_ctx(self._http_options)
self._httpx_client = SyncHttpxClient(**client_args)
self._async_httpx_client = AsyncHttpxClient(**async_client_args)

@staticmethod
def _ensure_ssl_ctx(options: HttpOptions) -> (
def _ensure_httpx_ssl_ctx(options: HttpOptions) -> (
Tuple[dict[str, Any], dict[str, Any]]):
"""Ensures the SSL context is present in the client args.

Expand Down Expand Up @@ -541,6 +519,58 @@ def _maybe_set(
_maybe_set(async_args, ctx),
)

@staticmethod
def _ensure_aiohttp_ssl_ctx(options: HttpOptions) -> dict[str, Any]:
"""Ensures the SSL context is present in the async client args.

Creates a default SSL context if one is not provided.

Args:
options: The http options to check for SSL context.

Returns:
An async aiohttp ClientSession._request args.
"""

verify = 'verify'
async_args = options.async_client_args
ctx = async_args.get(verify) if async_args else None

if not ctx:
# Initialize the SSL context for the httpx client.
# Unlike requests, the aiohttp package does not automatically pull in the
# environment variables SSL_CERT_FILE or SSL_CERT_DIR. They need to be
# enabled explicitly. Instead of 'verify' at client level in httpx,
# aiohttp uses 'ssl' at request level.
ctx = ssl.create_default_context(
cafile=os.environ.get('SSL_CERT_FILE', certifi.where()),
capath=os.environ.get('SSL_CERT_DIR'),
)

def _maybe_set(
args: Optional[dict[str, Any]],
ctx: ssl.SSLContext,
) -> dict[str, Any]:
"""Sets the SSL context in the client args if not set.

Does not override the SSL context if it is already set.

Args:
args: The client args to to check for SSL context.
ctx: The SSL context to set.

Returns:
The client args with the SSL context included.
"""
if not args or not args.get(verify):
args = (args or {}).copy()
args['ssl'] = ctx
else:
args['ssl'] = args.pop(verify)
return args

return _maybe_set(async_args, ctx)

def _websocket_base_url(self) -> str:
url_parts = urlparse(self._http_options.base_url)
return url_parts._replace(scheme='wss').geturl() # type: ignore[arg-type, return-value]
Expand Down Expand Up @@ -736,34 +766,35 @@ async def _async_request(
else:
data = http_request.data

async_client_args = self._ensure_aiohttp_ssl_ctx(self._http_options)
if stream:
httpx_request = self._async_httpx_client.build_request(
session = aiohttp.ClientSession()
response = await session.request(
method=http_request.method,
url=http_request.url,
content=data,
headers=http_request.headers,
timeout=http_request.timeout,
)
response = await self._async_httpx_client.send(
httpx_request,
stream=stream,
data=data,
timeout=aiohttp.ClientTimeout(connect=http_request.timeout),
**async_client_args,
)
await errors.APIError.raise_for_async_response(response)
return HttpResponse(
response.headers, response if stream else [response.text]
response.headers, response
)
else:
response = await self._async_httpx_client.request(
method=http_request.method,
url=http_request.url,
headers=http_request.headers,
content=data,
timeout=http_request.timeout,
)
await errors.APIError.raise_for_async_response(response)
return HttpResponse(
response.headers, response if stream else [response.text]
)
async with aiohttp.ClientSession() as session: #
response = await session.request(
method=http_request.method,
url=http_request.url,
headers=http_request.headers,
data=data,
timeout=aiohttp.ClientTimeout(connect=http_request.timeout),
**async_client_args,
)
await errors.APIError.raise_for_async_response(response)
return HttpResponse(
response.headers, response if stream else [await response.text()]
)

def get_read_only_http_options(self) -> dict[str, Any]:
if isinstance(self._http_options, BaseModel):
Expand Down Expand Up @@ -1048,68 +1079,77 @@ async def _async_upload_fd(
"""
offset = 0
# Upload the file in chunks
while True:
if isinstance(file, io.IOBase):
file_chunk = file.read(CHUNK_SIZE)
else:
file_chunk = await file.read(CHUNK_SIZE)
chunk_size = 0
if file_chunk:
chunk_size = len(file_chunk)
upload_command = 'upload'
# If last chunk, finalize the upload.
if chunk_size + offset >= upload_size:
upload_command += ', finalize'
http_options = http_options if http_options else self._http_options
timeout = (
http_options.get('timeout')
if isinstance(http_options, dict)
else http_options.timeout
)
if timeout is None:
# Per request timeout is not configured. Check the global timeout.
async with aiohttp.ClientSession() as session:
while True:
if isinstance(file, io.IOBase):
file_chunk = file.read(CHUNK_SIZE)
else:
file_chunk = await file.read(CHUNK_SIZE)
chunk_size = 0
if file_chunk:
chunk_size = len(file_chunk)
upload_command = 'upload'
# If last chunk, finalize the upload.
if chunk_size + offset >= upload_size:
upload_command += ', finalize'
http_options = http_options if http_options else self._http_options
timeout = (
self._http_options.timeout
if isinstance(self._http_options, dict)
else self._http_options.timeout
http_options.get('timeout')
if isinstance(http_options, dict)
else http_options.timeout
)
timeout_in_seconds = _get_timeout_in_seconds(timeout)
upload_headers = {
'X-Goog-Upload-Command': upload_command,
'X-Goog-Upload-Offset': str(offset),
'Content-Length': str(chunk_size),
}
_populate_server_timeout_header(upload_headers, timeout_in_seconds)

retry_count = 0
while retry_count < MAX_RETRY_COUNT:
response = await self._async_httpx_client.request(
method='POST',
url=upload_url,
content=file_chunk,
headers=upload_headers,
timeout=timeout_in_seconds,
)
if response.headers.get('x-goog-upload-status'):
break
delay_seconds = INITIAL_RETRY_DELAY * (DELAY_MULTIPLIER**retry_count)
retry_count += 1
time.sleep(delay_seconds)

offset += chunk_size
if response.headers.get('x-goog-upload-status') != 'active':
break # upload is complete or it has been interrupted.
if timeout is None:
# Per request timeout is not configured. Check the global timeout.
timeout = (
self._http_options.timeout
if isinstance(self._http_options, dict)
else self._http_options.timeout
)
timeout_in_seconds = _get_timeout_in_seconds(timeout)
upload_headers = {
'X-Goog-Upload-Command': upload_command,
'X-Goog-Upload-Offset': str(offset),
'Content-Length': str(chunk_size),
}
_populate_server_timeout_header(upload_headers, timeout_in_seconds)

retry_count = 0
response = None
while retry_count < MAX_RETRY_COUNT:
response = await session.request(
method='POST',
url=upload_url,
data=file_chunk,
headers=upload_headers,
timeout=aiohttp.ClientTimeout(connect=timeout_in_seconds),
)

if upload_size <= offset: # Status is not finalized.
if response.headers.get('X-Goog-Upload-Status'):
break
delay_seconds = INITIAL_RETRY_DELAY * (DELAY_MULTIPLIER**retry_count)
retry_count += 1
time.sleep(delay_seconds)

offset += chunk_size
if (
response is not None
and response.headers.get('X-Goog-Upload-Status') != 'active'
):
break # upload is complete or it has been interrupted.

if upload_size <= offset: # Status is not finalized.
raise ValueError(
f'All content has been uploaded, but the upload status is not'
f' finalized.'
)
if (
response is not None
and response.headers.get('X-Goog-Upload-Status') != 'final'
):
raise ValueError(
'All content has been uploaded, but the upload status is not'
f' finalized.'
'Failed to upload file: Upload status is not finalized.'
)
if response.headers.get('x-goog-upload-status') != 'final':
raise ValueError(
'Failed to upload file: Upload status is not finalized.'
)
return HttpResponse(response.headers, response_stream=[response.text])
return HttpResponse(response.headers, response_stream=[await response.text()])

async def async_download_file(
self,
Expand Down Expand Up @@ -1137,18 +1177,19 @@ async def async_download_file(
else:
data = http_request.data

response = await self._async_httpx_client.request(
method=http_request.method,
url=http_request.url,
headers=http_request.headers,
content=data,
timeout=http_request.timeout,
)
await errors.APIError.raise_for_async_response(response)
async with aiohttp.ClientSession() as session:
response = await session.request(
method=http_request.method,
url=http_request.url,
headers=http_request.headers,
data=data,
timeout=aiohttp.ClientTimeout(connect=http_request.timeout),
)
await errors.APIError.raise_for_async_response(response)

return HttpResponse(
response.headers, byte_stream=[response.read()]
).byte_stream[0]
return HttpResponse(
response.headers, byte_stream=[await response.read()]
).byte_stream[0]

# This method does nothing in the real api client. It is used in the
# replay_api_client to verify the response from the SDK method matches the
Expand Down
Loading
Loading