Skip to content

Commit 2a91ca1

Browse files
dpkpjeffwidman
authored and committed
Synchronize puts to KafkaConsumer protocol buffer during async sends
1 parent 8c07925 commit 2a91ca1

File tree

2 files changed

+60
-25
lines changed

2 files changed

+60
-25
lines changed

kafka/conn.py

+36-21
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
import socket
1818
import struct
1919
import sys
20+
import threading
2021
import time
2122

2223
from kafka.vendor import six
@@ -220,7 +221,6 @@ def __init__(self, host, port, afi, **configs):
220221
self.afi = afi
221222
self._sock_afi = afi
222223
self._sock_addr = None
223-
self.in_flight_requests = collections.deque()
224224
self._api_versions = None
225225

226226
self.config = copy.copy(self.DEFAULT_CONFIG)
@@ -255,6 +255,20 @@ def __init__(self, host, port, afi, **configs):
255255
assert gssapi is not None, 'GSSAPI lib not available'
256256
assert self.config['sasl_kerberos_service_name'] is not None, 'sasl_kerberos_service_name required for GSSAPI sasl'
257257

258+
# This is not a general lock / this class is not generally thread-safe yet
259+
# However, to avoid pushing responsibility for maintaining
260+
# per-connection locks to the upstream client, we will use this lock to
261+
# make sure that access to the protocol buffer is synchronized
262+
# when sends happen on multiple threads
263+
self._lock = threading.Lock()
264+
265+
# the protocol parser instance manages actual tracking of the
266+
# sequence of in-flight requests to responses, which should
267+
# function like a FIFO queue. For additional request data,
268+
# including tracking request futures and timestamps, we
269+
# can use a simple dictionary of correlation_id => request data
270+
self.in_flight_requests = dict()
271+
258272
self._protocol = KafkaProtocol(
259273
client_id=self.config['client_id'],
260274
api_version=self.config['api_version'])
@@ -729,7 +743,7 @@ def close(self, error=None):
729743
if error is None:
730744
error = Errors.Cancelled(str(self))
731745
while self.in_flight_requests:
732-
(_, future, _) = self.in_flight_requests.popleft()
746+
(_correlation_id, (future, _timestamp)) = self.in_flight_requests.popitem()
733747
future.failure(error)
734748
self.config['state_change_callback'](self)
735749

@@ -747,23 +761,22 @@ def send(self, request, blocking=True):
747761
def _send(self, request, blocking=True):
748762
assert self.state in (ConnectionStates.AUTHENTICATING, ConnectionStates.CONNECTED)
749763
future = Future()
750-
correlation_id = self._protocol.send_request(request)
751-
752-
# Attempt to replicate behavior from prior to introduction of
753-
# send_pending_requests() / async sends
754-
if blocking:
755-
error = self.send_pending_requests()
756-
if isinstance(error, Exception):
757-
future.failure(error)
758-
return future
764+
with self._lock:
765+
correlation_id = self._protocol.send_request(request)
759766

760767
log.debug('%s Request %d: %s', self, correlation_id, request)
761768
if request.expect_response():
762769
sent_time = time.time()
763-
ifr = (correlation_id, future, sent_time)
764-
self.in_flight_requests.append(ifr)
770+
assert correlation_id not in self.in_flight_requests, 'Correlation ID already in-flight!'
771+
self.in_flight_requests[correlation_id] = (future, sent_time)
765772
else:
766773
future.success(None)
774+
775+
# Attempt to replicate behavior from prior to introduction of
776+
# send_pending_requests() / async sends
777+
if blocking:
778+
self.send_pending_requests()
779+
767780
return future
768781

769782
def send_pending_requests(self):
@@ -818,8 +831,12 @@ def recv(self):
818831
return ()
819832

820833
# augment responses w/ correlation_id, future, and timestamp
821-
for i, response in enumerate(responses):
822-
(correlation_id, future, timestamp) = self.in_flight_requests.popleft()
834+
for i, (correlation_id, response) in enumerate(responses):
835+
try:
836+
(future, timestamp) = self.in_flight_requests.pop(correlation_id)
837+
except KeyError:
838+
self.close(Errors.KafkaConnectionError('Received unrecognized correlation id'))
839+
return ()
823840
latency_ms = (time.time() - timestamp) * 1000
824841
if self._sensors:
825842
self._sensors.request_time.record(latency_ms)
@@ -870,20 +887,18 @@ def _recv(self):
870887
self.close(e)
871888
return []
872889
else:
873-
return [resp for (_, resp) in responses] # drop correlation id
890+
return responses
874891

875892
def requests_timed_out(self):
876893
if self.in_flight_requests:
877-
(_, _, oldest_at) = self.in_flight_requests[0]
894+
get_timestamp = lambda v: v[1]
895+
oldest_at = min(map(get_timestamp,
896+
self.in_flight_requests.values()))
878897
timeout = self.config['request_timeout_ms'] / 1000.0
879898
if time.time() >= oldest_at + timeout:
880899
return True
881900
return False
882901

883-
def _next_correlation_id(self):
884-
self._correlation_id = (self._correlation_id + 1) % 2**31
885-
return self._correlation_id
886-
887902
def _handle_api_version_response(self, response):
888903
error_type = Errors.for_code(response.error_code)
889904
assert error_type is Errors.NoError, "API version check failed"

test/test_conn.py

+24-4
Original file line numberDiff line numberDiff line change
@@ -112,8 +112,8 @@ def test_send_connecting(conn):
112112
def test_send_max_ifr(conn):
113113
conn.state = ConnectionStates.CONNECTED
114114
max_ifrs = conn.config['max_in_flight_requests_per_connection']
115-
for _ in range(max_ifrs):
116-
conn.in_flight_requests.append('foo')
115+
for i in range(max_ifrs):
116+
conn.in_flight_requests[i] = 'foo'
117117
f = conn.send('foobar')
118118
assert f.failed() is True
119119
assert isinstance(f.exception, Errors.TooManyInFlightRequests)
@@ -170,9 +170,9 @@ def test_send_error(_socket, conn):
170170
def test_can_send_more(conn):
171171
assert conn.can_send_more() is True
172172
max_ifrs = conn.config['max_in_flight_requests_per_connection']
173-
for _ in range(max_ifrs):
173+
for i in range(max_ifrs):
174174
assert conn.can_send_more() is True
175-
conn.in_flight_requests.append('foo')
175+
conn.in_flight_requests[i] = 'foo'
176176
assert conn.can_send_more() is False
177177

178178

@@ -311,3 +311,23 @@ def test_relookup_on_failure():
311311
assert conn._sock_afi == afi2
312312
assert conn._sock_addr == sockaddr2
313313
conn.close()
314+
315+
316+
def test_requests_timed_out(conn):
317+
with mock.patch("time.time", return_value=0):
318+
# No in-flight requests, not timed out
319+
assert not conn.requests_timed_out()
320+
321+
# Single request, timestamp = now (0)
322+
conn.in_flight_requests[0] = ('foo', 0)
323+
assert not conn.requests_timed_out()
324+
325+
# Add another request whose timestamp is older than the request timeout (the raw ms value is subtracted, so it is well past expiry)
326+
request_timeout = conn.config['request_timeout_ms']
327+
expired_timestamp = 0 - request_timeout - 1
328+
conn.in_flight_requests[1] = ('bar', expired_timestamp)
329+
assert conn.requests_timed_out()
330+
331+
# Drop the expired request and we should be good to go again
332+
conn.in_flight_requests.pop(1)
333+
assert not conn.requests_timed_out()

0 commit comments

Comments
 (0)