From 5191babf0f33484a0f52aa6318d459efcbfdff96 Mon Sep 17 00:00:00 2001 From: Dawid Kraczkowski Date: Mon, 30 Oct 2023 09:56:42 +0100 Subject: [PATCH] Support for collections.UserString, accept any in Encoder, Decoder --- .gitignore | 1 + chili/decoder.py | 6 ++++- chili/encoder.py | 6 ++++- chili/typing.py | 4 +++ pyproject.toml | 2 +- tests/test_encoder.py | 42 ++++++++++++++++++++++++++++--- tests/usecases/userstring_test.py | 32 +++++++++++++++++++++++ 7 files changed, 87 insertions(+), 6 deletions(-) create mode 100644 tests/usecases/userstring_test.py diff --git a/.gitignore b/.gitignore index ef69993..c5f1bcf 100644 --- a/.gitignore +++ b/.gitignore @@ -3,3 +3,4 @@ .pytest_cache __pycache__ .DS_Store +/dist diff --git a/chili/decoder.py b/chili/decoder.py index b0a9ab0..f7d13ec 100644 --- a/chili/decoder.py +++ b/chili/decoder.py @@ -50,6 +50,7 @@ is_newtype, is_optional, is_typed_dict, + is_user_string, map_generic_type, resolve_forward_reference, unpack_optional, @@ -470,6 +471,9 @@ def build_type_decoder(a_type: Type, extra_decoders: TypeDecoders = None, module if is_class(origin_type) and is_typed_dict(origin_type): return TypedDictDecoder(origin_type, extra_decoders) + if is_class(origin_type) and is_user_string(origin_type): + return ProxyDecoder[origin_type](origin_type) + if origin_type is Union: type_args = get_type_args(a_type) if len(type_args) == 2 and type_args[-1] is type(None): # type: ignore @@ -568,7 +572,7 @@ def __class_getitem__(cls, item: Type) -> Type[Decoder]: # noqa: E501 item = decodable(item) if not hasattr(item, _DECODABLE): - raise DecoderError.invalid_type + item = decodable(item) return type( # type: ignore f"{cls.__qualname__}[{item.__module__}.{item.__qualname__}]", diff --git a/chili/encoder.py b/chili/encoder.py index e87bd5a..ad5dfda 100644 --- a/chili/encoder.py +++ b/chili/encoder.py @@ -33,6 +33,7 @@ is_newtype, is_optional, is_typed_dict, + is_user_string, map_generic_type, resolve_forward_reference, unpack_optional, @@ -379,6 +380,9 @@ def build_type_encoder(a_type: Type, extra_encoders: TypeEncoders = None, module if is_class(origin_type) and is_typed_dict(origin_type): return TypedDictEncoder(origin_type, extra_encoders) + if is_class(origin_type) and is_user_string(origin_type): + return ProxyEncoder[str](str) + if origin_type is Union: type_args = get_type_args(a_type) if len(type_args) == 2 and type_args[-1] is type(None): @@ -476,7 +480,7 @@ def __class_getitem__(cls, item: Any) -> Type[Encoder]: # noqa: E501 item = encodable(item) if not hasattr(item, _ENCODABLE): - raise EncoderError.invalid_type + item = encodable(item) return type( # type: ignore f"{cls.__qualname__}[{item.__module__}.{item.__qualname__}]", diff --git a/chili/typing.py b/chili/typing.py index c020988..2a1c9c6 100644 --- a/chili/typing.py +++ b/chili/typing.py @@ -2,6 +2,7 @@ import sys import typing +from collections import UserString from dataclasses import MISSING, Field, InitVar, is_dataclass from enum import Enum from inspect import isclass as is_class @@ -113,6 +114,9 @@ def is_typed_dict(type_name: Type) -> bool: return issubclass(type_name, dict) and hasattr(type_name, "__annotations__") +def is_user_string(type_name: Type) -> bool: + return issubclass(type_name, UserString) + def map_generic_type(type_name: Any, type_map: Dict[Any, Any]) -> Any: if not type_map: return type_name diff --git a/pyproject.toml b/pyproject.toml index 4736bb2..18a46eb 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -22,7 +22,7 @@ license = "MIT" name = "chili" readme = "README.md" repository = "https://github.com/kodemore/chili" -version = "2.4.2" +version = "2.5.0" [tool.poetry.dependencies] gaffe = "^0.2.1" diff --git a/tests/test_encoder.py b/tests/test_encoder.py index a2f232c..54ecdee 100644 --- a/tests/test_encoder.py +++ b/tests/test_encoder.py @@ -1,3 +1,5 @@ +from collections import UserString + import pytest from chili import Encoder, encodable @@ -18,12 +20,46 @@ class Example: assert instance.__generic__ == Example -def test_fail_encode_non_encodable_type() -> None: +def test_can_encode_non_encodable_type() -> None: # given class Example: name: str age: int + def __init__(self, name: str, age: int): + self.name = name + self.age = age + # when - with pytest.raises(EncoderError.invalid_type): - Encoder[Example]() + encoder = Encoder[Example]() + value = encoder.encode(Example("bob", 33)) + + # then + assert value == { + "name": "bob", + "age": 33, + } + + +def test_can_encode_complex_non_encodable_type() -> None: + # given + class ExampleName(UserString): + pass + + class Example: + name: ExampleName + age: int + + def __init__(self, name: ExampleName, age: int): + self.name = name + self.age = age + + # when + encoder = Encoder[Example]() + value = encoder.encode(Example(ExampleName("bob"), 33)) + + # then + assert value == { + "name": "bob", + "age": 33, + } diff --git a/tests/usecases/userstring_test.py b/tests/usecases/userstring_test.py new file mode 100644 index 0000000..57469ab --- /dev/null +++ b/tests/usecases/userstring_test.py @@ -0,0 +1,32 @@ +from collections import UserString + +from chili import decode, encode + + +def test_can_encode_userstring() -> None: + # given + class ComplexString(UserString): + pass + + string = ComplexString("Example String") + + # when + result = encode(string) + + # then + assert result == "Example String" + + +def test_can_decode_userstring() -> None: + # given + class ComplexString(UserString): + pass + string = "Example String" + + # when + result = decode(string, ComplexString) + + # then + assert result == ComplexString("Example String") + assert isinstance(result, ComplexString) + assert isinstance(result, UserString)