Skip to content

Commit

Permalink
Raise better exception if losing connection
Browse files Browse the repository at this point in the history
* Handle disconnects better in old ritz code
* Have the manager raise LostConnectionError for known trouble spots
* Make LostConnectionError a NotConnectedError
  • Loading branch information
hmpf authored Jun 6, 2024
1 parent 8b027ad commit 4544657
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 7 deletions.
8 changes: 7 additions & 1 deletion src/zinolib/controllers/zino1.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@
from .base import EventManager, EventOrId
from ..compat import StrEnum
from ..event_types import EventType, Event, HistoryEntry, LogEntry, AdmState
from ..ritz import ZinoError, ProtocolError, ritz, notifier
from ..ritz import ZinoError, ProtocolError, ritz, notifier, NotConnectedError
from ..utils import log_exception_with_params


Expand Down Expand Up @@ -120,6 +120,10 @@ class EventClosedError(Zino1Error):
pass


class LostConnectionError(NotConnectedError):
pass


def convert_timestamp(timestamp: int) -> datetime:
return datetime.fromtimestamp(timestamp, timezone.utc)

Expand Down Expand Up @@ -370,6 +374,8 @@ def get_event_ids(request):
return request.get_caseids()
except ProtocolError as e:
raise RetryError('Zino 1 failed to send a correct response header, retry') from e
except BrokenPipeError as e:
raise LostConnectionError('Lost connection to Zino 1 server') from e

@staticmethod
def poll(request, event: EventType) -> bool:
Expand Down
17 changes: 11 additions & 6 deletions src/zinolib/ritz.py
Original file line number Diff line number Diff line change
Expand Up @@ -378,11 +378,11 @@ def _request(self, command: bytes, recv_buffer=4096, **_):
while data:
try:
data = self._sock.recv(recv_buffer)
except socket.timeout:
except socket.timeout as e:
raise TimeoutError(
"Timed out waiting for data. command: %s buffer: %s"
% (repr(command), repr(buffer))
)
) from e
logger.debug("recv: %s" % data.__repr__())

buffer += data.decode("UTF-8", errors="windows_codepage_cp1252")
Expand Down Expand Up @@ -439,8 +439,8 @@ def connect(self):
self._sock = socket.create_connection(
(self.server, self.port), self.timeout
)
except socket.gaierror as E:
raise NotConnectedError(E)
except socket.gaierror as e:
raise NotConnectedError(e) from e
response = self._request(None)
if response.header[0] == 200:
self.authChallenge = response.header[1].split(" ", 1)[0]
Expand Down Expand Up @@ -1119,6 +1119,8 @@ def connect(self):
(self.zino_session.server, self.port), self.timeout
)
self._buff = self._sock.recv(4096)
if not self._buff:
raise NotConnectedError("Lost connection to server")
self._sock.setblocking(False)
rawHeader = self._buff.split(bytes(self.DELIMITER, 'ascii'))[0]
header = rawHeader.split(b" ", 1)
Expand Down Expand Up @@ -1146,8 +1148,11 @@ def poll(self, timeout=0):
r, _, _ = select.select([self._sock], [], [], timeout)
if r:
try:
self._buff += self._sock.recv(4096).decode()
except socket.error as e:
newbuff = self._sock.recv(4096).decode()
if not newbuff:
raise NotConnectedError("Lost connection to server")
self._buff += newbuff
except OSError as e:
if not (
e.args[0] == errno.EAGAIN or e.args[0] == errno.EWOULDBLOCK
):
Expand Down

0 comments on commit 4544657

Please sign in to comment.