Skip to content

Commit

Permalink
allow per-request timeouts; use api_version_auto_timeout_ms
Browse files Browse the repository at this point in the history
  • Loading branch information
dpkp committed Feb 25, 2025
1 parent 06517ce commit 1cc78c2
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 17 deletions.
30 changes: 16 additions & 14 deletions kafka/conn.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,7 +169,7 @@ class BrokerConnection(object):
(0, 10). Default: (0, 8, 2)
api_version_auto_timeout_ms (int): number of milliseconds to throw a
timeout exception from the constructor when checking the broker
api version. Only applies if api_version is None
api version. Only applies if api_version is None. Default: 2000.
selector (selectors.BaseSelector): Provide a specific selector
implementation to use for I/O multiplexing.
Default: selectors.DefaultSelector
Expand Down Expand Up @@ -215,6 +215,7 @@ class BrokerConnection(object):
'ssl_password': None,
'ssl_ciphers': None,
'api_version': (0, 8, 2), # default to most restrictive
'api_version_auto_timeout_ms': 2000,
'selector': selectors.DefaultSelector,
'state_change_callback': lambda node_id, sock, conn: True,
'metrics': None,
Expand Down Expand Up @@ -543,14 +544,14 @@ def _try_api_versions_check(self):
# ((0, 10), ApiVersionRequest[0]()),
request = ApiVersionRequest[0]()
future = Future()
response = self._send(request, blocking=True)
response = self._send(request, blocking=True, request_timeout_ms=self.config['api_version_auto_timeout_ms'])
response.add_callback(self._handle_api_versions_response, future)
response.add_errback(self._handle_api_versions_failure, future)
self._api_versions_future = future
elif self._check_version_idx < len(self.VERSION_CHECKS):
version, request = self.VERSION_CHECKS[self._check_version_idx]
future = Future()
response = self._send(request, blocking=True)
response = self._send(request, blocking=True, request_timeout_ms=self.config['api_version_auto_timeout_ms'])
response.add_callback(self._handle_check_version_response, future, version)
response.add_errback(self._handle_check_version_failure, future)
self._api_versions_future = future
Expand Down Expand Up @@ -1038,14 +1039,14 @@ def close(self, error=None):
# drop lock before state change callback and processing futures
self.config['state_change_callback'](self.node_id, sock, self)
sock.close()
for (_correlation_id, (future, _timestamp)) in ifrs:
for (_correlation_id, (future, _timestamp, _timeout)) in ifrs:
future.failure(error)

def _can_send_recv(self):
"""Return True iff socket is ready for requests / responses"""
return self.connected() or self.initializing()

def send(self, request, blocking=True):
def send(self, request, blocking=True, request_timeout_ms=None):
"""Queue request for async network send, return Future()"""
future = Future()
if self.connecting():
Expand All @@ -1054,9 +1055,9 @@ def send(self, request, blocking=True):
return future.failure(Errors.KafkaConnectionError(str(self)))
elif not self.can_send_more():
return future.failure(Errors.TooManyInFlightRequests(str(self)))
return self._send(request, blocking=blocking)
return self._send(request, blocking=blocking, request_timeout_ms=request_timeout_ms)

def _send(self, request, blocking=True):
def _send(self, request, blocking=True, request_timeout_ms=None):
future = Future()
with self._lock:
if not self._can_send_recv():
Expand All @@ -1069,9 +1070,11 @@ def _send(self, request, blocking=True):

log.debug('%s Request %d: %s', self, correlation_id, request)
if request.expect_response():
sent_time = time.time()
assert correlation_id not in self.in_flight_requests, 'Correlation ID already in-flight!'
self.in_flight_requests[correlation_id] = (future, sent_time)
sent_time = time.time()
request_timeout_ms = request_timeout_ms or self.config['request_timeout_ms']
timeout_at = sent_time + (request_timeout_ms / 1000)
self.in_flight_requests[correlation_id] = (future, sent_time, timeout_at)
else:
future.success(None)

Expand Down Expand Up @@ -1161,7 +1164,7 @@ def recv(self):
for i, (correlation_id, response) in enumerate(responses):
try:
with self._lock:
(future, timestamp) = self.in_flight_requests.pop(correlation_id)
(future, timestamp, _timeout) = self.in_flight_requests.pop(correlation_id)
except KeyError:
self.close(Errors.KafkaConnectionError('Received unrecognized correlation id'))
return ()
Expand Down Expand Up @@ -1235,10 +1238,9 @@ def requests_timed_out(self):
def next_ifr_request_timeout_ms(self):
with self._lock:
if self.in_flight_requests:
get_timestamp = lambda v: v[1]
oldest_at = min(map(get_timestamp,
self.in_flight_requests.values()))
next_timeout = oldest_at + self.config['request_timeout_ms'] / 1000.0
get_timeout = lambda v: v[2]
next_timeout = min(map(get_timeout,
self.in_flight_requests.values()))
return max(0, (next_timeout - time.time()) * 1000)
else:
return float('inf')
Expand Down
6 changes: 3 additions & 3 deletions test/test_conn.py
Original file line number Diff line number Diff line change
Expand Up @@ -351,14 +351,14 @@ def test_requests_timed_out(conn):
# No in-flight requests, not timed out
assert not conn.requests_timed_out()

# Single request, timestamp = now (0)
conn.in_flight_requests[0] = ('foo', 0)
# Single request, timeout_at > now (0)
conn.in_flight_requests[0] = ('foo', 0, 1)
assert not conn.requests_timed_out()

# Add another request w/ timestamp > request_timeout ago
request_timeout = conn.config['request_timeout_ms']
expired_timestamp = 0 - request_timeout - 1
conn.in_flight_requests[1] = ('bar', expired_timestamp)
conn.in_flight_requests[1] = ('bar', 0, expired_timestamp)
assert conn.requests_timed_out()

# Drop the expired request and we should be good to go again
Expand Down

0 comments on commit 1cc78c2

Please sign in to comment.