Skip to content

Commit

Permalink
Merge pull request #9 from seandstewart/seandstewart/type-alias-type-…
Browse files Browse the repository at this point in the history
…origin

fix: Track "unwrapped" types during routine resolution
  • Loading branch information
seandstewart authored Oct 30, 2024
2 parents 1021b39 + a05be1e commit f010baf
Show file tree
Hide file tree
Showing 7 changed files with 41 additions and 33 deletions.
38 changes: 20 additions & 18 deletions src/typelib/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,9 +50,6 @@ def static_order(
"""
# We want to leverage the cache if possible, hence the recursive call.
# Shouldn't actually recurse more than once or twice.
if inspection.istypealiastype(t):
value = inspection.unwrap(t)
return static_order(value)
if isinstance(t, (str, refs.ForwardRef)):
ref = refs.forwardref(t) if isinstance(t, str) else t
t = refs.evaluate(ref)
Expand Down Expand Up @@ -80,8 +77,6 @@ def itertypes(
[`static_order`][typelib.graph.static_order] instead of
[`itertypes`][typelib.graph.itertypes].
"""
if inspection.istypealiastype(t):
t = inspection.unwrap(t)
if isinstance(t, (str, refs.ForwardRef)): # pragma: no cover
ref = refs.forwardref(t) if isinstance(t, str) else t
t = refs.evaluate(ref)
Expand Down Expand Up @@ -113,33 +108,31 @@ def get_type_graph(t: type) -> graphlib.TopologicalSorter[TypeNode]:
resolve one level deep on each attempt, otherwise we will find ourselves stuck
in a closed loop which never terminates (infinite recursion).
"""
if inspection.istypealiastype(t):
t = inspection.unwrap(t)

graph: graphlib.TopologicalSorter = graphlib.TopologicalSorter()
root = TypeNode(t)
u = inspection.unwrap(t)
root = TypeNode(t, u)
stack = collections.deque([root])
visited = {root.type}
while stack:
parent = stack.popleft()
if inspection.isliteral(parent.type):
parent_unwrapped = inspection.unwrap(parent.type)
if inspection.isliteral(parent_unwrapped):
graph.add(parent)
continue

predecessors = []
for var, child in _level(parent.type):
for var, child in _level(parent_unwrapped):
# If no type was provided, there's no reason to do further processing.
if child in (constants.empty, typing.Any):
continue
if inspection.istypealiastype(child):
child = inspection.unwrap(child)

unwrapped = inspection.unwrap(child)
# Only subscripted generics or non-stdlib types can be cyclic.
# i.e., we may get `str` or `datetime` any number of times,
# that's not cyclic, so we can just add it to the graph.
is_visited = child in visited
is_subscripted = inspection.issubscriptedgeneric(child)
is_stdlib = inspection.isstdlibtype(child)
is_visited = child in visited or unwrapped in visited
is_subscripted = inspection.issubscriptedgeneric(unwrapped)
is_stdlib = inspection.isstdlibtype(unwrapped)
can_be_cyclic = is_subscripted or is_stdlib is False
# We detected a cyclic type,
# wrap in a ForwardRef and don't add it to the stack
Expand All @@ -155,10 +148,13 @@ def get_type_graph(t: type) -> graphlib.TopologicalSorter[TypeNode]:
ref = refs.forwardref(
refname, is_argument=is_argument, module=module, is_class=is_class
)
node = TypeNode(ref, var=var, cyclic=True)
uref = refs.forwardref(
unwrapped, is_argument=is_argument, module=module, is_class=is_class
)
node = TypeNode(ref, uref, var=var, cyclic=True)
# Otherwise, add the type to the stack and track that it's been seen.
else:
node = TypeNode(type=child, var=var)
node = TypeNode(type=child, unwrapped=unwrapped, var=var)
visited.add(node.type)
stack.append(node)
# Flag the type as a "predecessor" of the parent type.
Expand All @@ -177,11 +173,17 @@ class TypeNode:

type: typing.Any
"""The type annotation for this node."""
unwrapped: typing.Any | None = None
"""The unwrapped type annotation for this node."""
var: str | None = None
"""The variable or parameter name associated to the type annotation for this node."""
cyclic: bool = dataclasses.field(default=False, hash=False, compare=False)
"""Whether this type annotation is cyclic."""

def __post_init__(self):
if self.unwrapped is None:
self.unwrapped = self.type


def _level(t: typing.Any) -> typing.Iterable[tuple[str | None, type]]:
args = inspection.args(t)
Expand Down
9 changes: 6 additions & 3 deletions src/typelib/marshals/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ def marshaller(
root = nodes[-1]
for node in nodes:
context[node.type] = _get_unmarshaller(node, context=context)
context[node.unwrapped] = context[node.type]

return context[root.type]

Expand All @@ -66,10 +67,12 @@ def _get_unmarshaller( # type: ignore[return]
return context[node.type]

for check, unmarshaller_cls in _HANDLERS.items():
if check(node.type):
return unmarshaller_cls(node.type, context=context, var=node.var)
if check(node.unwrapped):
return unmarshaller_cls(node.unwrapped, context=context, var=node.var)

return routines.StructuredTypeMarshaller(node.type, context=context, var=node.var)
return routines.StructuredTypeMarshaller(
node.unwrapped, context=context, var=node.var
)


class DelayedMarshaller(routines.AbstractMarshaller[T]):
Expand Down
2 changes: 1 addition & 1 deletion src/typelib/py/inspection.py
Original file line number Diff line number Diff line change
Expand Up @@ -573,7 +573,7 @@ def isoptionaltype(obj: type[_OT]) -> compat.TypeIs[type[tp.Optional[_OT]]]:
tname = name(origin(obj))
nullarg = next((a for a in args if a in (type(None), None)), ...)
isoptional = tname == "Optional" or (
nullarg is not ... and tname in ("Union", "Uniontype", "Literal")
nullarg is not ... and tname in ("Union", "UnionType", "Literal")
)
return isoptional

Expand Down
9 changes: 6 additions & 3 deletions src/typelib/unmarshals/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ def unmarshaller(
root = nodes[-1]
for node in nodes:
context[node.type] = _get_unmarshaller(node, context=context)
context[node.unwrapped] = context[node.type]

return context[root.type]

Expand All @@ -60,10 +61,12 @@ def _get_unmarshaller( # type: ignore[return]
return context[node.type]

for check, unmarshaller_cls in _HANDLERS.items():
if check(node.type):
return unmarshaller_cls(node.type, context=context, var=node.var)
if check(node.unwrapped):
return unmarshaller_cls(node.unwrapped, context=context, var=node.var)

return routines.StructuredTypeUnmarshaller(node.type, context=context, var=node.var)
return routines.StructuredTypeUnmarshaller(
node.unwrapped, context=context, var=node.var
)


class DelayedUnmarshaller(routines.AbstractUnmarshaller[T]):
Expand Down
5 changes: 5 additions & 0 deletions tests/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,3 +99,8 @@ class NestedTypeAliasType:
RecursiveAlias = compat.TypeAliasType(
"RecursiveAlias", "dict[str, RecursiveAlias | ValueAlias]"
)

ScalarValue = compat.TypeAliasType("ScalarValue", "int | float | str | bool | None")
Record = compat.TypeAliasType(
"Record", "dict[str, list[Record] | list[ScalarValue] | Record | ScalarValue]"
)
9 changes: 2 additions & 7 deletions tests/unit/marshals/test_routines.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@

import pytest

from typelib import graph
from typelib.marshals import routines

from tests import models
Expand Down Expand Up @@ -405,12 +404,8 @@ def test_fixed_tuple_marshaller(
@pytest.mark.suite(
context=dict(
given_context={
graph.TypeNode(int, var="value"): routines.IntegerMarshaller(
int, {}, var="value"
),
graph.TypeNode(str, var="field"): routines.StringMarshaller(
str, {}, var="field"
),
int: routines.IntegerMarshaller(int, {}, var="value"),
str: routines.StringMarshaller(str, {}, var="field"),
},
expected_output=dict(field="data", value=1),
),
Expand Down
2 changes: 1 addition & 1 deletion tests/unit/test_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ class NoTypes:
given_type=models.NestedTypeAliasType,
expected_nodes=[
graph.TypeNode(type=int),
graph.TypeNode(type=list[int], var="alias"),
graph.TypeNode(type=models.ListAlias, unwrapped=list[int], var="alias"),
graph.TypeNode(type=NestedTypeAliasType),
],
),
Expand Down

0 comments on commit f010baf

Please sign in to comment.