From c77ee8fab9c3fd5dfe177b762ee0d5a9d66d5c3e Mon Sep 17 00:00:00 2001 From: "Mark E. Haase" Date: Thu, 11 Oct 2018 11:15:34 -0400 Subject: [PATCH] Add cancel_scope argument to client API (#22) This optional scope is cancelled when the connection is closed. --- tests/test_connection.py | 9 +++++++ trio_websocket/__init__.py | 53 ++++++++++++++++++++++++++++---------- 2 files changed, 48 insertions(+), 14 deletions(-) diff --git a/tests/test_connection.py b/tests/test_connection.py index 4c4b6ba..5ee1889 100644 --- a/tests/test_connection.py +++ b/tests/test_connection.py @@ -274,3 +274,12 @@ async def handler(stream): async with open_websocket(HOST, port, RESOURCE, use_ssl=False) as client: with pytest.raises(ConnectionClosed): await client.get_message() + + +async def test_connection_cancel_scope(echo_server): + async with trio.open_nursery() as nursery: + async with open_websocket(HOST, echo_server.port, RESOURCE, + use_ssl=False, cancel_scope=nursery.cancel_scope) as conn: + pass + await trio.sleep(0) + assert nursery.cancel_scope.cancel_called diff --git a/trio_websocket/__init__.py b/trio_websocket/__init__.py index c510553..a628e6f 100644 --- a/trio_websocket/__init__.py +++ b/trio_websocket/__init__.py @@ -24,7 +24,7 @@ @asynccontextmanager @async_generator -async def open_websocket(host, port, resource, use_ssl): +async def open_websocket(host, port, resource, use_ssl, *, cancel_scope=None): ''' Open a WebSocket client connection to a host. @@ -40,15 +40,18 @@ async def open_websocket(host, port, resource, use_ssl): :param int port: the port to connect to :param str resource: the resource a.k.a. path :param use_ssl: a bool or SSLContext + :param cancel_scope: A Trio cancel scope that is cancelled when the + connection is closed. ''' async with trio.open_nursery() as new_nursery: connection = await connect_websocket(new_nursery, host, port, resource, - use_ssl) + use_ssl, cancel_scope=cancel_scope) async with connection: await yield_(connection) -async def connect_websocket(nursery, host, port, resource, use_ssl): +async def connect_websocket(nursery, host, port, resource, use_ssl, *, + cancel_scope=None): ''' Return a WebSocket client connection to a host. @@ -62,6 +65,8 @@ async def connect_websocket(nursery, host, port, resource, use_ssl): :param int port: the port to connect to :param str resource: the resource a.k.a. path :param use_ssl: a bool or SSLContext + :param cancel_scope: A Trio cancel scope that is cancelled when the + connection is closed. :rtype: WebSocketConnection ''' if use_ssl == True: @@ -86,13 +91,14 @@ async def connect_websocket(nursery, host, port, resource, use_ssl): host_header = '{}:{}'.format(host, port) wsproto = wsconnection.WSConnection(wsconnection.CLIENT, host=host_header, resource=resource) - connection = WebSocketConnection(stream, wsproto, path=resource) + connection = WebSocketConnection(stream, wsproto, path=resource, + cancel_scope=cancel_scope) nursery.start_soon(connection._reader_task) await connection._open_handshake.wait() return connection -def open_websocket_url(url, ssl_context=None): +def open_websocket_url(url, ssl_context=None, *, cancel_scope=None): ''' Open a WebSocket client connection to a URL. @@ -106,12 +112,16 @@ def open_websocket_url(url, ssl_context=None): :param str url: a WebSocket URL :param ssl_context: optional ``SSLContext`` used for ``wss:`` URLs + :param cancel_scope: A Trio cancel scope that is cancelled when the + connection is closed. ''' host, port, resource, ssl_context = _url_to_host(url, ssl_context) - return open_websocket(host, port, resource, ssl_context) + return open_websocket(host, port, resource, ssl_context, + cancel_scope=cancel_scope) -async def connect_websocket_url(nursery, url, ssl_context=None): +async def connect_websocket_url(nursery, url, ssl_context=None, *, + cancel_scope=None): ''' Return a WebSocket client connection to a URL. @@ -126,10 +136,13 @@ async def connect_websocket_url(nursery, url, ssl_context=None): :param str url: a WebSocket URL :param ssl_context: optional ``SSLContext`` used for ``wss:`` URLs :param nursery: a Trio nursery to run background tasks in + :param cancel_scope: A Trio cancel scope that is cancelled when the + connection is closed. :rtype: WebSocketConnection ''' host, port, resource, ssl_context = _url_to_host(url, ssl_context) - return await connect_websocket(nursery, host, port, resource, ssl_context) + return await connect_websocket(nursery, host, port, resource, ssl_context, + cancel_scope=None) def _url_to_host(url, ssl_context): @@ -155,7 +168,8 @@ def _url_to_host(url, ssl_context): return url.host, url.port, resource, ssl_context -async def wrap_client_stream(nursery, stream, host, resource): +async def wrap_client_stream(nursery, stream, host, resource, *, + cancel_scope=None): ''' Wrap an arbitrary stream in a client-side ``WebSocketConnection``. @@ -167,11 +181,14 @@ async def wrap_client_stream(nursery, stream, host, resource): :param str host: A host string that will be sent in the ``Host:`` header. :param str resource: A resource string, i.e. the path component to be accessed on the server. + :param cancel_scope: A Trio cancel scope that is cancelled when the + connection is closed. :rtype: WebSocketConnection ''' wsproto = wsconnection.WSConnection(wsconnection.CLIENT, host=host, resource=resource) - connection = WebSocketConnection(stream, wsproto, path=resource) + connection = WebSocketConnection(stream, wsproto, path=resource, + cancel_scope=cancel_scope) nursery.start_soon(connection._reader_task) await connection._open_handshake.wait() return connection @@ -304,13 +321,16 @@ class WebSocketConnection(trio.abc.AsyncResource): CONNECTION_ID = itertools.count() - def __init__(self, stream, wsproto, path=None): + def __init__(self, stream, wsproto, path=None, cancel_scope=None): ''' Constructor. - :param SocketStream stream: - :param wsproto: a WSConnection instance - :param client: a Trio cancel scope (only used by the server) + :param SocketStream stream: A stream to use for WebSocket protocol. + :param WSConnection wsproto: A wsproto connection instance. + :param str path: A URL path to request. (Only valid for client + connections.) + :param cancel_scope: A Trio cancel_scope that is cancelled when the + connection is closed. ''' self._close_reason = None self._id = next(self.__class__.CONNECTION_ID) @@ -321,6 +341,7 @@ def __init__(self, stream, wsproto, path=None): self._str_message = '' self._reader_running = True self._path = path + self._cancel_scope = cancel_scope self._put_channel, self._get_channel = open_channel(0) # Set once the WebSocket open handshake takes place, i.e. # ConnectionRequested for server or ConnectedEstablished for client. @@ -446,6 +467,8 @@ def _abort_web_socket(self): # We didn't really handshake, but we want any task waiting on this event # (e.g. self.aclose()) to resume. self._close_handshake.set() + if self._cancel_scope: + self._cancel_scope.cancel() async def _close_stream(self): ''' Close the TCP connection. ''' @@ -455,6 +478,8 @@ async def _close_stream(self): except trio.BrokenResourceError: # This means the TCP connection is already dead. pass + if self._cancel_scope: + self._cancel_scope.cancel() def _close_web_socket(self, code, reason=None): '''