Skip to content

Commit 59fa2e3

Browse files
committed
Add thread-based implementation.
1 parent 125ffe2 commit 59fa2e3

16 files changed

+3251
-13
lines changed

Diff for: .github/workflows/tests.yml

+3
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,9 @@ on:
88
branches:
99
- main
1010

11+
env:
12+
WEBSOCKETS_TESTS_TIMEOUT_FACTOR: 10
13+
1114
jobs:
1215
coverage:
1316
name: Run test coverage checks

Diff for: setup.cfg

+1
Original file line numberDiff line numberDiff line change
@@ -37,3 +37,4 @@ exclude_lines =
3737
raise AssertionError
3838
raise NotImplementedError
3939
self.fail\(".*"\)
40+
@unittest.skip

Diff for: src/websockets/server.py

+2
Original file line numberDiff line numberDiff line change
@@ -164,6 +164,8 @@ def accept(self, request: Request) -> Response:
164164
f"Failed to open a WebSocket connection: {exc}.\n",
165165
)
166166
except Exception as exc:
167+
# Handle exceptions raised by user-provided select_subprotocol and
168+
# unexpected errors.
167169
request._exception = exc
168170
self.handshake_exc = exc
169171
self.logger.error("opening handshake failed", exc_info=True)

Diff for: src/websockets/sync/client.py

