diff --git a/bench/unmarshall.py b/bench/unmarshall.py new file mode 100644 index 0000000..cde28b3 --- /dev/null +++ b/bench/unmarshall.py @@ -0,0 +1,22 @@ +import io +import timeit + +from dbus_next._private.unmarshaller import Unmarshaller + +bluez_rssi_message = ( + "6c04010134000000e25389019500000001016f00250000002f6f72672f626c75657a2f686369302f6465" + "765f30385f33415f46325f31455f32425f3631000000020173001f0000006f72672e667265656465736b" + "746f702e444275732e50726f7065727469657300030173001100000050726f706572746965734368616e" + "67656400000000000000080167000873617b73767d617300000007017300040000003a312e3400000000" + "110000006f72672e626c75657a2e446576696365310000000e0000000000000004000000525353490001" + "6e00a7ff000000000000" +) + + +def unmarhsall_bluez_rssi_message(): + Unmarshaller(io.BytesIO(bytes.fromhex(bluez_rssi_message))).unmarshall() + + +count = 1000000 +time = timeit.Timer(unmarhsall_bluez_rssi_message).timeit(count) +print(f"Unmarshalling {count} bluetooth rssi messages took {time} seconds") diff --git a/dbus_next/_private/constants.py b/dbus_next/_private/constants.py index 22a6b80..d9ee0c2 100644 --- a/dbus_next/_private/constants.py +++ b/dbus_next/_private/constants.py @@ -16,3 +16,6 @@ class HeaderField(Enum): SENDER = 7 SIGNATURE = 8 UNIX_FDS = 9 + + +HEADER_NAME_MAP = {field.value: field.name for field in HeaderField} diff --git a/dbus_next/_private/unmarshaller.py b/dbus_next/_private/unmarshaller.py index 38ef657..55b9884 100644 --- a/dbus_next/_private/unmarshaller.py +++ b/dbus_next/_private/unmarshaller.py @@ -1,307 +1,336 @@ +from typing import Any, Callable, Dict, List, Optional, Tuple from ..message import Message -from .constants import HeaderField, LITTLE_ENDIAN, BIG_ENDIAN, PROTOCOL_VERSION -from ..constants import MessageType, MessageFlag -from ..signature import SignatureTree, Variant +from .constants import ( + HeaderField, + LITTLE_ENDIAN, + BIG_ENDIAN, + PROTOCOL_VERSION, + HEADER_NAME_MAP, +) +from ..constants import MessageType, MessageFlag, MESSAGE_FLAG_MAP, MESSAGE_TYPE_MAP +from ..signature import SignatureTree, SignatureType, Variant from ..errors import InvalidMessageError import array +import io import socket -from codecs import decode -from struct import unpack_from +import sys +from struct import Struct MAX_UNIX_FDS = 16 +UNPACK_SYMBOL = {LITTLE_ENDIAN: "<", BIG_ENDIAN: ">"} +UNPACK_LENGTHS = {BIG_ENDIAN: Struct(">III"), LITTLE_ENDIAN: Struct(" bytes: + """reads from the socket, storing any fds sent and handling errors + from the read itself""" + unix_fd_list = array.array("i") - self.readers = { - 'y': self.read_byte, - 'b': self.read_boolean, - 'n': self.read_int16, - 'q': self.read_uint16, - 'i': self.read_int32, - 'u': self.read_uint32, - 'x': self.read_int64, - 't': self.read_uint64, - 'd': self.read_double, - 'h': self.read_uint32, - 'o': self.read_string, - 's': self.read_string, - 'g': self.read_signature, - 'a': self.read_array, - '(': self.read_struct, - '{': self.read_dict_entry, - 'v': self.read_variant - } + try: + msg, ancdata, *_ = self.sock.recvmsg( + length, socket.CMSG_LEN(MAX_UNIX_FDS * unix_fd_list.itemsize)) + except BlockingIOError: + raise MarshallerStreamEndError() - def read(self, n, prefetch=False): - """ - Read from underlying socket into buffer and advance offset accordingly. + for level, type_, data in ancdata: + if not (level == socket.SOL_SOCKET and type_ == socket.SCM_RIGHTS): + continue + unix_fd_list.frombytes(data[:len(data) - (len(data) % unix_fd_list.itemsize)]) + self.unix_fds.extend(list(unix_fd_list)) - :arg n: - Number of bytes to read. If not enough bytes are available in the - buffer, read more from it. - :arg prefetch: - Do not update current offset after reading. + return msg - :returns: - Previous offset (before reading). To get the actual read bytes, - use the returned value and self.buf. - """ - def read_sock(length): - '''reads from the socket, storing any fds sent and handling errors - from the read itself''' - if self.sock is not None: - unix_fd_list = array.array("i") - - try: - msg, ancdata, *_ = self.sock.recvmsg( - length, socket.CMSG_LEN(MAX_UNIX_FDS * unix_fd_list.itemsize)) - except BlockingIOError: - raise MarshallerStreamEndError() - - for level, type_, data in ancdata: - if not (level == socket.SOL_SOCKET and type_ == socket.SCM_RIGHTS): - continue - unix_fd_list.frombytes(data[:len(data) - (len(data) % unix_fd_list.itemsize)]) - self.unix_fds.extend(list(unix_fd_list)) - - return msg - else: - return self.stream.read(length) - - # store previously read data in a buffer so we can resume on socket - # interruptions - missing_bytes = n - (len(self.buf) - self.offset) - if missing_bytes > 0: - data = read_sock(missing_bytes) - if data == b'': - raise EOFError() - elif data is None: - raise MarshallerStreamEndError() - self.buf.extend(data) - if len(data) != missing_bytes: - raise MarshallerStreamEndError() - prev = self.offset - if not prefetch: - self.offset += n - return prev - - @staticmethod - def _padding(offset, align): + def read_to_offset(self, offset: int) -> None: """ - Get padding bytes to get to the next align bytes mark. - - For any align value, the correct padding formula is: - - (align - (offset % align)) % align + Read from underlying socket into buffer. - However, if align is a power of 2 (always the case here), the slow MOD - operator can be replaced by a bitwise AND: + Raises MarshallerStreamEndError if there is not enough data to be read. - (align - (offset & (align - 1))) & (align - 1) - - Which can be simplified to: + :arg offset: + The offset to read to. If not enough bytes are available in the + buffer, read more from it. - (-offset) & (align - 1) + :returns: + None """ - return (-offset) & (align - 1) - - def align(self, n): - padding = self._padding(self.offset, n) - if padding > 0: - self.read(padding) - - def read_byte(self, _=None): - return self.buf[self.read(1)] - - def read_boolean(self, _=None): - data = self.read_uint32() - if data: - return True + start_len = len(self.buf) + missing_bytes = offset - (start_len - self.offset) + if self.sock is None: + data = self.stream.read(missing_bytes) else: - return False - - def read_int16(self, _=None): - return self.read_ctype('h', 2) - - def read_uint16(self, _=None): - return self.read_ctype('H', 2) + data = self.read_sock(missing_bytes) + if data == b"": + raise EOFError() + if data is None: + raise MarshallerStreamEndError() + self.buf.extend(data) + if len(data) + start_len != offset: + raise MarshallerStreamEndError() - def read_int32(self, _=None): - return self.read_ctype('i', 4) - - def read_uint32(self, _=None): - return self.read_ctype('I', 4) - - def read_int64(self, _=None): - return self.read_ctype('q', 8) - - def read_uint64(self, _=None): - return self.read_ctype('Q', 8) - - def read_double(self, _=None): - return self.read_ctype('d', 8) - - def read_ctype(self, fmt, size): - self.align(size) - if self.endian == LITTLE_ENDIAN: - fmt = '<' + fmt - else: - fmt = '>' + fmt - o = self.read(size) - return unpack_from(fmt, self.buf, o)[0] + def read_boolean(self, _=None): + return bool(self.read_argument(UINT32_SIGNATURE)) def read_string(self, _=None): - str_length = self.read_uint32() - o = self.read(str_length + 1) # read terminating '\0' byte as well - # avoid buffer copies when slicing - str_mem_slice = memoryview(self.buf)[o:o + str_length] - return decode(str_mem_slice) + str_length = self.read_argument(UINT32_SIGNATURE) + str_start = self.offset + # read terminating '\0' byte as well (str_length + 1) + self.offset += str_length + 1 + return self.buf[str_start:str_start + str_length].decode() def read_signature(self, _=None): - signature_len = self.read_byte() - o = self.read(signature_len + 1) # read terminating '\0' byte as well - # avoid buffer copies when slicing - sig_mem_slice = memoryview(self.buf)[o:o + signature_len] - return decode(sig_mem_slice) + signature_len = self.view[self.offset] # byte + o = self.offset + 1 + # read terminating '\0' byte as well (str_length + 1) + self.offset = o + signature_len + 1 + return self.buf[o:o + signature_len].decode() def read_variant(self, _=None): - signature = self.read_signature() - signature_tree = SignatureTree._get(signature) - value = self.read_argument(signature_tree.types[0]) - return Variant(signature_tree, value) + tree = SignatureTree._get(self.read_signature()) + # verify in Variant is only useful on construction not unmarshalling + return Variant(tree, self.read_argument(tree.types[0]), verify=False) - def read_struct(self, type_): - self.align(8) + def read_struct(self, type_: SignatureType): + self.offset += -self.offset & 7 # align 8 + return [self.read_argument(child_type) for child_type in type_.children] - result = [] - for child_type in type_.children: - result.append(self.read_argument(child_type)) + def read_dict_entry(self, type_: SignatureType): + self.offset += -self.offset & 7 # align 8 + return self.read_argument(type_.children[0]), self.read_argument(type_.children[1]) - return result - - def read_dict_entry(self, type_): - self.align(8) - - key = self.read_argument(type_.children[0]) - value = self.read_argument(type_.children[1]) - - return key, value - - def read_array(self, type_): - self.align(4) - array_length = self.read_uint32() + def read_array(self, type_: SignatureType): + self.offset += -self.offset & 3 # align 4 for the array + array_length = self.read_argument(UINT32_SIGNATURE) child_type = type_.children[0] - if child_type.token in 'xtd{(': + if child_type.token in "xtd{(": # the first alignment is not included in the array size - self.align(8) + self.offset += -self.offset & 7 # align 8 + + if child_type.token == "y": + self.offset += array_length + return self.buf[self.offset - array_length:self.offset] beginning_offset = self.offset - result = None - if child_type.token == '{': - result = {} + if child_type.token == "{": + result_dict = {} while self.offset - beginning_offset < array_length: key, value = self.read_dict_entry(child_type) - result[key] = value - elif child_type.token == 'y': - o = self.read(array_length) - # avoid buffer copies when slicing - array_mem_slice = memoryview(self.buf)[o:o + array_length] - result = array_mem_slice.tobytes() - else: - result = [] - while self.offset - beginning_offset < array_length: - result.append(self.read_argument(child_type)) - - return result - - def read_argument(self, type_): - t = type_.token - - if t not in self.readers: - raise Exception(f'dont know how to read yet: "{t}"') - - return self.readers[t](type_) - - def _unmarshall(self): - self.offset = 0 - self.read(16, prefetch=True) - self.endian = self.read_byte() - if self.endian != LITTLE_ENDIAN and self.endian != BIG_ENDIAN: - raise InvalidMessageError('Expecting endianness as the first byte') - message_type = MessageType(self.read_byte()) - flags = MessageFlag(self.read_byte()) - - protocol_version = self.read_byte() - + result_dict[key] = value + return result_dict + + result_list = [] + while self.offset - beginning_offset < array_length: + result_list.append(self.read_argument(child_type)) + return result_list + + def read_argument(self, type_: SignatureType) -> Any: + """Dispatch to an argument reader or cast/unpack a C type.""" + token = type_.token + reader, ctype, size, struct = self.readers[token] + if reader: # complex type + return reader(self, type_) + self.offset += size + (-self.offset & (size - 1)) # align + if self.can_cast: + return self.view[self.offset - size:self.offset].cast(ctype)[0] + return struct.unpack_from(self.view, self.offset - size)[0] + + def header_fields(self, header_length): + """Header fields are always a(yv).""" + beginning_offset = self.offset + headers = {} + while self.offset - beginning_offset < header_length: + # Now read the y (byte) of struct (yv) + self.offset += (-self.offset & 7) + 1 # align 8 + 1 for 'y' byte + field_0 = self.view[self.offset - 1] + + # Now read the v (variant) of struct (yv) + signature_len = self.view[self.offset] # byte + o = self.offset + 1 + self.offset += signature_len + 2 # one for the byte, one for the '\0' + tree = SignatureTree._get(self.buf[o:o + signature_len].decode()) + headers[HEADER_NAME_MAP[field_0]] = self.read_argument(tree.types[0]) + return headers + + def _read_header(self): + """Read the header of the message.""" + # Signature is of the header is + # BYTE, BYTE, BYTE, BYTE, UINT32, UINT32, ARRAY of STRUCT of (BYTE,VARIANT) + self.read_to_offset(HEADER_SIGNATURE_SIZE) + buffer = self.buf + endian = buffer[0] + self.message_type = MESSAGE_TYPE_MAP[buffer[1]] + self.flag = MESSAGE_FLAG_MAP[buffer[2]] + protocol_version = buffer[3] + + if endian != LITTLE_ENDIAN and endian != BIG_ENDIAN: + raise InvalidMessageError( + f"Expecting endianness as the first byte, got {endian} from {buffer}") if protocol_version != PROTOCOL_VERSION: - raise InvalidMessageError(f'got unknown protocol version: {protocol_version}') - - body_len = self.read_uint32() - serial = self.read_uint32() - - header_len = self.read_uint32() - msg_len = header_len + self._padding(header_len, 8) + body_len - self.read(msg_len, prefetch=True) - # backtrack offset since header array length needs to be read again - self.offset -= 4 - - header_fields = {} - for field_struct in self.read_argument(SignatureTree._get('a(yv)').types[0]): - field = HeaderField(field_struct[0]) - header_fields[field.name] = field_struct[1].value - - self.align(8) - - path = header_fields.get(HeaderField.PATH.name) - interface = header_fields.get(HeaderField.INTERFACE.name) - member = header_fields.get(HeaderField.MEMBER.name) - error_name = header_fields.get(HeaderField.ERROR_NAME.name) - reply_serial = header_fields.get(HeaderField.REPLY_SERIAL.name) - destination = header_fields.get(HeaderField.DESTINATION.name) - sender = header_fields.get(HeaderField.SENDER.name) - signature = header_fields.get(HeaderField.SIGNATURE.name, '') - signature_tree = SignatureTree._get(signature) - # unix_fds = header_fields.get(HeaderField.UNIX_FDS.name, 0) - - body = [] - - if body_len: - for type_ in signature_tree.types: - body.append(self.read_argument(type_)) - - self.message = Message(destination=destination, - path=path, - interface=interface, - member=member, - message_type=message_type, - flags=flags, - error_name=error_name, - reply_serial=reply_serial, - sender=sender, - unix_fds=self.unix_fds, - signature=signature_tree, - body=body, - serial=serial) + raise InvalidMessageError(f"got unknown protocol version: {protocol_version}") + + self.body_len, self.serial, self.header_len = UNPACK_LENGTHS[endian].unpack_from(buffer, 4) + self.msg_len = (self.header_len + (-self.header_len & 7) + self.body_len) # align 8 + if IS_BIG_ENDIAN and endian == BIG_ENDIAN: + self.can_cast = True + elif IS_LITTLE_ENDIAN and endian == LITTLE_ENDIAN: + self.can_cast = True + self.readers = self._readers_by_type[endian] + + def _read_body(self): + """Read the body of the message.""" + self.read_to_offset(HEADER_SIGNATURE_SIZE + self.msg_len) + self.view = memoryview(self.buf) + self.offset = HEADER_ARRAY_OF_STRUCT_SIGNATURE_POSITION + header_fields = self.header_fields(self.header_len) + self.offset += -self.offset & 7 # align 8 + tree = SignatureTree._get(header_fields.get(HeaderField.SIGNATURE.name, "")) + self.message = Message( + destination=header_fields.get(HEADER_DESTINATION), + path=header_fields.get(HEADER_PATH), + interface=header_fields.get(HEADER_INTERFACE), + member=header_fields.get(HEADER_MEMBER), + message_type=self.message_type, + flags=self.flag, + error_name=header_fields.get(HEADER_ERROR_NAME), + reply_serial=header_fields.get(HEADER_REPLY_SERIAL), + sender=header_fields.get(HEADER_SENDER), + unix_fds=self.unix_fds, + signature=tree.signature, + body=[self.read_argument(t) for t in tree.types] if self.body_len else [], + serial=self.serial, + ) def unmarshall(self): + """Unmarshall the message. + + The underlying read function will raise MarshallerStreamEndError + if there are not enough bytes in the buffer. This allows unmarshall + to be resumed when more data comes in over the wire. + """ try: - self._unmarshall() - return self.message + if not self.message_type: + self._read_header() + self._read_body() except MarshallerStreamEndError: return None + return self.message + + _complex_parsers: Dict[str, Tuple[Callable[["Unmarshaller", SignatureType], Any], None, None, + None]] = { + "b": (read_boolean, None, None, None), + "o": (read_string, None, None, None), + "s": (read_string, None, None, None), + "g": (read_signature, None, None, None), + "a": (read_array, None, None, None), + "(": (read_struct, None, None, None), + "{": (read_dict_entry, None, None, None), + "v": (read_variant, None, None, None), + } + + _ctype_by_endian: Dict[int, Dict[str, Tuple[None, str, int, Struct]]] = { + endian: { + dbus_type: ( + None, + *ctype_size, + Struct(f"{UNPACK_SYMBOL[endian]}{ctype_size[0]}"), + ) + for dbus_type, ctype_size in DBUS_TO_CTYPE.items() + } + for endian in (BIG_ENDIAN, LITTLE_ENDIAN) + } + + _readers_by_type: Dict[int, READER_TYPE] = { + BIG_ENDIAN: { + **_ctype_by_endian[BIG_ENDIAN], + **_complex_parsers + }, + LITTLE_ENDIAN: { + **_ctype_by_endian[LITTLE_ENDIAN], + **_complex_parsers + }, + } diff --git a/dbus_next/constants.py b/dbus_next/constants.py index 6afc9b8..b9494f8 100644 --- a/dbus_next/constants.py +++ b/dbus_next/constants.py @@ -17,6 +17,9 @@ class MessageType(Enum): SIGNAL = 4 #: A broadcast signal to subscribed connections +MESSAGE_TYPE_MAP = {field.value: field for field in MessageType} + + class MessageFlag(IntFlag): """Flags that affect the behavior of sent and received messages """ @@ -26,6 +29,9 @@ class MessageFlag(IntFlag): ALLOW_INTERACTIVE_AUTHORIZATION = 4 +MESSAGE_FLAG_MAP = {field.value: field for field in MessageFlag} + + class NameFlag(IntFlag): """A flag that affects the behavior of a name request. """ diff --git a/dbus_next/message.py b/dbus_next/message.py index 1f8085b..c43abd8 100644 --- a/dbus_next/message.py +++ b/dbus_next/message.py @@ -7,6 +7,13 @@ from typing import List, Any +REQUIRED_FIELDS = { + MessageType.METHOD_CALL: ('path', 'member'), + MessageType.SIGNAL: ('path', 'member', 'interface'), + MessageType.ERROR: ('error_name', 'reply_serial'), + MessageType.METHOD_RETURN: ('reply_serial', ), +} + class Message: """A class for sending and receiving messages through the @@ -95,21 +102,12 @@ def __init__(self, if self.error_name is not None: assert_interface_name_valid(self.error_name) - def require_fields(*fields): - for field in fields: - if not getattr(self, field): - raise InvalidMessageError(f'missing required field: {field}') - - if self.message_type == MessageType.METHOD_CALL: - require_fields('path', 'member') - elif self.message_type == MessageType.SIGNAL: - require_fields('path', 'member', 'interface') - elif self.message_type == MessageType.ERROR: - require_fields('error_name', 'reply_serial') - elif self.message_type == MessageType.METHOD_RETURN: - require_fields('reply_serial') - else: + required_fields = REQUIRED_FIELDS.get(self.message_type) + if not required_fields: raise InvalidMessageError(f'got unknown message type: {self.message_type}') + for field in required_fields: + if not getattr(self, field): + raise InvalidMessageError(f'missing required field: {field}') @staticmethod def new_error(msg: 'Message', error_name: str, error_text: str) -> 'Message': diff --git a/dbus_next/signature.py b/dbus_next/signature.py index 254c842..dd83bed 100644 --- a/dbus_next/signature.py +++ b/dbus_next/signature.py @@ -1,6 +1,7 @@ from .validators import is_object_path_valid from .errors import InvalidSignatureError, SignatureBodyMismatchError +from functools import lru_cache from typing import Any, List, Union @@ -20,9 +21,9 @@ class to parse signatures. """ _tokens = 'ybnqiuxtdsogavh({' - def __init__(self, token): + def __init__(self, token: str) -> None: self.token = token - self.children = [] + self.children: List[SignatureType] = [] self._signature = None def __eq__(self, other): @@ -216,7 +217,7 @@ def _verify_array(self, body): child_type.children[0].verify(key) child_type.children[1].verify(value) elif child_type.token == 'y': - if not isinstance(body, bytes): + if not isinstance(body, (bytearray, bytes)): raise SignatureBodyMismatchError( f'DBus ARRAY type "a" with BYTE child must be Python type "bytes", got {type(body)}' ) @@ -257,43 +258,33 @@ def verify(self, body: Any) -> bool: """ if body is None: raise SignatureBodyMismatchError('Cannot serialize Python type "None"') - elif self.token == 'y': - self._verify_byte(body) - elif self.token == 'b': - self._verify_boolean(body) - elif self.token == 'n': - self._verify_int16(body) - elif self.token == 'q': - self._verify_uint16(body) - elif self.token == 'i': - self._verify_int32(body) - elif self.token == 'u': - self._verify_uint32(body) - elif self.token == 'x': - self._verify_int64(body) - elif self.token == 't': - self._verify_uint64(body) - elif self.token == 'd': - self._verify_double(body) - elif self.token == 'h': - self._verify_unix_fd(body) - elif self.token == 'o': - self._verify_object_path(body) - elif self.token == 's': - self._verify_string(body) - elif self.token == 'g': - self._verify_signature(body) - elif self.token == 'a': - self._verify_array(body) - elif self.token == '(': - self._verify_struct(body) - elif self.token == 'v': - self._verify_variant(body) + validator = self.validators.get(self.token) + if validator: + validator(self, body) else: raise Exception(f'cannot verify type with token {self.token}') return True + validators = { + "y": _verify_byte, + "b": _verify_boolean, + "n": _verify_int16, + "q": _verify_uint16, + "i": _verify_int32, + "u": _verify_uint32, + "x": _verify_int64, + "t": _verify_uint64, + "d": _verify_double, + "h": _verify_uint32, + "o": _verify_string, + "s": _verify_string, + "g": _verify_signature, + "a": _verify_array, + "(": _verify_struct, + "v": _verify_variant, + } + class SignatureTree: """A class that represents a signature as a tree structure for conveniently @@ -310,20 +301,15 @@ class SignatureTree: :raises: :class:`InvalidSignatureError` if the given signature is not valid. """ - - _cache = {} - @staticmethod - def _get(signature: str = ''): - if signature in SignatureTree._cache: - return SignatureTree._cache[signature] - SignatureTree._cache[signature] = SignatureTree(signature) - return SignatureTree._cache[signature] + @lru_cache(maxsize=None) + def _get(signature: str = '') -> "SignatureTree": + return SignatureTree(signature) def __init__(self, signature: str = ''): self.signature = signature - self.types = [] + self.types: List[SignatureType] = [] if len(signature) > 0xff: raise InvalidSignatureError('A signature must be less than 256 characters') @@ -381,7 +367,10 @@ class Variant: :class:`InvalidSignatureError` if the signature is not valid. :class:`SignatureBodyMismatchError` if the signature does not match the body. """ - def __init__(self, signature: Union[str, SignatureTree, SignatureType], value: Any): + def __init__(self, + signature: Union[str, SignatureTree, SignatureType], + value: Any, + verify: bool = True): signature_str = '' signature_tree = None signature_type = None @@ -397,12 +386,13 @@ def __init__(self, signature: Union[str, SignatureTree, SignatureType], value: A raise TypeError('signature must be a SignatureTree, SignatureType, or a string') if signature_tree: - if len(signature_tree.types) != 1: + if verify and len(signature_tree.types) != 1: raise ValueError('variants must have a signature for a single complete type') signature_str = signature_tree.signature signature_type = signature_tree.types[0] - signature_type.verify(value) + if verify: + signature_type.verify(value) self.type = signature_type self.signature = signature_str diff --git a/dbus_next/validators.py b/dbus_next/validators.py index 3b73127..9460200 100644 --- a/dbus_next/validators.py +++ b/dbus_next/validators.py @@ -1,5 +1,6 @@ import re from .errors import InvalidBusNameError, InvalidObjectPathError, InvalidInterfaceNameError, InvalidMemberNameError +from functools import lru_cache _bus_name_re = re.compile(r'^[A-Za-z_-][A-Za-z0-9_-]*$') _path_re = re.compile(r'^[A-Za-z0-9_]+$') @@ -7,6 +8,7 @@ _member_re = re.compile(r'^[A-Za-z_][A-Za-z0-9_-]*$') +@lru_cache(maxsize=32) def is_bus_name_valid(name: str) -> bool: """Whether this is a valid bus name. @@ -41,6 +43,7 @@ def is_bus_name_valid(name: str) -> bool: return True +@lru_cache(maxsize=1024) def is_object_path_valid(path: str) -> bool: """Whether this is a valid object path. @@ -71,6 +74,7 @@ def is_object_path_valid(path: str) -> bool: return True +@lru_cache(maxsize=32) def is_interface_name_valid(name: str) -> bool: """Whether this is a valid interface name. @@ -101,6 +105,7 @@ def is_interface_name_valid(name: str) -> bool: return True +@lru_cache(maxsize=512) def is_member_name_valid(member: str) -> bool: """Whether this is a valid member name. diff --git a/test/test_marshaller.py b/test/test_marshaller.py index 65f6384..f09cf49 100644 --- a/test/test_marshaller.py +++ b/test/test_marshaller.py @@ -1,10 +1,13 @@ +from typing import Any, Dict from dbus_next._private.unmarshaller import Unmarshaller -from dbus_next import Message, Variant, SignatureTree +from dbus_next import Message, Variant, SignatureTree, MessageType, MessageFlag import json import os import io +import pytest + def print_buf(buf): i = 0 @@ -17,24 +20,36 @@ def print_buf(buf): # these messages have been verified with another library -table = json.load(open(os.path.dirname(__file__) + '/data/messages.json')) +table = json.load(open(os.path.dirname(__file__) + "/data/messages.json")) + + +def json_to_message(message: Dict[str, Any]) -> Message: + copy = dict(message) + if "message_type" in copy: + copy["message_type"] = MessageType(copy["message_type"]) + if "flags" in copy: + copy["flags"] = MessageFlag(copy["flags"]) + + return Message(**copy) # variants are an object in the json def replace_variants(type_, item): - if type_.token == 'v' and type(item) is not Variant: - item = Variant(item['signature'], - replace_variants(SignatureTree(item['signature']).types[0], item['value'])) - elif type_.token == 'a': + if type_.token == "v" and type(item) is not Variant: + item = Variant( + item["signature"], + replace_variants(SignatureTree(item["signature"]).types[0], item["value"]), + ) + elif type_.token == "a": for i, item_child in enumerate(item): - if type_.children[0].token == '{': + if type_.children[0].token == "{": for k, v in item.items(): item[k] = replace_variants(type_.children[0].children[1], v) else: item[i] = replace_variants(type_.children[0], item_child) - elif type_.token == '(': + elif type_.token == "(": for i, item_child in enumerate(item): - if type_.children[0].token == '{': + if type_.children[0].token == "{": assert False else: item[i] = replace_variants(type_.children[i], item_child) @@ -54,7 +69,7 @@ def dumper(obj): def test_marshalling_with_table(): for item in table: - message = Message(**item['message']) + message = json_to_message(item["message"]) body = [] for i, type_ in enumerate(message.signature_tree.types): @@ -62,34 +77,35 @@ def test_marshalling_with_table(): message.body = body buf = message._marshall() - data = bytes.fromhex(item['data']) + data = bytes.fromhex(item["data"]) if buf != data: - print('message:') - print(json_dump(item['message'])) - print('') - print('mine:') + print("message:") + print(json_dump(item["message"])) + print("") + print("mine:") print_buf(bytes(buf)) - print('') - print('theirs:') + print("") + print("theirs:") print_buf(data) assert buf == data -def test_unmarshalling_with_table(): - for item in table: +@pytest.mark.parametrize("unmarshall_table", (table, )) +def test_unmarshalling_with_table(unmarshall_table): + for item in unmarshall_table: - stream = io.BytesIO(bytes.fromhex(item['data'])) + stream = io.BytesIO(bytes.fromhex(item["data"])) unmarshaller = Unmarshaller(stream) try: unmarshaller.unmarshall() except Exception as e: - print('message failed to unmarshall:') - print(json_dump(item['message'])) + print("message failed to unmarshall:") + print(json_dump(item["message"])) raise e - message = Message(**item['message']) + message = json_to_message(item["message"]) body = [] for i, type_ in enumerate(message.signature_tree.types): @@ -97,16 +113,54 @@ def test_unmarshalling_with_table(): message.body = body for attr in [ - 'body', 'signature', 'message_type', 'destination', 'path', 'interface', 'member', - 'flags', 'serial' + "body", + "signature", + "message_type", + "destination", + "path", + "interface", + "member", + "flags", + "serial", ]: assert getattr(unmarshaller.message, - attr) == getattr(message, attr), f'attr doesnt match: {attr}' + attr) == getattr(message, attr), f"attr doesnt match: {attr}" + + +def test_unmarshall_can_resume(): + """Verify resume works.""" + bluez_rssi_message = ( + "6c04010134000000e25389019500000001016f00250000002f6f72672f626c75657a2f686369302f6465" + "765f30385f33415f46325f31455f32425f3631000000020173001f0000006f72672e667265656465736b" + "746f702e444275732e50726f7065727469657300030173001100000050726f706572746965734368616e" + "67656400000000000000080167000873617b73767d617300000007017300040000003a312e3400000000" + "110000006f72672e626c75657a2e446576696365310000000e0000000000000004000000525353490001" + "6e00a7ff000000000000") + message_bytes = bytes.fromhex(bluez_rssi_message) + + class SlowStream(io.IOBase): + """A fake stream that will only give us one byte at a time.""" + def __init__(self): + self.data = message_bytes + self.pos = 0 + + def read(self, n) -> bytes: + data = self.data[self.pos:self.pos + 1] + self.pos += 1 + return data + + stream = SlowStream() + unmarshaller = Unmarshaller(stream) + + for _ in range(len(bluez_rssi_message)): + if unmarshaller.unmarshall(): + break + assert unmarshaller.message is not None def test_ay_buffer(): body = [bytes(10000)] - msg = Message(path='/test', member='test', signature='ay', body=body) + msg = Message(path="/test", member="test", signature="ay", body=body) marshalled = msg._marshall() unmarshalled_msg = Unmarshaller(io.BytesIO(marshalled)).unmarshall() assert unmarshalled_msg.body[0] == body[0] diff --git a/test/test_validators.py b/test/test_validators.py index e0bfc7c..dd4dde9 100644 --- a/test/test_validators.py +++ b/test/test_validators.py @@ -5,7 +5,7 @@ def test_object_path_validator(): valid_paths = ['/', '/foo', '/foo/bar', '/foo/bar/bat'] invalid_paths = [ - None, {}, '', 'foo', 'foo/bar', '/foo/bar/', '/$/foo/bar', '/foo//bar', '/foo$bar/baz' + None, '', 'foo', 'foo/bar', '/foo/bar/', '/$/foo/bar', '/foo//bar', '/foo$bar/baz' ] for path in valid_paths: @@ -20,7 +20,7 @@ def test_bus_name_validator(): 'org.mpris.MediaPlayer2.google-play-desktop-player' ] invalid_names = [ - None, {}, '', '5foo.bar', 'foo.6bar', '.foo.bar', 'bar..baz', '$foo.bar', 'foo$.ba$r' + None, '', '5foo.bar', 'foo.6bar', '.foo.bar', 'bar..baz', '$foo.bar', 'foo$.ba$r' ] for name in valid_names: @@ -32,7 +32,7 @@ def test_bus_name_validator(): def test_interface_name_validator(): valid_names = ['foo.bar', 'foo.bar.bat', '_foo._bar', 'foo.bar69'] invalid_names = [ - None, {}, '', '5foo.bar', 'foo.6bar', '.foo.bar', 'bar..baz', '$foo.bar', 'foo$.ba$r', + None, '', '5foo.bar', 'foo.6bar', '.foo.bar', 'bar..baz', '$foo.bar', 'foo$.ba$r', 'org.mpris.MediaPlayer2.google-play-desktop-player' ] @@ -44,7 +44,7 @@ def test_interface_name_validator(): def test_member_name_validator(): valid_members = ['foo', 'FooBar', 'Bat_Baz69', 'foo-bar'] - invalid_members = [None, {}, '', 'foo.bar', '5foo', 'foo$bar'] + invalid_members = [None, '', 'foo.bar', '5foo', 'foo$bar'] for member in valid_members: assert is_member_name_valid(member), f'member name should be valid: "{member}"'