Skip to content

Commit

Permalink
Merge pull request #29 from kodemore/fix-union-force-mode
Browse files Browse the repository at this point in the history
Add force mode to unions and hotfix union encoding
  • Loading branch information
dkraczkowski authored Nov 14, 2023
2 parents b0b8ac6 + 5f583f7 commit e537a8f
Show file tree
Hide file tree
Showing 8 changed files with 115 additions and 18 deletions.
2 changes: 1 addition & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ pylint:
poetry run pylint chili

mypy:
poetry run mypy --install-types --non-interactive .
poetry run mypy --install-types --non-interactive chili

bandit:
poetry run bandit -r . -x ./tests,./test,./.venv
Expand Down
25 changes: 16 additions & 9 deletions chili/decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
UNDEFINED,
TypeSchema,
create_schema,
get_non_optional_fields,
get_origin_type,
get_parameters_map,
get_type_args,
Expand Down Expand Up @@ -286,15 +287,19 @@ class UnionDecoder(TypeDecoder):
}
_CASTABLES_TYPES = {decimal.Decimal}

def __init__(self, valid_types: List[Type]):
def __init__(self, valid_types: List[Type], extra_decoders: TypeDecoders = None, force: bool = False):
self.valid_types = valid_types
self._type_decoders = {}

for a_type in valid_types:
if a_type in self._PRIMITIVE_TYPES:
self._type_decoders[a_type] = a_type
continue
self._type_decoders[a_type] = build_type_decoder(a_type) # type: ignore
self._type_decoders[a_type] = build_type_decoder(
a_type, extra_decoders=extra_decoders, force=force # type: ignore
)

self.force = force

def decode(self, value: Any) -> Any:
passed_type = type(value)
Expand All @@ -317,13 +322,15 @@ def decode(self, value: Any) -> Any:
continue

if passed_type is dict:
value_keys = value.keys()
for decodable, decoder in self._type_decoders.items():
provided_fields = set(value.keys())
for class_name, decoder in self._type_decoders.items():
try:
if not is_decodable(decodable) and value_keys == get_type_hints(decodable).keys():
return decoder.decode(value)
if is_decodable(decodable) and value_keys == getattr(decodable, _PROPERTIES, {}).keys():
return decoder.decode(value)
if is_decodable(class_name) or is_dataclass(class_name) or self.force:
expected_fields = set(get_non_optional_fields(class_name))
if provided_fields.issubset(expected_fields):
return decoder.decode(value)
continue
continue
except Exception:
continue

Expand Down Expand Up @@ -497,7 +504,7 @@ def build_type_decoder(
return OptionalTypeDecoder(
build_type_decoder(a_type=type_args[0], extra_decoders=extra_decoders) # type: ignore
)
return UnionDecoder(type_args)
return UnionDecoder(type_args, extra_decoders=extra_decoders, force=force)

if isinstance(a_type, typing.ForwardRef) and module is not None:
resolved_reference = resolve_forward_reference(module, a_type)
Expand Down
8 changes: 4 additions & 4 deletions chili/encoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,6 @@
unpack_optional,
)


if sys.version_info >= (3, 10):
from types import UnionType
else:
Expand Down Expand Up @@ -341,16 +340,17 @@ def encode(self, value: Any) -> Any:


class UnionEncoder(TypeEncoder):
def __init__(self, supported_types: List[Type], extra_encoders: TypeEncoders = None):
def __init__(self, supported_types: List[Type], extra_encoders: TypeEncoders = None, force: bool = False):
self.supported_types = supported_types
self._extra_encoders = extra_encoders
self.force = force

def encode(self, value: Any) -> Any:
value_type = type(value)
if value_type not in self.supported_types:
raise EncoderError.invalid_input

return build_type_encoder(value_type, self._extra_encoders).encode(value) # type: ignore
return build_type_encoder(value_type, self._extra_encoders, force=self.force).encode(value) # type: ignore


