@@ -348,22 +348,34 @@ def _run_reader(*args: Any) -> List[Event]:
348
348
return normalize_data_events (events )
349
349
350
350
351
- def t_body_reader (thunk : Any , data : bytes , expected : Any , do_eof : bool = False ) -> None :
351
+ def t_body_reader (thunk : Any , data : bytes , expected : list , do_eof : bool = False ) -> None :
352
352
# Simple: consume whole thing
353
353
print ("Test 1" )
354
354
buf = makebuf (data )
355
- assert _run_reader (thunk (), buf , do_eof ) == expected
355
+ try :
356
+ assert _run_reader (thunk (), buf , do_eof ) == expected
357
+ except LocalProtocolError :
358
+ if LocalProtocolError in expected :
359
+ pass
360
+ else :
361
+ raise
356
362
357
363
# Incrementally growing buffer
358
364
print ("Test 2" )
359
365
reader = thunk ()
360
366
buf = ReceiveBuffer ()
361
367
events = []
362
- for i in range (len (data )):
363
- events += _run_reader (reader , buf , False )
364
- buf += data [i : i + 1 ]
365
- events += _run_reader (reader , buf , do_eof )
366
- assert normalize_data_events (events ) == expected
368
+ try :
369
+ for i in range (len (data )):
370
+ events += _run_reader (reader , buf , False )
371
+ buf += data [i : i + 1 ]
372
+ events += _run_reader (reader , buf , do_eof )
373
+ assert normalize_data_events (events ) == expected
374
+ except LocalProtocolError :
375
+ if LocalProtocolError in expected :
376
+ pass
377
+ else :
378
+ raise
367
379
368
380
is_complete = any (type (event ) is EndOfMessage for event in expected )
369
381
if is_complete and not do_eof :
@@ -424,14 +436,12 @@ def test_ChunkedReader() -> None:
424
436
)
425
437
426
438
# refuses arbitrarily long chunk integers
427
- with pytest .raises (LocalProtocolError ):
428
- # Technically this is legal HTTP/1.1, but we refuse to process chunk
429
- # sizes that don't fit into 20 characters of hex
430
- t_body_reader (ChunkedReader , b"9" * 100 + b"\r \n xxx" , [Data (data = b"xxx" )])
439
+ # Technically this is legal HTTP/1.1, but we refuse to process chunk
440
+ # sizes that don't fit into 20 characters of hex
441
+ t_body_reader (ChunkedReader , b"9" * 100 + b"\r \n xxx" , [LocalProtocolError ])
431
442
432
443
# refuses garbage in the chunk count
433
- with pytest .raises (LocalProtocolError ):
434
- t_body_reader (ChunkedReader , b"10\x00 \r \n xxx" , None )
444
+ t_body_reader (ChunkedReader , b"10\x00 \r \n xxx" , [LocalProtocolError ])
435
445
436
446
# handles (and discards) "chunk extensions" omg wtf
437
447
t_body_reader (
@@ -445,10 +455,22 @@ def test_ChunkedReader() -> None:
445
455
446
456
t_body_reader (
447
457
ChunkedReader ,
448
- b"5 \r \n 01234\r \n " + b"0\r \n \r \n " ,
458
+ b"5 \t \r \n 01234\r \n " + b"0\r \n \r \n " ,
449
459
[Data (data = b"01234" ), EndOfMessage ()],
450
460
)
451
461
462
+ # Chunked encoding with bad chunk termination characters are refused. Originally we
463
+ # simply dropped the 2 bytes after a chunk, instead of validating that the bytes
464
+ # were \r\n -- so we would successfully decode the data below as b"xxxa". And
465
+ # apparently there are other HTTP processors that ignore the chunk length and just
466
+ # keep reading until they see \r\n, so they would decode it as b"xxx__1a". Any time
467
+ # two HTTP processors accept the same input but interpret it differently, there's a
468
+ # possibility of request smuggling shenanigans. So we now reject this.
469
+ t_body_reader (ChunkedReader , b"3\r \n xxx__1a\r \n " , [LocalProtocolError ])
470
+
471
+ # Confirm we check both bytes individually
472
+ t_body_reader (ChunkedReader , b"3\r \n xxx\r _1a\r \n " , [LocalProtocolError ])
473
+ t_body_reader (ChunkedReader , b"3\r \n xxx_\n 1a\r \n " , [LocalProtocolError ])
452
474
453
475
def test_ContentLengthWriter () -> None :
454
476
w = ContentLengthWriter (5 )
@@ -471,8 +493,8 @@ def test_ContentLengthWriter() -> None:
471
493
dowrite (w , EndOfMessage ())
472
494
473
495
w = ContentLengthWriter (5 )
474
- dowrite (w , Data (data = b"123" )) == b"123"
475
- dowrite (w , Data (data = b"45" )) == b"45"
496
+ assert dowrite (w , Data (data = b"123" )) == b"123"
497
+ assert dowrite (w , Data (data = b"45" )) == b"45"
476
498
with pytest .raises (LocalProtocolError ):
477
499
dowrite (w , EndOfMessage (headers = [("Etag" , "asdf" )]))
478
500
0 commit comments