From f6896676a657a9a9039338c96f5d62ef90960487 Mon Sep 17 00:00:00 2001 From: Lewis Juggins Date: Thu, 22 Jun 2017 13:06:52 +0100 Subject: [PATCH] Cache protocol, fix observation issues (#37) * Cache protocol for better performance, and to fix observation issues * Lint * Let the observations know we've lost connection to the protocol * Lint * Address comment * Address comment * Lint --- example_async.py | 30 ++++++++-------- pytradfri/api/aiocoap_api.py | 67 +++++++++++++++++++++++++----------- pytradfri/device.py | 5 +-- 3 files changed, 64 insertions(+), 38 deletions(-) diff --git a/example_async.py b/example_async.py index ee4006a8..3b37795b 100644 --- a/example_async.py +++ b/example_async.py @@ -61,27 +61,28 @@ def observe_callback(updated_device): def observe_err_callback(err): print('observe error:', err) - observe_command = light.observe(observe_callback, observe_err_callback, - duration=120) - # Start observation as a second task on the loop. - observe_future = ensure_future(api(observe_command)) - # Yield to allow observing to start. - yield from asyncio.sleep(0) - - # Example 1: checks state of the light 2 (true=on) + for light in lights: + observe_command = light.observe(observe_callback, observe_err_callback, + duration=120) + # Start observation as a second task on the loop. + ensure_future(api(observe_command)) + # Yield to allow observing to start. + yield from asyncio.sleep(0) + + # Example 1: checks state of the light 0 (true=on) print("Is on:", light.light_control.lights[0].state) - # Example 2: get dimmer level of light 2 + # Example 2: get dimmer level of light 0 print("Dimmer:", light.light_control.lights[0].dimmer) - # Example 3: What is the name of light 2 + # Example 3: What is the name of light 0 print("Name:", light.name) - # Example 4: Set the light level of light 2 + # Example 4: Set the light level of light 0 dim_command = light.light_control.set_dimmer(255) yield from api(dim_command) - # Example 5: Change color of light 2 + # Example 5: Change color of light 0 # f5faf6 = cold | f1e0b5 = normal | efd275 = warm color_command = light.light_control.set_hex_color('efd275') yield from api(color_command) @@ -96,9 +97,8 @@ def observe_err_callback(err): yield from api(dim_command_2) print("Waiting for observation to end (2 mins)") - print("Try altering the light (%s) in the app, and watch the events!" % - light.name) - yield from observe_future + print("Try altering any light in the app, and watch the events!") + yield from asyncio.sleep(120) asyncio.get_event_loop().run_until_complete(run()) diff --git a/pytradfri/api/aiocoap_api.py b/pytradfri/api/aiocoap_api.py index 03258314..54dd5569 100644 --- a/pytradfri/api/aiocoap_api.py +++ b/pytradfri/api/aiocoap_api.py @@ -3,9 +3,8 @@ import json import logging -import aiocoap from aiocoap import Message, Context -from aiocoap.error import RequestTimedOut +from aiocoap.error import RequestTimedOut, Error, ConstructionRenderableError from aiocoap.numbers.codes import Code from aiocoap.transports import tinydtls @@ -14,8 +13,6 @@ _LOGGER = logging.getLogger(__name__) -aiocoap.numbers.constants.MAX_RETRANSMIT = 10 - class PatchedDTLSSecurityStore: """Patched DTLS store in lieu of impl.""" @@ -43,15 +40,50 @@ def api_factory(host, security_code, loop=None): if loop is None: loop = asyncio.get_event_loop() - security_code = security_code.encode('utf-8') + PatchedDTLSSecurityStore.SECRET_PSK = security_code.encode('utf-8') - PatchedDTLSSecurityStore.SECRET_PSK = security_code + _observations_err_callbacks = [] + _protocol = yield from Context.create_client_context(loop=loop) @asyncio.coroutine def _get_protocol(): """Get the protocol for the request.""" - protocol = yield from Context.create_client_context(loop=loop) - return protocol + nonlocal _protocol + if not _protocol: + _protocol = yield from Context.create_client_context(loop=loop) + return _protocol + + @asyncio.coroutine + def _reset_protocol(exc): + """Reset the protocol if an error occurs. + This can be removed when chrysn/aiocoap#79 is closed.""" + # Be responsible and clean up. + protocol = yield from _get_protocol() + yield from protocol.shutdown() + nonlocal _protocol + _protocol = None + # Let any observers know the protocol has been shutdown. + nonlocal _observations_err_callbacks + for ob_error in _observations_err_callbacks: + ob_error(exc) + _observations_err_callbacks.clear() + + @asyncio.coroutine + def _get_response(msg): + """Perform the request, get the response.""" + try: + protocol = yield from _get_protocol() + pr = protocol.request(msg) + r = yield from pr.response + return (pr, r) + except ConstructionRenderableError as e: + raise ClientError("There was an error with the request.", e) + except RequestTimedOut as e: + yield from _reset_protocol(e) + raise RequestTimeout('Request timed out.', e) + except Error as e: + yield from _reset_protocol(e) + raise ServerError("There was an error with the request.", e) @asyncio.coroutine def _execute(api_command): @@ -80,11 +112,7 @@ def _execute(api_command): msg = Message(code=api_method, uri=url, **kwargs) - try: - protocol = yield from _get_protocol() - res = yield from protocol.request(msg).response - except RequestTimedOut: - raise RequestTimeout('Request timed out.') + _, res = yield from _get_response(msg) api_command.result = _process_output(res, parse_json) @@ -98,10 +126,7 @@ def request(*api_commands): return result commands = (_execute(api_command) for api_command in api_commands) - - command_results = [] - for command in commands: - command_results.append((yield from command)) + command_results = yield from asyncio.gather(*commands, loop=loop) return command_results @@ -114,11 +139,9 @@ def _observe(api_command): msg = Message(code=Code.GET, uri=url, observe=duration) - protocol = yield from _get_protocol() - pr = protocol.request(msg) - # Note that this is necessary to start observing - r = yield from pr.response + pr, r = yield from _get_response(msg) + api_command.result = _process_output(r) def success_callback(res): @@ -130,6 +153,8 @@ def error_callback(ex): ob = pr.observation ob.register_callback(success_callback) ob.register_errback(error_callback) + nonlocal _observations_err_callbacks + _observations_err_callbacks.append(ob.error) # This will cause a RequestError to be raised if credentials invalid yield from request(Command('get', ['status'])) diff --git a/pytradfri/device.py b/pytradfri/device.py index c352bec9..71390eaf 100644 --- a/pytradfri/device.py +++ b/pytradfri/device.py @@ -229,9 +229,10 @@ def raw(self): def __repr__(self): state = "on" if self.state else "off" return "".format(self.index, state, self.dimmer, self.hex_color, - self.xy_color) + ">".format(self.index, self.device.name, state, self.dimmer, + self.hex_color, self.xy_color)