Skip to content

Move callback processing from BrokerConnection to KafkaClient #1258

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Oct 21, 2017
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 13 additions & 9 deletions kafka/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,7 +175,8 @@ def _send_broker_unaware_request(self, payloads, encoder_fn, decoder_fn):

# Block
while not future.is_done:
conn.recv()
for r, f in conn.recv():
f.success(r)

if future.failed():
log.error("Request failed: %s", future.exception)
Expand Down Expand Up @@ -288,7 +289,8 @@ def failed_payloads(payloads):

if not future.is_done:
conn, _ = connections_by_future[future]
conn.recv()
for r, f in conn.recv():
f.success(r)
continue

_, broker = connections_by_future.pop(future)
Expand Down Expand Up @@ -352,8 +354,6 @@ def _send_consumer_aware_request(self, group, payloads, encoder_fn, decoder_fn):
try:
host, port, afi = get_ip_port_afi(broker.host)
conn = self._get_conn(host, broker.port, afi)
conn.send(request_id, request)

except ConnectionError as e:
log.warning('ConnectionError attempting to send request %s '
'to server %s: %s', request_id, broker, e)
Expand All @@ -365,6 +365,11 @@ def _send_consumer_aware_request(self, group, payloads, encoder_fn, decoder_fn):
# No exception, try to get response
else:

future = conn.send(request_id, request)
while not future.is_done:
for r, f in conn.recv():
f.success(r)

# decoder_fn=None signal that the server is expected to not
# send a response. This probably only applies to
# ProduceRequest w/ acks = 0
Expand All @@ -376,18 +381,17 @@ def _send_consumer_aware_request(self, group, payloads, encoder_fn, decoder_fn):
responses[topic_partition] = None
return []

try:
response = conn.recv(request_id)
except ConnectionError as e:
log.warning('ConnectionError attempting to receive a '
if future.failed():
log.warning('Error attempting to receive a '
'response to request %s from server %s: %s',
request_id, broker, e)
request_id, broker, future.exception)

for payload in payloads:
topic_partition = (payload.topic, payload.partition)
responses[topic_partition] = FailedPayloadsError(payload)

