Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support Literal and Union/Optional types on structs. Fixes #405 #408

Merged
merged 1 commit into from
Dec 4, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion csp/impl/struct.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ def __new__(cls, name, bases, dct):
# Lists need to be normalized too as potentially we need to add a boolean flag to use FastList
if v == FastList:
raise TypeError(f"{v} annotation is not supported without args")
if CspTypingUtils.is_generic_container(v):
if CspTypingUtils.is_generic_container(v) or CspTypingUtils.is_union_type(v):
actual_type = ContainerTypeNormalizer.normalized_type_to_actual_python_type(v)
if CspTypingUtils.is_generic_container(actual_type):
raise TypeError(f"{v} annotation is not supported as a struct field [{actual_type}]")
Expand Down
20 changes: 18 additions & 2 deletions csp/impl/types/container_type_normalizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,12 +73,28 @@ def _convert_containers_to_typing_generic_meta(cls, typ, is_within_container):
def normalized_type_to_actual_python_type(cls, typ, level=0):
if isinstance(typ, typing_extensions._AnnotatedAlias):
typ = CspTypingUtils.get_origin(typ)

if CspTypingUtils.is_generic_container(typ):
if CspTypingUtils.get_origin(typ) is FastList and level == 0:
origin = CspTypingUtils.get_origin(typ)
if origin is FastList and level == 0:
return [cls.normalized_type_to_actual_python_type(typ.__args__[0], level + 1), True]
if CspTypingUtils.get_origin(typ) is typing.List and level == 0:
if origin is typing.List and level == 0:
return [cls.normalized_type_to_actual_python_type(typ.__args__[0], level + 1)]
if origin in (typing.Literal, typing_extensions.Literal): # Not the same in python 3.8/3.9
# Import here to prevent circular import
from csp.impl.types.instantiation_type_resolver import UpcastRegistry

args = typing.get_args(typ)
typ = type(args[0])
for arg in args[1:]:
typ = UpcastRegistry.instance().resolve_type(typ, type(arg), raise_on_error=False)
if typ:
return typ
else:
return object
return cls._NORMALIZED_TYPE_MAPPING.get(CspTypingUtils.get_origin(typ), typ)
elif CspTypingUtils.is_union_type(typ):
return object
else:
return typ

Expand Down
37 changes: 36 additions & 1 deletion csp/tests/impl/test_struct.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import typing
import unittest
from datetime import date, datetime, time, timedelta
from typing import Dict, List, Set, Tuple
from typing import Dict, List, Literal, Optional, Set, Tuple, Union
from typing_extensions import Annotated

import csp
Expand Down Expand Up @@ -2960,6 +2960,41 @@ class StructWithAnnotations(csp.Struct):
)
self.assertEqual(StructWithAnnotations.metadata(typed=False), {"b": float, "d": dict, "s": str})

def test_literal(self):
simple_class0 = SimpleClass(0)
simple_class1 = SimpleClass(1)

class StructWithLiteral(csp.Struct):
s: Literal["foo", "bar"]
f: Literal[0, 1.0]
o: Literal["foo", 0]
c: Literal[simple_class0, simple_class1]

self.assertEqual(
StructWithLiteral.metadata(typed=True),
{
"s": Literal["foo", "bar"],
"f": Literal[0, 1.0],
"o": Literal["foo", 0],
"c": Literal[simple_class0, simple_class1],
},
)
self.assertEqual(StructWithLiteral.metadata(typed=False), {"s": str, "f": float, "o": object, "c": SimpleClass})

def test_union(self):
class StructWithUnion(csp.Struct):
o1: Union[int, float]
o2: Optional[str]

self.assertEqual(
StructWithUnion.metadata(typed=True),
{
"o1": Union[int, float],
"o2": Optional[str],
},
)
self.assertEqual(StructWithUnion.metadata(typed=False), {"o1": object, "o2": object})


if __name__ == "__main__":
unittest.main()
Loading