Skip to content

Commit ac3ea04

Browse files
committed
Add tests for encoding/decoding highly recursive messages
This adds a test case found in the `orjson` repo to ensure that we properly respect recursion limits when encoding or decoding deeply recursive messages. No code change is needed at this time, we properly manage recursion limits already.
1 parent 43c239e commit ac3ea04

File tree

3 files changed

+86
-31
lines changed

3 files changed

+86
-31
lines changed

tests/test_common.py

Lines changed: 57 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@
3232
)
3333

3434
import pytest
35-
from utils import temp_module
35+
from utils import temp_module, max_call_depth
3636

3737
try:
3838
import attrs
@@ -295,6 +295,62 @@ class Custom(metaclass=Metaclass):
295295
dec.decode(msg)
296296

297297

298+
@pytest.mark.skipif(
299+
PY312,
300+
reason=(
301+
"Python 3.12 harcodes the C recursion limit, making this "
302+
"behavior harder to test in CI"
303+
),
304+
)
305+
class TestRecursion:
306+
@staticmethod
307+
def nested(n, is_array):
308+
if is_array:
309+
obj = []
310+
for _ in range(n):
311+
obj = [obj]
312+
else:
313+
obj = {}
314+
for _ in range(n):
315+
obj = {"": obj}
316+
return obj
317+
318+
@pytest.mark.parametrize("is_array", [True, False])
319+
def test_encode_highly_recursive_msg_errors(self, is_array, proto):
320+
N = 200
321+
obj = self.nested(N, is_array)
322+
323+
# Errors if above the recursion limit
324+
with max_call_depth(N // 2):
325+
with pytest.raises(RecursionError):
326+
proto.encode(obj)
327+
328+
# Works if below the recursion limit
329+
with max_call_depth(N * 2):
330+
proto.encode(obj)
331+
332+
@pytest.mark.parametrize("is_array", [True, False])
333+
def test_decode_highly_recursive_msg_errors(self, is_array, proto):
334+
"""Ensure recursion is properly handled when decoding.
335+
Test case seen in https://github.com/ijl/orjson/issues/458."""
336+
N = 200
337+
obj = self.nested(N, is_array)
338+
339+
with max_call_depth(N * 2):
340+
msg = proto.encode(obj)
341+
342+
# Errors if above the recursion limit
343+
with max_call_depth(N // 2):
344+
with pytest.raises(RecursionError):
345+
proto.decode(msg)
346+
347+
# Works if below the recursion limit
348+
with max_call_depth(N * 2):
349+
obj2 = proto.decode(msg)
350+
351+
assert obj2
352+
353+
298354
class TestThreadSafe:
299355
def test_encode_threadsafe(self, proto):
300356
class Nested:

tests/test_convert.py

Lines changed: 1 addition & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -2,13 +2,11 @@
22
import decimal
33
import enum
44
import gc
5-
import inspect
65
import math
76
import sys
87
import uuid
98
from base64 import b64encode
109
from collections.abc import MutableMapping
11-
from contextlib import contextmanager
1210
from dataclasses import dataclass, field
1311
from typing import (
1412
Any,
@@ -25,7 +23,7 @@
2523
)
2624

2725
import pytest
28-
from utils import temp_module
26+
from utils import temp_module, max_call_depth
2927

3028
import msgspec
3129
from msgspec import Meta, Struct, ValidationError, convert, to_builtins
@@ -196,33 +194,6 @@ def assert_eq(x, y):
196194
assert x == y
197195

198196

199-
@contextmanager
200-
def max_call_depth(n):
201-
cur_depth = len(inspect.stack(0))
202-
orig = sys.getrecursionlimit()
203-
try:
204-
# Our measure of the current stack depth can be off by a bit. Trying to
205-
# set a recursionlimit < the current depth will raise a RecursionError.
206-
# We just try again with a slightly higher limit, bailing after an
207-
# unreasonable amount of adjustments.
208-
#
209-
# Note that python 3.8 also has a minimum recursion limit of 64, so
210-
# there's some additional fiddliness there.
211-
for i in range(64):
212-
try:
213-
sys.setrecursionlimit(cur_depth + i + n)
214-
break
215-
except RecursionError:
216-
pass
217-
else:
218-
raise ValueError(
219-
"Failed to set low recursion limit, something is wrong here"
220-
)
221-
yield
222-
finally:
223-
sys.setrecursionlimit(orig)
224-
225-
226197
def roundtrip(obj, typ):
227198
return convert(to_builtins(obj), typ)
228199

tests/utils.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import sys
2+
import inspect
23
import textwrap
34
import types
45
import uuid
@@ -20,3 +21,30 @@ def temp_module(code):
2021
yield mod
2122
finally:
2223
sys.modules.pop(name, None)
24+
25+
26+
@contextmanager
27+
def max_call_depth(n):
28+
cur_depth = len(inspect.stack(0))
29+
orig = sys.getrecursionlimit()
30+
try:
31+
# Our measure of the current stack depth can be off by a bit. Trying to
32+
# set a recursionlimit < the current depth will raise a RecursionError.
33+
# We just try again with a slightly higher limit, bailing after an
34+
# unreasonable amount of adjustments.
35+
#
36+
# Note that python 3.8 also has a minimum recursion limit of 64, so
37+
# there's some additional fiddliness there.
38+
for i in range(64):
39+
try:
40+
sys.setrecursionlimit(cur_depth + i + n)
41+
break
42+
except RecursionError:
43+
pass
44+
else:
45+
raise ValueError(
46+
"Failed to set low recursion limit, something is wrong here"
47+
)
48+
yield
49+
finally:
50+
sys.setrecursionlimit(orig)

0 commit comments

Comments
 (0)