diff --git a/sbot/serial_wrapper.py b/sbot/serial_wrapper.py index 38333f9..fc31b5f 100644 --- a/sbot/serial_wrapper.py +++ b/sbot/serial_wrapper.py @@ -6,6 +6,7 @@ """ from __future__ import annotations +import itertools import logging import sys import threading @@ -16,7 +17,6 @@ import serial from .exceptions import BoardDisconnectionError -from .logging import TRACE from .utils import IN_SIMULATOR, BoardIdentity logger = logging.getLogger(__name__) @@ -122,51 +122,88 @@ def stop(self) -> None: """ self._disconnect() + def _connect_if_needed(self) -> None: + if not self.serial.is_open: + if not self._connect(): + # If the serial port cannot be opened raise an error, + # this will be caught by the retry decorator + raise BoardDisconnectionError(( + f'Connection to board {self.identity.board_type}:' + f'{self.identity.asset_tag} could not be established', + )) + @retry(times=3, exceptions=(BoardDisconnectionError, UnicodeDecodeError)) - def query(self, data: str) -> str: + def query_multi(self, commands: list[str]) -> list[str]: """ Send a command to the board and return the response. - This method will automatically reconnect to the board and retry the command + This method will automatically reconnect to the board and retry the commands up to 3 times on serial errors. - :param data: The data to write to the board. + :param commands: The commands to write to the board. :raises BoardDisconnectionError: If the serial connection fails during the transaction, including failing to respond to the command. - :return: The response from the board with the trailing newline removed. + :return: The responses from the board with the trailing newlines removed. """ + # Verify no command has a newline in it, and build a command `bytes` from the + # list of commands + encoded_commands: list[bytes] = [] + invalid_commands: list[tuple[str, str]] = [] + + for command in commands: + if '\n' in command: + invalid_commands.append(("contains newline", command)) + else: + try: + byte_form = command.encode(encoding='utf-8') + except UnicodeEncodeError as e: + invalid_commands.append((str(e), command)) + else: + encoded_commands.append(byte_form) + encoded_commands.append(b'\n') + + if invalid_commands: + invalid_commands.sort() + + invalid_command_groups = dict(itertools.groupby( + invalid_commands, + key=lambda x: x[0], + )) + + error_message = "\n".join( + ["Invalid commands:"] + + [ + f" {reason}: " + ", ".join( + repr(command) + for _, command in grouped_commands + ) + for reason, grouped_commands in invalid_command_groups.items() + ], + ) + raise ValueError(error_message) + + full_commands = b''.join(encoded_commands) + with self._lock: - if not self.serial.is_open: - if not self._connect(): - # If the serial port cannot be opened raise an error, - # this will be caught by the retry decorator - raise BoardDisconnectionError(( - f'Connection to board {self.identity.board_type}:' - f'{self.identity.asset_tag} could not be established', - )) + # If the serial port is not open, try to connect + self._connect_if_needed() # TODO: Write me + # Contain all the serial IO in a try-catch; on error, disconnect and raise an error try: - logger.log(TRACE, f'Serial write - {data!r}') - cmd = data + '\n' - self.serial.write(cmd.encode()) - - response = self.serial.readline() - try: - response_str = response.decode().rstrip('\n') - except UnicodeDecodeError as e: - logger.warning( - f"Board {self.identity.board_type}:{self.identity.asset_tag} " - f"returned invalid characters: {response!r}") - raise e - logger.log( - TRACE, f'Serial read - {response_str!r}') - - if b'\n' not in response: - # If readline times out no error is raised, it returns an incomplete string - logger.warning(( - f'Connection to board {self.identity.board_type}:' - f'{self.identity.asset_tag} timed out waiting for response' - )) + # Send the commands to the board + self.serial.write(full_commands) + + # Read as many lines as there are commands + responses_binary = [ + self.serial.readline() + for _ in range(len(commands)) + ] + + # Check all responses have a trailing newline (an incomplete + # response will not). + # This is within the lock and try-catch to ensure the serial port + # is closed on error. + if not all(response.endswith(b'\n') for response in responses_binary): raise serial.SerialException('Timeout on readline') except serial.SerialException: # Serial connection failed, close the port and raise an error @@ -176,15 +213,51 @@ def query(self, data: str) -> str: 'disconnected during transaction' )) - if response_str.startswith('NACK'): - _, error_msg = response_str.split(':', maxsplit=1) - logger.error(( - f'Board {self.identity.board_type}:{self.identity.asset_tag} ' - f'returned NACK on write command: {error_msg}' - )) - raise RuntimeError(error_msg) + # Decode all the responses as UTF-8 + try: + responses_decoded = [ + response.decode("utf-8").rstrip('\n') + for response in responses_binary + ] + except UnicodeDecodeError as e: + logger.warning( + f"Board {self.identity.board_type}:{self.identity.asset_tag} " + f"returned invalid characters: {responses_binary!r}") + raise e + + # Collect any NACK responses; if any, raise an error + nack_prefix = 'NACK:' + nack_responses = [ + response + for response in responses_decoded + if response.startswith(nack_prefix) + ] + + if nack_responses: + errors = [response[len(nack_prefix):] for response in nack_responses] + # We can't use exception groups due to needing to support Python 3.8 + raise ( + RuntimeError(errors[0]) + if len(errors) == 1 + else RuntimeError("Multiple errors: " + ", ".join(errors)) + ) + + # Return the list of responses + return responses_decoded + + def query(self, data: str) -> str: + """ + Send a command to the board and return the response. + + This method will automatically reconnect to the board and retry the command + up to 3 times on serial errors. - return response_str + :param data: The data to write to the board. + :raises BoardDisconnectionError: If the serial connection fails during the transaction, + including failing to respond to the command. + :return: The response from the board with the trailing newline removed. + """ + return self.query_multi([data])[0] def write(self, data: str) -> None: """