diff --git a/Makefile b/Makefile index ac4e9a5..f5b272b 100644 --- a/Makefile +++ b/Makefile @@ -17,7 +17,7 @@ pylint: poetry run pylint chili mypy: - poetry run mypy --install-types --non-interactive . + poetry run mypy --install-types --non-interactive chili bandit: poetry run bandit -r . -x ./tests,./test,./.venv diff --git a/chili/decoder.py b/chili/decoder.py index 056eb0b..207009f 100644 --- a/chili/decoder.py +++ b/chili/decoder.py @@ -40,6 +40,7 @@ UNDEFINED, TypeSchema, create_schema, + get_non_optional_fields, get_origin_type, get_parameters_map, get_type_args, @@ -286,7 +287,7 @@ class UnionDecoder(TypeDecoder): } _CASTABLES_TYPES = {decimal.Decimal} - def __init__(self, valid_types: List[Type]): + def __init__(self, valid_types: List[Type], extra_decoders: TypeDecoders = None, force: bool = False): self.valid_types = valid_types self._type_decoders = {} @@ -294,7 +295,11 @@ def __init__(self, valid_types: List[Type]): if a_type in self._PRIMITIVE_TYPES: self._type_decoders[a_type] = a_type continue - self._type_decoders[a_type] = build_type_decoder(a_type) # type: ignore + self._type_decoders[a_type] = build_type_decoder( + a_type, extra_decoders=extra_decoders, force=force # type: ignore + ) + + self.force = force def decode(self, value: Any) -> Any: passed_type = type(value) @@ -317,13 +322,15 @@ def decode(self, value: Any) -> Any: continue if passed_type is dict: - value_keys = value.keys() - for decodable, decoder in self._type_decoders.items(): + provided_fields = set(value.keys()) + for class_name, decoder in self._type_decoders.items(): try: - if not is_decodable(decodable) and value_keys == get_type_hints(decodable).keys(): - return decoder.decode(value) - if is_decodable(decodable) and value_keys == getattr(decodable, _PROPERTIES, {}).keys(): - return decoder.decode(value) + if is_decodable(class_name) or is_dataclass(class_name) or self.force: + expected_fields = set(get_non_optional_fields(class_name)) + if provided_fields.issubset(expected_fields): + return decoder.decode(value) + continue + continue except Exception: continue @@ -497,7 +504,7 @@ def build_type_decoder( return OptionalTypeDecoder( build_type_decoder(a_type=type_args[0], extra_decoders=extra_decoders) # type: ignore ) - return UnionDecoder(type_args) + return UnionDecoder(type_args, extra_decoders=extra_decoders, force=force) if isinstance(a_type, typing.ForwardRef) and module is not None: resolved_reference = resolve_forward_reference(module, a_type) diff --git a/chili/encoder.py b/chili/encoder.py index f730730..ecd3b4e 100644 --- a/chili/encoder.py +++ b/chili/encoder.py @@ -40,7 +40,6 @@ unpack_optional, ) - if sys.version_info >= (3, 10): from types import UnionType else: @@ -341,16 +340,17 @@ def encode(self, value: Any) -> Any: class UnionEncoder(TypeEncoder): - def __init__(self, supported_types: List[Type], extra_encoders: TypeEncoders = None): + def __init__(self, supported_types: List[Type], extra_encoders: TypeEncoders = None, force: bool = False): self.supported_types = supported_types self._extra_encoders = extra_encoders + self.force = force def encode(self, value: Any) -> Any: value_type = type(value) if value_type not in self.supported_types: raise EncoderError.invalid_input - return build_type_encoder(value_type, self._extra_encoders).encode(value) # type: ignore + return build_type_encoder(value_type, self._extra_encoders, force=self.force).encode(value) # type: ignore _supported_generics = { @@ -405,7 +405,7 @@ def build_type_encoder( type_args = get_type_args(a_type) if len(type_args) == 2 and type_args[-1] is type(None): return OptionalTypeEncoder(build_type_encoder(type_args[0], extra_encoders)) # type: ignore - return UnionEncoder(type_args, extra_encoders) + return UnionEncoder(type_args, extra_encoders, force=force) if isinstance(a_type, typing.ForwardRef) and module is not None: resolved_reference = resolve_forward_reference(module, a_type) diff --git a/chili/typing.py b/chili/typing.py index d489038..a4505fc 100644 --- a/chili/typing.py +++ b/chili/typing.py @@ -7,7 +7,7 @@ from enum import Enum from functools import lru_cache from inspect import isclass as is_class -from typing import Any, Callable, ClassVar, Dict, List, NewType, Optional, Type, Union +from typing import Any, Callable, ClassVar, Dict, List, NewType, Optional, Type, Union, get_type_hints from chili.error import SerialisationError @@ -27,6 +27,7 @@ "get_origin_type", "get_parameters_map", "get_type_args", + "get_type_hints", "get_type_parameters", "is_class", "is_dataclass", @@ -43,6 +44,15 @@ ] +def get_non_optional_fields(type_name: Type) -> List[str]: + if is_encodable(type_name): + schema = getattr(type_name, _ENCODABLE) + else: + schema = create_schema(type_name) # type: ignore + + return [field.name for field in schema.values() if not is_optional(field.type)] + + def get_origin_type(type_name: Type) -> Optional[Type]: return getattr(type_name, "__origin__", None) diff --git a/pyproject.toml b/pyproject.toml index 70ca4b1..16aef60 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -22,7 +22,7 @@ license = "MIT" name = "chili" readme = "README.md" repository = "https://github.com/kodemore/chili" -version = "2.8.0" +version = "2.8.1" [tool.poetry.dependencies] gaffe = ">=0.3.0" diff --git a/tests/test_typing.py b/tests/test_typing.py new file mode 100644 index 0000000..0a97995 --- /dev/null +++ b/tests/test_typing.py @@ -0,0 +1,34 @@ +from dataclasses import dataclass +from typing import Optional + +from chili.typing import get_non_optional_fields + + +def test_get_non_optional_fields_from_data_class() -> None: + # given + @dataclass + class Person: + name: str + age: int + street_name: Optional[str] + street_number: Optional[int] + + # when + fields = get_non_optional_fields(Person) + + # then + assert fields == ["name", "age"] + + +def test_get_non_optional_fields_from_class() -> None: + class Person: + name: str + age: int + street_name: Optional[str] + street_number: Optional[int] + + # when + fields = get_non_optional_fields(Person) + + # then + assert fields == ["name", "age"] diff --git a/tests/usecases/newtype_test.py b/tests/usecases/newtype_test.py index 8b437e9..2fd10b6 100644 --- a/tests/usecases/newtype_test.py +++ b/tests/usecases/newtype_test.py @@ -1,6 +1,7 @@ +import sys from typing import NewType + from chili import Decoder, Encoder, decodable, encodable -import sys def test_can_encode_newtype_type() -> None: diff --git a/tests/usecases/union_usecase.py b/tests/usecases/union_usecase.py index 6f711e0..5f10924 100644 --- a/tests/usecases/union_usecase.py +++ b/tests/usecases/union_usecase.py @@ -1,5 +1,5 @@ from dataclasses import dataclass -from typing import List, Union +from typing import List, Optional, Union from chili import decode, encode @@ -88,3 +88,48 @@ class EmailAddress: # then assert result == {"address": "simple@email.com", "label": ""} + + +def test_can_decode_nested_union() -> None: + # given + @dataclass + class HomeAddress: + home_street: str + number: Optional[int] = 0 + + @dataclass + class OfficeAddress: + office_street: str + number: Optional[int] = 0 + + @dataclass + class Address: + street: str + number: Optional[int] = 0 + + @dataclass + class Person: + name: str + address: HomeAddress | OfficeAddress | Address + + data_office = { + "name": "Bob", + "address": {"office_street": "street"}, + } + + data_home = { + "name": "Bob", + "address": {"home_street": "home"}, + } + + # when + decoded = decode(data_office, Person) + + # then + assert isinstance(decoded.address, OfficeAddress) + + # when + decoded = decode(data_home, Person) + + # then + assert isinstance(decoded.address, HomeAddress)