Skip to content

Commit

Permalink
Add cancel_scope argument to client API (#22)
Browse files Browse the repository at this point in the history
This optional scope is cancelled when the connection is closed.
  • Loading branch information
mehaase committed Oct 11, 2018
1 parent 79452ce commit c77ee8f
Show file tree
Hide file tree
Showing 2 changed files with 48 additions and 14 deletions.
9 changes: 9 additions & 0 deletions tests/test_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -274,3 +274,12 @@ async def handler(stream):
async with open_websocket(HOST, port, RESOURCE, use_ssl=False) as client:
with pytest.raises(ConnectionClosed):
await client.get_message()


async def test_connection_cancel_scope(echo_server):
async with trio.open_nursery() as nursery:
async with open_websocket(HOST, echo_server.port, RESOURCE,
use_ssl=False, cancel_scope=nursery.cancel_scope) as conn:
pass
await trio.sleep(0)
assert nursery.cancel_scope.cancel_called
53 changes: 39 additions & 14 deletions trio_websocket/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@

@asynccontextmanager
@async_generator
async def open_websocket(host, port, resource, use_ssl):
async def open_websocket(host, port, resource, use_ssl, *, cancel_scope=None):
'''
Open a WebSocket client connection to a host.
Expand All @@ -40,15 +40,18 @@ async def open_websocket(host, port, resource, use_ssl):
:param int port: the port to connect to
:param str resource: the resource a.k.a. path
:param use_ssl: a bool or SSLContext
:param cancel_scope: A Trio cancel scope that is cancelled when the
connection is closed.
'''
async with trio.open_nursery() as new_nursery:
connection = await connect_websocket(new_nursery, host, port, resource,
use_ssl)
use_ssl, cancel_scope=cancel_scope)
async with connection:
await yield_(connection)


async def connect_websocket(nursery, host, port, resource, use_ssl):
async def connect_websocket(nursery, host, port, resource, use_ssl, *,
cancel_scope=None):
'''
Return a WebSocket client connection to a host.
Expand All @@ -62,6 +65,8 @@ async def connect_websocket(nursery, host, port, resource, use_ssl):
:param int port: the port to connect to
:param str resource: the resource a.k.a. path
:param use_ssl: a bool or SSLContext
:param cancel_scope: A Trio cancel scope that is cancelled when the
connection is closed.
:rtype: WebSocketConnection
'''
if use_ssl == True:
Expand All @@ -86,13 +91,14 @@ async def connect_websocket(nursery, host, port, resource, use_ssl):
host_header = '{}:{}'.format(host, port)
wsproto = wsconnection.WSConnection(wsconnection.CLIENT,
host=host_header, resource=resource)
connection = WebSocketConnection(stream, wsproto, path=resource)
connection = WebSocketConnection(stream, wsproto, path=resource,
cancel_scope=cancel_scope)
nursery.start_soon(connection._reader_task)
await connection._open_handshake.wait()
return connection


def open_websocket_url(url, ssl_context=None):
def open_websocket_url(url, ssl_context=None, *, cancel_scope=None):
'''
Open a WebSocket client connection to a URL.
Expand All @@ -106,12 +112,16 @@ def open_websocket_url(url, ssl_context=None):
:param str url: a WebSocket URL
:param ssl_context: optional ``SSLContext`` used for ``wss:`` URLs
:param cancel_scope: A Trio cancel scope that is cancelled when the
connection is closed.
'''
host, port, resource, ssl_context = _url_to_host(url, ssl_context)
return open_websocket(host, port, resource, ssl_context)
return open_websocket(host, port, resource, ssl_context,
cancel_scope=cancel_scope)


async def connect_websocket_url(nursery, url, ssl_context=None):
async def connect_websocket_url(nursery, url, ssl_context=None, *,
cancel_scope=None):
'''
Return a WebSocket client connection to a URL.
Expand All @@ -126,10 +136,13 @@ async def connect_websocket_url(nursery, url, ssl_context=None):
:param str url: a WebSocket URL
:param ssl_context: optional ``SSLContext`` used for ``wss:`` URLs
:param nursery: a Trio nursery to run background tasks in
:param cancel_scope: A Trio cancel scope that is cancelled when the
connection is closed.
:rtype: WebSocketConnection
'''
host, port, resource, ssl_context = _url_to_host(url, ssl_context)
return await connect_websocket(nursery, host, port, resource, ssl_context)
return await connect_websocket(nursery, host, port, resource, ssl_context,
cancel_scope=None)


def _url_to_host(url, ssl_context):
Expand All @@ -155,7 +168,8 @@ def _url_to_host(url, ssl_context):
return url.host, url.port, resource, ssl_context


async def wrap_client_stream(nursery, stream, host, resource):
async def wrap_client_stream(nursery, stream, host, resource, *,
cancel_scope=None):
'''
Wrap an arbitrary stream in a client-side ``WebSocketConnection``.
Expand All @@ -167,11 +181,14 @@ async def wrap_client_stream(nursery, stream, host, resource):
:param str host: A host string that will be sent in the ``Host:`` header.
:param str resource: A resource string, i.e. the path component to be
accessed on the server.
:param cancel_scope: A Trio cancel scope that is cancelled when the
connection is closed.
:rtype: WebSocketConnection
'''
wsproto = wsconnection.WSConnection(wsconnection.CLIENT, host=host,
resource=resource)
connection = WebSocketConnection(stream, wsproto, path=resource)
connection = WebSocketConnection(stream, wsproto, path=resource,
cancel_scope=cancel_scope)
nursery.start_soon(connection._reader_task)
await connection._open_handshake.wait()
return connection
Expand Down Expand Up @@ -304,13 +321,16 @@ class WebSocketConnection(trio.abc.AsyncResource):

CONNECTION_ID = itertools.count()

def __init__(self, stream, wsproto, path=None):
def __init__(self, stream, wsproto, path=None, cancel_scope=None):
'''
Constructor.
:param SocketStream stream:
:param wsproto: a WSConnection instance
:param client: a Trio cancel scope (only used by the server)
:param SocketStream stream: A stream to use for WebSocket protocol.
:param WSConnection wsproto: A wsproto connection instance.
:param str path: A URL path to request. (Only valid for client
connections.)
:param cancel_scope: A Trio cancel_scope that is cancelled when the
connection is closed.
'''
self._close_reason = None
self._id = next(self.__class__.CONNECTION_ID)
Expand All @@ -321,6 +341,7 @@ def __init__(self, stream, wsproto, path=None):
self._str_message = ''
self._reader_running = True
self._path = path
self._cancel_scope = cancel_scope
self._put_channel, self._get_channel = open_channel(0)
# Set once the WebSocket open handshake takes place, i.e.
# ConnectionRequested for server or ConnectedEstablished for client.
Expand Down Expand Up @@ -446,6 +467,8 @@ def _abort_web_socket(self):
# We didn't really handshake, but we want any task waiting on this event
# (e.g. self.aclose()) to resume.
self._close_handshake.set()
if self._cancel_scope:
self._cancel_scope.cancel()

async def _close_stream(self):
''' Close the TCP connection. '''
Expand All @@ -455,6 +478,8 @@ async def _close_stream(self):
except trio.BrokenResourceError:
# This means the TCP connection is already dead.
pass
if self._cancel_scope:
self._cancel_scope.cancel()

def _close_web_socket(self, code, reason=None):
'''
Expand Down

0 comments on commit c77ee8f

Please sign in to comment.