From 43b9e0c06351b9f84ef2913122b15424b201ffa0 Mon Sep 17 00:00:00 2001 From: Pascal Tomecek Date: Tue, 3 Dec 2024 17:42:46 -0500 Subject: [PATCH] Support Literal and Union/Optional types on structs Signed-off-by: Pascal Tomecek --- csp/impl/struct.py | 2 +- csp/impl/types/container_type_normalizer.py | 20 +++++++++++++-- csp/tests/impl/test_struct.py | 28 ++++++++++++++++++++- 3 files changed, 46 insertions(+), 4 deletions(-) diff --git a/csp/impl/struct.py b/csp/impl/struct.py index 777cca8d4..edb65c262 100644 --- a/csp/impl/struct.py +++ b/csp/impl/struct.py @@ -34,7 +34,7 @@ def __new__(cls, name, bases, dct): # Lists need to be normalized too as potentially we need to add a boolean flag to use FastList if v == FastList: raise TypeError(f"{v} annotation is not supported without args") - if CspTypingUtils.is_generic_container(v): + if CspTypingUtils.is_generic_container(v) or CspTypingUtils.is_union_type(v): actual_type = ContainerTypeNormalizer.normalized_type_to_actual_python_type(v) if CspTypingUtils.is_generic_container(actual_type): raise TypeError(f"{v} annotation is not supported as a struct field [{actual_type}]") diff --git a/csp/impl/types/container_type_normalizer.py b/csp/impl/types/container_type_normalizer.py index 9b4d50360..fbe4d5c63 100644 --- a/csp/impl/types/container_type_normalizer.py +++ b/csp/impl/types/container_type_normalizer.py @@ -73,12 +73,28 @@ def _convert_containers_to_typing_generic_meta(cls, typ, is_within_container): def normalized_type_to_actual_python_type(cls, typ, level=0): if isinstance(typ, typing_extensions._AnnotatedAlias): typ = CspTypingUtils.get_origin(typ) + if CspTypingUtils.is_generic_container(typ): - if CspTypingUtils.get_origin(typ) is FastList and level == 0: + origin = CspTypingUtils.get_origin(typ) + if origin is FastList and level == 0: return [cls.normalized_type_to_actual_python_type(typ.__args__[0], level + 1), True] - if CspTypingUtils.get_origin(typ) is typing.List and level == 0: + if origin is typing.List and level == 0: return [cls.normalized_type_to_actual_python_type(typ.__args__[0], level + 1)] + if origin in (typing.Literal, typing_extensions.Literal): # Not the same in python 3.8/3.9 + # Import here to prevent circular import + from csp.impl.types.instantiation_type_resolver import UpcastRegistry + + args = typing.get_args(typ) + typ = type(args[0]) + for arg in args[1:]: + typ = UpcastRegistry.instance().resolve_type(typ, type(arg), raise_on_error=False) + if typ: + return typ + else: + return object return cls._NORMALIZED_TYPE_MAPPING.get(CspTypingUtils.get_origin(typ), typ) + elif CspTypingUtils.is_union_type(typ): + return object else: return typ diff --git a/csp/tests/impl/test_struct.py b/csp/tests/impl/test_struct.py index 1a5605b13..d09249bca 100644 --- a/csp/tests/impl/test_struct.py +++ b/csp/tests/impl/test_struct.py @@ -6,7 +6,7 @@ import typing import unittest from datetime import date, datetime, time, timedelta -from typing import Dict, List, Set, Tuple +from typing import Dict, List, Literal, Optional, Set, Tuple, Union from typing_extensions import Annotated import csp @@ -2960,6 +2960,32 @@ class StructWithAnnotations(csp.Struct): ) self.assertEqual(StructWithAnnotations.metadata(typed=False), {"b": float, "d": dict, "s": str}) + def test_literal(self): + class StructWithLiteral(csp.Struct): + s: Literal["foo", "bar"] + f: Literal[0, 1.0] + o: Literal["foo", 0] + + self.assertEqual( + StructWithLiteral.metadata(typed=True), + {"s": Literal["foo", "bar"], "f": Literal[0, 1.0], "o": Literal["foo", 0]}, + ) + self.assertEqual(StructWithLiteral.metadata(typed=False), {"s": str, "f": float, "o": object}) + + def test_union(self): + class StructWithUnion(csp.Struct): + o1: Union[int, float] + o2: Optional[str] + + self.assertEqual( + StructWithUnion.metadata(typed=True), + { + "o1": Union[int, float], + "o2": Optional[str], + }, + ) + self.assertEqual(StructWithUnion.metadata(typed=False), {"o1": object, "o2": object}) + if __name__ == "__main__": unittest.main()