Skip to content

Commit 706fb44

Browse files
committed
generic function
1 parent f3c2a9b commit 706fb44

File tree

3 files changed

+98
-1
lines changed

3 files changed

+98
-1
lines changed

basedtyping/__init__.py

Lines changed: 66 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from __future__ import annotations
66

77
import sys
8+
from dataclasses import dataclass
89
from typing import ( # type: ignore[attr-defined]
910
TYPE_CHECKING,
1011
Any,
@@ -53,6 +54,7 @@
5354
"Untyped",
5455
"Intersection",
5556
"TypeForm",
57+
"generic",
5658
)
5759

5860
if TYPE_CHECKING:
@@ -508,7 +510,9 @@ def __reduce__(self) -> (object, object):
508510
if sys.version_info > (3, 9):
509511

510512
@_BasedSpecialForm
511-
def Intersection(self: _BasedSpecialForm, parameters: object) -> object: # noqa: N802
513+
def Intersection( # noqa: N802
514+
self: _BasedSpecialForm, parameters: object
515+
) -> object:
512516
"""Intersection type; Intersection[X, Y] means both X and Y.
513517
514518
To define an intersection:
@@ -574,3 +578,64 @@ def f[T](t: TypeForm[T]) -> T: ...
574578
reveal_type(f(int | str)) # int | str
575579
"""
576580
)
581+
582+
583+
@dataclass
584+
class _BaseGenericFunction(Generic[P, T]):
585+
fn: Callable[P, T]
586+
587+
588+
@dataclass
589+
class _GenericFunction(_BaseGenericFunction[P, T]):
590+
# TODO: make this an TypeVarTuple when mypy supports it
591+
# https://github.com/python/mypy/issues/16696
592+
__type_params__: tuple[object, ...] | None = None
593+
"""Generic type parameters. Currently unused"""
594+
595+
def __getitem__(self, items: object) -> _ConcreteFunction[P, T]:
596+
items = items if isinstance(items, tuple) else (items,)
597+
return _ConcreteFunction(self.fn, items)
598+
599+
600+
@dataclass
601+
class _ConcreteFunction(_BaseGenericFunction[P, T]):
602+
__type_args__: tuple[object, ...] | None = None
603+
"""Concrete type parameters. Currently unused"""
604+
605+
def __call__(self, *args: P.args, **kwargs: P.kwargs) -> T:
606+
return self.fn(*args, **kwargs)
607+
608+
609+
class _GenericFunctionFacilitator:
610+
__type_params__: tuple[object, ...] | None = None
611+
args: tuple[object, ...]
612+
613+
def __call__(self, fn: Callable[P, T]) -> _GenericFunction[P, T]:
614+
return _GenericFunction(fn, self.args)
615+
616+
617+
class _GenericFunctionDecorator:
618+
"""Decorate a function to allow supplying type parameters on calls:
619+
620+
@generic[T]
621+
def f1(t: T): ...
622+
623+
f1[int](1)
624+
625+
@generic
626+
def f2[T](t: T): ...
627+
628+
f2[int](1)
629+
"""
630+
631+
def __call__(self, fn: Callable[P, T]) -> _GenericFunction[P, T]:
632+
params = cast(Union[Tuple[object, ...], None], getattr(fn, "__type_params__", None))
633+
return _GenericFunction(fn, params)
634+
635+
def __getitem__(self, items: object) -> _GenericFunctionFacilitator:
636+
result = _GenericFunctionFacilitator()
637+
result.args = items if isinstance(items, tuple) else (items,)
638+
return result
639+
640+
641+
generic = _GenericFunctionDecorator()

tests/test_generic.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
from __future__ import annotations
2+
3+
import sys
4+
from typing import Callable, cast
5+
6+
import pytest
7+
8+
from basedtyping import T, generic
9+
10+
11+
def test_generic_with_args():
12+
deco = generic[T]
13+
14+
@deco # Python version 3.8 does not support arbitrary expressions as a decorator
15+
def f(t: T) -> T:
16+
return t
17+
18+
assert f.__type_params__ == (T,)
19+
assert f[int].__type_args__ == (int,)
20+
assert f[object](1) == 1
21+
22+
23+
@pytest.mark.skipif(sys.version_info < (3, 12), reason="Needs generic syntax support")
24+
def test_generic_without_args():
25+
local: dict[str, object] = {}
26+
# Can't use the actual function because then <3.12 wouldn't load
27+
exec("def f[T](t: T) -> T: return t", None, local)
28+
_f = cast(Callable[[object], object], local["f"])
29+
f = generic(_f)
30+
assert f.__type_params__ == _f.__type_params__ # type: ignore[attr-defined]
31+
assert f[int].__type_args__ == (int,)
32+
assert f[int](1) == 1

tests/test_generic_312.py

Whitespace-only changes.

0 commit comments

Comments
 (0)