17
17
import socket
18
18
import struct
19
19
import sys
20
+ import threading
20
21
import time
21
22
22
23
from kafka .vendor import six
@@ -220,7 +221,6 @@ def __init__(self, host, port, afi, **configs):
220
221
self .afi = afi
221
222
self ._sock_afi = afi
222
223
self ._sock_addr = None
223
- self .in_flight_requests = collections .deque ()
224
224
self ._api_versions = None
225
225
226
226
self .config = copy .copy (self .DEFAULT_CONFIG )
@@ -255,6 +255,20 @@ def __init__(self, host, port, afi, **configs):
255
255
assert gssapi is not None , 'GSSAPI lib not available'
256
256
assert self .config ['sasl_kerberos_service_name' ] is not None , 'sasl_kerberos_service_name required for GSSAPI sasl'
257
257
258
+ # This is not a general lock / this class is not generally thread-safe yet
259
+ # However, to avoid pushing responsibility for maintaining
260
+ # per-connection locks to the upstream client, we will use this lock to
261
+ # make sure that access to the protocol buffer is synchronized
262
+ # when sends happen on multiple threads
263
+ self ._lock = threading .Lock ()
264
+
265
+ # the protocol parser instance manages actual tracking of the
266
+ # sequence of in-flight requests to responses, which should
267
+ # function like a FIFO queue. For additional request data,
268
+ # including tracking request futures and timestamps, we
269
+ # can use a simple dictionary of correlation_id => request data
270
+ self .in_flight_requests = dict ()
271
+
258
272
self ._protocol = KafkaProtocol (
259
273
client_id = self .config ['client_id' ],
260
274
api_version = self .config ['api_version' ])
@@ -729,7 +743,7 @@ def close(self, error=None):
729
743
if error is None :
730
744
error = Errors .Cancelled (str (self ))
731
745
while self .in_flight_requests :
732
- (_ , future , _ ) = self .in_flight_requests .popleft ()
746
+ (_correlation_id , ( future , _timestamp )) = self .in_flight_requests .popitem ()
733
747
future .failure (error )
734
748
self .config ['state_change_callback' ](self )
735
749
@@ -747,23 +761,22 @@ def send(self, request, blocking=True):
747
761
def _send (self , request , blocking = True ):
748
762
assert self .state in (ConnectionStates .AUTHENTICATING , ConnectionStates .CONNECTED )
749
763
future = Future ()
750
- correlation_id = self ._protocol .send_request (request )
751
-
752
- # Attempt to replicate behavior from prior to introduction of
753
- # send_pending_requests() / async sends
754
- if blocking :
755
- error = self .send_pending_requests ()
756
- if isinstance (error , Exception ):
757
- future .failure (error )
758
- return future
764
+ with self ._lock :
765
+ correlation_id = self ._protocol .send_request (request )
759
766
760
767
log .debug ('%s Request %d: %s' , self , correlation_id , request )
761
768
if request .expect_response ():
762
769
sent_time = time .time ()
763
- ifr = ( correlation_id , future , sent_time )
764
- self .in_flight_requests . append ( ifr )
770
+ assert correlation_id not in self . in_flight_requests , 'Correlation ID already in-flight!'
771
+ self .in_flight_requests [ correlation_id ] = ( future , sent_time )
765
772
else :
766
773
future .success (None )
774
+
775
+ # Attempt to replicate behavior from prior to introduction of
776
+ # send_pending_requests() / async sends
777
+ if blocking :
778
+ self .send_pending_requests ()
779
+
767
780
return future
768
781
769
782
def send_pending_requests (self ):
@@ -818,8 +831,12 @@ def recv(self):
818
831
return ()
819
832
820
833
# augment respones w/ correlation_id, future, and timestamp
821
- for i , response in enumerate (responses ):
822
- (correlation_id , future , timestamp ) = self .in_flight_requests .popleft ()
834
+ for i , (correlation_id , response ) in enumerate (responses ):
835
+ try :
836
+ (future , timestamp ) = self .in_flight_requests .pop (correlation_id )
837
+ except KeyError :
838
+ self .close (Errors .KafkaConnectionError ('Received unrecognized correlation id' ))
839
+ return ()
823
840
latency_ms = (time .time () - timestamp ) * 1000
824
841
if self ._sensors :
825
842
self ._sensors .request_time .record (latency_ms )
@@ -870,20 +887,18 @@ def _recv(self):
870
887
self .close (e )
871
888
return []
872
889
else :
873
- return [ resp for ( _ , resp ) in responses ] # drop correlation id
890
+ return responses
874
891
875
892
def requests_timed_out (self ):
876
893
if self .in_flight_requests :
877
- (_ , _ , oldest_at ) = self .in_flight_requests [0 ]
894
+ get_timestamp = lambda v : v [1 ]
895
+ oldest_at = min (map (get_timestamp ,
896
+ self .in_flight_requests .values ()))
878
897
timeout = self .config ['request_timeout_ms' ] / 1000.0
879
898
if time .time () >= oldest_at + timeout :
880
899
return True
881
900
return False
882
901
883
- def _next_correlation_id (self ):
884
- self ._correlation_id = (self ._correlation_id + 1 ) % 2 ** 31
885
- return self ._correlation_id
886
-
887
902
def _handle_api_version_response (self , response ):
888
903
error_type = Errors .for_code (response .error_code )
889
904
assert error_type is Errors .NoError , "API version check failed"
0 commit comments