Skip to content

Commit

Permalink
AIOHTTPTransport default ssl cert validation add warning (#530)
Browse files Browse the repository at this point in the history
  • Loading branch information
leszekhanusz authored Feb 18, 2025
1 parent b066e89 commit 68ae2e6
Show file tree
Hide file tree
Showing 11 changed files with 628 additions and 32 deletions.
2 changes: 1 addition & 1 deletion MANIFEST.in
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ include tox.ini

include gql/py.typed

recursive-include tests *.py *.graphql *.cnf *.yaml *.pem
recursive-include tests *.py *.graphql *.cnf *.yaml *.pem *.crt
recursive-include docs *.txt *.rst conf.py Makefile make.bat
recursive-include docs/code_examples *.py

Expand Down
30 changes: 27 additions & 3 deletions gql/transport/aiohttp.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,19 @@
import io
import json
import logging
import warnings
from ssl import SSLContext
from typing import Any, AsyncGenerator, Callable, Dict, Optional, Tuple, Type, Union
from typing import (
Any,
AsyncGenerator,
Callable,
Dict,
Optional,
Tuple,
Type,
Union,
cast,
)

import aiohttp
from aiohttp.client_exceptions import ClientResponseError
Expand Down Expand Up @@ -46,7 +57,7 @@ def __init__(
headers: Optional[LooseHeaders] = None,
cookies: Optional[LooseCookies] = None,
auth: Optional[Union[BasicAuth, "AppSyncAuthentication"]] = None,
ssl: Union[SSLContext, bool, Fingerprint] = False,
ssl: Union[SSLContext, bool, Fingerprint, str] = "ssl_warning",
timeout: Optional[int] = None,
ssl_close_timeout: Optional[Union[int, float]] = 10,
json_serialize: Callable = json.dumps,
Expand Down Expand Up @@ -77,7 +88,20 @@ def __init__(
self.headers: Optional[LooseHeaders] = headers
self.cookies: Optional[LooseCookies] = cookies
self.auth: Optional[Union[BasicAuth, "AppSyncAuthentication"]] = auth
self.ssl: Union[SSLContext, bool, Fingerprint] = ssl

if ssl == "ssl_warning":
ssl = False
if str(url).startswith("https"):
warnings.warn(
"WARNING: By default, AIOHTTPTransport does not verify"
" ssl certificates. This will be fixed in the next major version."
" You can set ssl=True to force the ssl certificate verification"
" or ssl=False to disable this warning"
)

self.ssl: Union[SSLContext, bool, Fingerprint] = cast(
Union[SSLContext, bool, Fingerprint], ssl
)
self.timeout: Optional[int] = timeout
self.ssl_close_timeout: Optional[Union[int, float]] = ssl_close_timeout
self.client_session_args = client_session_args
Expand Down
23 changes: 23 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,29 @@ def get_localhost_ssl_context():
return (testcert, ssl_context)


def get_localhost_ssl_context_client():
"""
Create a client-side SSL context that verifies the specific self-signed certificate
used for our test.
"""
# Get the certificate from the server setup
cert_path = bytes(pathlib.Path(__file__).with_name("test_localhost_client.crt"))

# Create client SSL context
ssl_context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)

# Load just the certificate part as a trusted CA
ssl_context.load_verify_locations(cafile=cert_path)

# Require certificate verification
ssl_context.verify_mode = ssl.CERT_REQUIRED

# Enable hostname checking for localhost
ssl_context.check_hostname = True

return cert_path, ssl_context


class WebSocketServer:
"""Websocket server on localhost on a free port.
Expand Down
84 changes: 81 additions & 3 deletions tests/test_aiohttp.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,11 @@
TransportServerError,
)

from .conftest import TemporaryFile, strip_braces_spaces
from .conftest import (
TemporaryFile,
get_localhost_ssl_context_client,
strip_braces_spaces,
)

query1_str = """
query getContinents {
Expand Down Expand Up @@ -1285,7 +1289,10 @@ async def handler(request):

@pytest.mark.asyncio
@pytest.mark.parametrize("ssl_close_timeout", [0, 10])
async def test_aiohttp_query_https(event_loop, ssl_aiohttp_server, ssl_close_timeout):
@pytest.mark.parametrize("verify_https", ["disabled", "cert_provided"])
async def test_aiohttp_query_https(
event_loop, ssl_aiohttp_server, ssl_close_timeout, verify_https
):
from aiohttp import web
from gql.transport.aiohttp import AIOHTTPTransport

Expand All @@ -1300,8 +1307,20 @@ async def handler(request):

assert str(url).startswith("https://")

extra_args = {}

if verify_https == "cert_provided":
_, ssl_context = get_localhost_ssl_context_client()

extra_args["ssl"] = ssl_context
elif verify_https == "disabled":
extra_args["ssl"] = False

transport = AIOHTTPTransport(
url=url, timeout=10, ssl_close_timeout=ssl_close_timeout
url=url,
timeout=10,
ssl_close_timeout=ssl_close_timeout,
**extra_args,
)

async with Client(transport=transport) as session:
Expand All @@ -1318,6 +1337,65 @@ async def handler(request):
assert africa["code"] == "AF"


@pytest.mark.skip(reason="We will change the default to fix this in a future version")
@pytest.mark.asyncio
async def test_aiohttp_query_https_self_cert_fail(event_loop, ssl_aiohttp_server):
"""By default, we should verify the ssl certificate"""
from aiohttp.client_exceptions import ClientConnectorCertificateError
from aiohttp import web
from gql.transport.aiohttp import AIOHTTPTransport

async def handler(request):
return web.Response(text=query1_server_answer, content_type="application/json")

app = web.Application()
app.router.add_route("POST", "/", handler)
server = await ssl_aiohttp_server(app)

url = server.make_url("/")

assert str(url).startswith("https://")

transport = AIOHTTPTransport(url=url, timeout=10)

with pytest.raises(ClientConnectorCertificateError) as exc_info:
async with Client(transport=transport) as session:
query = gql(query1_str)

# Execute query asynchronously
await session.execute(query)

expected_error = "certificate verify failed: self-signed certificate"

assert expected_error in str(exc_info.value)
assert transport.session is None


@pytest.mark.asyncio
async def test_aiohttp_query_https_self_cert_warn(event_loop, ssl_aiohttp_server):
from aiohttp import web
from gql.transport.aiohttp import AIOHTTPTransport

async def handler(request):
return web.Response(text=query1_server_answer, content_type="application/json")

app = web.Application()
app.router.add_route("POST", "/", handler)
server = await ssl_aiohttp_server(app)

url = server.make_url("/")

assert str(url).startswith("https://")

expected_warning = (
"WARNING: By default, AIOHTTPTransport does not verify ssl certificates."
" This will be fixed in the next major version."
)

with pytest.warns(Warning, match=expected_warning):
AIOHTTPTransport(url=url, timeout=10)


@pytest.mark.asyncio
async def test_aiohttp_error_fetching_schema(event_loop, aiohttp_server):
from aiohttp import web
Expand Down
63 changes: 57 additions & 6 deletions tests/test_aiohttp_websocket_query.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import asyncio
import json
import ssl
import sys
from typing import Dict, Mapping

Expand All @@ -14,7 +13,7 @@
TransportServerError,
)

from .conftest import MS, WebSocketServerHelper
from .conftest import MS, WebSocketServerHelper, get_localhost_ssl_context_client

# Marking all tests in this file with the aiohttp AND websockets marker
pytestmark = pytest.mark.aiohttp
Expand Down Expand Up @@ -92,8 +91,9 @@ async def test_aiohttp_websocket_starting_client_in_context_manager(
@pytest.mark.websockets
@pytest.mark.parametrize("ws_ssl_server", [server1_answers], indirect=True)
@pytest.mark.parametrize("ssl_close_timeout", [0, 10])
@pytest.mark.parametrize("verify_https", ["disabled", "cert_provided"])
async def test_aiohttp_websocket_using_ssl_connection(
event_loop, ws_ssl_server, ssl_close_timeout
event_loop, ws_ssl_server, ssl_close_timeout, verify_https
):

from gql.transport.aiohttp_websockets import AIOHTTPWebsocketsTransport
Expand All @@ -103,11 +103,19 @@ async def test_aiohttp_websocket_using_ssl_connection(
url = f"wss://{server.hostname}:{server.port}/graphql"
print(f"url = {url}")

ssl_context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
ssl_context.load_verify_locations(ws_ssl_server.testcert)
extra_args = {}

if verify_https == "cert_provided":
_, ssl_context = get_localhost_ssl_context_client()

extra_args["ssl"] = ssl_context
elif verify_https == "disabled":
extra_args["ssl"] = False

transport = AIOHTTPWebsocketsTransport(
url=url, ssl=ssl_context, ssl_close_timeout=ssl_close_timeout
url=url,
ssl_close_timeout=ssl_close_timeout,
**extra_args,
)

async with Client(transport=transport) as session:
Expand All @@ -130,6 +138,49 @@ async def test_aiohttp_websocket_using_ssl_connection(
assert transport.websocket is None


@pytest.mark.asyncio
@pytest.mark.websockets
@pytest.mark.parametrize("ws_ssl_server", [server1_answers], indirect=True)
@pytest.mark.parametrize("ssl_close_timeout", [10])
@pytest.mark.parametrize("verify_https", ["explicitely_enabled", "default"])
async def test_aiohttp_websocket_using_ssl_connection_self_cert_fail(
event_loop, ws_ssl_server, ssl_close_timeout, verify_https
):

from aiohttp.client_exceptions import ClientConnectorCertificateError
from gql.transport.aiohttp_websockets import AIOHTTPWebsocketsTransport

server = ws_ssl_server

url = f"wss://{server.hostname}:{server.port}/graphql"
print(f"url = {url}")

extra_args = {}

if verify_https == "explicitely_enabled":
extra_args["ssl"] = True

transport = AIOHTTPWebsocketsTransport(
url=url,
ssl_close_timeout=ssl_close_timeout,
**extra_args,
)

with pytest.raises(ClientConnectorCertificateError) as exc_info:
async with Client(transport=transport) as session:

query1 = gql(query1_str)

await session.execute(query1)

expected_error = "certificate verify failed: self-signed certificate"

assert expected_error in str(exc_info.value)

# Check client is disconnect here
assert transport.websocket is None


@pytest.mark.asyncio
@pytest.mark.websockets
@pytest.mark.parametrize("server", [server1_answers], indirect=True)
Expand Down
Loading

0 comments on commit 68ae2e6

Please sign in to comment.