diff --git a/CHANGELOG.md b/CHANGELOG.md index 2d679b82..277bd466 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -16,6 +16,15 @@ Sections ### Developers --> +## [4.1.0] - 2021-08-22 + +### Added +- Add support for saving permissions when pairing. [#372](https://github.com/ikalchev/HAP-python/pull/372) +- Add accessory-level callbacks. [#373](https://github.com/ikalchev/HAP-python/pull/373) + +### Changed +- Increment the config version when the accessory changes. [#376](https://github.com/ikalchev/HAP-python/pull/376) + ## [4.0.0] - 2021-07-22 - Add support for HAP v 1.1. [#365](https://github.com/ikalchev/HAP-python/pull/365) diff --git a/pyhap/__init__.py b/pyhap/__init__.py index 86cba670..abf37c01 100644 --- a/pyhap/__init__.py +++ b/pyhap/__init__.py @@ -13,6 +13,7 @@ try: import base36 # noqa: F401 import pyqrcode # noqa: F401 + SUPPORT_QR_CODE = True except ImportError: pass diff --git a/pyhap/accessory.py b/pyhap/accessory.py index 57b8599c..85d10f8b 100644 --- a/pyhap/accessory.py +++ b/pyhap/accessory.py @@ -52,6 +52,7 @@ def __init__(self, driver, display_name, aid=None): self.driver = driver self.services = [] self.iid_manager = IIDManager() + self.setter_callback = None self.add_info_service() if aid == STANDALONE_AID: @@ -90,8 +91,7 @@ def add_info_service(self): def add_protocol_version_service(self): """Helper method to add the required HAP Protocol Information service""" serv_hap_proto_info = Service( - HAP_PROTOCOL_INFORMATION_SERVICE_UUID, - "HAPProtocolInformation" + HAP_PROTOCOL_INFORMATION_SERVICE_UUID, "HAPProtocolInformation" ) serv_hap_proto_info.add_characteristic(self.driver.loader.get_char("Version")) serv_hap_proto_info.configure_char("Version", value=HAP_PROTOCOL_VERSION) diff --git a/pyhap/accessory_driver.py b/pyhap/accessory_driver.py index 686796c0..170f12a8 100644 --- a/pyhap/accessory_driver.py +++ b/pyhap/accessory_driver.py @@ -45,7 +45,6 @@ HAP_REPR_PID, HAP_REPR_STATUS, HAP_REPR_VALUE, - MAX_CONFIG_VERSION, STANDALONE_AID, ) from pyhap.encoder import AccessoryEncoder @@ -69,6 +68,55 @@ DASH_REGEX = re.compile(r"[-]+") +def _wrap_char_setter(char, value, client_addr): + """Process an characteristic setter callback trapping and logging all exceptions.""" + try: + char.client_update_value(value, client_addr) + except Exception: # pylint: disable=broad-except + logger.exception( + "%s: Error while setting characteristic %s to %s", + client_addr, + char.display_name, + value, + ) + return HAP_SERVER_STATUS.SERVICE_COMMUNICATION_FAILURE + return HAP_SERVER_STATUS.SUCCESS + + +def _wrap_acc_setter(acc, updates_by_service, client_addr): + """Process an accessory setter callback trapping and logging all exceptions.""" + try: + acc.setter_callback(updates_by_service) + except Exception: # pylint: disable=broad-except + logger.exception( + "%s: Error while setting characteristics to %s for the %s accessory", + updates_by_service, + client_addr, + acc, + ) + return HAP_SERVER_STATUS.SERVICE_COMMUNICATION_FAILURE + return HAP_SERVER_STATUS.SUCCESS + + +def _wrap_service_setter(service, chars, client_addr): + """Process a service setter callback trapping and logging all exceptions.""" + # Ideally this would pass the chars as is without converting + # them to the display_name, but that would break existing + # consumers of the data. + service_chars = {char.display_name: value for char, value in chars.items()} + try: + service.setter_callback(service_chars) + except Exception: # pylint: disable=broad-except + logger.exception( + "%s: Error while setting characteristics to %s for the %s service", + service_chars, + client_addr, + service.display_name, + ) + return HAP_SERVER_STATUS.SERVICE_COMMUNICATION_FAILURE + return HAP_SERVER_STATUS.SUCCESS + + class AccessoryMDNSServiceInfo(ServiceInfo): """A mDNS service info representation of an accessory.""" @@ -265,8 +313,10 @@ def start(self): """ try: logger.info("Starting the event loop") - if threading.current_thread() is threading.main_thread() \ - and os.name != "nt": + if ( + threading.current_thread() is threading.main_thread() + and os.name != "nt" + ): logger.debug("Setting child watcher") watcher = asyncio.SafeChildWatcher() watcher.attach_loop(self.loop) @@ -325,6 +375,13 @@ async def async_start(self): logger.debug("Starting server.") await self.http_server.async_start(self.loop) + # Update the hash of the accessories + # in case the config version needs to be + # incremented to tell iOS to drop the cache + # for /accessories + if self.state.set_accessories_hash(self.accessories_hash): + self.async_persist() + # Advertise the accessory as a mDNS service. logger.debug("Starting mDNS.") self.mdns_service_info = AccessoryMDNSServiceInfo(self.accessory, self.state) @@ -519,7 +576,9 @@ def async_send_event(self, topic, data, sender_client_addr, immediate): client_addr, ) continue - logger.debug("Sending event to client: %s, immediate: %s", client_addr, immediate) + logger.debug( + "Sending event to client: %s, immediate: %s", client_addr, immediate + ) pushed = self.http_server.push_event(data, client_addr, immediate) if not pushed: logger.debug( @@ -538,9 +597,7 @@ def config_changed(self): restart. Also, updates the mDNS advertisement, so that iOS clients know they need to fetch new data. """ - self.state.config_version += 1 - if self.state.config_version > MAX_CONFIG_VERSION: - self.state.config_version = 1 + self.state.increment_config_version() self.persist() self.update_advertisement() @@ -589,11 +646,11 @@ def load(self): Must run in executor. """ - with open(self.persist_file, "r") as file_handle: + with open(self.persist_file, "r", encoding="utf8") as file_handle: self.encoder.load_into(file_handle, self.state) @callback - def pair(self, client_uuid, client_public): + def pair(self, client_uuid, client_public, client_permissions): """Called when a client has paired with the accessory. Persist the new accessory state. @@ -604,11 +661,14 @@ def pair(self, client_uuid, client_public): :param client_public: The client's public key. :type client_public: bytes + :param client_permissions: The client's permissions. + :type client_permissions: bytes (int) + :return: Whether the pairing is successful. :rtype: bool """ logger.info("Paired with %s.", client_uuid) - self.state.add_paired_client(client_uuid, client_public) + self.state.add_paired_client(client_uuid, client_public, client_permissions) self.async_persist() return True @@ -652,6 +712,13 @@ def setup_srp_verifier(self): verifier = SrpServer(ctx, b"Pair-Setup", self.state.pincode) self.srp_verifier = verifier + @property + def accessories_hash(self): + """Hash the get_accessories response to track configuration changes.""" + return hashlib.sha512( + util.to_sorted_hap_json(self.get_accessories()) + ).hexdigest() + def get_accessories(self): """Returns the accessory in HAP format. @@ -758,7 +825,7 @@ def set_characteristics(self, chars_query, client_addr): :type chars_query: dict """ # TODO: Add support for chars that do no support notifications. - accessory_callbacks = {} + updates = {} setter_results = {} had_error = False expired = False @@ -771,11 +838,10 @@ def set_characteristics(self, chars_query, client_addr): for cq in chars_query[HAP_REPR_CHARS]: aid, iid = cq[HAP_REPR_AID], cq[HAP_REPR_IID] - result = setter_results.setdefault(aid, {}) - char = self.accessory.get_characteristic(aid, iid) + setter_results.setdefault(aid, {}) if expired: - result[iid] = HAP_SERVER_STATUS.INVALID_VALUE_IN_REQUEST + setter_results[aid][iid] = HAP_SERVER_STATUS.INVALID_VALUE_IN_REQUEST had_error = True continue @@ -792,62 +858,50 @@ def set_characteristics(self, chars_query, client_addr): if HAP_REPR_VALUE not in cq: continue - value = cq[HAP_REPR_VALUE] + updates.setdefault(aid, {})[iid] = cq[HAP_REPR_VALUE] - try: - char.client_update_value(value, client_addr) - except Exception: # pylint: disable=broad-except - logger.exception( - "%s: Error while setting characteristic %s to %s", - client_addr, - char.display_name, - value, - ) - result[iid] = HAP_SERVER_STATUS.SERVICE_COMMUNICATION_FAILURE - had_error = True + for aid, new_iid_values in updates.items(): + if self.accessory.aid == aid: + acc = self.accessory else: - result[iid] = HAP_SERVER_STATUS.SUCCESS - - # For some services we want to send all the char value - # changes at once. This resolves an issue where we send - # ON and then BRIGHTNESS and the light would go to 100% - # and then dim to the brightness because each callback - # would only send one char at a time. - if not char.service or not char.service.setter_callback: - continue - - services = accessory_callbacks.setdefault(aid, {}) + acc = self.accessory.accessories.get(aid) - if char.service.display_name not in services: - services[char.service.display_name] = { - SERVICE_CALLBACK: char.service.setter_callback, - SERVICE_CHARS: {}, - SERVICE_IIDS: [], - } + updates_by_service = {} + char_to_iid = {} + for iid, value in new_iid_values.items(): + # Characteristic level setter callbacks + char = acc.get_characteristic(aid, iid) - service_data = services[char.service.display_name] - service_data[SERVICE_CHARS][char.display_name] = value - service_data[SERVICE_IIDS].append(iid) - - for aid, services in accessory_callbacks.items(): - for service_name, service_data in services.items(): - try: - service_data[SERVICE_CALLBACK](service_data[SERVICE_CHARS]) - except Exception: # pylint: disable=broad-except - logger.exception( - "%s: Error while setting characteristics to %s for the %s service", - service_data[SERVICE_CHARS], - client_addr, - service_name, - ) - set_result = HAP_SERVER_STATUS.SERVICE_COMMUNICATION_FAILURE + set_result = _wrap_char_setter(char, value, client_addr) + if set_result != HAP_SERVER_STATUS.SUCCESS: had_error = True - else: - set_result = HAP_SERVER_STATUS.SUCCESS - - for iid in service_data[SERVICE_IIDS]: + setter_results[aid][iid] = set_result + + if not char.service or ( + not acc.setter_callback and not char.service.setter_callback + ): + continue + char_to_iid[char] = iid + updates_by_service.setdefault(char.service, {}).update({char: value}) + + # Accessory level setter callbacks + if acc.setter_callback: + set_result = _wrap_acc_setter(acc, updates_by_service, client_addr) + if set_result != HAP_SERVER_STATUS.SUCCESS: + had_error = True + for iid in updates[aid]: setter_results[aid][iid] = set_result + # Service level setter callbacks + for service, chars in updates_by_service.items(): + if not service.setter_callback: + continue + set_result = _wrap_service_setter(service, chars, client_addr) + if set_result != HAP_SERVER_STATUS.SUCCESS: + had_error = True + for char in chars: + setter_results[aid][char_to_iid[char]] = set_result + if not had_error: return None @@ -880,7 +934,9 @@ def prepare(self, prepare_query, client_addr): try: ttl = prepare_query[HAP_REPR_TTL] pid = prepare_query[HAP_REPR_PID] - self.prepared_writes.setdefault(client_addr, {})[pid] = time.time() + (ttl / 1000) + self.prepared_writes.setdefault(client_addr, {})[pid] = time.time() + ( + ttl / 1000 + ) except (KeyError, ValueError): return {HAP_REPR_STATUS: HAP_SERVER_STATUS.INVALID_VALUE_IN_REQUEST} diff --git a/pyhap/const.py b/pyhap/const.py index d54edb4a..63dff979 100644 --- a/pyhap/const.py +++ b/pyhap/const.py @@ -1,6 +1,6 @@ """This module contains constants used by other modules.""" MAJOR_VERSION = 4 -MINOR_VERSION = 0 +MINOR_VERSION = 1 PATCH_VERSION = 0 __short_version__ = "{}.{}".format(MAJOR_VERSION, MINOR_VERSION) __version__ = "{}.{}".format(__short_version__, PATCH_VERSION) @@ -12,7 +12,7 @@ STANDALONE_AID = 1 # Standalone accessory ID (i.e. not bridged) # ### Default values ### -DEFAULT_CONFIG_VERSION = 2 +DEFAULT_CONFIG_VERSION = 1 DEFAULT_PORT = 51827 # ### Configuration version ### @@ -97,3 +97,12 @@ class HAP_SERVER_STATUS: RESOURCE_DOES_NOT_EXIST = -70409 INVALID_VALUE_IN_REQUEST = -70410 INSUFFICIENT_AUTHORIZATION = -70411 + + +class HAP_PERMISSIONS: + USER = b"\x00" + ADMIN = b"\x01" + + +# Client properties +CLIENT_PROP_PERMS = "permissions" diff --git a/pyhap/encoder.py b/pyhap/encoder.py index 9170c3e9..a6a1a0a0 100644 --- a/pyhap/encoder.py +++ b/pyhap/encoder.py @@ -9,6 +9,8 @@ from cryptography.hazmat.primitives import serialization from cryptography.hazmat.primitives.asymmetric import ed25519 +from .const import CLIENT_PROP_PERMS + class AccessoryEncoder: """This class defines the Accessory encoder interface. @@ -32,6 +34,7 @@ class AccessoryEncoder: - UUID and public key of all paired clients. - MAC address. - Config version - ok, this is debatable, but it retains the consistency. + - Accessories Hash The default implementation persists the above properties. @@ -50,14 +53,20 @@ def persist(fp, state): - Public and private key. - UUID and public key of paired clients. - Config version. + - Accessories Hash """ paired_clients = { str(client): bytes.hex(key) for client, key in state.paired_clients.items() } + client_properties = { + str(client): props for client, props in state.client_properties.items() + } config_state = { "mac": state.mac, "config_version": state.config_version, "paired_clients": paired_clients, + "client_properties": client_properties, + "accessories_hash": state.accessories_hash, "private_key": bytes.hex( state.private_key.private_bytes( encoding=serialization.Encoding.Raw, @@ -82,7 +91,20 @@ def load_into(fp, state): """ loaded = json.load(fp) state.mac = loaded["mac"] + state.accessories_hash = loaded.get("accessories_hash") state.config_version = loaded["config_version"] + if "client_properties" in loaded: + state.client_properties = { + uuid.UUID(client): props + for client, props in loaded["client_properties"].items() + } + else: + # If "client_properties" does not exist, everyone + # before that was paired as an admin + state.client_properties = { + uuid.UUID(client): {CLIENT_PROP_PERMS: 1} + for client in loaded["paired_clients"] + } state.paired_clients = { uuid.UUID(client): bytes.fromhex(key) for client, key in loaded["paired_clients"].items() diff --git a/pyhap/hap_event.py b/pyhap/hap_event.py index e2a6cd71..37bf2134 100644 --- a/pyhap/hap_event.py +++ b/pyhap/hap_event.py @@ -3,7 +3,6 @@ from .const import HAP_REPR_CHARS from .util import to_hap_json - EVENT_MSG_STUB = ( b"EVENT/1.0 200 OK\r\n" b"Content-Type: application/hap+json\r\n" diff --git a/pyhap/hap_handler.py b/pyhap/hap_handler.py index a9dc202b..d65f0c05 100644 --- a/pyhap/hap_handler.py +++ b/pyhap/hap_handler.py @@ -11,13 +11,13 @@ from cryptography.exceptions import InvalidSignature, InvalidTag from cryptography.hazmat.primitives import serialization -from cryptography.hazmat.primitives.asymmetric import ed25519 -from cryptography.hazmat.primitives.asymmetric import x25519 +from cryptography.hazmat.primitives.asymmetric import ed25519, x25519 from cryptography.hazmat.primitives.ciphers.aead import ChaCha20Poly1305 from pyhap import tlv from pyhap.const import ( CATEGORY_BRIDGE, + HAP_PERMISSIONS, HAP_REPR_CHARS, HAP_REPR_STATUS, HAP_SERVER_STATUS, @@ -27,7 +27,6 @@ from .hap_crypto import hap_hkdf, pad_tls_nonce from .util import to_hap_json - # iOS will terminate the connection if it does not respond within # 10 seconds, so we only allow 9 seconds to avoid this. RESPONSE_TIMEOUT = 9 @@ -85,11 +84,6 @@ class HAP_TLV_TAGS: PERMISSIONS = b"\x0B" -class HAP_PERMISSIONS: - USER = b"\x00" - ADMIN = b"\x01" - - class UnprivilegedRequestException(Exception): pass @@ -145,6 +139,7 @@ def __init__(self, accessory_handler, client_address): self.enc_context = None self.client_address = client_address self.is_encrypted = False + self.client_uuid = None self.path = None self.command = None @@ -421,7 +416,9 @@ def _pairing_five(self, client_username, client_ltpk, encryption_key): aead_message = bytes(cipher.encrypt(self.PAIRING_5_NONCE, bytes(message), b"")) client_uuid = uuid.UUID(str(client_username, "utf-8")) - should_confirm = self.accessory_handler.pair(client_uuid, client_ltpk) + should_confirm = self.accessory_handler.pair( + client_uuid, client_ltpk, HAP_PERMISSIONS.ADMIN + ) if not should_confirm: self.send_response_with_status( @@ -568,6 +565,7 @@ def _pair_verify_two(self, tlv_objects): self._send_tlv_pairing_response(data) self.response.shared_key = self.enc_context["shared_key"] self.is_encrypted = True + self.client_uuid = client_uuid del self.enc_context def handle_accessories(self): @@ -649,7 +647,8 @@ def handle_prepare(self): def handle_pairings(self): """Handles a client request to update or remove a pairing.""" - if not self.is_encrypted: + # Must be an admin to handle pairings + if not self.is_encrypted or not self.state.is_admin(self.client_uuid): self._send_authentication_error_tlv_response(HAP_TLV_STATES.M2) return @@ -671,8 +670,11 @@ def _handle_add_pairing(self, tlv_objects): logger.debug("%s: Adding client pairing.", self.client_address) client_username = tlv_objects[HAP_TLV_TAGS.USERNAME] client_public = tlv_objects[HAP_TLV_TAGS.PUBLIC_KEY] + permissions = tlv_objects[HAP_TLV_TAGS.PERMISSIONS] client_uuid = uuid.UUID(str(client_username, "utf-8")) - should_confirm = self.accessory_handler.pair(client_uuid, client_public) + should_confirm = self.accessory_handler.pair( + client_uuid, client_public, permissions + ) if not should_confirm: self._send_authentication_error_tlv_response(HAP_TLV_STATES.M2) return @@ -706,6 +708,7 @@ def _handle_list_pairings(self): logger.debug("%s: list pairings", self.client_address) response = [HAP_TLV_TAGS.SEQUENCE_NUM, HAP_TLV_STATES.M2] for client_uuid, client_public in self.state.paired_clients.items(): + admin = self.state.is_admin(client_uuid) response.extend( [ HAP_TLV_TAGS.USERNAME, @@ -713,7 +716,7 @@ def _handle_list_pairings(self): HAP_TLV_TAGS.PUBLIC_KEY, client_public, HAP_TLV_TAGS.PERMISSIONS, - HAP_PERMISSIONS.ADMIN, + HAP_PERMISSIONS.ADMIN if admin else HAP_PERMISSIONS.USER, ] ) diff --git a/pyhap/hap_protocol.py b/pyhap/hap_protocol.py index 0ddf6615..9490a2ee 100644 --- a/pyhap/hap_protocol.py +++ b/pyhap/hap_protocol.py @@ -50,11 +50,14 @@ def __init__(self, loop, connections, accessory_driver) -> None: self._event_timer = None self._event_queue = [] + self.start_time = None + def connection_lost(self, exc: Exception) -> None: """Handle connection lost.""" logger.debug( - "%s: Connection lost to %s: %s", + "%s (%s): Connection lost to %s: %s", self.peername, + self.handler.client_uuid, self.accessory_driver.accessory.display_name, exc, ) @@ -84,10 +87,20 @@ def write(self, data: bytes) -> None: self.last_activity = time.time() if self.hap_crypto: result = self.hap_crypto.encrypt(data) - logger.debug("%s: Send encrypted: %s", self.peername, data) + logger.debug( + "%s (%s): Send encrypted: %s", + self.peername, + self.handler.client_uuid, + data, + ) self.transport.write(result) else: - logger.debug("%s: Send unencrypted: %s", self.peername, data) + logger.debug( + "%s (%s): Send unencrypted: %s", + self.peername, + self.handler.client_uuid, + data, + ) self.transport.write(data) def close(self) -> None: @@ -152,18 +165,31 @@ def data_received(self, data: bytes) -> None: unencrypted_data = self.hap_crypto.decrypt() except InvalidTag as ex: logger.debug( - "%s: Decrypt failed, closing connection: %s", self.peername, ex + "%s (%s): Decrypt failed, closing connection: %s", + self.peername, + self.handler.client_uuid, + ex, ) self.close() return if unencrypted_data == b"": logger.debug("No decryptable data") return - logger.debug("%s: Recv decrypted: %s", self.peername, unencrypted_data) + logger.debug( + "%s (%s): Recv decrypted: %s", + self.peername, + self.handler.client_uuid, + unencrypted_data, + ) self.conn.receive_data(unencrypted_data) else: self.conn.receive_data(data) - logger.debug("%s: Recv unencrypted: %s", self.peername, data) + logger.debug( + "%s (%s): Recv unencrypted: %s", + self.peername, + self.handler.client_uuid, + data, + ) self._process_events() def _process_events(self): @@ -201,7 +227,9 @@ def _event_queue_with_active_subscriptions(self): def _process_one_event(self) -> bool: """Process one http event.""" event = self.conn.next_event() - logger.debug("%s: h11 Event: %s", self.peername, event) + logger.debug( + "%s (%s): h11 Event: %s", self.peername, self.handler.client_uuid, event + ) if event in (h11.NEED_DATA, h11.ConnectionClosed): return False @@ -254,13 +282,17 @@ def _handle_response_ready(self, task: asyncio.Task) -> None: response.body = task.result() except Exception as ex: # pylint: disable=broad-except logger.debug( - "%s: exception during delayed response", self.peername, exc_info=ex + "%s (%s): exception during delayed response", + self.peername, + self.handler.client_uuid, + exc_info=ex, ) response = self.handler.generic_failure_response() if self.transport.is_closing(): logger.debug( - "%s: delayed response not sent as the transport as closed.", + "%s (%s): delayed response not sent as the transport as closed.", self.peername, + self.handler.client_uuid, ) return self.send_response(response) @@ -268,9 +300,10 @@ def _handle_response_ready(self, task: asyncio.Task) -> None: def _handle_invalid_conn_state(self, message): """Log invalid state and close.""" logger.debug( - "%s: Invalid state: %s: close the client socket", - message, + "%s (%s): Invalid state: %s: close the client socket", self.peername, + self.handler.client_uuid, + message, ) self.close() return False diff --git a/pyhap/loader.py b/pyhap/loader.py index c5667868..7a89ce94 100644 --- a/pyhap/loader.py +++ b/pyhap/loader.py @@ -33,7 +33,7 @@ def __init__(self, path_char=CHARACTERISTICS_FILE, path_service=SERVICES_FILE): @staticmethod def _read_file(path): """Read file and return a dict.""" - with open(path, "r") as file: + with open(path, "r", encoding="utf8") as file: return json.load(file) def get_char(self, name): diff --git a/pyhap/state.py b/pyhap/state.py index 0c279601..c44dc4f7 100644 --- a/pyhap/state.py +++ b/pyhap/state.py @@ -2,7 +2,14 @@ from cryptography.hazmat.primitives.asymmetric import ed25519 from pyhap import util -from pyhap.const import DEFAULT_CONFIG_VERSION, DEFAULT_PORT +from pyhap.const import ( + CLIENT_PROP_PERMS, + DEFAULT_CONFIG_VERSION, + DEFAULT_PORT, + MAX_CONFIG_VERSION, +) + +ADMIN_BIT = 0x01 class State: @@ -24,9 +31,11 @@ def __init__(self, *, address=None, mac=None, pincode=None, port=None): self.config_version = DEFAULT_CONFIG_VERSION self.paired_clients = {} + self.client_properties = {} self.private_key = ed25519.Ed25519PrivateKey.generate() self.public_key = self.private_key.public_key() + self.accessories_hash = None # ### Pairing ### @property @@ -34,7 +43,13 @@ def paired(self): """Return if main accessory is currently paired.""" return len(self.paired_clients) > 0 - def add_paired_client(self, client_uuid, client_public): + def is_admin(self, client_uuid): + """Check if a paired client is an admin.""" + if client_uuid not in self.client_properties: + return False + return bool(self.client_properties[client_uuid][CLIENT_PROP_PERMS] & ADMIN_BIT) + + def add_paired_client(self, client_uuid, client_public, perms): """Add a given client to dictionary of paired clients. :param client_uuid: The client's UUID. @@ -45,6 +60,7 @@ def add_paired_client(self, client_uuid, client_public): :type client_public: bytes """ self.paired_clients[client_uuid] = client_public + self.client_properties[client_uuid] = {CLIENT_PROP_PERMS: ord(perms)} def remove_paired_client(self, client_uuid): """Remove a given client from dictionary of paired clients. @@ -53,3 +69,23 @@ def remove_paired_client(self, client_uuid): :type client_uuid: uuid.UUID """ self.paired_clients.pop(client_uuid) + self.client_properties.pop(client_uuid) + + # All pairings must be removed when the last admin is removed + if not any(self.is_admin(client_uuid) for client_uuid in self.paired_clients): + self.paired_clients.clear() + self.client_properties.clear() + + def set_accessories_hash(self, accessories_hash): + """Set the accessories hash and increment the config version if needed.""" + if self.accessories_hash == accessories_hash: + return False + self.accessories_hash = accessories_hash + self.increment_config_version() + return True + + def increment_config_version(self): + """Increment the config version.""" + self.config_version += 1 + if self.config_version > MAX_CONFIG_VERSION: + self.config_version = 1 diff --git a/pyhap/util.py b/pyhap/util.py index 1696a03c..ff1f22e0 100644 --- a/pyhap/util.py +++ b/pyhap/util.py @@ -61,7 +61,7 @@ def long_to_bytes(n): :return: ``long int`` in ``bytes`` format. :rtype: bytes """ - byteList = list() + byteList = [] x = 0 off = 0 while x != n: @@ -158,3 +158,8 @@ def hap_type_to_uuid(hap_type): def to_hap_json(dump_obj): """Convert an object to HAP json.""" return json.dumps(dump_obj, separators=(",", ":")).encode("utf-8") + + +def to_sorted_hap_json(dump_obj): + """Convert an object to sorted HAP json.""" + return json.dumps(dump_obj, sort_keys=True, separators=(",", ":")).encode("utf-8") diff --git a/tests/test_accessory_driver.py b/tests/test_accessory_driver.py index 69f0bb2c..ec87fc1d 100644 --- a/tests/test_accessory_driver.py +++ b/tests/test_accessory_driver.py @@ -436,6 +436,213 @@ def setup_message(self): driver.start() +def test_accessory_level_callbacks(driver): + bridge = Bridge(driver, "mybridge") + acc = Accessory(driver, "TestAcc", aid=2) + acc2 = UnavailableAccessory(driver, "TestAcc2", aid=3) + + service = Service(uuid1(), "Lightbulb") + char_on = Characteristic("On", uuid1(), CHAR_PROPS) + char_brightness = Characteristic("Brightness", uuid1(), CHAR_PROPS) + + service.add_characteristic(char_on) + service.add_characteristic(char_brightness) + + switch_service = Service(uuid1(), "Switch") + char_switch_on = Characteristic("On", uuid1(), CHAR_PROPS) + switch_service.add_characteristic(char_switch_on) + + mock_callback = MagicMock() + acc.setter_callback = mock_callback + + acc.add_service(service) + acc.add_service(switch_service) + bridge.add_accessory(acc) + + service2 = Service(uuid1(), "Lightbulb") + char_on2 = Characteristic("On", uuid1(), CHAR_PROPS) + char_brightness2 = Characteristic("Brightness", uuid1(), CHAR_PROPS) + + service2.add_characteristic(char_on2) + service2.add_characteristic(char_brightness2) + + mock_callback2 = MagicMock() + acc2.setter_callback = mock_callback2 + + acc2.add_service(service2) + bridge.add_accessory(acc2) + + char_switch_on_iid = char_switch_on.to_HAP()[HAP_REPR_IID] + char_on_iid = char_on.to_HAP()[HAP_REPR_IID] + char_brightness_iid = char_brightness.to_HAP()[HAP_REPR_IID] + char_on2_iid = char_on2.to_HAP()[HAP_REPR_IID] + char_brightness2_iid = char_brightness2.to_HAP()[HAP_REPR_IID] + + driver.add_accessory(bridge) + + response = driver.set_characteristics( + { + HAP_REPR_CHARS: [ + { + HAP_REPR_AID: acc.aid, + HAP_REPR_IID: char_on_iid, + HAP_REPR_VALUE: True, + }, + { + HAP_REPR_AID: acc.aid, + HAP_REPR_IID: char_switch_on_iid, + HAP_REPR_VALUE: True, + }, + { + HAP_REPR_AID: acc.aid, + HAP_REPR_IID: char_brightness_iid, + HAP_REPR_VALUE: 88, + }, + { + HAP_REPR_AID: acc2.aid, + HAP_REPR_IID: char_on2_iid, + HAP_REPR_VALUE: True, + }, + { + HAP_REPR_AID: acc2.aid, + HAP_REPR_IID: char_brightness2_iid, + HAP_REPR_VALUE: 12, + }, + ] + }, + "mock_addr", + ) + assert response is None + + mock_callback.assert_called_with( + { + service: {char_on: True, char_brightness: 88}, + switch_service: {char_switch_on: True}, + } + ) + mock_callback2.assert_called_with( + {service2: {char_on2: True, char_brightness2: 12}} + ) + + +def test_accessory_level_callbacks_with_a_failure(driver): + bridge = Bridge(driver, "mybridge") + acc = Accessory(driver, "TestAcc", aid=2) + acc2 = UnavailableAccessory(driver, "TestAcc2", aid=3) + + service = Service(uuid1(), "Lightbulb") + char_on = Characteristic("On", uuid1(), CHAR_PROPS) + char_brightness = Characteristic("Brightness", uuid1(), CHAR_PROPS) + + service.add_characteristic(char_on) + service.add_characteristic(char_brightness) + + switch_service = Service(uuid1(), "Switch") + char_switch_on = Characteristic("On", uuid1(), CHAR_PROPS) + switch_service.add_characteristic(char_switch_on) + + mock_callback = MagicMock() + acc.setter_callback = mock_callback + + acc.add_service(service) + acc.add_service(switch_service) + bridge.add_accessory(acc) + + service2 = Service(uuid1(), "Lightbulb") + char_on2 = Characteristic("On", uuid1(), CHAR_PROPS) + char_brightness2 = Characteristic("Brightness", uuid1(), CHAR_PROPS) + + service2.add_characteristic(char_on2) + service2.add_characteristic(char_brightness2) + + mock_callback2 = MagicMock(side_effect=OSError) + acc2.setter_callback = mock_callback2 + + acc2.add_service(service2) + bridge.add_accessory(acc2) + + char_switch_on_iid = char_switch_on.to_HAP()[HAP_REPR_IID] + char_on_iid = char_on.to_HAP()[HAP_REPR_IID] + char_brightness_iid = char_brightness.to_HAP()[HAP_REPR_IID] + char_on2_iid = char_on2.to_HAP()[HAP_REPR_IID] + char_brightness2_iid = char_brightness2.to_HAP()[HAP_REPR_IID] + + driver.add_accessory(bridge) + + response = driver.set_characteristics( + { + HAP_REPR_CHARS: [ + { + HAP_REPR_AID: acc.aid, + HAP_REPR_IID: char_on_iid, + HAP_REPR_VALUE: True, + }, + { + HAP_REPR_AID: acc.aid, + HAP_REPR_IID: char_switch_on_iid, + HAP_REPR_VALUE: True, + }, + { + HAP_REPR_AID: acc.aid, + HAP_REPR_IID: char_brightness_iid, + HAP_REPR_VALUE: 88, + }, + { + HAP_REPR_AID: acc2.aid, + HAP_REPR_IID: char_on2_iid, + HAP_REPR_VALUE: True, + }, + { + HAP_REPR_AID: acc2.aid, + HAP_REPR_IID: char_brightness2_iid, + HAP_REPR_VALUE: 12, + }, + ] + }, + "mock_addr", + ) + + mock_callback.assert_called_with( + { + service: {char_on: True, char_brightness: 88}, + switch_service: {char_switch_on: True}, + } + ) + mock_callback2.assert_called_with( + {service2: {char_on2: True, char_brightness2: 12}} + ) + + assert response == { + HAP_REPR_CHARS: [ + { + HAP_REPR_AID: acc.aid, + HAP_REPR_IID: char_on_iid, + HAP_REPR_STATUS: HAP_SERVER_STATUS.SUCCESS, + }, + { + HAP_REPR_AID: acc.aid, + HAP_REPR_IID: char_switch_on_iid, + HAP_REPR_STATUS: HAP_SERVER_STATUS.SUCCESS, + }, + { + HAP_REPR_AID: acc.aid, + HAP_REPR_IID: char_brightness_iid, + HAP_REPR_STATUS: HAP_SERVER_STATUS.SUCCESS, + }, + { + HAP_REPR_AID: acc2.aid, + HAP_REPR_IID: char_on2_iid, + HAP_REPR_STATUS: HAP_SERVER_STATUS.SERVICE_COMMUNICATION_FAILURE, + }, + { + HAP_REPR_AID: acc2.aid, + HAP_REPR_IID: char_brightness2_iid, + HAP_REPR_STATUS: HAP_SERVER_STATUS.SERVICE_COMMUNICATION_FAILURE, + }, + ] + } + + @pytest.mark.asyncio async def test_start_stop_sync_acc(async_zeroconf): with patch( @@ -462,6 +669,7 @@ def setup_message(self): driver.add_accessory(acc) driver.start_service() await run_event.wait() + assert driver.state.config_version == 2 assert not driver.loop.is_closed() await driver.async_stop() assert not driver.loop.is_closed() @@ -497,7 +705,35 @@ def setup_message(self): driver.start_service() await asyncio.sleep(0) await run_event.wait() + assert driver.state.config_version == 2 + assert not driver.loop.is_closed() + await driver.async_stop() + assert not driver.loop.is_closed() + + run_event.clear() + driver.start_service() + await asyncio.sleep(0) + await run_event.wait() + assert driver.state.config_version == 2 + await driver.async_stop() + assert not driver.loop.is_closed() + acc.add_preload_service("GarageDoorOpener") + + # Adding a new service should increment the config version + run_event.clear() + driver.start_service() + await asyncio.sleep(0) + await run_event.wait() + assert driver.state.config_version == 3 + await driver.async_stop() assert not driver.loop.is_closed() + + # But only once + run_event.clear() + driver.start_service() + await asyncio.sleep(0) + await run_event.wait() + assert driver.state.config_version == 3 await driver.async_stop() assert not driver.loop.is_closed() @@ -609,7 +845,7 @@ def test_mdns_service_info(driver): "md": "Test Accessory", "pv": "1.1", "id": "00:00:00:00:00:00", - "c#": "2", + "c#": "1", "s#": "1", "ff": "0", "ci": "1", diff --git a/tests/test_encoder.py b/tests/test_encoder.py index ca6d0c4a..ff2a4a2a 100644 --- a/tests/test_encoder.py +++ b/tests/test_encoder.py @@ -1,4 +1,5 @@ """Tests for pyhap.encoder.""" +import json import tempfile import uuid @@ -6,6 +7,7 @@ from cryptography.hazmat.primitives.asymmetric import ed25519 from pyhap import encoder +from pyhap.const import HAP_PERMISSIONS from pyhap.state import State from pyhap.util import generate_mac @@ -18,14 +20,26 @@ def test_persist_and_load(): _pk = ed25519.Ed25519PrivateKey.generate() sample_client_pk = _pk.public_key() state = State(mac=mac) + admin_client_uuid = uuid.uuid1() state.add_paired_client( - uuid.uuid1(), + admin_client_uuid, sample_client_pk.public_bytes( encoding=serialization.Encoding.Raw, format=serialization.PublicFormat.Raw, ), + HAP_PERMISSIONS.ADMIN, ) - + assert state.is_admin(admin_client_uuid) + user_client_uuid = uuid.uuid1() + state.add_paired_client( + user_client_uuid, + sample_client_pk.public_bytes( + encoding=serialization.Encoding.Raw, + format=serialization.PublicFormat.Raw, + ), + HAP_PERMISSIONS.USER, + ) + assert not state.is_admin(user_client_uuid) config_loaded = State() config_loaded.config_version += 2 # change the default state. enc = encoder.AccessoryEncoder() @@ -53,3 +67,51 @@ def test_persist_and_load(): ) assert state.config_version == config_loaded.config_version assert state.paired_clients == config_loaded.paired_clients + assert state.client_properties == config_loaded.client_properties + + +def test_migration_to_include_client_properties(): + """Verify we build client properties if its missing since it was not present in older versions.""" + mac = generate_mac() + _pk = ed25519.Ed25519PrivateKey.generate() + sample_client_pk = _pk.public_key() + state = State(mac=mac) + admin_client_uuid = uuid.uuid1() + state.add_paired_client( + admin_client_uuid, + sample_client_pk.public_bytes( + encoding=serialization.Encoding.Raw, + format=serialization.PublicFormat.Raw, + ), + HAP_PERMISSIONS.ADMIN, + ) + assert state.is_admin(admin_client_uuid) + user_client_uuid = uuid.uuid1() + state.add_paired_client( + user_client_uuid, + sample_client_pk.public_bytes( + encoding=serialization.Encoding.Raw, + format=serialization.PublicFormat.Raw, + ), + HAP_PERMISSIONS.USER, + ) + assert not state.is_admin(user_client_uuid) + + config_loaded = State() + config_loaded.config_version += 2 # change the default state. + enc = encoder.AccessoryEncoder() + with tempfile.TemporaryFile(mode="r+") as fp: + enc.persist(fp, state) + fp.seek(0) + loaded = json.load(fp) + fp.seek(0) + del loaded["client_properties"] + json.dump(loaded, fp) + fp.truncate() + fp.seek(0) + enc.load_into(fp, config_loaded) + + # When client_permissions are missing, all clients + # are imported as admins for backwards compatibility + assert config_loaded.is_admin(admin_client_uuid) + assert config_loaded.is_admin(user_client_uuid) diff --git a/tests/test_hap_handler.py b/tests/test_hap_handler.py index 2c0f820d..0225cfbd 100644 --- a/tests/test_hap_handler.py +++ b/tests/test_hap_handler.py @@ -1,18 +1,20 @@ """Tests for the HAPServerHandler.""" +import json from unittest.mock import patch from uuid import UUID -import json import pytest -from pyhap import hap_handler +from pyhap import hap_handler, tlv from pyhap.accessory import Accessory, Bridge from pyhap.characteristic import CharacteristicError -from pyhap import tlv +from pyhap.const import HAP_PERMISSIONS CLIENT_UUID = UUID("7d0d1ee9-46fe-4a56-a115-69df3f6860c1") +CLIENT2_UUID = UUID("7d0d1ee9-46fe-4a56-a115-69df3f6860c2") + PUBLIC_KEY = b"\x99\x98d%\x8c\xf6h\x06\xfa\x85\x9f\x90\x82\xf2\xe8\x18\x9f\xf8\xc75\x1f>~\xc32\xc1OC\x13\xbfH\xad" @@ -29,10 +31,7 @@ def test_list_pairings_unencrypted(driver): handler = hap_handler.HAPServerHandler(driver, "peername") handler.is_encrypted = False - driver.pair( - CLIENT_UUID, - PUBLIC_KEY, - ) + driver.pair(CLIENT_UUID, PUBLIC_KEY, HAP_PERMISSIONS.ADMIN) assert CLIENT_UUID in driver.state.paired_clients response = hap_handler.HAPResponse() @@ -56,10 +55,8 @@ def test_list_pairings(driver): handler = hap_handler.HAPServerHandler(driver, "peername") handler.is_encrypted = True - driver.pair( - CLIENT_UUID, - PUBLIC_KEY, - ) + handler.client_uuid = CLIENT_UUID + driver.pair(CLIENT_UUID, PUBLIC_KEY, HAP_PERMISSIONS.ADMIN) assert CLIENT_UUID in driver.state.paired_clients response = hap_handler.HAPResponse() @@ -79,32 +76,108 @@ def test_list_pairings(driver): } -def test_add_pairing(driver): +def test_add_pairing_admin(driver): """Verify an encrypted add pairing request.""" driver.add_accessory(Accessory(driver, "TestAcc")) handler = hap_handler.HAPServerHandler(driver, "peername") handler.is_encrypted = True + handler.client_uuid = CLIENT_UUID + assert driver.state.paired is False + driver.pair(CLIENT_UUID, PUBLIC_KEY, HAP_PERMISSIONS.ADMIN) + response = hap_handler.HAPResponse() handler.response = response handler.request_body = tlv.encode( hap_handler.HAP_TLV_TAGS.REQUEST_TYPE, hap_handler.HAP_TLV_STATES.M3, hap_handler.HAP_TLV_TAGS.USERNAME, - str(CLIENT_UUID).encode("utf-8"), + str(CLIENT2_UUID).encode("utf-8"), hap_handler.HAP_TLV_TAGS.PUBLIC_KEY, PUBLIC_KEY, hap_handler.HAP_TLV_TAGS.PERMISSIONS, hap_handler.HAP_PERMISSIONS.ADMIN, ) + handler.handle_pairings() + assert tlv.decode(response.body) == { + hap_handler.HAP_TLV_TAGS.SEQUENCE_NUM: hap_handler.HAP_TLV_STATES.M2 + } + assert driver.state.paired is True + assert CLIENT2_UUID in driver.state.paired_clients + assert driver.state.is_admin(CLIENT2_UUID) + + +def test_add_pairing_user(driver): + """Verify an encrypted add pairing request.""" + driver.add_accessory(Accessory(driver, "TestAcc")) + + handler = hap_handler.HAPServerHandler(driver, "peername") + handler.is_encrypted = True + handler.client_uuid = CLIENT_UUID assert driver.state.paired is False + driver.pair(CLIENT_UUID, PUBLIC_KEY, HAP_PERMISSIONS.ADMIN) + + response = hap_handler.HAPResponse() + handler.response = response + handler.request_body = tlv.encode( + hap_handler.HAP_TLV_TAGS.REQUEST_TYPE, + hap_handler.HAP_TLV_STATES.M3, + hap_handler.HAP_TLV_TAGS.USERNAME, + str(CLIENT2_UUID).encode("utf-8"), + hap_handler.HAP_TLV_TAGS.PUBLIC_KEY, + PUBLIC_KEY, + hap_handler.HAP_TLV_TAGS.PERMISSIONS, + hap_handler.HAP_PERMISSIONS.USER, + ) + handler.handle_pairings() + assert tlv.decode(response.body) == { + hap_handler.HAP_TLV_TAGS.SEQUENCE_NUM: hap_handler.HAP_TLV_STATES.M2 + } + assert driver.state.paired is True + assert CLIENT2_UUID in driver.state.paired_clients + assert not driver.state.is_admin(CLIENT2_UUID) + # Verify upgrade to admin + response = hap_handler.HAPResponse() + handler.response = response + handler.request_body = tlv.encode( + hap_handler.HAP_TLV_TAGS.REQUEST_TYPE, + hap_handler.HAP_TLV_STATES.M3, + hap_handler.HAP_TLV_TAGS.USERNAME, + str(CLIENT2_UUID).encode("utf-8"), + hap_handler.HAP_TLV_TAGS.PUBLIC_KEY, + PUBLIC_KEY, + hap_handler.HAP_TLV_TAGS.PERMISSIONS, + hap_handler.HAP_PERMISSIONS.ADMIN, + ) handler.handle_pairings() assert tlv.decode(response.body) == { hap_handler.HAP_TLV_TAGS.SEQUENCE_NUM: hap_handler.HAP_TLV_STATES.M2 } assert driver.state.paired is True - assert CLIENT_UUID in driver.state.paired_clients + assert CLIENT2_UUID in driver.state.paired_clients + assert driver.state.is_admin(CLIENT2_UUID) + + # Verify downgrade to normal user + response = hap_handler.HAPResponse() + handler.response = response + handler.request_body = tlv.encode( + hap_handler.HAP_TLV_TAGS.REQUEST_TYPE, + hap_handler.HAP_TLV_STATES.M3, + hap_handler.HAP_TLV_TAGS.USERNAME, + str(CLIENT2_UUID).encode("utf-8"), + hap_handler.HAP_TLV_TAGS.PUBLIC_KEY, + PUBLIC_KEY, + hap_handler.HAP_TLV_TAGS.PERMISSIONS, + hap_handler.HAP_PERMISSIONS.USER, + ) + handler.handle_pairings() + assert tlv.decode(response.body) == { + hap_handler.HAP_TLV_TAGS.SEQUENCE_NUM: hap_handler.HAP_TLV_STATES.M2 + } + assert driver.state.paired is True + assert CLIENT2_UUID in driver.state.paired_clients + assert not driver.state.is_admin(CLIENT2_UUID) def test_remove_pairing(driver): @@ -113,10 +186,11 @@ def test_remove_pairing(driver): handler = hap_handler.HAPServerHandler(driver, "peername") handler.is_encrypted = True - driver.pair( - CLIENT_UUID, - PUBLIC_KEY, - ) + handler.client_uuid = CLIENT_UUID + + driver.pair(CLIENT_UUID, PUBLIC_KEY, HAP_PERMISSIONS.ADMIN) + driver.pair(CLIENT2_UUID, PUBLIC_KEY, HAP_PERMISSIONS.USER) + assert driver.state.paired is True assert CLIENT_UUID in driver.state.paired_clients @@ -127,7 +201,7 @@ def test_remove_pairing(driver): hap_handler.HAP_TLV_TAGS.REQUEST_TYPE, hap_handler.HAP_TLV_STATES.M4, hap_handler.HAP_TLV_TAGS.USERNAME, - str(CLIENT_UUID).encode("utf-8"), + str(CLIENT2_UUID).encode("utf-8"), hap_handler.HAP_TLV_TAGS.PUBLIC_KEY, PUBLIC_KEY, ) @@ -135,8 +209,50 @@ def test_remove_pairing(driver): assert tlv.decode(response.body) == { hap_handler.HAP_TLV_TAGS.SEQUENCE_NUM: hap_handler.HAP_TLV_STATES.M2 } - assert CLIENT_UUID not in driver.state.paired_clients - assert driver.state.paired is False + assert CLIENT2_UUID not in driver.state.paired_clients + assert driver.state.paired is True + + # Now remove the last admin + response = hap_handler.HAPResponse() + handler.response = response + handler.request_body = tlv.encode( + hap_handler.HAP_TLV_TAGS.REQUEST_TYPE, + hap_handler.HAP_TLV_STATES.M4, + hap_handler.HAP_TLV_TAGS.USERNAME, + str(CLIENT_UUID).encode("utf-8"), + hap_handler.HAP_TLV_TAGS.PUBLIC_KEY, + PUBLIC_KEY, + ) + handler.handle_pairings() + assert tlv.decode(response.body) == { + hap_handler.HAP_TLV_TAGS.SEQUENCE_NUM: hap_handler.HAP_TLV_STATES.M2 + } + assert CLIENT_UUID not in driver.state.paired_clients + assert driver.state.paired is False + + +def test_non_admin_pairings_request(driver): + """Verify only admins can access pairings.""" + driver.add_accessory(Accessory(driver, "TestAcc")) + + handler = hap_handler.HAPServerHandler(driver, "peername") + handler.is_encrypted = True + handler.client_uuid = CLIENT_UUID + + driver.pair(CLIENT_UUID, PUBLIC_KEY, HAP_PERMISSIONS.USER) + assert CLIENT_UUID in driver.state.paired_clients + + response = hap_handler.HAPResponse() + handler.response = response + handler.request_body = tlv.encode( + hap_handler.HAP_TLV_TAGS.REQUEST_TYPE, hap_handler.HAP_TLV_STATES.M5 + ) + + handler.handle_pairings() + assert tlv.decode(response.body) == { + hap_handler.HAP_TLV_TAGS.SEQUENCE_NUM: hap_handler.HAP_TLV_STATES.M2, + hap_handler.HAP_TLV_TAGS.ERROR_CODE: hap_handler.HAP_TLV_ERRORS.AUTHENTICATION, + } def test_invalid_pairings_request(driver): @@ -145,10 +261,9 @@ def test_invalid_pairings_request(driver): handler = hap_handler.HAPServerHandler(driver, "peername") handler.is_encrypted = True - driver.pair( - CLIENT_UUID, - PUBLIC_KEY, - ) + handler.client_uuid = CLIENT_UUID + + driver.pair(CLIENT_UUID, PUBLIC_KEY, HAP_PERMISSIONS.ADMIN) assert CLIENT_UUID in driver.state.paired_clients response = hap_handler.HAPResponse() @@ -167,10 +282,7 @@ def test_pair_verify_one(driver): handler = hap_handler.HAPServerHandler(driver, "peername") handler.is_encrypted = False - driver.pair( - CLIENT_UUID, - PUBLIC_KEY, - ) + driver.pair(CLIENT_UUID, PUBLIC_KEY, HAP_PERMISSIONS.ADMIN) assert CLIENT_UUID in driver.state.paired_clients response = hap_handler.HAPResponse() @@ -222,10 +334,7 @@ def test_pair_verify_two_invaild_state(driver): handler = hap_handler.HAPServerHandler(driver, "peername") handler.is_encrypted = False - driver.pair( - CLIENT_UUID, - PUBLIC_KEY, - ) + driver.pair(CLIENT_UUID, PUBLIC_KEY, HAP_PERMISSIONS.ADMIN) assert CLIENT_UUID in driver.state.paired_clients response = hap_handler.HAPResponse() @@ -269,10 +378,7 @@ def test_invalid_pairing_request(driver): handler = hap_handler.HAPServerHandler(driver, "peername") handler.is_encrypted = False - driver.pair( - CLIENT_UUID, - PUBLIC_KEY, - ) + driver.pair(CLIENT_UUID, PUBLIC_KEY, HAP_PERMISSIONS.ADMIN) assert CLIENT_UUID in driver.state.paired_clients response = hap_handler.HAPResponse() @@ -559,10 +665,7 @@ def test_attempt_to_pair_when_already_paired(driver): handler = hap_handler.HAPServerHandler(driver, "peername") handler.is_encrypted = False - driver.pair( - CLIENT_UUID, - PUBLIC_KEY, - ) + driver.pair(CLIENT_UUID, PUBLIC_KEY, HAP_PERMISSIONS.ADMIN) response = hap_handler.HAPResponse() handler.response = response diff --git a/tests/test_hap_protocol.py b/tests/test_hap_protocol.py index d4bab68d..69e8a4b4 100644 --- a/tests/test_hap_protocol.py +++ b/tests/test_hap_protocol.py @@ -6,7 +6,7 @@ from cryptography.exceptions import InvalidTag import pytest -from pyhap import hap_protocol, hap_handler +from pyhap import hap_handler, hap_protocol from pyhap.accessory import Accessory, Bridge from pyhap.hap_handler import HAPResponse diff --git a/tests/test_state.py b/tests/test_state.py index cb117ae1..989d3b2d 100644 --- a/tests/test_state.py +++ b/tests/test_state.py @@ -4,6 +4,7 @@ from cryptography.hazmat.primitives.asymmetric import ed25519 import pytest +from pyhap.const import CLIENT_PROP_PERMS, HAP_PERMISSIONS from pyhap.state import State @@ -45,10 +46,10 @@ def test_setup(): assert mock_gen_mac.called assert mock_gen_pincode.called assert state.port == 51827 - assert state.config_version == 2 + assert state.config_version == 1 -def test_pairing(): +def test_pairing_remove_last_admin(): """Test if pairing methods work.""" with patch("pyhap.util.get_local_address"), patch("pyhap.util.generate_mac"), patch( "pyhap.util.generate_pincode" @@ -58,10 +59,51 @@ def test_pairing(): assert not state.paired assert not state.paired_clients - state.add_paired_client("uuid", "public") + state.add_paired_client("uuid", "public", HAP_PERMISSIONS.ADMIN) assert state.paired assert state.paired_clients == {"uuid": "public"} + assert state.client_properties == {"uuid": {CLIENT_PROP_PERMS: 1}} + state.add_paired_client("uuid2", "public", HAP_PERMISSIONS.USER) + assert state.paired + assert state.paired_clients == {"uuid": "public", "uuid2": "public"} + assert state.client_properties == { + "uuid": {CLIENT_PROP_PERMS: 1}, + "uuid2": {CLIENT_PROP_PERMS: 0}, + } + + # Removing the last admin should remove all non-admins state.remove_paired_client("uuid") assert not state.paired assert not state.paired_clients + + +def test_pairing_two_admins(): + """Test if pairing methods work.""" + with patch("pyhap.util.get_local_address"), patch("pyhap.util.generate_mac"), patch( + "pyhap.util.generate_pincode" + ), patch("pyhap.util.generate_setup_id"): + state = State() + + assert not state.paired + assert not state.paired_clients + + state.add_paired_client("uuid", "public", HAP_PERMISSIONS.ADMIN) + assert state.paired + assert state.paired_clients == {"uuid": "public"} + assert state.client_properties == {"uuid": {CLIENT_PROP_PERMS: 1}} + + state.add_paired_client("uuid2", "public", HAP_PERMISSIONS.ADMIN) + assert state.paired + assert state.paired_clients == {"uuid": "public", "uuid2": "public"} + assert state.client_properties == { + "uuid": {CLIENT_PROP_PERMS: 1}, + "uuid2": {CLIENT_PROP_PERMS: 1}, + } + + # Removing the admin should leave the other admin + state.remove_paired_client("uuid2") + assert state.paired + assert state.paired_clients == {"uuid": "public"} + assert state.client_properties == {"uuid": {CLIENT_PROP_PERMS: 1}} + assert not state.is_admin("uuid2")