From 68ae2e683e54b3f97fc33ca3f7dd394217bbf81d Mon Sep 17 00:00:00 2001 From: Hanusz Leszek Date: Tue, 18 Feb 2025 13:29:37 +0100 Subject: [PATCH] AIOHTTPTransport default ssl cert validation add warning (#530) --- MANIFEST.in | 2 +- gql/transport/aiohttp.py | 30 ++++++- tests/conftest.py | 23 +++++ tests/test_aiohttp.py | 84 +++++++++++++++++- tests/test_aiohttp_websocket_query.py | 63 ++++++++++++-- tests/test_httpx.py | 114 +++++++++++++++++++++++- tests/test_httpx_async.py | 61 ++++++++++++- tests/test_localhost_client.crt | 20 +++++ tests/test_phoenix_channel_query.py | 88 +++++++++++++++++-- tests/test_requests.py | 120 +++++++++++++++++++++++++- tests/test_websocket_query.py | 55 ++++++++++-- 11 files changed, 628 insertions(+), 32 deletions(-) create mode 100644 tests/test_localhost_client.crt diff --git a/MANIFEST.in b/MANIFEST.in index ddebd0b0..ca670908 100644 --- a/MANIFEST.in +++ b/MANIFEST.in @@ -12,7 +12,7 @@ include tox.ini include gql/py.typed -recursive-include tests *.py *.graphql *.cnf *.yaml *.pem +recursive-include tests *.py *.graphql *.cnf *.yaml *.pem *.crt recursive-include docs *.txt *.rst conf.py Makefile make.bat recursive-include docs/code_examples *.py diff --git a/gql/transport/aiohttp.py b/gql/transport/aiohttp.py index 6455e2d8..0c332205 100644 --- a/gql/transport/aiohttp.py +++ b/gql/transport/aiohttp.py @@ -3,8 +3,19 @@ import io import json import logging +import warnings from ssl import SSLContext -from typing import Any, AsyncGenerator, Callable, Dict, Optional, Tuple, Type, Union +from typing import ( + Any, + AsyncGenerator, + Callable, + Dict, + Optional, + Tuple, + Type, + Union, + cast, +) import aiohttp from aiohttp.client_exceptions import ClientResponseError @@ -46,7 +57,7 @@ def __init__( headers: Optional[LooseHeaders] = None, cookies: Optional[LooseCookies] = None, auth: Optional[Union[BasicAuth, "AppSyncAuthentication"]] = None, - ssl: Union[SSLContext, bool, Fingerprint] = False, + ssl: Union[SSLContext, bool, Fingerprint, str] = "ssl_warning", timeout: Optional[int] = None, ssl_close_timeout: Optional[Union[int, float]] = 10, json_serialize: Callable = json.dumps, @@ -77,7 +88,20 @@ def __init__( self.headers: Optional[LooseHeaders] = headers self.cookies: Optional[LooseCookies] = cookies self.auth: Optional[Union[BasicAuth, "AppSyncAuthentication"]] = auth - self.ssl: Union[SSLContext, bool, Fingerprint] = ssl + + if ssl == "ssl_warning": + ssl = False + if str(url).startswith("https"): + warnings.warn( + "WARNING: By default, AIOHTTPTransport does not verify" + " ssl certificates. This will be fixed in the next major version." + " You can set ssl=True to force the ssl certificate verification" + " or ssl=False to disable this warning" + ) + + self.ssl: Union[SSLContext, bool, Fingerprint] = cast( + Union[SSLContext, bool, Fingerprint], ssl + ) self.timeout: Optional[int] = timeout self.ssl_close_timeout: Optional[Union[int, float]] = ssl_close_timeout self.client_session_args = client_session_args diff --git a/tests/conftest.py b/tests/conftest.py index c164c355..c0b2037f 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -156,6 +156,29 @@ def get_localhost_ssl_context(): return (testcert, ssl_context) +def get_localhost_ssl_context_client(): + """ + Create a client-side SSL context that verifies the specific self-signed certificate + used for our test. + """ + # Get the certificate from the server setup + cert_path = bytes(pathlib.Path(__file__).with_name("test_localhost_client.crt")) + + # Create client SSL context + ssl_context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) + + # Load just the certificate part as a trusted CA + ssl_context.load_verify_locations(cafile=cert_path) + + # Require certificate verification + ssl_context.verify_mode = ssl.CERT_REQUIRED + + # Enable hostname checking for localhost + ssl_context.check_hostname = True + + return cert_path, ssl_context + + class WebSocketServer: """Websocket server on localhost on a free port. diff --git a/tests/test_aiohttp.py b/tests/test_aiohttp.py index 55b08260..81af20ff 100644 --- a/tests/test_aiohttp.py +++ b/tests/test_aiohttp.py @@ -14,7 +14,11 @@ TransportServerError, ) -from .conftest import TemporaryFile, strip_braces_spaces +from .conftest import ( + TemporaryFile, + get_localhost_ssl_context_client, + strip_braces_spaces, +) query1_str = """ query getContinents { @@ -1285,7 +1289,10 @@ async def handler(request): @pytest.mark.asyncio @pytest.mark.parametrize("ssl_close_timeout", [0, 10]) -async def test_aiohttp_query_https(event_loop, ssl_aiohttp_server, ssl_close_timeout): +@pytest.mark.parametrize("verify_https", ["disabled", "cert_provided"]) +async def test_aiohttp_query_https( + event_loop, ssl_aiohttp_server, ssl_close_timeout, verify_https +): from aiohttp import web from gql.transport.aiohttp import AIOHTTPTransport @@ -1300,8 +1307,20 @@ async def handler(request): assert str(url).startswith("https://") + extra_args = {} + + if verify_https == "cert_provided": + _, ssl_context = get_localhost_ssl_context_client() + + extra_args["ssl"] = ssl_context + elif verify_https == "disabled": + extra_args["ssl"] = False + transport = AIOHTTPTransport( - url=url, timeout=10, ssl_close_timeout=ssl_close_timeout + url=url, + timeout=10, + ssl_close_timeout=ssl_close_timeout, + **extra_args, ) async with Client(transport=transport) as session: @@ -1318,6 +1337,65 @@ async def handler(request): assert africa["code"] == "AF" +@pytest.mark.skip(reason="We will change the default to fix this in a future version") +@pytest.mark.asyncio +async def test_aiohttp_query_https_self_cert_fail(event_loop, ssl_aiohttp_server): + """By default, we should verify the ssl certificate""" + from aiohttp.client_exceptions import ClientConnectorCertificateError + from aiohttp import web + from gql.transport.aiohttp import AIOHTTPTransport + + async def handler(request): + return web.Response(text=query1_server_answer, content_type="application/json") + + app = web.Application() + app.router.add_route("POST", "/", handler) + server = await ssl_aiohttp_server(app) + + url = server.make_url("/") + + assert str(url).startswith("https://") + + transport = AIOHTTPTransport(url=url, timeout=10) + + with pytest.raises(ClientConnectorCertificateError) as exc_info: + async with Client(transport=transport) as session: + query = gql(query1_str) + + # Execute query asynchronously + await session.execute(query) + + expected_error = "certificate verify failed: self-signed certificate" + + assert expected_error in str(exc_info.value) + assert transport.session is None + + +@pytest.mark.asyncio +async def test_aiohttp_query_https_self_cert_warn(event_loop, ssl_aiohttp_server): + from aiohttp import web + from gql.transport.aiohttp import AIOHTTPTransport + + async def handler(request): + return web.Response(text=query1_server_answer, content_type="application/json") + + app = web.Application() + app.router.add_route("POST", "/", handler) + server = await ssl_aiohttp_server(app) + + url = server.make_url("/") + + assert str(url).startswith("https://") + + expected_warning = ( + "WARNING: By default, AIOHTTPTransport does not verify ssl certificates." + " This will be fixed in the next major version." + ) + + with pytest.warns(Warning, match=expected_warning): + AIOHTTPTransport(url=url, timeout=10) + + @pytest.mark.asyncio async def test_aiohttp_error_fetching_schema(event_loop, aiohttp_server): from aiohttp import web diff --git a/tests/test_aiohttp_websocket_query.py b/tests/test_aiohttp_websocket_query.py index f154386b..ff2bcf02 100644 --- a/tests/test_aiohttp_websocket_query.py +++ b/tests/test_aiohttp_websocket_query.py @@ -1,6 +1,5 @@ import asyncio import json -import ssl import sys from typing import Dict, Mapping @@ -14,7 +13,7 @@ TransportServerError, ) -from .conftest import MS, WebSocketServerHelper +from .conftest import MS, WebSocketServerHelper, get_localhost_ssl_context_client # Marking all tests in this file with the aiohttp AND websockets marker pytestmark = pytest.mark.aiohttp @@ -92,8 +91,9 @@ async def test_aiohttp_websocket_starting_client_in_context_manager( @pytest.mark.websockets @pytest.mark.parametrize("ws_ssl_server", [server1_answers], indirect=True) @pytest.mark.parametrize("ssl_close_timeout", [0, 10]) +@pytest.mark.parametrize("verify_https", ["disabled", "cert_provided"]) async def test_aiohttp_websocket_using_ssl_connection( - event_loop, ws_ssl_server, ssl_close_timeout + event_loop, ws_ssl_server, ssl_close_timeout, verify_https ): from gql.transport.aiohttp_websockets import AIOHTTPWebsocketsTransport @@ -103,11 +103,19 @@ async def test_aiohttp_websocket_using_ssl_connection( url = f"wss://{server.hostname}:{server.port}/graphql" print(f"url = {url}") - ssl_context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) - ssl_context.load_verify_locations(ws_ssl_server.testcert) + extra_args = {} + + if verify_https == "cert_provided": + _, ssl_context = get_localhost_ssl_context_client() + + extra_args["ssl"] = ssl_context + elif verify_https == "disabled": + extra_args["ssl"] = False transport = AIOHTTPWebsocketsTransport( - url=url, ssl=ssl_context, ssl_close_timeout=ssl_close_timeout + url=url, + ssl_close_timeout=ssl_close_timeout, + **extra_args, ) async with Client(transport=transport) as session: @@ -130,6 +138,49 @@ async def test_aiohttp_websocket_using_ssl_connection( assert transport.websocket is None +@pytest.mark.asyncio +@pytest.mark.websockets +@pytest.mark.parametrize("ws_ssl_server", [server1_answers], indirect=True) +@pytest.mark.parametrize("ssl_close_timeout", [10]) +@pytest.mark.parametrize("verify_https", ["explicitely_enabled", "default"]) +async def test_aiohttp_websocket_using_ssl_connection_self_cert_fail( + event_loop, ws_ssl_server, ssl_close_timeout, verify_https +): + + from aiohttp.client_exceptions import ClientConnectorCertificateError + from gql.transport.aiohttp_websockets import AIOHTTPWebsocketsTransport + + server = ws_ssl_server + + url = f"wss://{server.hostname}:{server.port}/graphql" + print(f"url = {url}") + + extra_args = {} + + if verify_https == "explicitely_enabled": + extra_args["ssl"] = True + + transport = AIOHTTPWebsocketsTransport( + url=url, + ssl_close_timeout=ssl_close_timeout, + **extra_args, + ) + + with pytest.raises(ClientConnectorCertificateError) as exc_info: + async with Client(transport=transport) as session: + + query1 = gql(query1_str) + + await session.execute(query1) + + expected_error = "certificate verify failed: self-signed certificate" + + assert expected_error in str(exc_info.value) + + # Check client is disconnect here + assert transport.websocket is None + + @pytest.mark.asyncio @pytest.mark.websockets @pytest.mark.parametrize("server", [server1_answers], indirect=True) diff --git a/tests/test_httpx.py b/tests/test_httpx.py index af12f717..8ef57a84 100644 --- a/tests/test_httpx.py +++ b/tests/test_httpx.py @@ -11,7 +11,7 @@ TransportServerError, ) -from .conftest import TemporaryFile, strip_braces_spaces +from .conftest import TemporaryFile, get_localhost_ssl_context, strip_braces_spaces # Marking all tests in this file with the httpx marker pytestmark = pytest.mark.httpx @@ -77,6 +77,118 @@ def test_code(): await run_sync_test(event_loop, server, test_code) +@pytest.mark.aiohttp +@pytest.mark.asyncio +@pytest.mark.parametrize("verify_https", ["disabled", "cert_provided"]) +async def test_httpx_query_https( + event_loop, ssl_aiohttp_server, run_sync_test, verify_https +): + from aiohttp import web + from gql.transport.httpx import HTTPXTransport + + async def handler(request): + return web.Response( + text=query1_server_answer, + content_type="application/json", + headers={"dummy": "test1234"}, + ) + + app = web.Application() + app.router.add_route("POST", "/", handler) + server = await ssl_aiohttp_server(app) + + url = str(server.make_url("/")) + + assert str(url).startswith("https://") + + def test_code(): + extra_args = {} + + if verify_https == "cert_provided": + cert, _ = get_localhost_ssl_context() + + extra_args["verify"] = cert.decode() + elif verify_https == "disabled": + extra_args["verify"] = False + + transport = HTTPXTransport( + url=url, + **extra_args, + ) + + with Client(transport=transport) as session: + + query = gql(query1_str) + + # Execute query synchronously + result = session.execute(query) + + continents = result["continents"] + + africa = continents[0] + + assert africa["code"] == "AF" + + # Checking response headers are saved in the transport + assert hasattr(transport, "response_headers") + assert isinstance(transport.response_headers, Mapping) + assert transport.response_headers["dummy"] == "test1234" + + await run_sync_test(event_loop, server, test_code) + + +@pytest.mark.aiohttp +@pytest.mark.asyncio +@pytest.mark.parametrize("verify_https", ["explicitely_enabled", "default"]) +async def test_httpx_query_https_self_cert_fail( + event_loop, ssl_aiohttp_server, run_sync_test, verify_https +): + """By default, we should verify the ssl certificate""" + from aiohttp import web + from httpx import ConnectError + from gql.transport.httpx import HTTPXTransport + + async def handler(request): + return web.Response( + text=query1_server_answer, + content_type="application/json", + headers={"dummy": "test1234"}, + ) + + app = web.Application() + app.router.add_route("POST", "/", handler) + server = await ssl_aiohttp_server(app) + + url = str(server.make_url("/")) + + assert str(url).startswith("https://") + + def test_code(): + extra_args = {} + + if verify_https == "explicitely_enabled": + extra_args["verify"] = True + + transport = HTTPXTransport( + url=url, + **extra_args, + ) + + with pytest.raises(ConnectError) as exc_info: + with Client(transport=transport) as session: + + query = gql(query1_str) + + # Execute query synchronously + session.execute(query) + + expected_error = "certificate verify failed: self-signed certificate" + + assert expected_error in str(exc_info.value) + + await run_sync_test(event_loop, server, test_code) + + @pytest.mark.aiohttp @pytest.mark.asyncio async def test_httpx_cookies(event_loop, aiohttp_server, run_sync_test): diff --git a/tests/test_httpx_async.py b/tests/test_httpx_async.py index 17be0db5..47744538 100644 --- a/tests/test_httpx_async.py +++ b/tests/test_httpx_async.py @@ -14,7 +14,11 @@ TransportServerError, ) -from .conftest import TemporaryFile, get_localhost_ssl_context, strip_braces_spaces +from .conftest import ( + TemporaryFile, + get_localhost_ssl_context_client, + strip_braces_spaces, +) query1_str = """ query getContinents { @@ -1162,7 +1166,8 @@ async def handler(request): @pytest.mark.aiohttp @pytest.mark.asyncio -async def test_httpx_query_https(event_loop, ssl_aiohttp_server): +@pytest.mark.parametrize("verify_https", ["disabled", "cert_provided"]) +async def test_httpx_query_https(event_loop, ssl_aiohttp_server, verify_https): from aiohttp import web from gql.transport.httpx import HTTPXAsyncTransport @@ -1177,9 +1182,16 @@ async def handler(request): assert url.startswith("https://") - cert, _ = get_localhost_ssl_context() + extra_args = {} + + if verify_https == "cert_provided": + _, ssl_context = get_localhost_ssl_context_client() - transport = HTTPXAsyncTransport(url=url, timeout=10, verify=cert.decode()) + extra_args["verify"] = ssl_context + elif verify_https == "disabled": + extra_args["verify"] = False + + transport = HTTPXAsyncTransport(url=url, timeout=10, **extra_args) async with Client(transport=transport) as session: @@ -1195,6 +1207,47 @@ async def handler(request): assert africa["code"] == "AF" +@pytest.mark.aiohttp +@pytest.mark.asyncio +@pytest.mark.parametrize("verify_https", ["explicitely_enabled", "default"]) +async def test_httpx_query_https_self_cert_fail( + event_loop, ssl_aiohttp_server, verify_https +): + from aiohttp import web + from gql.transport.httpx import HTTPXAsyncTransport + from httpx import ConnectError + + async def handler(request): + return web.Response(text=query1_server_answer, content_type="application/json") + + app = web.Application() + app.router.add_route("POST", "/", handler) + server = await ssl_aiohttp_server(app) + + url = str(server.make_url("/")) + + assert url.startswith("https://") + + extra_args = {} + + if verify_https == "explicitely_enabled": + extra_args["verify"] = True + + transport = HTTPXAsyncTransport(url=url, timeout=10, **extra_args) + + with pytest.raises(ConnectError) as exc_info: + async with Client(transport=transport) as session: + + query = gql(query1_str) + + # Execute query asynchronously + await session.execute(query) + + expected_error = "certificate verify failed: self-signed certificate" + + assert expected_error in str(exc_info.value) + + @pytest.mark.aiohttp @pytest.mark.asyncio async def test_httpx_error_fetching_schema(event_loop, aiohttp_server): diff --git a/tests/test_localhost_client.crt b/tests/test_localhost_client.crt new file mode 100644 index 00000000..0bbed2f5 --- /dev/null +++ b/tests/test_localhost_client.crt @@ -0,0 +1,20 @@ +-----BEGIN CERTIFICATE----- +MIIDTTCCAjWgAwIBAgIJAJ6VG2cQlsepMA0GCSqGSIb3DQEBCwUAMEwxCzAJBgNV +BAYTAkZSMQ4wDAYDVQQHDAVQYXJpczEZMBcGA1UECgwQQXltZXJpYyBBdWd1c3Rp +bjESMBAGA1UEAwwJbG9jYWxob3N0MCAXDTE4MDUwNTE2NTc1NloYDzIwNjAwNTA0 +MTY1NzU2WjBMMQswCQYDVQQGEwJGUjEOMAwGA1UEBwwFUGFyaXMxGTAXBgNVBAoM +EEF5bWVyaWMgQXVndXN0aW4xEjAQBgNVBAMMCWxvY2FsaG9zdDCCASIwDQYJKoZI +hvcNAQEBBQADggEPADCCAQoCggEBAJSCtBWQ1sBZGWjNlSPXhR/PtgSnYxea+aF2 +V84YvCPL7E873xolG/n+dgXZ5YzeWVyYt7wVsFIr5AVOjiy7tlWdzqohM4epxINT +DTpZqtBQyz3huEdS9CnW7z5vaE2Ix4bDr5CIEjo4lE6IaktFuQ3pSPcArCLxJhWg +vIyLO27Bs3IZ/x8XcMOkdm0GK0a0xIEIyxCx8HjrmmXZSjIGtZraWxsu3dW8Flm8 +ep8S4+OmOMo3lRIhedp/Q2LNpHqmzcTJ9+1bLiLvMhA3m5MTG9o8PI+f2cfer92R +P32ZIxJTUC9NOlfw83sOWoTrBkxtCwE9EZbsYSVD47Egp0o4uTkCAwEAAaMwMC4w +LAYDVR0RBCUwI4IJbG9jYWxob3N0hwR/AAABhxAAAAAAAAAAAAAAAAAAAAABMA0G +CSqGSIb3DQEBCwUAA4IBAQA0imKp/rflfbDCCx78NdsR5rt0jKem2t3YPGT6tbeU ++FQz62SEdeD2OHWxpvfPf+6h3iTXJbkakr2R4lP3z7GHUe61lt3So9VHAvgbtPTH +aB1gOdThA83o0fzQtnIv67jCvE9gwPQInViZLEcm2iQEZLj6AuSvBKmluTR7vNRj +8/f2R4LsDfCWGrzk2W+deGRvSow7irS88NQ8BW8S8otgMiBx4D2UlOmQwqr6X+/r +jYIDuMb6GDKRXtBUGDokfE94hjj9u2mrNRwt8y4tqu8ZNa//yLEQ0Ow2kP3QJPLY +941VZpwRi2v/+JvI7OBYlvbOTFwM8nAk79k+Dgviygd9 +-----END CERTIFICATE----- diff --git a/tests/test_phoenix_channel_query.py b/tests/test_phoenix_channel_query.py index b13a8c55..666fec34 100644 --- a/tests/test_phoenix_channel_query.py +++ b/tests/test_phoenix_channel_query.py @@ -2,6 +2,8 @@ from gql import Client, gql +from .conftest import get_localhost_ssl_context_client + # Marking all tests in this file with the websockets marker pytestmark = pytest.mark.websockets @@ -56,17 +58,91 @@ async def test_phoenix_channel_query(event_loop, server, query_str): path = "/graphql" url = f"ws://{server.hostname}:{server.port}{path}" - sample_transport = PhoenixChannelWebsocketsTransport( - channel_name="test_channel", url=url + transport = PhoenixChannelWebsocketsTransport(channel_name="test_channel", url=url) + + query = gql(query_str) + async with Client(transport=transport) as session: + result = await session.execute(query) + + print("Client received:", result) + + +@pytest.mark.skip(reason="ssl=False is not working for now") +@pytest.mark.asyncio +@pytest.mark.parametrize("ws_ssl_server", [query_server], indirect=True) +@pytest.mark.parametrize("query_str", [query1_str]) +@pytest.mark.parametrize("verify_https", ["disabled", "cert_provided"]) +async def test_phoenix_channel_query_ssl( + event_loop, ws_ssl_server, query_str, verify_https +): + from gql.transport.phoenix_channel_websockets import ( + PhoenixChannelWebsocketsTransport, + ) + + path = "/graphql" + server = ws_ssl_server + url = f"wss://{server.hostname}:{server.port}{path}" + + extra_args = {} + + if verify_https == "cert_provided": + _, ssl_context = get_localhost_ssl_context_client() + + extra_args["ssl"] = ssl_context + elif verify_https == "disabled": + extra_args["ssl"] = False + + transport = PhoenixChannelWebsocketsTransport( + channel_name="test_channel", + url=url, + **extra_args, ) query = gql(query_str) - async with Client(transport=sample_transport) as session: + async with Client(transport=transport) as session: result = await session.execute(query) print("Client received:", result) +@pytest.mark.asyncio +@pytest.mark.parametrize("ws_ssl_server", [query_server], indirect=True) +@pytest.mark.parametrize("query_str", [query1_str]) +@pytest.mark.parametrize("verify_https", ["explicitely_enabled", "default"]) +async def test_phoenix_channel_query_ssl_self_cert_fail( + event_loop, ws_ssl_server, query_str, verify_https +): + from gql.transport.phoenix_channel_websockets import ( + PhoenixChannelWebsocketsTransport, + ) + from ssl import SSLCertVerificationError + + path = "/graphql" + server = ws_ssl_server + url = f"wss://{server.hostname}:{server.port}{path}" + + extra_args = {} + + if verify_https == "explicitely_enabled": + extra_args["ssl"] = True + + transport = PhoenixChannelWebsocketsTransport( + channel_name="test_channel", + url=url, + **extra_args, + ) + + query = gql(query_str) + + with pytest.raises(SSLCertVerificationError) as exc_info: + async with Client(transport=transport) as session: + await session.execute(query) + + expected_error = "certificate verify failed: self-signed certificate" + + assert expected_error in str(exc_info.value) + + query2_str = """ subscription getContinents { continents { @@ -133,13 +209,11 @@ async def test_phoenix_channel_subscription(event_loop, server, query_str): path = "/graphql" url = f"ws://{server.hostname}:{server.port}{path}" - sample_transport = PhoenixChannelWebsocketsTransport( - channel_name="test_channel", url=url - ) + transport = PhoenixChannelWebsocketsTransport(channel_name="test_channel", url=url) first_result = None query = gql(query_str) - async with Client(transport=sample_transport) as session: + async with Client(transport=transport) as session: async for result in session.subscribe(query): first_result = result break diff --git a/tests/test_requests.py b/tests/test_requests.py index ba666243..95db0b3f 100644 --- a/tests/test_requests.py +++ b/tests/test_requests.py @@ -11,7 +11,11 @@ TransportServerError, ) -from .conftest import TemporaryFile, strip_braces_spaces +from .conftest import ( + TemporaryFile, + get_localhost_ssl_context_client, + strip_braces_spaces, +) # Marking all tests in this file with the requests marker pytestmark = pytest.mark.requests @@ -77,6 +81,120 @@ def test_code(): await run_sync_test(event_loop, server, test_code) +@pytest.mark.aiohttp +@pytest.mark.asyncio +@pytest.mark.parametrize("verify_https", ["disabled", "cert_provided"]) +async def test_requests_query_https( + event_loop, ssl_aiohttp_server, run_sync_test, verify_https +): + from aiohttp import web + from gql.transport.requests import RequestsHTTPTransport + import warnings + + async def handler(request): + return web.Response( + text=query1_server_answer, + content_type="application/json", + headers={"dummy": "test1234"}, + ) + + app = web.Application() + app.router.add_route("POST", "/", handler) + server = await ssl_aiohttp_server(app) + + url = server.make_url("/") + + def test_code(): + with warnings.catch_warnings(): + + extra_args = {} + + if verify_https == "cert_provided": + cert_path, _ = get_localhost_ssl_context_client() + + extra_args["verify"] = cert_path + elif verify_https == "disabled": + extra_args["verify"] = False + + # Ignoring Insecure Request warning + warnings.filterwarnings("ignore") + + transport = RequestsHTTPTransport( + url=url, + **extra_args, + ) + + with Client(transport=transport) as session: + + query = gql(query1_str) + + # Execute query synchronously + result = session.execute(query) + + continents = result["continents"] + + africa = continents[0] + + assert africa["code"] == "AF" + + # Checking response headers are saved in the transport + assert hasattr(transport, "response_headers") + assert isinstance(transport.response_headers, Mapping) + assert transport.response_headers["dummy"] == "test1234" + + await run_sync_test(event_loop, server, test_code) + + +@pytest.mark.aiohttp +@pytest.mark.asyncio +@pytest.mark.parametrize("verify_https", ["explicitely_enabled", "default"]) +async def test_requests_query_https_self_cert_fail( + event_loop, ssl_aiohttp_server, run_sync_test, verify_https +): + """By default, we should verify the ssl certificate""" + from aiohttp import web + from gql.transport.requests import RequestsHTTPTransport + from requests.exceptions import SSLError + + async def handler(request): + return web.Response( + text=query1_server_answer, + content_type="application/json", + headers={"dummy": "test1234"}, + ) + + app = web.Application() + app.router.add_route("POST", "/", handler) + server = await ssl_aiohttp_server(app) + + url = server.make_url("/") + + def test_code(): + extra_args = {} + + if verify_https == "explicitely_enabled": + extra_args["verify"] = True + + transport = RequestsHTTPTransport( + url=url, + **extra_args, + ) + + with pytest.raises(SSLError) as exc_info: + with Client(transport=transport) as session: + + query = gql(query1_str) + + # Execute query synchronously + session.execute(query) + + expected_error = "certificate verify failed: self-signed certificate" + + assert expected_error in str(exc_info.value) + + await run_sync_test(event_loop, server, test_code) + + @pytest.mark.aiohttp @pytest.mark.asyncio async def test_requests_cookies(event_loop, aiohttp_server, run_sync_test): diff --git a/tests/test_websocket_query.py b/tests/test_websocket_query.py index 9e6fd4ab..56dd150f 100644 --- a/tests/test_websocket_query.py +++ b/tests/test_websocket_query.py @@ -1,6 +1,5 @@ import asyncio import json -import ssl import sys from typing import Dict, Mapping @@ -14,7 +13,7 @@ TransportServerError, ) -from .conftest import MS, WebSocketServerHelper +from .conftest import MS, WebSocketServerHelper, get_localhost_ssl_context_client # Marking all tests in this file with the websockets marker pytestmark = pytest.mark.websockets @@ -89,9 +88,11 @@ async def test_websocket_starting_client_in_context_manager(event_loop, server): assert transport.websocket is None +@pytest.mark.skip(reason="ssl=False is not working for now") @pytest.mark.asyncio @pytest.mark.parametrize("ws_ssl_server", [server1_answers], indirect=True) -async def test_websocket_using_ssl_connection(event_loop, ws_ssl_server): +@pytest.mark.parametrize("verify_https", ["disabled", "cert_provided"]) +async def test_websocket_using_ssl_connection(event_loop, ws_ssl_server, verify_https): import websockets from gql.transport.websockets import WebsocketsTransport @@ -100,10 +101,16 @@ async def test_websocket_using_ssl_connection(event_loop, ws_ssl_server): url = f"wss://{server.hostname}:{server.port}/graphql" print(f"url = {url}") - ssl_context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) - ssl_context.load_verify_locations(ws_ssl_server.testcert) + extra_args = {} - transport = WebsocketsTransport(url=url, ssl=ssl_context) + if verify_https == "cert_provided": + _, ssl_context = get_localhost_ssl_context_client() + + extra_args["ssl"] = ssl_context + elif verify_https == "disabled": + extra_args["ssl"] = False + + transport = WebsocketsTransport(url=url, **extra_args) async with Client(transport=transport) as session: @@ -129,6 +136,42 @@ async def test_websocket_using_ssl_connection(event_loop, ws_ssl_server): assert transport.websocket is None +@pytest.mark.asyncio +@pytest.mark.parametrize("ws_ssl_server", [server1_answers], indirect=True) +@pytest.mark.parametrize("verify_https", ["explicitely_enabled", "default"]) +async def test_websocket_using_ssl_connection_self_cert_fail( + event_loop, ws_ssl_server, verify_https +): + from gql.transport.websockets import WebsocketsTransport + from ssl import SSLCertVerificationError + + server = ws_ssl_server + + url = f"wss://{server.hostname}:{server.port}/graphql" + print(f"url = {url}") + + extra_args = {} + + if verify_https == "explicitely_enabled": + extra_args["ssl"] = True + + transport = WebsocketsTransport(url=url, **extra_args) + + with pytest.raises(SSLCertVerificationError) as exc_info: + async with Client(transport=transport) as session: + + query1 = gql(query1_str) + + await session.execute(query1) + + expected_error = "certificate verify failed: self-signed certificate" + + assert expected_error in str(exc_info.value) + + # Check client is disconnect here + assert transport.websocket is None + + @pytest.mark.asyncio @pytest.mark.parametrize("server", [server1_answers], indirect=True) @pytest.mark.parametrize("query_str", [query1_str])