Skip to content

Commit

Permalink
Do not call state_change_callback with lock (#1775)
Browse files Browse the repository at this point in the history
  • Loading branch information
dpkp authored Apr 3, 2019
1 parent 27cd93b commit 91d3149
Show file tree
Hide file tree
Showing 3 changed files with 40 additions and 31 deletions.
16 changes: 8 additions & 8 deletions kafka/client_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -260,16 +260,16 @@ def _can_connect(self, node_id):
conn = self._conns[node_id]
return conn.disconnected() and not conn.blacked_out()

def _conn_state_change(self, node_id, conn):
def _conn_state_change(self, node_id, sock, conn):
with self._lock:
if conn.connecting():
# SSL connections can enter this state 2x (second during Handshake)
if node_id not in self._connecting:
self._connecting.add(node_id)
try:
self._selector.register(conn._sock, selectors.EVENT_WRITE)
self._selector.register(sock, selectors.EVENT_WRITE)
except KeyError:
self._selector.modify(conn._sock, selectors.EVENT_WRITE)
self._selector.modify(sock, selectors.EVENT_WRITE)

if self.cluster.is_bootstrap(node_id):
self._last_bootstrap = time.time()
Expand All @@ -280,9 +280,9 @@ def _conn_state_change(self, node_id, conn):
self._connecting.remove(node_id)

try:
self._selector.modify(conn._sock, selectors.EVENT_READ, conn)
self._selector.modify(sock, selectors.EVENT_READ, conn)
except KeyError:
self._selector.register(conn._sock, selectors.EVENT_READ, conn)
self._selector.register(sock, selectors.EVENT_READ, conn)

if self._sensors:
self._sensors.connection_created.record()
Expand All @@ -298,11 +298,11 @@ def _conn_state_change(self, node_id, conn):
self._conns.pop(node_id).close()

# Connection failures imply that our metadata is stale, so let's refresh
elif conn.state is ConnectionStates.DISCONNECTING:
elif conn.state is ConnectionStates.DISCONNECTED:
if node_id in self._connecting:
self._connecting.remove(node_id)
try:
self._selector.unregister(conn._sock)
self._selector.unregister(sock)
except KeyError:
pass

Expand Down Expand Up @@ -369,7 +369,7 @@ def _maybe_connect(self, node_id):
log.debug("Initiating connection to node %s at %s:%s",
node_id, broker.host, broker.port)
host, port, afi = get_ip_port_afi(broker.host)
cb = functools.partial(WeakMethod(self._conn_state_change), node_id)
cb = WeakMethod(self._conn_state_change)
conn = BrokerConnection(host, broker.port, afi,
state_change_callback=cb,
node_id=node_id,
Expand Down
34 changes: 21 additions & 13 deletions kafka/conn.py
Original file line number Diff line number Diff line change
Expand Up @@ -212,7 +212,7 @@ class BrokerConnection(object):
'ssl_ciphers': None,
'api_version': (0, 8, 2), # default to most restrictive
'selector': selectors.DefaultSelector,
'state_change_callback': lambda conn: True,
'state_change_callback': lambda node_id, sock, conn: True,
'metrics': None,
'metric_group_prefix': '',
'sasl_mechanism': None,
Expand Down Expand Up @@ -357,6 +357,7 @@ def connect(self):
return self.state
else:
log.debug('%s: creating new socket', self)
assert self._sock is None
self._sock_afi, self._sock_addr = next_lookup
self._sock = socket.socket(self._sock_afi, socket.SOCK_STREAM)

Expand All @@ -366,7 +367,7 @@ def connect(self):

self._sock.setblocking(False)
self.state = ConnectionStates.CONNECTING
self.config['state_change_callback'](self)
self.config['state_change_callback'](self.node_id, self._sock, self)
log.info('%s: connecting to %s:%d [%s %s]', self, self.host,
self.port, self._sock_addr, AFI_NAMES[self._sock_afi])

Expand All @@ -386,21 +387,21 @@ def connect(self):
if self.config['security_protocol'] in ('SSL', 'SASL_SSL'):
log.debug('%s: initiating SSL handshake', self)
self.state = ConnectionStates.HANDSHAKE
self.config['state_change_callback'](self)
self.config['state_change_callback'](self.node_id, self._sock, self)
# _wrap_ssl can alter the connection state -- disconnects on failure
self._wrap_ssl()

elif self.config['security_protocol'] == 'SASL_PLAINTEXT':
log.debug('%s: initiating SASL authentication', self)
self.state = ConnectionStates.AUTHENTICATING
self.config['state_change_callback'](self)
self.config['state_change_callback'](self.node_id, self._sock, self)

else:
# security_protocol PLAINTEXT
log.info('%s: Connection complete.', self)
self.state = ConnectionStates.CONNECTED
self._reset_reconnect_backoff()
self.config['state_change_callback'](self)
self.config['state_change_callback'](self.node_id, self._sock, self)

# Connection failed
# WSAEINVAL == 10022, but errno.WSAEINVAL is not available on non-win systems
Expand All @@ -425,7 +426,7 @@ def connect(self):
log.info('%s: Connection complete.', self)
self.state = ConnectionStates.CONNECTED
self._reset_reconnect_backoff()
self.config['state_change_callback'](self)
self.config['state_change_callback'](self.node_id, self._sock, self)

if self.state is ConnectionStates.AUTHENTICATING:
assert self.config['security_protocol'] in ('SASL_PLAINTEXT', 'SASL_SSL')
Expand All @@ -435,7 +436,7 @@ def connect(self):
log.info('%s: Connection complete.', self)
self.state = ConnectionStates.CONNECTED
self._reset_reconnect_backoff()
self.config['state_change_callback'](self)
self.config['state_change_callback'](self.node_id, self._sock, self)

if self.state not in (ConnectionStates.CONNECTED,
ConnectionStates.DISCONNECTED):
Expand Down Expand Up @@ -802,15 +803,13 @@ def close(self, error=None):
will be failed with this exception.
Default: kafka.errors.KafkaConnectionError.
"""
if self.state is ConnectionStates.DISCONNECTED:
return
with self._lock:
if self.state is ConnectionStates.DISCONNECTED:
return
log.info('%s: Closing connection. %s', self, error or '')
self.state = ConnectionStates.DISCONNECTING
self.config['state_change_callback'](self)
self._update_reconnect_backoff()
self._close_socket()
self.state = ConnectionStates.DISCONNECTED
self._sasl_auth_future = None
self._protocol = KafkaProtocol(
client_id=self.config['client_id'],
Expand All @@ -819,9 +818,18 @@ def close(self, error=None):
error = Errors.Cancelled(str(self))
ifrs = list(self.in_flight_requests.items())
self.in_flight_requests.clear()
self.config['state_change_callback'](self)
self.state = ConnectionStates.DISCONNECTED
# To avoid race conditions and/or deadlocks
# keep a reference to the socket but leave it
# open until after the state_change_callback
# This should give clients a change to deregister
# the socket fd from selectors cleanly.
sock = self._sock
self._sock = None

# drop lock before processing futures
# drop lock before state change callback and processing futures
self.config['state_change_callback'](self.node_id, sock, self)
sock.close()
for (_correlation_id, (future, _timestamp)) in ifrs:
future.failure(error)

Expand Down
21 changes: 11 additions & 10 deletions test/test_client_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,28 +95,29 @@ def test_conn_state_change(mocker, cli, conn):
node_id = 0
cli._conns[node_id] = conn
conn.state = ConnectionStates.CONNECTING
cli._conn_state_change(node_id, conn)
sock = conn._sock
cli._conn_state_change(node_id, sock, conn)
assert node_id in cli._connecting
sel.register.assert_called_with(conn._sock, selectors.EVENT_WRITE)
sel.register.assert_called_with(sock, selectors.EVENT_WRITE)

conn.state = ConnectionStates.CONNECTED
cli._conn_state_change(node_id, conn)
cli._conn_state_change(node_id, sock, conn)
assert node_id not in cli._connecting
sel.modify.assert_called_with(conn._sock, selectors.EVENT_READ, conn)
sel.modify.assert_called_with(sock, selectors.EVENT_READ, conn)

# Failure to connect should trigger metadata update
assert cli.cluster._need_update is False
conn.state = ConnectionStates.DISCONNECTING
cli._conn_state_change(node_id, conn)
conn.state = ConnectionStates.DISCONNECTED
cli._conn_state_change(node_id, sock, conn)
assert node_id not in cli._connecting
assert cli.cluster._need_update is True
sel.unregister.assert_called_with(conn._sock)
sel.unregister.assert_called_with(sock)

conn.state = ConnectionStates.CONNECTING
cli._conn_state_change(node_id, conn)
cli._conn_state_change(node_id, sock, conn)
assert node_id in cli._connecting
conn.state = ConnectionStates.DISCONNECTING
cli._conn_state_change(node_id, conn)
conn.state = ConnectionStates.DISCONNECTED
cli._conn_state_change(node_id, sock, conn)
assert node_id not in cli._connecting


Expand Down

0 comments on commit 91d3149

Please sign in to comment.