|
20 | 20 | """
|
21 | 21 |
|
22 | 22 | import asyncio
|
23 |
| -from collections.abc import Awaitable, Generator |
| 23 | +from collections.abc import Generator |
24 | 24 | import copy
|
25 | 25 | from dataclasses import dataclass
|
26 |
| -import datetime |
27 |
| -import http |
28 | 26 | import io
|
29 | 27 | import json
|
30 | 28 | import logging
|
31 | 29 | import math
|
| 30 | +from multidict import CIMultiDictProxy |
32 | 31 | import os
|
33 | 32 | import ssl
|
34 | 33 | import sys
|
35 | 34 | import threading
|
36 | 35 | import time
|
37 |
| -from typing import Any, AsyncIterator, Optional, Tuple, Union |
| 36 | +from typing import Any, AsyncIterator, Optional, Tuple, Union, cast |
38 | 37 | from urllib.parse import urlparse
|
39 | 38 | from urllib.parse import urlunparse
|
40 | 39 |
|
|
44 | 43 | import google.auth.credentials
|
45 | 44 | from google.auth.credentials import Credentials
|
46 | 45 | from google.auth.transport.requests import Request
|
| 46 | +import aiohttp # type: ignore |
47 | 47 | import httpx
|
48 | 48 | from pydantic import BaseModel
|
49 | 49 | from pydantic import Field
|
|
53 | 53 | from . import errors
|
54 | 54 | from . import version
|
55 | 55 | from .types import HttpOptions
|
56 |
| -from .types import HttpOptionsDict |
57 | 56 | from .types import HttpOptionsOrDict
|
58 | 57 |
|
59 | 58 |
|
@@ -216,7 +215,7 @@ class HttpResponse:
|
216 | 215 |
|
217 | 216 | def __init__(
|
218 | 217 | self,
|
219 |
| - headers: Union[dict[str, str], httpx.Headers], |
| 218 | + headers: Union[dict[str, str], httpx.Headers, "CIMultiDictProxy[str]"], |
220 | 219 | response_stream: Union[Any, str] = None,
|
221 | 220 | byte_stream: Union[Any, bytes] = None,
|
222 | 221 | ):
|
@@ -271,8 +270,8 @@ async def async_segments(self) -> AsyncIterator[Any]:
|
271 | 270 | yield c
|
272 | 271 | else:
|
273 | 272 | # Iterator of objects retrieved from the API.
|
274 |
| - if hasattr(self.response_stream, 'aiter_lines'): |
275 |
| - async for chunk in self.response_stream.aiter_lines(): |
| 273 | + if hasattr(self.response_stream, 'content'): |
| 274 | + async for chunk in self.response_stream.content.iter_any(): |
276 | 275 | # This is httpx.Response.
|
277 | 276 | if chunk:
|
278 | 277 | # In async streaming mode, the chunk of JSON is prefixed with
|
@@ -324,26 +323,6 @@ def __del__(self) -> None:
|
324 | 323 | pass
|
325 | 324 |
|
326 | 325 |
|
327 |
| -class AsyncHttpxClient(httpx.AsyncClient): |
328 |
| - """Async httpx client.""" |
329 |
| - |
330 |
| - def __init__(self, **kwargs: Any) -> None: |
331 |
| - """Initializes the httpx client.""" |
332 |
| - kwargs.setdefault('follow_redirects', True) |
333 |
| - super().__init__(**kwargs) |
334 |
| - |
335 |
| - def __del__(self) -> None: |
336 |
| - try: |
337 |
| - if self.is_closed: |
338 |
| - return |
339 |
| - except Exception: |
340 |
| - pass |
341 |
| - try: |
342 |
| - asyncio.get_running_loop().create_task(self.aclose()) |
343 |
| - except Exception: |
344 |
| - pass |
345 |
| - |
346 |
| - |
347 | 326 | class BaseApiClient:
|
348 | 327 | """Client for calling HTTP APIs sending and receiving JSON."""
|
349 | 328 |
|
@@ -480,12 +459,11 @@ def __init__(
|
480 | 459 | if self._http_options.headers is not None:
|
481 | 460 | _append_library_version_headers(self._http_options.headers)
|
482 | 461 |
|
483 |
| - client_args, async_client_args = self._ensure_ssl_ctx(self._http_options) |
| 462 | + client_args, _ = self._ensure_httpx_ssl_ctx(self._http_options) |
484 | 463 | self._httpx_client = SyncHttpxClient(**client_args)
|
485 |
| - self._async_httpx_client = AsyncHttpxClient(**async_client_args) |
486 | 464 |
|
487 | 465 | @staticmethod
|
488 |
| - def _ensure_ssl_ctx(options: HttpOptions) -> ( |
| 466 | + def _ensure_httpx_ssl_ctx(options: HttpOptions) -> ( |
489 | 467 | Tuple[dict[str, Any], dict[str, Any]]):
|
490 | 468 | """Ensures the SSL context is present in the client args.
|
491 | 469 |
|
@@ -541,6 +519,58 @@ def _maybe_set(
|
541 | 519 | _maybe_set(async_args, ctx),
|
542 | 520 | )
|
543 | 521 |
|
| 522 | + @staticmethod |
| 523 | + def _ensure_aiohttp_ssl_ctx(options: HttpOptions) -> dict[str, Any]: |
| 524 | + """Ensures the SSL context is present in the async client args. |
| 525 | +
|
| 526 | + Creates a default SSL context if one is not provided. |
| 527 | +
|
| 528 | + Args: |
| 529 | + options: The http options to check for SSL context. |
| 530 | +
|
| 531 | + Returns: |
| 532 | + An async aiohttp ClientSession._request args. |
| 533 | + """ |
| 534 | + |
| 535 | + verify = 'verify' |
| 536 | + async_args = options.async_client_args |
| 537 | + ctx = async_args.get(verify) if async_args else None |
| 538 | + |
| 539 | + if not ctx: |
| 540 | + # Initialize the SSL context for the httpx client. |
| 541 | + # Unlike requests, the aiohttp package does not automatically pull in the |
| 542 | + # environment variables SSL_CERT_FILE or SSL_CERT_DIR. They need to be |
| 543 | + # enabled explicitly. Instead of 'verify' at client level in httpx, |
| 544 | + # aiohttp uses 'ssl' at request level. |
| 545 | + ctx = ssl.create_default_context( |
| 546 | + cafile=os.environ.get('SSL_CERT_FILE', certifi.where()), |
| 547 | + capath=os.environ.get('SSL_CERT_DIR'), |
| 548 | + ) |
| 549 | + |
| 550 | + def _maybe_set( |
| 551 | + args: Optional[dict[str, Any]], |
| 552 | + ctx: ssl.SSLContext, |
| 553 | + ) -> dict[str, Any]: |
| 554 | + """Sets the SSL context in the client args if not set. |
| 555 | +
|
| 556 | + Does not override the SSL context if it is already set. |
| 557 | +
|
| 558 | + Args: |
| 559 | + args: The client args to to check for SSL context. |
| 560 | + ctx: The SSL context to set. |
| 561 | +
|
| 562 | + Returns: |
| 563 | + The client args with the SSL context included. |
| 564 | + """ |
| 565 | + if not args or not args.get(verify): |
| 566 | + args = (args or {}).copy() |
| 567 | + args['ssl'] = ctx |
| 568 | + else: |
| 569 | + args['ssl'] = args.pop(verify) |
| 570 | + return args |
| 571 | + |
| 572 | + return _maybe_set(async_args, ctx) |
| 573 | + |
544 | 574 | def _websocket_base_url(self) -> str:
|
545 | 575 | url_parts = urlparse(self._http_options.base_url)
|
546 | 576 | return url_parts._replace(scheme='wss').geturl() # type: ignore[arg-type, return-value]
|
@@ -736,34 +766,35 @@ async def _async_request(
|
736 | 766 | else:
|
737 | 767 | data = http_request.data
|
738 | 768 |
|
| 769 | + async_client_args = self._ensure_aiohttp_ssl_ctx(self._http_options) |
739 | 770 | if stream:
|
740 |
| - httpx_request = self._async_httpx_client.build_request( |
| 771 | + session = aiohttp.ClientSession() |
| 772 | + response = await session.request( |
741 | 773 | method=http_request.method,
|
742 | 774 | url=http_request.url,
|
743 |
| - content=data, |
744 | 775 | headers=http_request.headers,
|
745 |
| - timeout=http_request.timeout, |
746 |
| - ) |
747 |
| - response = await self._async_httpx_client.send( |
748 |
| - httpx_request, |
749 |
| - stream=stream, |
| 776 | + data=data, |
| 777 | + timeout=aiohttp.ClientTimeout(connect=http_request.timeout), |
| 778 | + **async_client_args, |
750 | 779 | )
|
751 | 780 | await errors.APIError.raise_for_async_response(response)
|
752 | 781 | return HttpResponse(
|
753 |
| - response.headers, response if stream else [response.text] |
| 782 | + response.headers, response |
754 | 783 | )
|
755 | 784 | else:
|
756 |
| - response = await self._async_httpx_client.request( |
757 |
| - method=http_request.method, |
758 |
| - url=http_request.url, |
759 |
| - headers=http_request.headers, |
760 |
| - content=data, |
761 |
| - timeout=http_request.timeout, |
762 |
| - ) |
763 |
| - await errors.APIError.raise_for_async_response(response) |
764 |
| - return HttpResponse( |
765 |
| - response.headers, response if stream else [response.text] |
766 |
| - ) |
| 785 | + async with aiohttp.ClientSession() as session: # |
| 786 | + response = await session.request( |
| 787 | + method=http_request.method, |
| 788 | + url=http_request.url, |
| 789 | + headers=http_request.headers, |
| 790 | + data=data, |
| 791 | + timeout=aiohttp.ClientTimeout(connect=http_request.timeout), |
| 792 | + **async_client_args, |
| 793 | + ) |
| 794 | + await errors.APIError.raise_for_async_response(response) |
| 795 | + return HttpResponse( |
| 796 | + response.headers, response if stream else [await response.text()] |
| 797 | + ) |
767 | 798 |
|
768 | 799 | def get_read_only_http_options(self) -> dict[str, Any]:
|
769 | 800 | if isinstance(self._http_options, BaseModel):
|
@@ -1048,68 +1079,77 @@ async def _async_upload_fd(
|
1048 | 1079 | """
|
1049 | 1080 | offset = 0
|
1050 | 1081 | # Upload the file in chunks
|
1051 |
| - while True: |
1052 |
| - if isinstance(file, io.IOBase): |
1053 |
| - file_chunk = file.read(CHUNK_SIZE) |
1054 |
| - else: |
1055 |
| - file_chunk = await file.read(CHUNK_SIZE) |
1056 |
| - chunk_size = 0 |
1057 |
| - if file_chunk: |
1058 |
| - chunk_size = len(file_chunk) |
1059 |
| - upload_command = 'upload' |
1060 |
| - # If last chunk, finalize the upload. |
1061 |
| - if chunk_size + offset >= upload_size: |
1062 |
| - upload_command += ', finalize' |
1063 |
| - http_options = http_options if http_options else self._http_options |
1064 |
| - timeout = ( |
1065 |
| - http_options.get('timeout') |
1066 |
| - if isinstance(http_options, dict) |
1067 |
| - else http_options.timeout |
1068 |
| - ) |
1069 |
| - if timeout is None: |
1070 |
| - # Per request timeout is not configured. Check the global timeout. |
| 1082 | + async with aiohttp.ClientSession() as session: |
| 1083 | + while True: |
| 1084 | + if isinstance(file, io.IOBase): |
| 1085 | + file_chunk = file.read(CHUNK_SIZE) |
| 1086 | + else: |
| 1087 | + file_chunk = await file.read(CHUNK_SIZE) |
| 1088 | + chunk_size = 0 |
| 1089 | + if file_chunk: |
| 1090 | + chunk_size = len(file_chunk) |
| 1091 | + upload_command = 'upload' |
| 1092 | + # If last chunk, finalize the upload. |
| 1093 | + if chunk_size + offset >= upload_size: |
| 1094 | + upload_command += ', finalize' |
| 1095 | + http_options = http_options if http_options else self._http_options |
1071 | 1096 | timeout = (
|
1072 |
| - self._http_options.timeout |
1073 |
| - if isinstance(self._http_options, dict) |
1074 |
| - else self._http_options.timeout |
| 1097 | + http_options.get('timeout') |
| 1098 | + if isinstance(http_options, dict) |
| 1099 | + else http_options.timeout |
1075 | 1100 | )
|
1076 |
| - timeout_in_seconds = _get_timeout_in_seconds(timeout) |
1077 |
| - upload_headers = { |
1078 |
| - 'X-Goog-Upload-Command': upload_command, |
1079 |
| - 'X-Goog-Upload-Offset': str(offset), |
1080 |
| - 'Content-Length': str(chunk_size), |
1081 |
| - } |
1082 |
| - _populate_server_timeout_header(upload_headers, timeout_in_seconds) |
1083 |
| - |
1084 |
| - retry_count = 0 |
1085 |
| - while retry_count < MAX_RETRY_COUNT: |
1086 |
| - response = await self._async_httpx_client.request( |
1087 |
| - method='POST', |
1088 |
| - url=upload_url, |
1089 |
| - content=file_chunk, |
1090 |
| - headers=upload_headers, |
1091 |
| - timeout=timeout_in_seconds, |
1092 |
| - ) |
1093 |
| - if response.headers.get('x-goog-upload-status'): |
1094 |
| - break |
1095 |
| - delay_seconds = INITIAL_RETRY_DELAY * (DELAY_MULTIPLIER**retry_count) |
1096 |
| - retry_count += 1 |
1097 |
| - time.sleep(delay_seconds) |
1098 |
| - |
1099 |
| - offset += chunk_size |
1100 |
| - if response.headers.get('x-goog-upload-status') != 'active': |
1101 |
| - break # upload is complete or it has been interrupted. |
| 1101 | + if timeout is None: |
| 1102 | + # Per request timeout is not configured. Check the global timeout. |
| 1103 | + timeout = ( |
| 1104 | + self._http_options.timeout |
| 1105 | + if isinstance(self._http_options, dict) |
| 1106 | + else self._http_options.timeout |
| 1107 | + ) |
| 1108 | + timeout_in_seconds = _get_timeout_in_seconds(timeout) |
| 1109 | + upload_headers = { |
| 1110 | + 'X-Goog-Upload-Command': upload_command, |
| 1111 | + 'X-Goog-Upload-Offset': str(offset), |
| 1112 | + 'Content-Length': str(chunk_size), |
| 1113 | + } |
| 1114 | + _populate_server_timeout_header(upload_headers, timeout_in_seconds) |
| 1115 | + |
| 1116 | + retry_count = 0 |
| 1117 | + response = None |
| 1118 | + while retry_count < MAX_RETRY_COUNT: |
| 1119 | + response = await session.request( |
| 1120 | + method='POST', |
| 1121 | + url=upload_url, |
| 1122 | + data=file_chunk, |
| 1123 | + headers=upload_headers, |
| 1124 | + timeout=aiohttp.ClientTimeout(connect=timeout_in_seconds), |
| 1125 | + ) |
1102 | 1126 |
|
1103 |
| - if upload_size <= offset: # Status is not finalized. |
| 1127 | + if response.headers.get('X-Goog-Upload-Status'): |
| 1128 | + break |
| 1129 | + delay_seconds = INITIAL_RETRY_DELAY * (DELAY_MULTIPLIER**retry_count) |
| 1130 | + retry_count += 1 |
| 1131 | + time.sleep(delay_seconds) |
| 1132 | + |
| 1133 | + offset += chunk_size |
| 1134 | + if ( |
| 1135 | + response is not None |
| 1136 | + and response.headers.get('X-Goog-Upload-Status') != 'active' |
| 1137 | + ): |
| 1138 | + break # upload is complete or it has been interrupted. |
| 1139 | + |
| 1140 | + if upload_size <= offset: # Status is not finalized. |
| 1141 | + raise ValueError( |
| 1142 | + f'All content has been uploaded, but the upload status is not' |
| 1143 | + f' finalized.' |
| 1144 | + ) |
| 1145 | + if ( |
| 1146 | + response is not None |
| 1147 | + and response.headers.get('X-Goog-Upload-Status') != 'final' |
| 1148 | + ): |
1104 | 1149 | raise ValueError(
|
1105 |
| - 'All content has been uploaded, but the upload status is not' |
1106 |
| - f' finalized.' |
| 1150 | + 'Failed to upload file: Upload status is not finalized.' |
1107 | 1151 | )
|
1108 |
| - if response.headers.get('x-goog-upload-status') != 'final': |
1109 |
| - raise ValueError( |
1110 |
| - 'Failed to upload file: Upload status is not finalized.' |
1111 |
| - ) |
1112 |
| - return HttpResponse(response.headers, response_stream=[response.text]) |
| 1152 | + return HttpResponse(response.headers, response_stream=[await response.text()]) |
1113 | 1153 |
|
1114 | 1154 | async def async_download_file(
|
1115 | 1155 | self,
|
@@ -1137,18 +1177,19 @@ async def async_download_file(
|
1137 | 1177 | else:
|
1138 | 1178 | data = http_request.data
|
1139 | 1179 |
|
1140 |
| - response = await self._async_httpx_client.request( |
1141 |
| - method=http_request.method, |
1142 |
| - url=http_request.url, |
1143 |
| - headers=http_request.headers, |
1144 |
| - content=data, |
1145 |
| - timeout=http_request.timeout, |
1146 |
| - ) |
1147 |
| - await errors.APIError.raise_for_async_response(response) |
| 1180 | + async with aiohttp.ClientSession() as session: |
| 1181 | + response = await session.request( |
| 1182 | + method=http_request.method, |
| 1183 | + url=http_request.url, |
| 1184 | + headers=http_request.headers, |
| 1185 | + data=data, |
| 1186 | + timeout=aiohttp.ClientTimeout(connect=http_request.timeout), |
| 1187 | + ) |
| 1188 | + await errors.APIError.raise_for_async_response(response) |
1148 | 1189 |
|
1149 |
| - return HttpResponse( |
1150 |
| - response.headers, byte_stream=[response.read()] |
1151 |
| - ).byte_stream[0] |
| 1190 | + return HttpResponse( |
| 1191 | + response.headers, byte_stream=[await response.read()] |
| 1192 | + ).byte_stream[0] |
1152 | 1193 |
|
1153 | 1194 | # This method does nothing in the real api client. It is used in the
|
1154 | 1195 | # replay_api_client to verify the response from the SDK method matches the
|
|
0 commit comments