Skip to content

Commit

Permalink
Implement zerocopy writes for the encrypted protocol (#476)
Browse files Browse the repository at this point in the history
* Implement zerocopy writes for the encrypted protocol

With Python 3.12+ and later `transport.writelines` is implemented as [`sendmsg(..., IOV_MAX)`](python/cpython#91166) which allows us to avoid joining the bytes and sending them in one go.

Older Python will effectively do the same thing we do now `b"".join(...)`

* update tests
  • Loading branch information
bdraco authored Nov 3, 2024
1 parent 7cb69d9 commit 20f5151
Show file tree
Hide file tree
Showing 4 changed files with 47 additions and 44 deletions.
9 changes: 3 additions & 6 deletions pyhap/hap_crypto.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import logging
import struct
from struct import Struct
from typing import List
from typing import Iterable, List

from chacha20poly1305_reuseable import ChaCha20Poly1305Reusable as ChaCha20Poly1305
from cryptography.hazmat.backends import default_backend
Expand Down Expand Up @@ -112,7 +112,7 @@ def decrypt(self) -> bytes:

return result

def encrypt(self, data: bytes) -> bytes:
def encrypt(self, data: bytes) -> Iterable[bytes]:
"""Encrypt and send the return bytes."""
result: List[bytes] = []
offset = 0
Expand All @@ -127,7 +127,4 @@ def encrypt(self, data: bytes) -> bytes:
offset += length
self._out_count += 1

# Join the result once instead of concatenating each time
# as this is much faster than generating an new immutable
# byte string each time.
return b"".join(result)
return result
2 changes: 1 addition & 1 deletion pyhap/hap_protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ def write(self, data: bytes) -> None:
self.handler.client_uuid,
data,
)
self.transport.write(result)
self.transport.writelines(result)
else:
logger.debug(
"%s (%s): Send unencrypted: %s",
Expand Down
2 changes: 1 addition & 1 deletion tests/test_hap_crypto.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ def test_round_trip():
crypto.OUT_CIPHER_INFO = crypto.IN_CIPHER_INFO
crypto.reset(key)

encrypted = bytearray(crypto.encrypt(plaintext))
encrypted = bytearray(b"".join(crypto.encrypt(plaintext)))

# Receive no data
assert crypto.decrypt() == b""
Expand Down
78 changes: 42 additions & 36 deletions tests/test_hap_protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -246,13 +246,13 @@ def test_get_accessories_with_crypto(driver):
hap_proto.hap_crypto = MockHAPCrypto()
hap_proto.handler.is_encrypted = True

with patch.object(hap_proto.transport, "write") as writer:
with patch.object(hap_proto.transport, "writelines") as writelines:
hap_proto.data_received(
b"GET /accessories HTTP/1.1\r\nHost: Bridge\\032C77C47._hap._tcp.local\r\n\r\n" # pylint: disable=line-too-long
)

hap_proto.close()
assert b"accessories" in writer.call_args_list[0][0][0]
assert b"accessories" in b"".join(writelines.call_args_list[0][0])


def test_get_characteristics_with_crypto(driver):
Expand All @@ -273,7 +273,7 @@ def test_get_characteristics_with_crypto(driver):
hap_proto.hap_crypto = MockHAPCrypto()
hap_proto.handler.is_encrypted = True

with patch.object(hap_proto.transport, "write") as writer:
with patch.object(hap_proto.transport, "writelines") as writelines:
hap_proto.data_received(
b"GET /characteristics?id=3762173001.7 HTTP/1.1\r\nHost: HASS\\032Bridge\\032YPHW\\032B223AD._hap._tcp.local\r\n\r\n" # pylint: disable=line-too-long
)
Expand All @@ -282,13 +282,15 @@ def test_get_characteristics_with_crypto(driver):
)

hap_proto.close()
assert b"Content-Length:" in writer.call_args_list[0][0][0]
assert b"Transfer-Encoding: chunked\r\n\r\n" not in writer.call_args_list[0][0][0]
assert b"-70402" in writer.call_args_list[0][0][0]
joined0 = b"".join(writelines.call_args_list[0][0])
assert b"Content-Length:" in joined0
assert b"Transfer-Encoding: chunked\r\n\r\n" not in joined0
assert b"-70402" in joined0

assert b"Content-Length:" in writer.call_args_list[1][0][0]
assert b"Transfer-Encoding: chunked\r\n\r\n" not in writer.call_args_list[1][0][0]
assert b"TestAcc" in writer.call_args_list[1][0][0]
joined1 = b"".join(writelines.call_args_list[1][0])
assert b"Content-Length:" in joined1
assert b"Transfer-Encoding: chunked\r\n\r\n" not in joined1
assert b"TestAcc" in joined1


def test_set_characteristics_with_crypto(driver):
Expand All @@ -309,13 +311,15 @@ def test_set_characteristics_with_crypto(driver):
hap_proto.hap_crypto = MockHAPCrypto()
hap_proto.handler.is_encrypted = True

with patch.object(hap_proto.transport, "write") as writer:
with patch.object(hap_proto.transport, "writelines") as writelines:
hap_proto.data_received(
b'PUT /characteristics HTTP/1.1\r\nHost: HASS12\\032AD1C22._hap._tcp.local\r\nContent-Length: 49\r\nContent-Type: application/hap+json\r\n\r\n{"characteristics":[{"aid":1,"iid":9,"ev":true}]}' # pylint: disable=line-too-long
)

hap_proto.close()
assert writer.call_args_list[0][0][0] == b"HTTP/1.1 204 No Content\r\n\r\n"
assert (
b"".join(writelines.call_args_list[0][0]) == b"HTTP/1.1 204 No Content\r\n\r\n"
)


def test_crypto_failure_closes_connection(driver):
Expand Down Expand Up @@ -352,14 +356,14 @@ def test_empty_encrypted_data(driver):

hap_proto.hap_crypto = MockHAPCrypto()
hap_proto.handler.is_encrypted = True
with patch.object(hap_proto.transport, "write") as writer:
with patch.object(hap_proto.transport, "writelines") as writelines:
hap_proto.data_received(b"")
hap_proto.data_received(
b"GET /accessories HTTP/1.1\r\nHost: Bridge\\032C77C47._hap._tcp.local\r\n\r\n" # pylint: disable=line-too-long
)

hap_proto.close()
assert b"accessories" in writer.call_args_list[0][0][0]
assert b"accessories" in b"".join(writelines.call_args_list[0][0])


def test_http_11_keep_alive(driver):
Expand Down Expand Up @@ -434,13 +438,13 @@ def test_camera_snapshot_without_snapshot_support(driver):
hap_proto.hap_crypto = MockHAPCrypto()
hap_proto.handler.is_encrypted = True

with patch.object(hap_proto.transport, "write") as writer:
with patch.object(hap_proto.transport, "writelines") as writelines:
hap_proto.data_received(
b'POST /resource HTTP/1.1\r\nHost: HASS\\032Bridge\\032BROZ\\0323BF435._hap._tcp.local\r\nContent-Length: 79\r\nContent-Type: application/hap+json\r\n\r\n{"image-height":360,"resource-type":"image","image-width":640,"aid":1411620844}' # pylint: disable=line-too-long
)

hap_proto.close()
assert b"-70402" in writer.call_args_list[0][0][0]
assert b"-70402" in b"".join(writelines.call_args_list[0][0])


@pytest.mark.asyncio
Expand All @@ -464,14 +468,14 @@ def _get_snapshot(*_):
hap_proto.hap_crypto = MockHAPCrypto()
hap_proto.handler.is_encrypted = True

with patch.object(hap_proto.transport, "write") as writer:
with patch.object(hap_proto.transport, "writelines") as writelines:
hap_proto.data_received(
b'POST /resource HTTP/1.1\r\nHost: HASS\\032Bridge\\032BROZ\\0323BF435._hap._tcp.local\r\nContent-Length: 79\r\nContent-Type: application/hap+json\r\n\r\n{"image-height":360,"resource-type":"image","image-width":640,"aid":1411620844}' # pylint: disable=line-too-long
)
await hap_proto.response.task
await asyncio.sleep(0)

assert b"fakesnap" in writer.call_args_list[0][0][0]
assert b"fakesnap" in b"".join(writelines.call_args_list[0][0])

hap_proto.close()

Expand All @@ -497,14 +501,14 @@ async def _async_get_snapshot(*_):
hap_proto.hap_crypto = MockHAPCrypto()
hap_proto.handler.is_encrypted = True

with patch.object(hap_proto.transport, "write") as writer:
with patch.object(hap_proto.transport, "writelines") as writelines:
hap_proto.data_received(
b'POST /resource HTTP/1.1\r\nHost: HASS\\032Bridge\\032BROZ\\0323BF435._hap._tcp.local\r\nContent-Length: 79\r\nContent-Type: application/hap+json\r\n\r\n{"image-height":360,"resource-type":"image","image-width":640,"aid":1411620844}' # pylint: disable=line-too-long
)
await hap_proto.response.task
await asyncio.sleep(0)

assert b"fakesnap" in writer.call_args_list[0][0][0]
assert b"fakesnap" in b"".join(writelines.call_args_list[0][0])

hap_proto.close()

Expand Down Expand Up @@ -532,14 +536,14 @@ async def _async_get_snapshot(*_):
hap_proto.handler.is_encrypted = True

with patch.object(hap_handler, "RESPONSE_TIMEOUT", 0.1), patch.object(
hap_proto.transport, "write"
) as writer:
hap_proto.transport, "writelines"
) as writelines:
hap_proto.data_received(
b'POST /resource HTTP/1.1\r\nHost: HASS\\032Bridge\\032BROZ\\0323BF435._hap._tcp.local\r\nContent-Length: 79\r\nContent-Type: application/hap+json\r\n\r\n{"image-height":360,"resource-type":"image","image-width":640,"aid":1411620844}' # pylint: disable=line-too-long
)
await asyncio.sleep(0.3)

