From 5d188e8c0ddef6ce633ca702dbdd4a90f2799597 Mon Sep 17 00:00:00 2001 From: Miguel Grinberg Date: Wed, 3 Jan 2024 00:03:34 +0000 Subject: [PATCH] Add a limit to WebSocket message size (Fixes #193) --- src/microdot/test_client.py | 2 ++ src/microdot/websocket.py | 43 +++++++++++++++++++++++++++++++------ tests/test_websocket.py | 38 ++++++++++++++++++-------------- 3 files changed, 60 insertions(+), 23 deletions(-) diff --git a/src/microdot/test_client.py b/src/microdot/test_client.py index 1530ee6..a6d7141 100644 --- a/src/microdot/test_client.py +++ b/src/microdot/test_client.py @@ -292,6 +292,8 @@ async def readline(self): async def awrite(self, data): if self.started: h = WebSocket._parse_frame_header(data[0:2]) + if h[1] not in [WebSocket.TEXT, WebSocket.BINARY]: + return if h[3] < 0: data = data[2 - h[3]:] else: diff --git a/src/microdot/websocket.py b/src/microdot/websocket.py index c7b6034..925f7dc 100644 --- a/src/microdot/websocket.py +++ b/src/microdot/websocket.py @@ -1,7 +1,12 @@ import binascii import hashlib -from microdot import Response -from microdot.microdot import MUTED_SOCKET_ERRORS +from microdot import Request, Response +from microdot.microdot import MUTED_SOCKET_ERRORS, print_exception + + +class WebSocketError(Exception): + """Exception raised when an error occurs in a WebSocket connection.""" + pass class WebSocket: @@ -17,6 +22,18 @@ class WebSocket: PING = 9 PONG = 10 + #: Specify the maximum message size that can be received when calling the + #: ``receive()`` method. Messages with payloads that are larger than this + #: size will be rejected and the connection closed. Set to 0 to disable + #: the size check (be aware of potential security issues if you do this), + #: or to -1 to use the value set in + #: ``Request.max_body_length``. The default is -1. + #: + #: Example:: + #: + #: WebSocket.max_message_length = 4 * 1024 # up to 4KB messages + max_message_length = -1 + def __init__(self, request): self.request = request self.closed = False @@ -86,7 +103,7 @@ def _parse_frame_header(cls, header): fin = header[0] & 0x80 opcode = header[0] & 0x0f if fin == 0 or opcode == cls.CONT: # pragma: no cover - raise OSError(32, 'Continuation frames not supported') + raise WebSocketError('Continuation frames not supported') has_mask = header[1] & 0x80 length = header[1] & 0x7f if length == 126: @@ -101,7 +118,7 @@ def _process_websocket_frame(self, opcode, payload): elif opcode == self.BINARY: pass elif opcode == self.CLOSE: - raise OSError(32, 'Websocket connection closed') + raise WebSocketError('Websocket connection closed') elif opcode == self.PING: return self.PONG, payload elif opcode == self.PONG: # pragma: no branch @@ -128,7 +145,7 @@ def _encode_websocket_frame(cls, opcode, payload): async def _read_frame(self): header = await self.request.sock[0].read(2) if len(header) != 2: # pragma: no cover - raise OSError(32, 'Websocket connection closed') + raise WebSocketError('Websocket connection closed') fin, opcode, has_mask, length = self._parse_frame_header(header) if length == -2: length = await self.request.sock[0].read(2) @@ -136,6 +153,10 @@ async def _read_frame(self): elif length == -8: length = await self.request.sock[0].read(8) length = int.from_bytes(length, 'big') + max_allowed_length = Request.max_body_length \ + if self.max_message_length == -1 else self.max_message_length + if length > max_allowed_length: + raise WebSocketError('Message too large') if has_mask: # pragma: no cover mask = await self.request.sock[0].read(4) payload = await self.request.sock[0].read(length) @@ -175,11 +196,19 @@ async def wrapper(request, *args, **kwargs): ws = await upgrade_function(request) try: await f(request, ws, *args, **kwargs) - await ws.close() # pragma: no cover except OSError as exc: if exc.errno not in MUTED_SOCKET_ERRORS: # pragma: no cover raise - return '' + except WebSocketError: + pass + except Exception as exc: + print_exception(exc) + finally: # pragma: no cover + try: + await ws.close() + except Exception: + pass + return Response.already_handled return wrapper diff --git a/tests/test_websocket.py b/tests/test_websocket.py index 4d2a507..9c20682 100644 --- a/tests/test_websocket.py +++ b/tests/test_websocket.py @@ -1,8 +1,8 @@ import asyncio import sys import unittest -from microdot import Microdot -from microdot.websocket import with_websocket, WebSocket +from microdot import Microdot, Request +from microdot.websocket import with_websocket, WebSocket, WebSocketError from microdot.test_client import TestClient @@ -17,6 +17,7 @@ def _run(self, coro): return self.loop.run_until_complete(coro) def test_websocket_echo(self): + WebSocket.max_message_length = 65537 app = Microdot() @app.route('/echo') @@ -26,9 +27,14 @@ async def index(req, ws): data = await ws.receive() await ws.send(data) + @app.route('/divzero') + @with_websocket + async def divzero(req, ws): + 1 / 0 + results = [] - def ws(): + async def ws(): data = yield 'hello' results.append(data) data = yield b'bye' @@ -43,34 +49,34 @@ def ws(): self.assertIsNone(res) self.assertEqual(results, ['hello', b'bye', b'*' * 300, b'+' * 65537]) + res = self._run(client.websocket('/divzero', ws)) + self.assertIsNone(res) + WebSocket.max_message_length = -1 + @unittest.skipIf(sys.implementation.name == 'micropython', 'no support for async generators in MicroPython') - def test_websocket_echo_async_client(self): + def test_websocket_large_message(self): + saved_max_body_length = Request.max_body_length + Request.max_body_length = 10 app = Microdot() @app.route('/echo') @with_websocket async def index(req, ws): - while True: - data = await ws.receive() - await ws.send(data) + data = await ws.receive() + await ws.send(data) results = [] async def ws(): - data = yield 'hello' - results.append(data) - data = yield b'bye' - results.append(data) - data = yield b'*' * 300 - results.append(data) - data = yield b'+' * 65537 + data = yield '0123456789abcdef' results.append(data) client = TestClient(app) res = self._run(client.websocket('/echo', ws)) self.assertIsNone(res) - self.assertEqual(results, ['hello', b'bye', b'*' * 300, b'+' * 65537]) + self.assertEqual(results, []) + Request.max_body_length = saved_max_body_length def test_bad_websocket_request(self): app = Microdot() @@ -106,7 +112,7 @@ def test_process_websocket_frame(self): (None, 'foo')) self.assertEqual(ws._process_websocket_frame(WebSocket.BINARY, b'foo'), (None, b'foo')) - self.assertRaises(OSError, ws._process_websocket_frame, + self.assertRaises(WebSocketError, ws._process_websocket_frame, WebSocket.CLOSE, b'') self.assertEqual(ws._process_websocket_frame(WebSocket.PING, b'foo'), (WebSocket.PONG, b'foo'))