From f4aab6ed3fcd3e70354073784f7b0870863e8105 Mon Sep 17 00:00:00 2001 From: Yukinari Tani Date: Sat, 20 May 2023 11:20:28 +0900 Subject: [PATCH] Fix recursive union --- examples/recursive_union.py | 31 +++++++++++++++++++++++++++++++ examples/runner.py | 2 ++ serde/de.py | 2 +- serde/se.py | 2 +- 4 files changed, 35 insertions(+), 2 deletions(-) create mode 100644 examples/recursive_union.py diff --git a/examples/recursive_union.py b/examples/recursive_union.py new file mode 100644 index 00000000..80f41cbb --- /dev/null +++ b/examples/recursive_union.py @@ -0,0 +1,31 @@ +from typing import List, Union +from serde import serde, to_dict, InternalTagging, from_dict +from dataclasses import dataclass + + +@serde(tagging=InternalTagging("type")) +@dataclass +class Leaf: + value: int + + +@dataclass +class Node: + name: str + children: List[Union[Leaf, "Node"]] + + +serde(Node, tagging=InternalTagging("type")) + + +def main() -> None: + node1 = Node("node1", [Leaf(10)]) + node2 = Node("node2", [node1]) + d = to_dict(node2) + print(f"Into dict: {d}") + node = from_dict(Node, d) + print(f"From dict: {node}") + + +if __name__ == "__main__": + main() diff --git a/examples/runner.py b/examples/runner.py index dd65abfe..c5757d73 100644 --- a/examples/runner.py +++ b/examples/runner.py @@ -25,6 +25,7 @@ import plain_dataclass_class_attribute import recursive import recursive_list +import recursive_union import rename import rename_all import simple @@ -80,6 +81,7 @@ def run_all(): run(alias) run(recursive) run(recursive_list) + run(recursive_union) run(class_var) run(plain_dataclass) run(plain_dataclass_class_attribute) diff --git a/serde/de.py b/serde/de.py index 2d36330e..5a76676b 100644 --- a/serde/de.py +++ b/serde/de.py @@ -275,7 +275,7 @@ def wrap(cls: Type): # We call deserialize and not wrap to make sure that we will use the default serde # configuration for generating the deserialization function. deserialize(typ) - if typ is cls or (is_primitive(typ) and not is_enum(typ)): + if is_primitive(typ) and not is_enum(typ): continue if is_generic(typ): g[typename(typ)] = get_origin(typ) diff --git a/serde/se.py b/serde/se.py index c05e582f..cf3c1a9e 100644 --- a/serde/se.py +++ b/serde/se.py @@ -275,7 +275,7 @@ def wrap(cls: Type[Any]): # configuration for generating the serialization function. serialize(typ) - if typ is cls or (is_primitive(typ) and not is_enum(typ)): + if is_primitive(typ) and not is_enum(typ): continue g[typename(typ)] = typ