Skip to content

SentinelManagedConnection searches for new master upon connection failure (#3560) #3601

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 50 additions & 7 deletions redis/asyncio/sentinel.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import asyncio
import inspect
import random
import socket
import weakref
from typing import AsyncIterator, Iterable, Mapping, Optional, Sequence, Tuple, Type

Expand All @@ -11,8 +13,13 @@
SSLConnection,
)
from redis.commands import AsyncSentinelCommands
from redis.exceptions import ConnectionError, ReadOnlyError, ResponseError, TimeoutError
from redis.utils import str_if_bytes
from redis.exceptions import (
ConnectionError,
ReadOnlyError,
RedisError,
ResponseError,
TimeoutError,
)


class MasterNotFoundError(ConnectionError):
Expand All @@ -37,11 +44,47 @@ def __repr__(self):

async def connect_to(self, address):
    """
    Point this connection at ``address`` and establish it.

    ``address`` is unpacked into ``self.host`` and ``self.port``. If the
    connection already reports ``is_connected`` nothing else happens.
    Otherwise ``self._connect()`` is awaited exactly once; no retry is
    applied at this level.
    NOTE(review): presumably retries are left to the caller so that each
    attempt can target a freshly discovered address - confirm against
    ``_connect_retry``.

    After the transport is up, the handshake runs: either the default
    ``on_connect_check_health`` (health check controlled by
    ``connection_pool.check_connection``) or the user supplied
    ``redis_connect_func``. Finally the registered connect callbacks are
    invoked.

    Raises:
        TimeoutError: the connect attempt timed out.
        ConnectionError: any other failure while connecting.
        RedisError: raised by the handshake; the connection is
            disconnected before the error propagates.
    """
    self.host, self.port = address

    if self.is_connected:
        return
    try:
        await self._connect()
    except asyncio.CancelledError:
        raise  # in 3.7 and earlier, this is an Exception, not BaseException
    except (socket.timeout, asyncio.TimeoutError) as e:
        # Must precede the OSError clause: on current Pythons the timeout
        # types are OSError subclasses.
        raise TimeoutError("Timeout connecting to server") from e
    except OSError as e:
        raise ConnectionError(self._error_message(e)) from e
    except Exception as exc:
        raise ConnectionError(exc) from exc

    try:
        if not self.redis_connect_func:
            # Use the default on_connect function
            await self.on_connect_check_health(
                check_health=self.connection_pool.check_connection
            )
        elif asyncio.iscoroutinefunction(self.redis_connect_func):
            # Use the passed coroutine function redis_connect_func
            await self.redis_connect_func(self)
        else:
            # Use the passed plain function redis_connect_func
            self.redis_connect_func(self)
    except RedisError:
        # clean up after any error in on_connect
        await self.disconnect()
        raise

    # run any user callbacks. right now the only internal callback
    # is for pubsub channel/pattern resubscription
    # first, remove any dead weakrefs
    self._connect_callbacks = [ref for ref in self._connect_callbacks if ref()]
    for ref in self._connect_callbacks:
        callback = ref()
        if not callback:
            # The referent died after the pruning pass above; skip it
            # instead of calling None.
            continue
        task = callback(self)
        if task and inspect.isawaitable(task):
            await task

async def _connect_retry(self):
if self._reader:
Expand Down
52 changes: 45 additions & 7 deletions redis/sentinel.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,18 @@
import random
import socket
import weakref
from typing import Optional

from redis.client import Redis
from redis.commands import SentinelCommands
from redis.connection import Connection, ConnectionPool, SSLConnection
from redis.exceptions import ConnectionError, ReadOnlyError, ResponseError, TimeoutError
from redis.utils import str_if_bytes
from redis.exceptions import (
ConnectionError,
ReadOnlyError,
RedisError,
ResponseError,
TimeoutError,
)


class MasterNotFoundError(ConnectionError):
Expand Down Expand Up @@ -35,11 +41,39 @@ def __repr__(self):

