diff --git a/drf_orjson_renderer/renderers.py b/drf_orjson_renderer/renderers.py index c091a43..aae1454 100644 --- a/drf_orjson_renderer/renderers.py +++ b/drf_orjson_renderer/renderers.py @@ -4,7 +4,14 @@ from decimal import Decimal from typing import Any, Optional +import django import orjson + +if django.VERSION < (5, 0): + from django.db.models.enums import ChoicesMeta as ChoicesType +elif django.VERSION <= (6, 0): + from django.db.models.enums import ChoicesType + from django.utils.functional import Promise from rest_framework.renderers import BaseRenderer from rest_framework.settings import api_settings @@ -50,7 +57,7 @@ def default(obj: Any) -> Any: return str(obj) else: return float(obj) - elif isinstance(obj, (str, uuid.UUID, Promise)): + elif isinstance(obj, (str, uuid.UUID, Promise, ChoicesType)): return str(obj) elif hasattr(obj, "tolist"): return obj.tolist() diff --git a/requirements/dev.txt b/requirements/dev.txt index 8e04366..05537ac 100644 --- a/requirements/dev.txt +++ b/requirements/dev.txt @@ -7,7 +7,7 @@ ipdb ipython isort mypy==0.910 -numpy==1.21.4 +numpy==1.26.0 pre-commit pytest-cov==3.0.0 twine==4.0.1 diff --git a/tests/test_renderer.py b/tests/test_renderer.py index 73ba6b4..533d21c 100644 --- a/tests/test_renderer.py +++ b/tests/test_renderer.py @@ -9,7 +9,8 @@ import numpy import orjson import pytest -from django.utils.functional import Promise, lazy +from django.db.models import TextChoices +from django.utils.functional import lazy from rest_framework import status from rest_framework.exceptions import ErrorDetail, ParseError from rest_framework.settings import api_settings @@ -38,6 +39,10 @@ def tolist(self): return [1] +class ChoiceObj(TextChoices): + FIELD = "option-one", "Option One" + + UUID_PARAM = uuid.uuid4() string_doubler = lazy(lambda i: i + i, str) @@ -53,10 +58,15 @@ def tolist(self): (IterObj(1), [1], False), (ReturnList([{"1": 1}], serializer=None), [{"1": 1}], False), (ReturnDict({"a": "b"}, serializer=None), {"a": "b"}, False), + (ChoiceObj.FIELD, "option-one", False,) ] -@pytest.mark.parametrize("test_input,expected,coerce_decimal", DATA_PARAMS) +@pytest.mark.parametrize( + "test_input,expected,coerce_decimal", + DATA_PARAMS, + ids=[type(item[0]) for item in DATA_PARAMS], +) def test_built_in_default_method(test_input, expected, coerce_decimal): """Ensure that the built-in default method works for all data types.""" api_settings.COERCE_DECIMAL_TO_STRING = True if coerce_decimal else False diff --git a/version.py b/version.py index 3c1e9cb..008aaa9 100644 --- a/version.py +++ b/version.py @@ -1 +1 @@ -__version__ = "1.7.1" +__version__ = "1.7.3"