_supported_generics = {
Expand Down Expand Up @@ -405,7 +405,7 @@ def build_type_encoder(
type_args = get_type_args(a_type)
if len(type_args) == 2 and type_args[-1] is type(None):
return OptionalTypeEncoder(build_type_encoder(type_args[0], extra_encoders)) # type: ignore
return UnionEncoder(type_args, extra_encoders)
return UnionEncoder(type_args, extra_encoders, force=force)

if isinstance(a_type, typing.ForwardRef) and module is not None:
resolved_reference = resolve_forward_reference(module, a_type)
Expand Down
12 changes: 11 additions & 1 deletion chili/typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from enum import Enum
from functools import lru_cache
from inspect import isclass as is_class
from typing import Any, Callable, ClassVar, Dict, List, NewType, Optional, Type, Union
from typing import Any, Callable, ClassVar, Dict, List, NewType, Optional, Type, Union, get_type_hints

from chili.error import SerialisationError

Expand All @@ -27,6 +27,7 @@
"get_origin_type",
"get_parameters_map",
"get_type_args",
"get_type_hints",
"get_type_parameters",
"is_class",
"is_dataclass",
Expand All @@ -43,6 +44,15 @@
]


def get_non_optional_fields(type_name: Type) -> List[str]:
if is_encodable(type_name):
schema = getattr(type_name, _ENCODABLE)
else:
schema = create_schema(type_name) # type: ignore

return [field.name for field in schema.values() if not is_optional(field.type)]


def get_origin_type(type_name: Type) -> Optional[Type]:
return getattr(type_name, "__origin__", None)

Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ license = "MIT"
name = "chili"
readme = "README.md"
repository = "https://github.com/kodemore/chili"
version = "2.8.0"
version = "2.8.1"

[tool.poetry.dependencies]
gaffe = ">=0.3.0"
Expand Down
34 changes: 34 additions & 0 deletions tests/test_typing.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
from dataclasses import dataclass
from typing import Optional

from chili.typing import get_non_optional_fields


def test_get_non_optional_fields_from_data_class() -> None:
# given
@dataclass
class Person:
name: str
age: int
street_name: Optional[str]
street_number: Optional[int]

# when
fields = get_non_optional_fields(Person)

# then
assert fields == ["name", "age"]


def test_get_non_optional_fields_from_class() -> None:
class Person:
name: str
age: int
street_name: Optional[str]
street_number: Optional[int]

# when
fields = get_non_optional_fields(Person)

# then
assert fields == ["name", "age"]
3 changes: 2 additions & 1 deletion tests/usecases/newtype_test.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import sys
from typing import NewType

from chili import Decoder, Encoder, decodable, encodable
import sys


def test_can_encode_newtype_type() -> None:
Expand Down
47 changes: 46 additions & 1 deletion tests/usecases/union_usecase.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from dataclasses import dataclass
from typing import List, Union
from typing import List, Optional, Union

from chili import decode, encode

Expand Down Expand Up @@ -88,3 +88,48 @@ class EmailAddress:

# then
assert result == {"address": "[email protected]", "label": ""}


def test_can_decode_nested_union() -> None:
# given
@dataclass
class HomeAddress:
home_street: str
number: Optional[int] = 0

@dataclass
class OfficeAddress:
office_street: str
number: Optional[int] = 0

@dataclass
class Address:
street: str
number: Optional[int] = 0

@dataclass
class Person:
name: str
address: HomeAddress | OfficeAddress | Address

data_office = {
"name": "Bob",
"address": {"office_street": "street"},
}

data_home = {
"name": "Bob",
"address": {"home_street": "home"},
}

# when
decoded = decode(data_office, Person)

# then
assert isinstance(decoded.address, OfficeAddress)

# when
decoded = decode(data_home, Person)

# then
assert isinstance(decoded.address, HomeAddress)

0 comments on commit e537a8f

Please sign in to comment.