def connect_to(self, address):
    """
    Point this connection at ``address`` and establish it.

    ``address`` is unpacked into ``self.host`` and ``self.port``. If a
    socket is already held (``self._sock``) nothing else happens.
    Otherwise ``self._connect()`` is called exactly once; no retry is
    applied at this level.
    NOTE(review): presumably retries are left to the caller so that each
    attempt can target a freshly discovered address - confirm against
    ``_connect_retry``.

    After the socket is stored, the handshake runs: either the default
    ``on_connect_check_health`` (health check controlled by
    ``connection_pool.check_connection``) or the user supplied
    ``redis_connect_func``. Finally the registered connect callbacks are
    invoked.

    Raises:
        TimeoutError: the connect attempt timed out.
        ConnectionError: any other OS-level failure while connecting.
        RedisError: raised by the handshake; the connection is
            disconnected before the error propagates.
    """
    self.host, self.port = address

    if self._sock:
        return
    try:
        sock = self._connect()
    except socket.timeout as e:
        # Must precede the OSError clause: socket.timeout is an OSError
        # subclass.
        raise TimeoutError("Timeout connecting to server") from e
    except OSError as e:
        raise ConnectionError(self._error_message(e)) from e

    self._sock = sock
    try:
        if self.redis_connect_func is None:
            # Use the default on_connect function
            self.on_connect_check_health(
                check_health=self.connection_pool.check_connection
            )
        else:
            # Use the passed function redis_connect_func
            self.redis_connect_func(self)
    except RedisError:
        # clean up after any error in on_connect
        self.disconnect()
        raise

    # run any user callbacks. right now the only internal callback
    # is for pubsub channel/pattern resubscription
    # first, remove any dead weakrefs
    self._connect_callbacks = [ref for ref in self._connect_callbacks if ref()]
    for ref in self._connect_callbacks:
        callback = ref()
        # The referent may have died since the pruning pass above.
        if callback:
            callback(self)

def _connect_retry(self):
if self._sock:
Expand Down Expand Up @@ -294,13 +328,16 @@ def discover_master(self, service_name):
"""
collected_errors = list()
for sentinel_no, sentinel in enumerate(self.sentinels):
# print(f"Sentinel: {sentinel_no}")
try:
masters = sentinel.sentinel_masters()
except (ConnectionError, TimeoutError) as e:
collected_errors.append(f"{sentinel} - {e!r}")
continue
state = masters.get(service_name)
# print(f"Found master: {state}")
if state and self.check_master_state(state, service_name):
# print("Valid state")
# Put this sentinel at the top of the list
self.sentinels[0], self.sentinels[sentinel_no] = (
sentinel,
Expand All @@ -313,6 +350,7 @@ def discover_master(self, service_name):
else state["ip"]
)
return ip, state["port"]
# print("Ignoring it")

error_info = ""
if len(collected_errors) > 0:
Expand Down
1 change: 1 addition & 0 deletions tests/test_asyncio/test_sentinel_managed_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,4 +34,5 @@ async def mock_connect():
conn._connect.side_effect = mock_connect
await conn.connect()
assert conn._connect.call_count == 3
assert connection_pool.get_master_address.call_count == 3
await conn.disconnect()
34 changes: 34 additions & 0 deletions tests/test_sentinel_managed_connection.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
import socket

from redis.retry import Retry
from redis.sentinel import SentinelManagedConnection
from redis.backoff import NoBackoff
from unittest import mock


def test_connect_retry_on_timeout_error(master_host):
    """Test that the _connect function is retried in case of a timeout.

    The first two ``_connect`` calls raise ``socket.timeout``; the third
    delegates to the real ``_connect``. Besides the number of connect
    attempts, the test checks that the master address is looked up once
    per attempt.

    NOTE(review): ``master_host`` is presumably a pytest fixture yielding a
    ``(host, port)`` pair for a reachable master - confirm in conftest.
    """
    connection_pool = mock.Mock()
    connection_pool.get_master_address = mock.Mock(
        return_value=(master_host[0], master_host[1])
    )
    conn = SentinelManagedConnection(
        retry_on_timeout=True,
        retry=Retry(NoBackoff(), 3),
        connection_pool=connection_pool,
    )
    origin_connect = conn._connect
    conn._connect = mock.Mock()

    def mock_connect():
        # connect only on the last retry
        if conn._connect.call_count <= 2:
            raise socket.timeout
        return origin_connect()

    conn._connect.side_effect = mock_connect
    try:
        conn.connect()
        assert conn._connect.call_count == 3
        assert connection_pool.get_master_address.call_count == 3
    finally:
        # Release the real socket even when connect() or an assertion fails.
        conn.disconnect()
Loading