Skip to content

Commit 18e881b

Browse files
committed
add connection handshake tests for non-async version
1 parent a2944b0 commit 18e881b

File tree

1 file changed

+89
-1
lines changed

1 file changed

+89
-1
lines changed

tests/test_connect.py

Lines changed: 89 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,15 @@
33
import socketserver
44
import ssl
55
import threading
6+
from unittest.mock import patch
67

78
import pytest
8-
from redis.connection import Connection, SSLConnection, UnixDomainSocketConnection
9+
from redis.connection import (
10+
Connection,
11+
ResponseError,
12+
SSLConnection,
13+
UnixDomainSocketConnection,
14+
)
915

1016
from . import resp
1117
from .ssl_utils import get_ssl_filename
@@ -55,6 +61,88 @@ def test_tcp_ssl_connect(tcp_address):
5561
_assert_connect(conn, tcp_address, certfile=certfile, keyfile=keyfile)
5662

5763

64+
@pytest.mark.parametrize(
65+
("use_server_ver", "use_protocol", "use_auth", "use_client_name"),
66+
[
67+
(5, 2, False, True),
68+
(5, 2, True, True),
69+
(5, 3, True, True),
70+
(6, 2, False, True),
71+
(6, 2, True, True),
72+
(6, 3, False, False),
73+
(6, 3, True, False),
74+
(6, 3, False, True),
75+
(6, 3, True, True),
76+
],
77+
)
78+
# @pytest.mark.parametrize("use_protocol", [2, 3])
79+
# @pytest.mark.parametrize("use_auth", [False, True])
80+
def test_tcp_auth(tcp_address, use_protocol, use_auth, use_server_ver, use_client_name):
81+
"""
82+
Test that various initial handshake cases are handled correctly by the client
83+
"""
84+
got_auth = []
85+
got_protocol = None
86+
got_name = None
87+
88+
def on_auth(self, auth):
89+
got_auth[:] = auth
90+
91+
def on_protocol(self, proto):
92+
nonlocal got_protocol
93+
got_protocol = proto
94+
95+
def on_setname(self, name):
96+
nonlocal got_name
97+
got_name = name
98+
99+
def get_server_version(self):
100+
return use_server_ver
101+
102+
if use_auth:
103+
auth_args = {"username": "myuser", "password": "mypassword"}
104+
else:
105+
auth_args = {}
106+
got_protocol = None
107+
host, port = tcp_address
108+
conn = Connection(
109+
host=host,
110+
port=port,
111+
client_name=_CLIENT_NAME if use_client_name else None,
112+
socket_timeout=10,
113+
protocol=use_protocol,
114+
**auth_args,
115+
)
116+
try:
117+
with patch.multiple(
118+
resp.RespServer,
119+
on_auth=on_auth,
120+
get_server_version=get_server_version,
121+
on_protocol=on_protocol,
122+
on_setname=on_setname,
123+
):
124+
if use_server_ver < 6 and use_protocol > 2:
125+
with pytest.raises(ResponseError):
126+
_assert_connect(conn, tcp_address)
127+
return
128+
129+
_assert_connect(conn, tcp_address)
130+
if use_protocol == 3:
131+
assert got_protocol == use_protocol
132+
if use_auth:
133+
if use_server_ver < 6:
134+
assert got_auth == ["mypassword"]
135+
else:
136+
assert got_auth == ["myuser", "mypassword"]
137+
138+
if use_client_name:
139+
assert got_name == _CLIENT_NAME
140+
else:
141+
assert got_name is None
142+
finally:
143+
conn.disconnect()
144+
145+
58146
def _assert_connect(conn, server_address, certfile=None, keyfile=None):
59147
if isinstance(server_address, str):
60148
if not _RedisUDSServer:

0 commit comments

Comments
 (0)