diff --git a/proxy/core/connection/client.py b/proxy/core/connection/client.py index f241c56a06..82f8194795 100644 --- a/proxy/core/connection/client.py +++ b/proxy/core/connection/client.py @@ -42,7 +42,8 @@ def connection(self) -> TcpOrTlsSocket: def wrap(self, keyfile: str, certfile: str) -> None: self.connection.setblocking(True) self.flush() - self._conn = ssl.wrap_socket( + ssl_context = ssl.SSLContext(protocol=ssl.PROTOCOL_TLS_CLIENT) + self._conn = ssl_context.wrap_socket( self.connection, server_side=True, certfile=certfile, diff --git a/tests/http/proxy/test_http_proxy_tls_interception.py b/tests/http/proxy/test_http_proxy_tls_interception.py index 654bbc5fcd..699cc22cfc 100644 --- a/tests/http/proxy/test_http_proxy_tls_interception.py +++ b/tests/http/proxy/test_http_proxy_tls_interception.py @@ -62,9 +62,9 @@ async def test_e2e(self, mocker: MockerFixture) -> None: self.mock_ssl_context.return_value.wrap_socket.return_value = upstream_tls_sock # Used for client wrapping - self.mock_ssl_wrap = mocker.patch('ssl.wrap_socket') + self.mock_ssl_wrap = mocker.patch('ssl.SSLContext') client_tls_sock = mock.MagicMock(spec=ssl.SSLSocket) - self.mock_ssl_wrap.return_value = client_tls_sock + self.mock_ssl_wrap.return_value.wrap_socket.return_value = client_tls_sock plain_connection = mock.MagicMock(spec=socket.socket) @@ -251,6 +251,9 @@ async def asyncReturn(val: T) -> T: ) assert self.flags.ca_cert_dir is not None self.mock_ssl_wrap.assert_called_with( + protocol=ssl.PROTOCOL_TLS_CLIENT, + ) + self.mock_ssl_wrap.return_value.wrap_socket.assert_called_with( self._conn, server_side=True, keyfile=self.flags.ca_signing_key_file, diff --git a/tests/plugin/test_http_proxy_plugins_with_tls_interception.py b/tests/plugin/test_http_proxy_plugins_with_tls_interception.py index 3d8d6a28f4..a0a05b61f8 100644 --- a/tests/plugin/test_http_proxy_plugins_with_tls_interception.py +++ b/tests/plugin/test_http_proxy_plugins_with_tls_interception.py @@ -46,7 +46,7 @@ def _setUp(self, request: Any, mocker: MockerFixture) -> None: 'proxy.http.proxy.server.TcpServerConnection', ) self.mock_ssl_context = mocker.patch('ssl.create_default_context') - self.mock_ssl_wrap = mocker.patch('ssl.wrap_socket') + self.mock_ssl_wrap = mocker.patch('ssl.SSLContext') self.mock_sign_csr.return_value = True self.mock_gen_csr.return_value = True @@ -82,7 +82,7 @@ def _setUp(self, request: Any, mocker: MockerFixture) -> None: self.server_ssl_connection = mocker.MagicMock(spec=ssl.SSLSocket) self.mock_ssl_context.return_value.wrap_socket.return_value = self.server_ssl_connection self.client_ssl_connection = mocker.MagicMock(spec=ssl.SSLSocket) - self.mock_ssl_wrap.return_value = self.client_ssl_connection + self.mock_ssl_wrap.return_value.wrap_socket.return_value = self.client_ssl_connection def has_buffer() -> bool: return cast(bool, self.server.queue.called)