Skip to content

Commit

Permalink
Format connection errors in the same way everywhere (#3305)
Browse files Browse the repository at this point in the history
Connection errors are formatted in four places, sync and async, network
socket and unix socket. Each place has some small differences compared
to the others, while they could be, and should be, formatted in an
uniform way. Factor out the logic in a helper method and call that
method in all four places. Arguably we lose some specificity, e.g. the
words "unix socket" won't be there anymore, but it is more valuable to not
have code duplication.
  • Loading branch information
gerzse authored Jul 4, 2024
1 parent 04962e0 commit 0be67bf
Show file tree
Hide file tree
Showing 5 changed files with 111 additions and 81 deletions.
40 changes: 3 additions & 37 deletions redis/asyncio/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@
)
from urllib.parse import ParseResult, parse_qs, unquote, urlparse

from ..utils import format_error_message

# the functionality is available in 3.11.x but has a major issue before
# 3.11.3. See https://github.com/redis/redis-py/issues/2633
if sys.version_info >= (3, 11, 3):
Expand Down Expand Up @@ -345,9 +347,8 @@ async def _connect(self):
def _host_error(self) -> str:
pass

@abstractmethod
def _error_message(self, exception: BaseException) -> str:
pass
return format_error_message(self._host_error(), exception)

async def on_connect(self) -> None:
"""Initialize the connection, authenticate and select a database"""
Expand Down Expand Up @@ -799,27 +800,6 @@ async def _connect(self):
def _host_error(self) -> str:
return f"{self.host}:{self.port}"

def _error_message(self, exception: BaseException) -> str:
# args for socket.error can either be (errno, "message")
# or just "message"

host_error = self._host_error()

if not exception.args:
# asyncio has a bug where on Connection reset by peer, the
# exception is not instanciated, so args is empty. This is the
# workaround.
# See: https://github.com/redis/redis-py/issues/2237
# See: https://github.com/python/cpython/issues/94061
return f"Error connecting to {host_error}. Connection reset by peer"
elif len(exception.args) == 1:
return f"Error connecting to {host_error}. {exception.args[0]}."
else:
return (
f"Error {exception.args[0]} connecting to {host_error}. "
f"{exception}."
)


class SSLConnection(Connection):
"""Manages SSL connections to and from the Redis server(s).
Expand Down Expand Up @@ -971,20 +951,6 @@ async def _connect(self):
def _host_error(self) -> str:
return self.path

def _error_message(self, exception: BaseException) -> str:
# args for socket.error can either be (errno, "message")
# or just "message"
host_error = self._host_error()
if len(exception.args) == 1:
return (
f"Error connecting to unix socket: {host_error}. {exception.args[0]}."
)
else:
return (
f"Error {exception.args[0]} connecting to unix socket: "
f"{host_error}. {exception.args[1]}."
)


FALSE_STRINGS = ("0", "F", "FALSE", "N", "NO")

Expand Down
39 changes: 2 additions & 37 deletions redis/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
HIREDIS_AVAILABLE,
HIREDIS_PACK_AVAILABLE,
SSL_AVAILABLE,
format_error_message,
get_lib_version,
str_if_bytes,
)
Expand Down Expand Up @@ -338,9 +339,8 @@ def _connect(self):
def _host_error(self):
pass

@abstractmethod
def _error_message(self, exception):
pass
return format_error_message(self._host_error(), exception)

def on_connect(self):
"Initialize the connection, authenticate and select a database"
Expand Down Expand Up @@ -733,27 +733,6 @@ def _connect(self):
def _host_error(self):
return f"{self.host}:{self.port}"

def _error_message(self, exception):
# args for socket.error can either be (errno, "message")
# or just "message"

host_error = self._host_error()

if len(exception.args) == 1:
try:
return f"Error connecting to {host_error}. \
{exception.args[0]}."
except AttributeError:
return f"Connection Error: {exception.args[0]}"
else:
try:
return (
f"Error {exception.args[0]} connecting to "
f"{host_error}. {exception.args[1]}."
)
except AttributeError:
return f"Connection Error: {exception.args[0]}"


class SSLConnection(Connection):
"""Manages SSL connections to and from the Redis server(s).
Expand Down Expand Up @@ -930,20 +909,6 @@ def _connect(self):
def _host_error(self):
return self.path

def _error_message(self, exception):
# args for socket.error can either be (errno, "message")
# or just "message"
host_error = self._host_error()
if len(exception.args) == 1:
return (
f"Error connecting to unix socket: {host_error}. {exception.args[0]}."
)
else:
return (
f"Error {exception.args[0]} connecting to unix socket: "
f"{host_error}. {exception.args[1]}."
)


FALSE_STRINGS = ("0", "F", "FALSE", "N", "NO")

Expand Down
12 changes: 12 additions & 0 deletions redis/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,3 +141,15 @@ def get_lib_version():
except metadata.PackageNotFoundError:
libver = "99.99.99"
return libver


