1
1
import itertools
2
2
from contextlib import closing
3
- from typing import Any , Generator , List , Optional , Tuple , Union
3
+ from typing import Any , Callable , Dict , Generator , List , Optional , Tuple , Union
4
4
5
5
CRNL = b"\r \n "
6
6
@@ -34,11 +34,11 @@ def __init__(self, code: str, value: str) -> None:
34
34
def __repr__ (self ) -> str :
35
35
return f"ErrorString({ self .code !r} , { super ().__repr__ ()} )"
36
36
37
- def __str__ (self ):
37
+ def __str__ (self ) -> str :
38
38
return f"{ self .code } { super ().__str__ ()} "
39
39
40
40
41
- class PushData (list ):
41
+ class PushData (List [ Any ] ):
42
42
"""
43
43
A special type of list indicating data from a push response
44
44
"""
@@ -47,7 +47,7 @@ def __repr__(self) -> str:
47
47
return f"PushData({ super ().__repr__ ()} )"
48
48
49
49
50
- class Attribute (dict ):
50
+ class Attribute (Dict [ Any , Any ] ):
51
51
"""
52
52
A special type of map indicating data from a attribute response
53
53
"""
@@ -62,7 +62,7 @@ class RespEncoder:
62
62
"""
63
63
64
64
def __init__ (
65
- self , protocol : int = 2 , encoding : str = "utf-8" , errorhander = "strict"
65
+ self , protocol : int = 2 , encoding : str = "utf-8" , errorhander : str = "strict"
66
66
) -> None :
67
67
self .protocol = protocol
68
68
self .encoding = encoding
@@ -248,7 +248,7 @@ def parse(
248
248
rest += incoming
249
249
string = self .decode_bytes (rest [: (count + 4 )])
250
250
if string [3 ] != ":" :
251
- raise ValueError (f"Expected colon after hint, got { bulkstr [3 ]} " )
251
+ raise ValueError (f"Expected colon after hint, got { string [3 ]} " )
252
252
hint = string [:3 ]
253
253
string = string [4 : (count + 4 )]
254
254
yield VerbatimStr (string , hint ), rest [expect :]
@@ -310,8 +310,8 @@ def parse(
310
310
# we decode them automatically
311
311
decoded = self .decode_bytes (arg )
312
312
assert isinstance (decoded , str )
313
- code , value = decoded .split (" " , 1 )
314
- yield ErrorStr (code , value ), rest
313
+ err , value = decoded .split (" " , 1 )
314
+ yield ErrorStr (err , value ), rest
315
315
316
316
elif code == b"!" : # resp3 error
317
317
count = int (arg )
@@ -323,8 +323,8 @@ def parse(
323
323
bulkstr = rest [:count ]
324
324
decoded = self .decode_bytes (bulkstr )
325
325
assert isinstance (decoded , str )
326
- code , value = decoded .split (" " , 1 )
327
- yield ErrorStr (code , value ), rest [expect :]
326
+ err , value = decoded .split (" " , 1 )
327
+ yield ErrorStr (err , value ), rest [expect :]
328
328
329
329
else :
330
330
raise ValueError (f"Unknown opcode '{ code .decode ()} '" )
@@ -427,26 +427,26 @@ class RespServer:
427
427
Accepts RESP commands and returns RESP responses.
428
428
"""
429
429
430
- handlers = {}
430
+ handlers : Dict [ str , Callable [..., Any ]] = {}
431
431
432
- def __init__ (self ):
432
+ def __init__ (self ) -> None :
433
433
self .protocol = 2
434
434
self .server_ver = self .get_server_version ()
435
- self .auth = []
435
+ self .auth : List [ Any ] = []
436
436
self .client_name = ""
437
437
438
438
# patchable methods for testing
439
439
440
- def get_server_version (self ):
440
+ def get_server_version (self ) -> int :
441
441
return 6
442
442
443
- def on_auth (self , auth ) :
443
+ def on_auth (self , auth : List [ Any ]) -> None :
444
444
pass
445
445
446
- def on_setname (self , name ) :
446
+ def on_setname (self , name : str ) -> None :
447
447
pass
448
448
449
- def on_protocol (self , proto ) :
449
+ def on_protocol (self , proto : int ) -> None :
450
450
pass
451
451
452
452
def command (self , cmd : Any ) -> bytes :
@@ -466,7 +466,7 @@ def _command(self, cmd: Any) -> Any:
466
466
467
467
return ErrorStr ("ERR" , "unknown command {cmd!r}" )
468
468
469
- def handle_auth (self , args ) :
469
+ def handle_auth (self , args : List [ Any ]) -> Union [ str , ErrorStr ] :
470
470
self .auth = args [:]
471
471
self .on_auth (self .auth )
472
472
expect = 2 if self .server_ver >= 6 else 1
@@ -476,21 +476,21 @@ def handle_auth(self, args):
476
476
477
477
handlers ["AUTH" ] = handle_auth
478
478
479
- def handle_client (self , args ) :
479
+ def handle_client (self , args : List [ Any ]) -> Union [ str , ErrorStr ] :
480
480
if args [0 ] == "SETNAME" :
481
481
return self .handle_setname (args [1 :])
482
482
return ErrorStr ("ERR" , "unknown subcommand or wrong number of arguments" )
483
483
484
484
handlers ["CLIENT" ] = handle_client
485
485
486
- def handle_setname (self , args ) :
486
+ def handle_setname (self , args : List [ Any ]) -> Union [ str , ErrorStr ] :
487
487
if len (args ) != 1 :
488
488
return ErrorStr ("ERR" , "wrong number of arguments" )
489
489
self .client_name = args [0 ]
490
490
self .on_setname (self .client_name )
491
491
return "OK"
492
492
493
- def handle_hello (self , args ) :
493
+ def handle_hello (self , args : List [ Any ]) -> Union [ ErrorStr , Dict [ str , Any ]] :
494
494
if self .server_ver < 6 :
495
495
return ErrorStr ("ERR" , "unknown command 'HELLO'" )
496
496
proto = self .protocol
@@ -507,14 +507,14 @@ def handle_hello(self, args):
507
507
auth_args = args [:2 ]
508
508
args = args [2 :]
509
509
res = self .handle_auth (auth_args )
510
- if res != "OK" :
510
+ if isinstance ( res , ErrorStr ) :
511
511
return res
512
512
continue
513
513
if cmd == "SETNAME" :
514
514
setname_args = args [:1 ]
515
515
args = args [1 :]
516
516
res = self .handle_setname (setname_args )
517
- if res != "OK" :
517
+ if isinstance ( res , ErrorStr ) :
518
518
return res
519
519
continue
520
520
return ErrorStr ("ERR" , "unknown subcommand or wrong number of arguments" )
0 commit comments