Skip to content

Commit

Permalink
Support Literal and Union/Optional types on structs
Browse files Browse the repository at this point in the history
Signed-off-by: Pascal Tomecek <[email protected]>
  • Loading branch information
ptomecek committed Dec 4, 2024
1 parent 277203b commit 43b9e0c
Show file tree
Hide file tree
Showing 3 changed files with 46 additions and 4 deletions.
2 changes: 1 addition & 1 deletion csp/impl/struct.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ def __new__(cls, name, bases, dct):
# Lists need to be normalized too as potentially we need to add a boolean flag to use FastList
if v == FastList:
raise TypeError(f"{v} annotation is not supported without args")
if CspTypingUtils.is_generic_container(v):
if CspTypingUtils.is_generic_container(v) or CspTypingUtils.is_union_type(v):
actual_type = ContainerTypeNormalizer.normalized_type_to_actual_python_type(v)
if CspTypingUtils.is_generic_container(actual_type):
raise TypeError(f"{v} annotation is not supported as a struct field [{actual_type}]")
Expand Down
20 changes: 18 additions & 2 deletions csp/impl/types/container_type_normalizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,12 +73,28 @@ def _convert_containers_to_typing_generic_meta(cls, typ, is_within_container):
def normalized_type_to_actual_python_type(cls, typ, level=0):
if isinstance(typ, typing_extensions._AnnotatedAlias):
typ = CspTypingUtils.get_origin(typ)

if CspTypingUtils.is_generic_container(typ):
if CspTypingUtils.get_origin(typ) is FastList and level == 0:
origin = CspTypingUtils.get_origin(typ)
if origin is FastList and level == 0:
return [cls.normalized_type_to_actual_python_type(typ.__args__[0], level + 1), True]
if CspTypingUtils.get_origin(typ) is typing.List and level == 0:
if origin is typing.List and level == 0:
return [cls.normalized_type_to_actual_python_type(typ.__args__[0], level + 1)]
if origin in (typing.Literal, typing_extensions.Literal): # Not the same in python 3.8/3.9
# Import here to prevent circular import
from csp.impl.types.instantiation_type_resolver import UpcastRegistry

args = typing.get_args(typ)
typ = type(args[0])
for arg in args[1:]:
typ = UpcastRegistry.instance().resolve_type(typ, type(arg), raise_on_error=False)
if typ:
return typ
else:
return object
return cls._NORMALIZED_TYPE_MAPPING.get(CspTypingUtils.get_origin(typ), typ)
elif CspTypingUtils.is_union_type(typ):
return object
else:
return typ

Expand Down
28 changes: 27 additions & 1 deletion csp/tests/impl/test_struct.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import typing
import unittest
from datetime import date, datetime, time, timedelta
from typing import Dict, List, Set, Tuple
from typing import Dict, List, Literal, Optional, Set, Tuple, Union
from typing_extensions import Annotated

import csp
Expand Down Expand Up @@ -2960,6 +2960,32 @@ class StructWithAnnotations(csp.Struct):
)
self.assertEqual(StructWithAnnotations.metadata(typed=False), {"b": float, "d": dict, "s": str})

def test_literal(self):
class StructWithLiteral(csp.Struct):
s: Literal["foo", "bar"]
f: Literal[0, 1.0]
o: Literal["foo", 0]

self.assertEqual(
StructWithLiteral.metadata(typed=True),
{"s": Literal["foo", "bar"], "f": Literal[0, 1.0], "o": Literal["foo", 0]},
)
self.assertEqual(StructWithLiteral.metadata(typed=False), {"s": str, "f": float, "o": object})

def test_union(self):
class StructWithUnion(csp.Struct):
o1: Union[int, float]
o2: Optional[str]

self.assertEqual(
StructWithUnion.metadata(typed=True),
{
"o1": Union[int, float],
"o2": Optional[str],
},
)
self.assertEqual(StructWithUnion.metadata(typed=False), {"o1": object, "o2": object})


if __name__ == "__main__":
unittest.main()

0 comments on commit 43b9e0c

Please sign in to comment.