diff --git a/tests/test_common.py b/tests/test_common.py index a4531c95..132e3dcd 100644 --- a/tests/test_common.py +++ b/tests/test_common.py @@ -32,7 +32,7 @@ ) import pytest -from utils import temp_module +from utils import temp_module, max_call_depth try: import attrs @@ -295,6 +295,62 @@ class Custom(metaclass=Metaclass): dec.decode(msg) +@pytest.mark.skipif( + PY312, + reason=( + "Python 3.12 harcodes the C recursion limit, making this " + "behavior harder to test in CI" + ), +) +class TestRecursion: + @staticmethod + def nested(n, is_array): + if is_array: + obj = [] + for _ in range(n): + obj = [obj] + else: + obj = {} + for _ in range(n): + obj = {"": obj} + return obj + + @pytest.mark.parametrize("is_array", [True, False]) + def test_encode_highly_recursive_msg_errors(self, is_array, proto): + N = 200 + obj = self.nested(N, is_array) + + # Errors if above the recursion limit + with max_call_depth(N // 2): + with pytest.raises(RecursionError): + proto.encode(obj) + + # Works if below the recursion limit + with max_call_depth(N * 2): + proto.encode(obj) + + @pytest.mark.parametrize("is_array", [True, False]) + def test_decode_highly_recursive_msg_errors(self, is_array, proto): + """Ensure recursion is properly handled when decoding. + Test case seen in https://github.com/ijl/orjson/issues/458.""" + N = 200 + obj = self.nested(N, is_array) + + with max_call_depth(N * 2): + msg = proto.encode(obj) + + # Errors if above the recursion limit + with max_call_depth(N // 2): + with pytest.raises(RecursionError): + proto.decode(msg) + + # Works if below the recursion limit + with max_call_depth(N * 2): + obj2 = proto.decode(msg) + + assert obj2 + + class TestThreadSafe: def test_encode_threadsafe(self, proto): class Nested: diff --git a/tests/test_convert.py b/tests/test_convert.py index 959192c3..95d496e3 100644 --- a/tests/test_convert.py +++ b/tests/test_convert.py @@ -2,13 +2,11 @@ import decimal import enum import gc -import inspect import math import sys import uuid from base64 import b64encode from collections.abc import MutableMapping -from contextlib import contextmanager from dataclasses import dataclass, field from typing import ( Any, @@ -25,7 +23,7 @@ ) import pytest -from utils import temp_module +from utils import temp_module, max_call_depth import msgspec from msgspec import Meta, Struct, ValidationError, convert, to_builtins @@ -196,33 +194,6 @@ def assert_eq(x, y): assert x == y -@contextmanager -def max_call_depth(n): - cur_depth = len(inspect.stack(0)) - orig = sys.getrecursionlimit() - try: - # Our measure of the current stack depth can be off by a bit. Trying to - # set a recursionlimit < the current depth will raise a RecursionError. - # We just try again with a slightly higher limit, bailing after an - # unreasonable amount of adjustments. - # - # Note that python 3.8 also has a minimum recursion limit of 64, so - # there's some additional fiddliness there. - for i in range(64): - try: - sys.setrecursionlimit(cur_depth + i + n) - break - except RecursionError: - pass - else: - raise ValueError( - "Failed to set low recursion limit, something is wrong here" - ) - yield - finally: - sys.setrecursionlimit(orig) - - def roundtrip(obj, typ): return convert(to_builtins(obj), typ) diff --git a/tests/utils.py b/tests/utils.py index a64c70b4..e95788ca 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -1,4 +1,5 @@ import sys +import inspect import textwrap import types import uuid @@ -20,3 +21,30 @@ def temp_module(code): yield mod finally: sys.modules.pop(name, None) + + +@contextmanager +def max_call_depth(n): + cur_depth = len(inspect.stack(0)) + orig = sys.getrecursionlimit() + try: + # Our measure of the current stack depth can be off by a bit. Trying to + # set a recursionlimit < the current depth will raise a RecursionError. + # We just try again with a slightly higher limit, bailing after an + # unreasonable amount of adjustments. + # + # Note that python 3.8 also has a minimum recursion limit of 64, so + # there's some additional fiddliness there. + for i in range(64): + try: + sys.setrecursionlimit(cur_depth + i + n) + break + except RecursionError: + pass + else: + raise ValueError( + "Failed to set low recursion limit, something is wrong here" + ) + yield + finally: + sys.setrecursionlimit(orig)