Skip to content

Commit a4b059c

Browse files
liuchongauvipy
andauthored
Fix passing host to in headers (#268)
* Fix passing host to in headers * Update ws4py/client/__init__.py Signed-off-by: Asif Saif Uddin <[email protected]> * Add unit test for passing host to in headers --------- Signed-off-by: Asif Saif Uddin <[email protected]> Co-authored-by: Asif Saif Uddin <[email protected]>
1 parent 6c8bc76 commit a4b059c

File tree

2 files changed

+15
-1
lines changed

2 files changed

+15
-1
lines changed

test/test_client.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -106,6 +106,15 @@ def test_parse_wss_scheme_with_query_string(self):
106106
self.assertEqual(c.resource, "/?token=value")
107107
self.assertEqual(c.bind_addr, ("127.0.0.1", 443))
108108

109+
def test_overriding_host_from_headers(self):
110+
c = WebSocketBaseClient(url="wss://127.0.0.1", headers=[("Host", "example123.com")])
111+
self.assertEqual(c.host, "127.0.0.1")
112+
self.assertEqual(c.port, 443)
113+
self.assertEqual(c.bind_addr, ("127.0.0.1", 443))
114+
for h in c.handshake_headers:
115+
if h[0].lower() == "host":
116+
self.assertEqual(h[1], "example123.com")
117+
109118
@patch('ws4py.client.socket')
110119
def test_connect_and_close(self, sock):
111120

ws4py/client/__init__.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -261,7 +261,6 @@ def handshake_headers(self):
261261
handshake.
262262
"""
263263
headers = [
264-
('Host', '%s:%s' % (self.host, self.port)),
265264
('Connection', 'Upgrade'),
266265
('Upgrade', 'websocket'),
267266
('Sec-WebSocket-Key', self.key.decode('utf-8')),
@@ -274,6 +273,12 @@ def handshake_headers(self):
274273
if self.extra_headers:
275274
headers.extend(self.extra_headers)
276275

276+
# keep old logic if no overriding Host in headers
277+
if not any(x for x in headers if x[0].lower() == 'host') and \
278+
'host' not in self.exclude_headers:
279+
headers.append(('Host', '%s:%s' % (self.host, self.port)))
280+
281+
277282
if not any(x for x in headers if x[0].lower() == 'origin') and \
278283
'origin' not in self.exclude_headers:
279284

0 commit comments

Comments
 (0)