+327
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,327 @@
1+
from __future__ import annotations
2+
3+
import socket
4+
import ssl
5+
import threading
6+
from typing import Any, Optional, Sequence, Type
7+
8+
from ..client import ClientProtocol
9+
from ..datastructures import HeadersLike
10+
from ..extensions.base import ClientExtensionFactory
11+
from ..extensions.permessage_deflate import enable_client_permessage_deflate
12+
from ..headers import validate_subprotocols
13+
from ..http import USER_AGENT
14+
from ..http11 import Response
15+
from ..protocol import CONNECTING, OPEN, Event
16+
from ..typing import LoggerLike, Origin, Subprotocol
17+
from ..uri import parse_uri
18+
from .connection import Connection
19+
from .utils import Deadline
20+
21+
22+
__all__ = ["connect", "unix_connect", "ClientConnection"]
23+
24+
25+
class ClientConnection(Connection):
26+
"""
27+
Threaded implementation of a WebSocket client connection.
28+
29+
:class:`ClientConnection` provides :meth:`recv` and :meth:`send` methods for
30+
receiving and sending messages.
31+
32+
It supports iteration to receive messages::
33+
34+
for message in websocket:
35+
process(message)
36+
37+
The iterator exits normally when the connection is closed with close code
38+
1000 (OK) or 1001 (going away) or without a close code. It raises a
39+
:exc:`~websockets.exceptions.ConnectionClosedError` when the connection is
40+
closed with any other code.
41+
42+
Args:
43+
socket: Socket connected to a WebSocket server.
44+
protocol: Sans-I/O connection.
45+
close_timeout: Timeout for closing the connection in seconds.
46+
47+
"""
48+
49+
def __init__(
50+
self,
51+
socket: socket.socket,
52+
protocol: ClientProtocol,
53+
*,
54+
close_timeout: Optional[float] = 10,
55+
) -> None:
56+
self.protocol: ClientProtocol
57+
self.response_rcvd = threading.Event()
58+
super().__init__(
59+
socket,
60+
protocol,
61+
close_timeout=close_timeout,
62+
)
63+
64+
def handshake(
65+
self,
66+
additional_headers: Optional[HeadersLike] = None,
67+
user_agent_header: Optional[str] = USER_AGENT,
68+
timeout: Optional[float] = None,
69+
) -> None:
70+
"""
71+
Perform the opening handshake.
72+
73+
"""
74+
with self.send_context(expected_state=CONNECTING):
75+
self.request = self.protocol.connect()
76+
if additional_headers is not None:
77+
self.request.headers.update(additional_headers)
78+
if user_agent_header is not None:
79+
self.request.headers["User-Agent"] = user_agent_header
80+
self.protocol.send_request(self.request)
81+
82+
if not self.response_rcvd.wait(timeout):
83+
self.close_socket()
84+
self.recv_events_thread.join()
85+
raise TimeoutError("timed out during handshake")
86+
87+
if self.response is None:
88+
self.close_socket()
89+
self.recv_events_thread.join()
90+
raise ConnectionError("connection closed during handshake")
91+
92+
if self.protocol.state is not OPEN:
93+
self.recv_events_thread.join(self.close_timeout)
94+
self.close_socket()
95+
self.recv_events_thread.join()
96+
97+
if self.protocol.handshake_exc is not None:
98+
raise self.protocol.handshake_exc
99+
100+
def process_event(self, event: Event) -> None:
101+
"""
102+
Process one incoming event.
103+
104+
"""
105+
# First event - handshake response.
106+
if self.response is None:
107+
assert isinstance(event, Response)
108+
self.response = event
109+
self.response_rcvd.set()
110+
# Later events - frames.
111+
else:
112+
super().process_event(event)
113+
114+
def recv_events(self) -> None:
115+
"""
116+
Read incoming data from the socket and process events.
117+
118+
"""
119+
try:
120+
super().recv_events()
121+
finally:
122+
# If the connection is closed during the handshake, unblock it.
123+
self.response_rcvd.set()
124+
125+
126+
def connect(
127+
uri: str,
128+
*,
129+
# TCP/TLS — unix and path are only for unix_connect()
130+
sock: Optional[socket.socket] = None,
131+
ssl_context: Optional[ssl.SSLContext] = None,
132+
server_hostname: Optional[str] = None,
133+
unix: bool = False,
134+
path: Optional[str] = None,
135+
# WebSocket
136+
origin: Optional[Origin] = None,
137+
extensions: Optional[Sequence[ClientExtensionFactory]] = None,
138+
subprotocols: Optional[Sequence[Subprotocol]] = None,
139+
additional_headers: Optional[HeadersLike] = None,
140+
user_agent_header: Optional[str] = USER_AGENT,
141+
compression: Optional[str] = "deflate",
142+
# Timeouts
143+
open_timeout: Optional[float] = 10,
144+
close_timeout: Optional[float] = 10,
145+
# Limits
146+
max_size: Optional[int] = 2**20,
147+
# Logging
148+
logger: Optional[LoggerLike] = None,
149+
# Escape hatch for advanced customization
150+
create_connection: Optional[Type[ClientConnection]] = None,
151+
) -> ClientConnection:
152+
"""
153+
Connect to the WebSocket server at ``uri``.
154+
155+
This function returns a :class:`ClientConnection` instance, which you can
156+
use to send and receive messages.
157+
158+
:func:`connect` may be used as a context manager::
159+
160+
async with websockets.sync.client.connect(...) as websocket:
161+
...
162+
163+
The connection is closed automatically when exiting the context.
164+
165+
Args:
166+
uri: URI of the WebSocket server.
167+
sock: Preexisting TCP socket. ``sock`` overrides the host and port
168+
from ``uri``. You may call :func:`socket.create_connection` to
169+
create a suitable TCP socket.
170+
ssl_context: Configuration for enabling TLS on the connection.
171+
server_hostname: Hostname for the TLS handshake. ``server_hostname``
172+
overrides the hostname from ``uri``.
173+
origin: Value of the ``Origin`` header, for servers that require it.
174+
extensions: List of supported extensions, in order in which they
175+
should be negotiated and run.
176+
subprotocols: List of supported subprotocols, in order of decreasing
177+
preference.
178+
additional_headers (HeadersLike | None): Arbitrary HTTP headers to add
179+
to the handshake request.
180+
user_agent_header: Value of the ``User-Agent`` request header.
181+
It defaults to ``"Python/x.y.z websockets/X.Y"``.
182+
Setting it to :obj:`None` removes the header.
183+
compression: The "permessage-deflate" extension is enabled by default.
184+
Set ``compression`` to :obj:`None` to disable it. See the
185+
:doc:`compression guide <../../topics/compression>` for details.
186+
open_timeout: Timeout for opening the connection in seconds.
187+
:obj:`None` disables the timeout.
188+
close_timeout: Timeout for closing the connection in seconds.
189+
:obj:`None` disables the timeout.
190+
max_size: Maximum size of incoming messages in bytes.
191+
:obj:`None` disables the limit.
192+
logger: Logger for this client.
193+
It defaults to ``logging.getLogger("websockets.client")``.
194+
See the :doc:`logging guide <../../topics/logging>` for details.
195+
create_connection: Factory for the :class:`ClientConnection` managing
196+
the connection. Set it to a wrapper or a subclass to customize
197+
connection handling.
198+
199+
Raises:
200+
InvalidURI: If ``uri`` isn't a valid WebSocket URI.
201+
InvalidHandshake: If the opening handshake fails.
202+
TimeoutError: If the opening handshake times out.
203+
204+
"""
205+
206+
# Process parameters
207+
208+
wsuri = parse_uri(uri)
209+
if not wsuri.secure and ssl_context is not None:
210+
raise TypeError("ssl_context argument is incompatible with a ws:// URI")
211+
212+
if unix:
213+
if path is None and sock is None:
214+
raise TypeError("missing path argument")
215+
elif path is not None and sock is not None:
216+
raise TypeError("path and sock arguments are incompatible")
217+
else:
218+
assert path is None # private argument, only set by unix_connect()
219+
220+
if subprotocols is not None:
221+
validate_subprotocols(subprotocols)
222+
223+
if compression == "deflate":
224+
extensions = enable_client_permessage_deflate(extensions)
225+
elif compression is not None:
226+
raise ValueError(f"unsupported compression: {compression}")
227+
228+
# Calculate timeouts on the TCP, TLS, and WebSocket handshakes.
229+
# The TCP and TLS timeouts must be set on the socket, then removed
230+
# to avoid conflicting with the WebSocket timeout in handshake().
231+
deadline = Deadline(open_timeout)
232+
233+
if create_connection is None:
234+
create_connection = ClientConnection
235+
236+
try:
237+
# Connect socket
238+
239+
if sock is None:
240+
if unix:
241+
sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
242+
sock.settimeout(deadline.timeout())
243+
assert path is not None # validated above -- this is for mpypy
244+
sock.connect(path)
245+
else:
246+
sock = socket.create_connection(
247+
(wsuri.host, wsuri.port),
248+
deadline.timeout(),
249+
)
250+
sock.settimeout(None)
251+
252+
# Disable Nagle algorithm
253+
254+
if not unix:
255+
sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, True)
256+
257+
# Initialize TLS wrapper and perform TLS handshake
258+
259+
if wsuri.secure:
260+
if ssl_context is None:
261+
ssl_context = ssl.create_default_context()
262+
if server_hostname is None:
263+
server_hostname = wsuri.host
264+
sock.settimeout(deadline.timeout())
265+
sock = ssl_context.wrap_socket(sock, server_hostname=server_hostname)
266+
sock.settimeout(None)
267+
268+
# Initialize WebSocket connection
269+
270+
protocol = ClientProtocol(
271+
wsuri,
272+
origin=origin,
273+
extensions=extensions,
274+
subprotocols=subprotocols,
275+
state=CONNECTING,
276+
max_size=max_size,
277+
logger=logger,
278+
)
279+
280+
# Initialize WebSocket protocol
281+
282+
connection = create_connection(
283+
sock,
284+
protocol,
285+
close_timeout=close_timeout,
286+
)
287+
# On failure, handshake() closes the socket and raises an exception.
288+
connection.handshake(
289+
additional_headers,
290+
user_agent_header,
291+
deadline.timeout(),
292+
)
293+
294+
except Exception:
295+
if sock is not None:
296+
sock.close()
297+
raise
298+
299+
return connection
300+
301+
302+
def unix_connect(
303+
path: Optional[str] = None,
304+
uri: Optional[str] = None,
305+
**kwargs: Any,
306+
) -> ClientConnection:
307+
"""
308+
Connect to a WebSocket server listening on a Unix socket.
309+
310+
This function is identical to :func:`connect`, except for the additional
311+
``path`` argument. It's only available on Unix.
312+
313+
It's mainly useful for debugging servers listening on Unix sockets.
314+
315+
Args:
316+
path: File system path to the Unix socket.
317+
uri: URI of the WebSocket server. ``uri`` defaults to
318+
``ws://localhost/`` or, when a ``ssl_context`` is provided, to
319+
``wss://localhost/``.
320+
321+
"""
322+
if uri is None:
323+
if kwargs.get("ssl_context") is None:
324+
uri = "ws://localhost/"
325+
else:
326+
uri = "wss://localhost/"
327+
return connect(uri=uri, unix=True, path=path, **kwargs)

0 commit comments

Comments
 (0)