@@ -323,3 +323,40 @@ async def test_headers_passed_to_websocket_connect(self, mock_connect):
323
323
)
324
324
325
325
mock_handle .assert_has_calls ([call ({"data" : {"messageAdded" : "one" }})])
326
+
327
+ @patch ("websockets.connect" )
328
+ async def test_init_payload_passed_in_init_message (self , mock_connect ):
329
+ """Subsribe a GraphQL subscription."""
330
+ mock_websocket = mock_connect .return_value .__aenter__ .return_value
331
+ mock_websocket .send = AsyncMock ()
332
+ mock_websocket .__aiter__ .return_value = [
333
+ '{"type": "connection_init", "payload": '
334
+ '{"init": "this is the init_payload"}}' ,
335
+ '{"type": "data", "id": "1", "payload": {"data": {"messageAdded": "one"}}}' ,
336
+ ]
337
+ expected_endpoint = "ws://www.test-api.com/graphql"
338
+ client = GraphqlClient (endpoint = expected_endpoint )
339
+
340
+ query = """
341
+ subscription onMessageAdded {
342
+ messageAdded
343
+ }
344
+ """
345
+ init_payload = '{"init": "this is the init_payload"}'
346
+
347
+ mock_handle = MagicMock ()
348
+
349
+ await client .subscribe (
350
+ query = query , handle = mock_handle , init_payload = init_payload
351
+ )
352
+
353
+ mock_connect .assert_called_with (
354
+ expected_endpoint , subprotocols = ["graphql-ws" ], extra_headers = {}
355
+ )
356
+
357
+ mock_handle .assert_has_calls (
358
+ [
359
+ call ({"init" : "this is the init_payload" }),
360
+ call ({"data" : {"messageAdded" : "one" }}),
361
+ ]
362
+ )
0 commit comments