assert b"-70402" in writer.call_args_list[0][0][0]
assert b"-70402" in b"".join(writelines.call_args_list[0][0])

hap_proto.close()

Expand All @@ -564,7 +568,7 @@ def _make_response(*_):
response.shared_key = b"newkey"
return response

with patch.object(hap_proto.transport, "write"), patch.object(
with patch.object(hap_proto.transport, "writelines"), patch.object(
hap_proto.handler, "dispatch", _make_response
):
hap_proto.data_received(
Expand Down Expand Up @@ -635,7 +639,7 @@ async def _async_get_snapshot(*_):
hap_proto.hap_crypto = MockHAPCrypto()
hap_proto.handler.is_encrypted = True

with patch.object(hap_proto.transport, "write") as writer:
with patch.object(hap_proto.transport, "writelines") as writelines:
hap_proto.data_received(
b'POST /resource HTTP/1.1\r\nHost: HASS\\032Bridge\\032BROZ\\0323BF435._hap._tcp.local\r\nContent-Length: 79\r\nContent-Type: application/hap+json\r\n\r\n{"image-height":360,"resource-type":"image","image-width":640,"aid":1411620844}' # pylint: disable=line-too-long
)
Expand All @@ -645,7 +649,7 @@ async def _async_get_snapshot(*_):
pass
await asyncio.sleep(0)

assert b"-70402" in writer.call_args_list[0][0][0]
assert b"-70402" in b"".join(writelines.call_args_list[0][0])

hap_proto.close()

Expand All @@ -671,7 +675,7 @@ def _get_snapshot(*_):
hap_proto.hap_crypto = MockHAPCrypto()
hap_proto.handler.is_encrypted = True

with patch.object(hap_proto.transport, "write") as writer:
with patch.object(hap_proto.transport, "writelines") as writelines:
hap_proto.data_received(
b'POST /resource HTTP/1.1\r\nHost: HASS\\032Bridge\\032BROZ\\0323BF435._hap._tcp.local\r\nContent-Length: 79\r\nContent-Type: application/hap+json\r\n\r\n{"image-height":360,"resource-type":"image","image-width":640,"aid":1411620844}' # pylint: disable=line-too-long
)
Expand All @@ -681,7 +685,7 @@ def _get_snapshot(*_):
pass
await asyncio.sleep(0)

assert b"-70402" in writer.call_args_list[0][0][0]
assert b"-70402" in b"".join(writelines.call_args_list[0][0])

hap_proto.close()

Expand All @@ -702,14 +706,14 @@ async def test_camera_snapshot_missing_accessory(driver):
hap_proto.hap_crypto = MockHAPCrypto()
hap_proto.handler.is_encrypted = True

with patch.object(hap_proto.transport, "write") as writer:
with patch.object(hap_proto.transport, "writelines") as writelines:
hap_proto.data_received(
b'POST /resource HTTP/1.1\r\nHost: HASS\\032Bridge\\032BROZ\\0323BF435._hap._tcp.local\r\nContent-Length: 79\r\nContent-Type: application/hap+json\r\n\r\n{"image-height":360,"resource-type":"image","image-width":640,"aid":1411620844}' # pylint: disable=line-too-long
)
await asyncio.sleep(0)

assert hap_proto.response is None
assert b"-70402" in writer.call_args_list[0][0][0]
assert b"-70402" in b"".join(writelines.call_args_list[0][0])
hap_proto.close()


Expand Down Expand Up @@ -777,20 +781,22 @@ def test_explicit_close(driver: AccessoryDriver):
hap_proto.handler.is_encrypted = True
assert hap_proto.transport.is_closing() is False

with patch.object(hap_proto.transport, "write") as writer:
with patch.object(hap_proto.transport, "writelines") as writelines:
hap_proto.data_received(
b"GET /characteristics?id=3762173001.7 HTTP/1.1\r\nHost: HASS\\032Bridge\\032YPHW\\032B223AD._hap._tcp.local\r\n\r\n" # pylint: disable=line-too-long
)
hap_proto.data_received(
b"GET /characteristics?id=1.5 HTTP/1.1\r\nConnection: close\r\nHost: HASS\\032Bridge\\032YPHW\\032B223AD._hap._tcp.local\r\n\r\n" # pylint: disable=line-too-long
)

assert b"Content-Length:" in writer.call_args_list[0][0][0]
assert b"Transfer-Encoding: chunked\r\n\r\n" not in writer.call_args_list[0][0][0]
assert b"-70402" in writer.call_args_list[0][0][0]
join0 = b"".join(writelines.call_args_list[0][0])
assert b"Content-Length:" in join0
assert b"Transfer-Encoding: chunked\r\n\r\n" not in join0
assert b"-70402" in join0

assert b"Content-Length:" in writer.call_args_list[1][0][0]
assert b"Transfer-Encoding: chunked\r\n\r\n" not in writer.call_args_list[1][0][0]
assert b"TestAcc" in writer.call_args_list[1][0][0]
join1 = b"".join(writelines.call_args_list[1][0])
assert b"Content-Length:" in join1
assert b"Transfer-Encoding: chunked\r\n\r\n" not in join1
assert b"TestAcc" in join1

assert hap_proto.transport.is_closing() is True

0 comments on commit 20f5151

Please sign in to comment.