Skip to content

Commit

Permalink
Merge pull request #28 from kodemore/fix-union-support
Browse files Browse the repository at this point in the history
Add union and newtype support for python > 3.10
  • Loading branch information
dkraczkowski authored Nov 9, 2023
2 parents 5558b91 + 48609e8 commit b0b8ac6
Show file tree
Hide file tree
Showing 6 changed files with 90 additions and 4 deletions.
8 changes: 7 additions & 1 deletion chili/decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import datetime
import decimal
import re
import sys
import typing
from abc import abstractmethod
from base64 import b64decode
Expand Down Expand Up @@ -56,6 +57,11 @@
unpack_optional,
)

if sys.version_info >= (3, 10):
from types import UnionType
else:
UnionType = None

from .error import DecoderError
from .iso_datetime import parse_iso_date, parse_iso_datetime, parse_iso_duration, parse_iso_time
from .mapper import Mapper
Expand Down Expand Up @@ -485,7 +491,7 @@ def build_type_decoder(
if is_class(origin_type) and is_user_string(origin_type):
return SimpleDecoder[origin_type](origin_type) # type: ignore

if origin_type is Union:
if origin_type is Union or (UnionType and isinstance(origin_type, UnionType)):
type_args = get_type_args(a_type)
if len(type_args) == 2 and type_args[-1] is type(None): # type: ignore
return OptionalTypeDecoder(
Expand Down
9 changes: 8 additions & 1 deletion chili/encoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import datetime
import decimal
import re
import sys
import typing
from abc import abstractmethod
from base64 import b64encode
Expand Down Expand Up @@ -39,6 +40,12 @@
unpack_optional,
)


if sys.version_info >= (3, 10):
from types import UnionType
else:
UnionType = None

from .error import EncoderError
from .iso_datetime import timedelta_to_iso_duration
from .mapper import Mapper
Expand Down Expand Up @@ -394,7 +401,7 @@ def build_type_encoder(
if is_class(origin_type) and is_user_string(origin_type):
return SimpleEncoder[str](str)

if origin_type is Union:
if origin_type is Union or (UnionType and isinstance(origin_type, UnionType)):
type_args = get_type_args(a_type)
if len(type_args) == 2 and type_args[-1] is type(None):
return OptionalTypeEncoder(build_type_encoder(type_args[0], extra_encoders)) # type: ignore
Expand Down
4 changes: 4 additions & 0 deletions chili/typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,8 +72,12 @@ def is_optional(type_name: Type) -> bool:


def is_newtype(type_name: Type) -> bool:
if sys.version_info >= (3, 10):
return isinstance(type_name, NewType) # type: ignore

if not hasattr(type_name, "__qualname__"):
return False

if type_name.__qualname__ != "NewType.<locals>.new_type":
return False

Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ license = "MIT"
name = "chili"
readme = "README.md"
repository = "https://github.com/kodemore/chili"
version = "2.7.1"
version = "2.8.0"

[tool.poetry.dependencies]
gaffe = ">=0.3.0"
Expand Down
2 changes: 1 addition & 1 deletion tests/usecases/newtype_test.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from typing import NewType

from chili import Decoder, Encoder, decodable, encodable
import sys


def test_can_encode_newtype_type() -> None:
Expand Down
69 changes: 69 additions & 0 deletions tests/usecases/union_usecase.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,3 +19,72 @@ class Child(Parent):

assert isinstance(after[0], Parent)
assert isinstance(after[1], Child)


def test_union_simple_type_or_dataclass() -> None:
# given
@dataclass
class EmailAddress:
address: str
label: str

value = "[email protected]"
complex_value = {"address": "[email protected]", "label": ""}

# when
result = decode(value, Union[str, EmailAddress])

# then
assert result == "[email protected]"

# when
result = decode(complex_value, Union[str, EmailAddress])

# then
assert isinstance(result, EmailAddress)


def test_new_union_style_decode() -> None:
# given
@dataclass
class EmailAddress:
address: str
label: str

value = "[email protected]"
complex_value = {"address": "[email protected]", "label": ""}

# when
result = decode(value, str | EmailAddress)

# then
assert result == "[email protected]"

# when
result = decode(complex_value, str | EmailAddress)

# then
assert isinstance(result, EmailAddress)


def test_new_union_style_encode() -> None:
# given
@dataclass
class EmailAddress:
address: str
label: str

value = "[email protected]"
complex_value = {"address": "[email protected]", "label": ""}

# when
result = encode(value, str | EmailAddress)

# then
assert result == "[email protected]"

# when
result = encode(EmailAddress(address="[email protected]", label=""), str | EmailAddress)

# then
assert result == {"address": "[email protected]", "label": ""}

0 comments on commit b0b8ac6

Please sign in to comment.