Skip to content

Commit 8b0d685

Browse files
committed
Saving response_headers in WebsocketsTransport + test
1 parent 544a7d5 commit 8b0d685

File tree

3 files changed

+41
-29
lines changed

3 files changed

+41
-29
lines changed

gql/transport/websockets_base.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
import websockets
1010
from graphql import DocumentNode, ExecutionResult
1111
from websockets.client import WebSocketClientProtocol
12-
from websockets.datastructures import HeadersLike
12+
from websockets.datastructures import Headers, HeadersLike
1313
from websockets.exceptions import ConnectionClosed
1414
from websockets.typing import Data, Subprotocol
1515

@@ -169,6 +169,8 @@ def __init__(
169169
# The list of supported subprotocols should be defined in the subclass
170170
self.supported_subprotocols: List[Subprotocol] = []
171171

172+
self.response_headers: Optional[Headers] = None
173+
172174
async def _initialize(self):
173175
"""Hook to send the initialization messages after the connection
174176
and potentially wait for the backend ack.
@@ -495,6 +497,8 @@ async def connect(self) -> None:
495497

496498
self.websocket = cast(WebSocketClientProtocol, self.websocket)
497499

500+
self.response_headers = self.websocket.response_headers
501+
498502
# Run the after_connect hook of the subclass
499503
await self._after_connect()
500504

tests/conftest.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -174,6 +174,9 @@ async def start(self, handler, extra_serve_args=None):
174174
self.testcert, ssl_context = get_localhost_ssl_context()
175175
extra_serve_args["ssl"] = ssl_context
176176

177+
# Adding dummy response headers
178+
extra_serve_args["extra_headers"] = {"dummy": "test1234"}
179+
177180
# Start a server with a random open port
178181
self.start_server = websockets.server.serve(
179182
handler, "127.0.0.1", 0, **extra_serve_args

tests/test_websocket_query.py

Lines changed: 33 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
import json
33
import ssl
44
import sys
5-
from typing import Dict
5+
from typing import Dict, Mapping
66

77
import pytest
88

@@ -58,12 +58,12 @@ async def test_websocket_starting_client_in_context_manager(event_loop, server):
5858
url = f"ws://{server.hostname}:{server.port}/graphql"
5959
print(f"url = {url}")
6060

61-
sample_transport = WebsocketsTransport(url=url)
61+
transport = WebsocketsTransport(url=url)
6262

63-
async with Client(transport=sample_transport) as session:
63+
async with Client(transport=transport) as session:
6464

6565
assert isinstance(
66-
sample_transport.websocket, websockets.client.WebSocketClientProtocol
66+
transport.websocket, websockets.client.WebSocketClientProtocol
6767
)
6868

6969
query1 = gql(query1_str)
@@ -80,8 +80,13 @@ async def test_websocket_starting_client_in_context_manager(event_loop, server):
8080

8181
assert africa["code"] == "AF"
8282

83+
# Checking response headers are saved in the transport
84+
assert hasattr(transport, "response_headers")
85+
assert isinstance(transport.response_headers, Mapping)
86+
assert transport.response_headers["dummy"] == "test1234"
87+
8388
# Check client is disconnect here
84-
assert sample_transport.websocket is None
89+
assert transport.websocket is None
8590

8691

8792
@pytest.mark.asyncio
@@ -98,12 +103,12 @@ async def test_websocket_using_ssl_connection(event_loop, ws_ssl_server):
98103
ssl_context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
99104
ssl_context.load_verify_locations(ws_ssl_server.testcert)
100105

101-
sample_transport = WebsocketsTransport(url=url, ssl=ssl_context)
106+
transport = WebsocketsTransport(url=url, ssl=ssl_context)
102107

103-
async with Client(transport=sample_transport) as session:
108+
async with Client(transport=transport) as session:
104109

105110
assert isinstance(
106-
sample_transport.websocket, websockets.client.WebSocketClientProtocol
111+
transport.websocket, websockets.client.WebSocketClientProtocol
107112
)
108113

109114
query1 = gql(query1_str)
@@ -121,7 +126,7 @@ async def test_websocket_using_ssl_connection(event_loop, ws_ssl_server):
121126
assert africa["code"] == "AF"
122127

123128
# Check client is disconnect here
124-
assert sample_transport.websocket is None
129+
assert transport.websocket is None
125130

126131

127132
@pytest.mark.asyncio
@@ -301,19 +306,19 @@ async def test_websocket_multiple_connections_in_series(event_loop, server):
301306
url = f"ws://{server.hostname}:{server.port}/graphql"
302307
print(f"url = {url}")
303308

304-
sample_transport = WebsocketsTransport(url=url)
309+
transport = WebsocketsTransport(url=url)
305310

306-
async with Client(transport=sample_transport) as session:
311+
async with Client(transport=transport) as session:
307312
await assert_client_is_working(session)
308313

309314
# Check client is disconnect here
310-
assert sample_transport.websocket is None
315+
assert transport.websocket is None
311316

312-
async with Client(transport=sample_transport) as session:
317+
async with Client(transport=transport) as session:
313318
await assert_client_is_working(session)
314319

315320
# Check client is disconnect here
316-
assert sample_transport.websocket is None
321+
assert transport.websocket is None
317322

318323

319324
@pytest.mark.asyncio
@@ -325,8 +330,8 @@ async def test_websocket_multiple_connections_in_parallel(event_loop, server):
325330
print(f"url = {url}")
326331

327332
async def task_coro():
328-
sample_transport = WebsocketsTransport(url=url)
329-
async with Client(transport=sample_transport) as session:
333+
transport = WebsocketsTransport(url=url)
334+
async with Client(transport=transport) as session:
330335
await assert_client_is_working(session)
331336

332337
task1 = asyncio.ensure_future(task_coro())
@@ -345,12 +350,12 @@ async def test_websocket_trying_to_connect_to_already_connected_transport(
345350
url = f"ws://{server.hostname}:{server.port}/graphql"
346351
print(f"url = {url}")
347352

348-
sample_transport = WebsocketsTransport(url=url)
349-
async with Client(transport=sample_transport) as session:
353+
transport = WebsocketsTransport(url=url)
354+
async with Client(transport=transport) as session:
350355
await assert_client_is_working(session)
351356

352357
with pytest.raises(TransportAlreadyConnected):
353-
async with Client(transport=sample_transport):
358+
async with Client(transport=transport):
354359
pass
355360

356361

@@ -395,9 +400,9 @@ async def test_websocket_connect_success_with_authentication_in_connection_init(
395400

396401
init_payload = {"Authorization": 12345}
397402

398-
sample_transport = WebsocketsTransport(url=url, init_payload=init_payload)
403+
transport = WebsocketsTransport(url=url, init_payload=init_payload)
399404

400-
async with Client(transport=sample_transport) as session:
405+
async with Client(transport=transport) as session:
401406

402407
query1 = gql(query_str)
403408

@@ -428,10 +433,10 @@ async def test_websocket_connect_failed_with_authentication_in_connection_init(
428433
url = f"ws://{server.hostname}:{server.port}/graphql"
429434
print(f"url = {url}")
430435

431-
sample_transport = WebsocketsTransport(url=url, init_payload=init_payload)
436+
transport = WebsocketsTransport(url=url, init_payload=init_payload)
432437

433438
with pytest.raises(TransportServerError):
434-
async with Client(transport=sample_transport) as session:
439+
async with Client(transport=transport) as session:
435440
query1 = gql(query_str)
436441

437442
await session.execute(query1)
@@ -444,9 +449,9 @@ def test_websocket_execute_sync(server):
444449
url = f"ws://{server.hostname}:{server.port}/graphql"
445450
print(f"url = {url}")
446451

447-
sample_transport = WebsocketsTransport(url=url)
452+
transport = WebsocketsTransport(url=url)
448453

449-
client = Client(transport=sample_transport)
454+
client = Client(transport=transport)
450455

451456
query1 = gql(query1_str)
452457

@@ -476,7 +481,7 @@ def test_websocket_execute_sync(server):
476481
assert africa["code"] == "AF"
477482

478483
# Check client is disconnect here
479-
assert sample_transport.websocket is None
484+
assert transport.websocket is None
480485

481486

482487
@pytest.mark.asyncio
@@ -487,11 +492,11 @@ async def test_websocket_add_extra_parameters_to_connect(event_loop, server):
487492
url = f"ws://{server.hostname}:{server.port}/graphql"
488493

489494
# Increase max payload size to avoid websockets.exceptions.PayloadTooBig exceptions
490-
sample_transport = WebsocketsTransport(url=url, connect_args={"max_size": 2 ** 21})
495+
transport = WebsocketsTransport(url=url, connect_args={"max_size": 2 ** 21})
491496

492497
query = gql(query1_str)
493498

494-
async with Client(transport=sample_transport) as session:
499+
async with Client(transport=transport) as session:
495500
await session.execute(query)
496501

497502

0 commit comments

Comments
 (0)