else:
response = future.value
_resps = []
for payload_response in decoder_fn(response):
topic_partition = (payload_response.topic,
Expand Down
32 changes: 25 additions & 7 deletions kafka/client_async.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import absolute_import, division

import collections
import copy
import functools
import heapq
Expand Down Expand Up @@ -204,6 +205,11 @@ def __init__(self, **configs):
self._wake_r, self._wake_w = socket.socketpair()
self._wake_r.setblocking(False)
self._wake_lock = threading.Lock()

# when requests complete, they are transferred to this queue prior to
# invocation.
self._pending_completion = collections.deque()

self._selector.register(self._wake_r, selectors.EVENT_READ)
self._idle_expiry_manager = IdleConnectionManager(self.config['connections_max_idle_ms'])
self._closed = False
Expand Down Expand Up @@ -254,7 +260,8 @@ def _bootstrap(self, hosts):
future = bootstrap.send(metadata_request)
while not future.is_done:
self._selector.select(1)
bootstrap.recv()
for r, f in bootstrap.recv():
f.success(r)
if future.failed():
bootstrap.close()
continue
Expand Down Expand Up @@ -512,7 +519,9 @@ def poll(self, timeout_ms=None, future=None, delayed_tasks=True):
Returns:
list: responses received (can be empty)
"""
if timeout_ms is None:
if future is not None:
timeout_ms = 100
elif timeout_ms is None:
timeout_ms = self.config['request_timeout_ms']

responses = []
Expand Down Expand Up @@ -551,7 +560,9 @@ def poll(self, timeout_ms=None, future=None, delayed_tasks=True):
self.config['request_timeout_ms'])
timeout = max(0, timeout / 1000.0) # avoid negative timeouts

responses.extend(self._poll(timeout))
self._poll(timeout)

responses.extend(self._fire_pending_completed_requests())

# If all we had was a timeout (future is None) - only do one poll
# If we do have a future, we keep looping until it is done
Expand All @@ -561,7 +572,7 @@ def poll(self, timeout_ms=None, future=None, delayed_tasks=True):
return responses

def _poll(self, timeout):
responses = []
"""Returns list of (response, future) tuples"""
processed = set()

start_select = time.time()
Expand Down Expand Up @@ -600,14 +611,14 @@ def _poll(self, timeout):
continue

self._idle_expiry_manager.update(conn.node_id)
responses.extend(conn.recv()) # Note: conn.recv runs callbacks / errbacks
self._pending_completion.extend(conn.recv())

# Check for additional pending SSL bytes
if self.config['security_protocol'] in ('SSL', 'SASL_SSL'):
# TODO: optimize
for conn in self._conns.values():
if conn not in processed and conn.connected() and conn._sock.pending():
responses.extend(conn.recv())
self._pending_completion.extend(conn.recv())

for conn in six.itervalues(self._conns):
if conn.requests_timed_out():
Expand All @@ -621,7 +632,6 @@ def _poll(self, timeout):
self._sensors.io_time.record((time.time() - end_select) * 1000000000)

self._maybe_close_oldest_connection()
return responses

def in_flight_request_count(self, node_id=None):
"""Get the number of in-flight requests for a node or all nodes.
Expand All @@ -640,6 +650,14 @@ def in_flight_request_count(self, node_id=None):
else:
return sum([len(conn.in_flight_requests) for conn in self._conns.values()])

def _fire_pending_completed_requests(self):
responses = []
while self._pending_completion:
response, future = self._pending_completion.popleft()
future.success(response)
responses.append(response)
return responses

def least_loaded_node(self):
"""Choose the node with fewest outstanding requests, with fallbacks.

Expand Down
39 changes: 25 additions & 14 deletions kafka/conn.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,14 @@
import errno
import logging
from random import shuffle, uniform

# selectors in stdlib as of py3.4
try:
import selectors # pylint: disable=import-error
except ImportError:
# vendored backport module
from .vendor import selectors34 as selectors

import socket
import struct
import sys
Expand Down Expand Up @@ -138,6 +146,9 @@ class BrokerConnection(object):
api_version_auto_timeout_ms (int): number of milliseconds to throw a
timeout exception from the constructor when checking the broker
api version. Only applies if api_version is None
selector (selectors.BaseSelector): Provide a specific selector
implementation to use for I/O multiplexing.
Default: selectors.DefaultSelector
state_change_callback (callable): function to be called when the
connection state changes from CONNECTING to CONNECTED etc.
metrics (kafka.metrics.Metrics): Optionally provide a metrics
Expand Down Expand Up @@ -173,6 +184,7 @@ class BrokerConnection(object):
'ssl_crlfile': None,
'ssl_password': None,
'api_version': (0, 8, 2), # default to most restrictive
'selector': selectors.DefaultSelector,
'state_change_callback': lambda conn: True,
'metrics': None,
'metric_group_prefix': '',
Expand Down Expand Up @@ -705,7 +717,7 @@ def can_send_more(self):
def recv(self):
"""Non-blocking network receive.

Return response if available
Return list of (response, future)
"""
if not self.connected() and not self.state is ConnectionStates.AUTHENTICATING:
log.warning('%s cannot recv: socket not connected', self)
Expand All @@ -728,17 +740,16 @@ def recv(self):
self.config['request_timeout_ms']))
return ()

for response in responses:
# augment respones w/ correlation_id, future, and timestamp
# augment responses w/ correlation_id, future, and timestamp
for i in range(len(responses)):
(correlation_id, future, timestamp) = self.in_flight_requests.popleft()
if isinstance(response, Errors.KafkaError):
self.close(response)
break

latency_ms = (time.time() - timestamp) * 1000
if self._sensors:
self._sensors.request_time.record((time.time() - timestamp) * 1000)
self._sensors.request_time.record(latency_ms)

log.debug('%s Response %d: %s', self, correlation_id, response)
future.success(response)
response = responses[i]
log.debug('%s Response %d (%s ms): %s', self, correlation_id, latency_ms, response)
responses[i] = (response, future)

return responses

Expand Down Expand Up @@ -900,12 +911,12 @@ def connect():
# request was unrecognized
mr = self.send(MetadataRequest[0]([]))

if self._sock:
self._sock.setblocking(True)
selector = self.config['selector']()
selector.register(self._sock, selectors.EVENT_READ)
while not (f.is_done and mr.is_done):
self.recv()
if self._sock:
self._sock.setblocking(False)
for response, future in self.recv():
future.success(response)
selector.select(1)

if f.succeeded():
if isinstance(request, ApiVersionRequest[0]):
Expand Down
3 changes: 2 additions & 1 deletion test/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ def mock_conn(conn, success=True):
else:
mocked.send.return_value = Future().failure(Exception())
conn.return_value = mocked
conn.recv.return_value = []


class TestSimpleClient(unittest.TestCase):
Expand Down Expand Up @@ -94,7 +95,7 @@ def test_send_broker_unaware_request(self):
mock_conn(mocked_conns[('kafka03', 9092)], success=False)
future = Future()
mocked_conns[('kafka02', 9092)].send.return_value = future
mocked_conns[('kafka02', 9092)].recv.side_effect = lambda: future.success('valid response')
mocked_conns[('kafka02', 9092)].recv.return_value = [('valid response', future)]

def mock_get_conn(host, port, afi):
return mocked_conns[(host, port)]
Expand Down