Skip to content

Commit

Permalink
Merge pull request #542 from yukinarit/fix-flatten-optional
Browse files Browse the repository at this point in the history
Fix flatten for optional
  • Loading branch information
yukinarit authored Jun 9, 2024
2 parents df8db22 + 3c91ee8 commit b564f3f
Show file tree
Hide file tree
Showing 3 changed files with 58 additions and 9 deletions.
18 changes: 15 additions & 3 deletions serde/de.py
Original file line number Diff line number Diff line change
Expand Up @@ -584,13 +584,17 @@ def __getitem__(self, n: int) -> Union[DeField[Any], InnerField[Any]]:
"""
typ = type_args(self.type)[n]
opts: dict[str, Any] = {
"kw_only": self.kw_only,
"case": self.case,
"alias": self.alias,
"rename": self.rename,
"skip": self.skip,
"skip_if": self.skip_if,
"skip_if_false": self.skip_if_false,
"skip_if_default": self.skip_if_default,
"serializer": self.serializer,
"deserializer": self.deserializer,
"flatten": self.flatten,
"alias": self.alias,
"parent": self.parent,
}
if is_list(self.type) or is_dict(self.type) or is_set(self.type):
Expand Down Expand Up @@ -856,9 +860,17 @@ def opt(self, arg: DeField[Any]) -> str:
maybe_generic_type_vars=maybe_generic_type_vars, variable_type_args=None, \
reuse_instances=reuse_instances)) if data.get("f") is not None else None'
"""
value_arg = arg[0]
inner = arg[0]
if arg.iterbased:
exists = f"{arg.data} is not None"
elif arg.flatten:
# Check nullabilities of all nested fields.
exists = " and ".join(
[
f'{arg.datavar}.get("{f.name}") is not None'
for f in dataclasses.fields(inner.type)
]
)
else:
name = arg.conv_name()
if arg.alias:
Expand All @@ -867,7 +879,7 @@ def opt(self, arg: DeField[Any]) -> str:
else:
get = f'{arg.datavar}.get("{name}")'
exists = f"{get} is not None"
return f"({self.render(value_arg)}) if {exists} else None"
return f"({self.render(inner)}) if {exists} else None"

def list(self, arg: DeField[Any]) -> str:
"""
Expand Down
30 changes: 25 additions & 5 deletions serde/se.py
Original file line number Diff line number Diff line change
Expand Up @@ -480,7 +480,20 @@ def varname(self) -> str:

def __getitem__(self, n: int) -> SeField[Any]:
typ = type_args(self.type)[n]
return SeField(typ, name=None)
opts: dict[str, Any] = {
"kw_only": self.kw_only,
"case": self.case,
"alias": self.alias,
"rename": self.rename,
"skip": self.skip,
"skip_if": self.skip_if,
"skip_if_false": self.skip_if_false,
"skip_if_default": self.skip_if_default,
"serializer": self.serializer,
"deserializer": self.deserializer,
"flatten": self.flatten,
}
return SeField(typ, name=None, **opts)


def sefields(cls: type[Any], serialize_class_var: bool = False) -> Iterator[SeField[Any]]:
Expand Down Expand Up @@ -654,8 +667,12 @@ def render(self, arg: SeField[Any]) -> str:
"""
if is_dataclass(arg.type) and arg.flatten:
return self.flatten(arg)
else:
return f'res["{arg.conv_name(self.case)}"]'
elif is_opt(arg.type) and arg.flatten:
inner = arg[0]
if is_dataclass(inner.type):
return self.flatten(inner)

return f'res["{arg.conv_name(self.case)}"]'

def flatten(self, arg: SeField[Any]) -> str:
"""
Expand Down Expand Up @@ -818,9 +835,12 @@ def opt(self, arg: SeField[Any]) -> str:
"""
if is_bare_opt(arg.type):
return f"{arg.varname} if {arg.varname} is not None else None"

inner = arg[0]
inner.name = arg.varname
if arg.flatten:
return self.render(inner)
else:
inner = arg[0]
inner.name = arg.varname
return f"({self.render(inner)}) if {arg.varname} is not None else None"

def list(self, arg: SeField[Any]) -> str:
Expand Down
19 changes: 18 additions & 1 deletion tests/test_flatten.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
Tests for flatten attribute.
"""

from typing import Any
from typing import Any, Optional

import pytest

Expand Down Expand Up @@ -51,3 +51,20 @@ class Foo:

f = Foo(a=10, b="foo", bar=Bar(c=100.0, d=True, baz=Baz([1, 2], {"a": "10"})))
assert de(Foo, se(f)) == f


@pytest.mark.parametrize("se,de", all_formats)
def test_flatten_optional(se: Any, de: Any) -> None:
@serde
class Bar:
c: float
d: bool

@serde
class Foo:
a: int
b: str
bar: Optional[Bar] = field(flatten=True)

f = Foo(a=10, b="foo", bar=Bar(c=100.0, d=True))
assert de(Foo, se(f)) == f

0 comments on commit b564f3f

Please sign in to comment.