Skip to content

Commit

Permalink
Add a limit to WebSocket message size (Fixes #193)
Browse files Browse the repository at this point in the history
  • Loading branch information
miguelgrinberg committed Jan 3, 2024
1 parent b80b6b6 commit 5d188e8
Show file tree
Hide file tree
Showing 3 changed files with 60 additions and 23 deletions.
2 changes: 2 additions & 0 deletions src/microdot/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -292,6 +292,8 @@ async def readline(self):
async def awrite(self, data):
if self.started:
h = WebSocket._parse_frame_header(data[0:2])
if h[1] not in [WebSocket.TEXT, WebSocket.BINARY]:
return
if h[3] < 0:
data = data[2 - h[3]:]
else:
Expand Down
43 changes: 36 additions & 7 deletions src/microdot/websocket.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,12 @@
import binascii
import hashlib
from microdot import Response
from microdot.microdot import MUTED_SOCKET_ERRORS
from microdot import Request, Response
from microdot.microdot import MUTED_SOCKET_ERRORS, print_exception


class WebSocketError(Exception):
"""Exception raised when an error occurs in a WebSocket connection."""
pass


class WebSocket:
Expand All @@ -17,6 +22,18 @@ class WebSocket:
PING = 9
PONG = 10

#: Specify the maximum message size that can be received when calling the
#: ``receive()`` method. Messages with payloads that are larger than this
#: size will be rejected and the connection closed. Set to 0 to disable
#: the size check (be aware of potential security issues if you do this),
#: or to -1 to use the value set in
#: ``Request.max_body_length``. The default is -1.
#:
#: Example::
#:
#: WebSocket.max_message_length = 4 * 1024 # up to 4KB messages
max_message_length = -1

def __init__(self, request):
self.request = request
self.closed = False
Expand Down Expand Up @@ -86,7 +103,7 @@ def _parse_frame_header(cls, header):
fin = header[0] & 0x80
opcode = header[0] & 0x0f
if fin == 0 or opcode == cls.CONT: # pragma: no cover
raise OSError(32, 'Continuation frames not supported')
raise WebSocketError('Continuation frames not supported')
has_mask = header[1] & 0x80
length = header[1] & 0x7f
if length == 126:
Expand All @@ -101,7 +118,7 @@ def _process_websocket_frame(self, opcode, payload):
elif opcode == self.BINARY:
pass
elif opcode == self.CLOSE:
raise OSError(32, 'Websocket connection closed')
raise WebSocketError('Websocket connection closed')
elif opcode == self.PING:
return self.PONG, payload
elif opcode == self.PONG: # pragma: no branch
Expand All @@ -128,14 +145,18 @@ def _encode_websocket_frame(cls, opcode, payload):
async def _read_frame(self):
header = await self.request.sock[0].read(2)
if len(header) != 2: # pragma: no cover
raise OSError(32, 'Websocket connection closed')
raise WebSocketError('Websocket connection closed')
fin, opcode, has_mask, length = self._parse_frame_header(header)
if length == -2:
length = await self.request.sock[0].read(2)
length = int.from_bytes(length, 'big')
elif length == -8:
length = await self.request.sock[0].read(8)
length = int.from_bytes(length, 'big')
max_allowed_length = Request.max_body_length \
if self.max_message_length == -1 else self.max_message_length
if length > max_allowed_length:
raise WebSocketError('Message too large')
if has_mask: # pragma: no cover
mask = await self.request.sock[0].read(4)
payload = await self.request.sock[0].read(length)
Expand Down Expand Up @@ -175,11 +196,19 @@ async def wrapper(request, *args, **kwargs):
ws = await upgrade_function(request)
try:
await f(request, ws, *args, **kwargs)
await ws.close() # pragma: no cover
except OSError as exc:
if exc.errno not in MUTED_SOCKET_ERRORS: # pragma: no cover
raise
return ''
except WebSocketError:
pass
except Exception as exc:
print_exception(exc)
finally: # pragma: no cover
try:
await ws.close()
except Exception:
pass
return Response.already_handled
return wrapper


Expand Down
38 changes: 22 additions & 16 deletions tests/test_websocket.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
import asyncio
import sys
import unittest
from microdot import Microdot
from microdot.websocket import with_websocket, WebSocket
from microdot import Microdot, Request
from microdot.websocket import with_websocket, WebSocket, WebSocketError
from microdot.test_client import TestClient


Expand All @@ -17,6 +17,7 @@ def _run(self, coro):
return self.loop.run_until_complete(coro)

def test_websocket_echo(self):
WebSocket.max_message_length = 65537
app = Microdot()

@app.route('/echo')
Expand All @@ -26,9 +27,14 @@ async def index(req, ws):
data = await ws.receive()
await ws.send(data)

@app.route('/divzero')
@with_websocket
async def divzero(req, ws):
1 / 0

results = []

def ws():
async def ws():
data = yield 'hello'
results.append(data)
data = yield b'bye'
Expand All @@ -43,34 +49,34 @@ def ws():
self.assertIsNone(res)
self.assertEqual(results, ['hello', b'bye', b'*' * 300, b'+' * 65537])

res = self._run(client.websocket('/divzero', ws))
self.assertIsNone(res)
WebSocket.max_message_length = -1

@unittest.skipIf(sys.implementation.name == 'micropython',
'no support for async generators in MicroPython')
def test_websocket_echo_async_client(self):
def test_websocket_large_message(self):
saved_max_body_length = Request.max_body_length
Request.max_body_length = 10
app = Microdot()

@app.route('/echo')
@with_websocket
async def index(req, ws):
while True:
data = await ws.receive()
await ws.send(data)
data = await ws.receive()
await ws.send(data)

results = []

async def ws():
data = yield 'hello'
results.append(data)
data = yield b'bye'
results.append(data)
data = yield b'*' * 300
results.append(data)
data = yield b'+' * 65537
data = yield '0123456789abcdef'
results.append(data)

client = TestClient(app)
res = self._run(client.websocket('/echo', ws))
self.assertIsNone(res)
self.assertEqual(results, ['hello', b'bye', b'*' * 300, b'+' * 65537])
self.assertEqual(results, [])
Request.max_body_length = saved_max_body_length

def test_bad_websocket_request(self):
app = Microdot()
Expand Down Expand Up @@ -106,7 +112,7 @@ def test_process_websocket_frame(self):
(None, 'foo'))
self.assertEqual(ws._process_websocket_frame(WebSocket.BINARY, b'foo'),
(None, b'foo'))
self.assertRaises(OSError, ws._process_websocket_frame,
self.assertRaises(WebSocketError, ws._process_websocket_frame,
WebSocket.CLOSE, b'')
self.assertEqual(ws._process_websocket_frame(WebSocket.PING, b'foo'),
(WebSocket.PONG, b'foo'))
Expand Down

0 comments on commit 5d188e8

Please sign in to comment.