def format_error_message(host_error: str, exception: BaseException) -> str:
if not exception.args:
return f"Error connecting to {host_error}."
elif len(exception.args) == 1:
return f"Error {exception.args[0]} connecting to {host_error}."
else:
return (
f"Error {exception.args[0]} connecting to {host_error}. "
f"{exception.args[1]}."
)
51 changes: 44 additions & 7 deletions tests/test_asyncio/test_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,12 @@
_AsyncRESPBase,
)
from redis.asyncio import ConnectionPool, Redis
from redis.asyncio.connection import Connection, UnixDomainSocketConnection, parse_url
from redis.asyncio.connection import (
Connection,
SSLConnection,
UnixDomainSocketConnection,
parse_url,
)
from redis.asyncio.retry import Retry
from redis.backoff import NoBackoff
from redis.exceptions import ConnectionError, InvalidResponse, TimeoutError
Expand Down Expand Up @@ -494,18 +499,50 @@ async def test_connection_garbage_collection(request):


@pytest.mark.parametrize(
"error, expected_message",
"conn, error, expected_message",
[
(OSError(), "Error connecting to localhost:6379. Connection reset by peer"),
(OSError(12), "Error connecting to localhost:6379. 12."),
(SSLConnection(), OSError(), "Error connecting to localhost:6379."),
(SSLConnection(), OSError(12), "Error 12 connecting to localhost:6379."),
(
SSLConnection(),
OSError(12, "Some Error"),
"Error 12 connecting to localhost:6379. [Errno 12] Some Error.",
"Error 12 connecting to localhost:6379. Some Error.",
),
(
UnixDomainSocketConnection(path="unix:///tmp/redis.sock"),
OSError(),
"Error connecting to unix:///tmp/redis.sock.",
),
(
UnixDomainSocketConnection(path="unix:///tmp/redis.sock"),
OSError(12),
"Error 12 connecting to unix:///tmp/redis.sock.",
),
(
UnixDomainSocketConnection(path="unix:///tmp/redis.sock"),
OSError(12, "Some Error"),
"Error 12 connecting to unix:///tmp/redis.sock. Some Error.",
),
],
)
async def test_connect_error_message(error, expected_message):
async def test_format_error_message(conn, error, expected_message):
"""Test that the _error_message function formats errors correctly"""
conn = Connection()
error_message = conn._error_message(error)
assert error_message == expected_message


async def test_network_connection_failure():
with pytest.raises(ConnectionError) as e:
redis = Redis(host="127.0.0.1", port=9999)
await redis.set("a", "b")
assert str(e.value).startswith("Error 111 connecting to 127.0.0.1:9999. Connect")


async def test_unix_socket_connection_failure():
with pytest.raises(ConnectionError) as e:
redis = Redis(unix_socket_path="unix:///tmp/a.sock")
await redis.set("a", "b")
assert (
str(e.value)
== "Error 2 connecting to unix:///tmp/a.sock. No such file or directory."
)
50 changes: 50 additions & 0 deletions tests/test_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -296,3 +296,53 @@ def mock_disconnect(_):

assert called == 1
pool.disconnect()


@pytest.mark.parametrize(
"conn, error, expected_message",
[
(SSLConnection(), OSError(), "Error connecting to localhost:6379."),
(SSLConnection(), OSError(12), "Error 12 connecting to localhost:6379."),
(
SSLConnection(),
OSError(12, "Some Error"),
"Error 12 connecting to localhost:6379. Some Error.",
),
(
UnixDomainSocketConnection(path="unix:///tmp/redis.sock"),
OSError(),
"Error connecting to unix:///tmp/redis.sock.",
),
(
UnixDomainSocketConnection(path="unix:///tmp/redis.sock"),
OSError(12),
"Error 12 connecting to unix:///tmp/redis.sock.",
),
(
UnixDomainSocketConnection(path="unix:///tmp/redis.sock"),
OSError(12, "Some Error"),
"Error 12 connecting to unix:///tmp/redis.sock. Some Error.",
),
],
)
def test_format_error_message(conn, error, expected_message):
"""Test that the _error_message function formats errors correctly"""
error_message = conn._error_message(error)
assert error_message == expected_message


def test_network_connection_failure():
with pytest.raises(ConnectionError) as e:
redis = Redis(port=9999)
redis.set("a", "b")
assert str(e.value) == "Error 111 connecting to localhost:9999. Connection refused."


def test_unix_socket_connection_failure():
with pytest.raises(ConnectionError) as e:
redis = Redis(unix_socket_path="unix:///tmp/a.sock")
redis.set("a", "b")
assert (
str(e.value)
== "Error 2 connecting to unix:///tmp/a.sock. No such file or directory."
)

0 comments on commit 0be67bf

Please sign in to comment.