diff --git a/pyproject.toml b/pyproject.toml index 03a0750..b6d9f4d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -26,8 +26,6 @@ main = ["poetry==1.7.1"] [tool.mypy] python_version = 3.9 packages = ["basedtyping", "tests"] -# we can't use override until we bump the minimum typing_extensions or something -disable_error_code = ["explicit-override"] [tool.ruff.format] skip-magic-trailing-comma = true diff --git a/src/basedtyping/__init__.py b/src/basedtyping/__init__.py index c4aaa85..5f66757 100644 --- a/src/basedtyping/__init__.py +++ b/src/basedtyping/__init__.py @@ -27,7 +27,8 @@ ) import typing_extensions -from typing_extensions import Never, ParamSpec, Self, TypeAlias, TypeGuard, TypeVarTuple +from typing_extensions import Never, ParamSpec, Self, TypeAlias, TypeGuard, \ + TypeVarTuple, override from basedtyping import transformer from basedtyping.runtime_only import OldUnionType @@ -69,6 +70,7 @@ class _BasedSpecialForm(_SpecialForm, _root=True): # type: ignore[misc] _name: str + @override def __init_subclass__(cls, _root=False): # noqa: FBT002 super().__init_subclass__(_root=_root) # type: ignore[call-arg] @@ -76,6 +78,7 @@ def __init__(self, *args: object, **kwargs: object): self.alias = kwargs.pop("alias", _BasedGenericAlias) super().__init__(*args, **kwargs) + @override def __repr__(self) -> str: return "basedtyping." + self._name @@ -244,6 +247,7 @@ def _is_subclass(cls, subclass: object) -> TypeGuard[_ReifiedGenericMetaclass]: cast(_ReifiedGenericMetaclass, subclass)._orig_class(), ) + @override def __subclasscheck__(cls, subclass: object) -> bool: if not cls._is_subclass(subclass): return False @@ -264,6 +268,7 @@ def __subclasscheck__(cls, subclass: object) -> bool: subclass._check_generics_reified() return cls._type_var_check(subclass.__reified_generics__) + @override def __instancecheck__(cls, instance: object) -> bool: if not cls._is_subclass(type(instance)): return False @@ -272,6 +277,7 @@ def __instancecheck__(cls, instance: object) -> bool: return cls._type_var_check(cast(ReifiedGeneric[object], instance).__reified_generics__) # need the generic here for pyright. see https://github.com/microsoft/pyright/issues/5488 + @override def __call__(cls: type[T], *args: object, **kwargs: object) -> T: """A placeholder ``__call__`` method that gets called when the class is instantiated directly, instead of first supplying the type parameters. @@ -400,6 +406,7 @@ def __class_getitem__( # type: ignore[no-any-decorated] reified_generic_copy._can_do_instance_and_subclass_checks_without_generics = False return reified_generic_copy + @override def __init_subclass__(cls): cls._can_do_instance_and_subclass_checks_without_generics = True super().__init_subclass__() @@ -473,14 +480,17 @@ def Untyped( # noqa: N802 class _IntersectionGenericAlias(_BasedGenericAlias, _root=True): + @override def copy_with(self, args: object) -> Self: # type: ignore[override] # TODO: put in the overloads # noqa: TD003 return cast(Self, Intersection[args]) + @override def __eq__(self, other: object) -> bool: if not isinstance(other, _IntersectionGenericAlias): return NotImplemented return set(self.__args__) == set(other.__args__) + @override def __hash__(self) -> int: return hash(frozenset(self.__args__)) @@ -490,6 +500,7 @@ def __instancecheck__(self, obj: object) -> bool: def __subclasscheck__(self, cls: type[object]) -> bool: return all(issubclass(cls, arg) for arg in self.__args__) + @override def __reduce__(self) -> (object, object): func, (_, args) = super().__reduce__() # type: ignore[no-any-expr, misc] return func, (Intersection, args) @@ -538,10 +549,13 @@ def Intersection(self: _BasedSpecialForm, parameters: object) -> object: # noqa class _TypeFormForm(_BasedSpecialForm, _root=True): # type: ignore[misc] + # TODO: decorator-ify def __init__(self, doc: str): + super().__init__() self._name = "TypeForm" self._doc = self.__doc__ = doc + @override def __getitem__(self, parameters: object | tuple[object]) -> _BasedGenericAlias: if not isinstance(parameters, tuple): parameters = (parameters,) @@ -610,8 +624,8 @@ def __init__(self, arg: str, *, is_argument=True, module: object = None, is_clas except SyntaxError: # Callable: () -> int if "->" not in arg_to_compile: - raise RuntimeError( - f"expected a callable type, but found... what is {arg_to_compile}?" + raise SyntaxError( + f"invalid syntax in ForwardRef: {arg_to_compile}?" ) from None code = compile("'un-representable callable type'", "", "eval") @@ -625,6 +639,7 @@ def __init__(self, arg: str, *, is_argument=True, module: object = None, is_clas if sys.version_info >= (3, 13): + @override def _evaluate( self, globalns: dict[str, object] | None, @@ -637,6 +652,7 @@ def _evaluate( elif sys.version_info >= (3, 12): + @override def _evaluate( self, globalns: dict[str, object] | None, @@ -649,6 +665,7 @@ def _evaluate( else: + @override def _evaluate( self, globalns: dict[str, object] | None, diff --git a/src/basedtyping/transformer.py b/src/basedtyping/transformer.py index 9ec8aa1..99ed39a 100644 --- a/src/basedtyping/transformer.py +++ b/src/basedtyping/transformer.py @@ -9,13 +9,14 @@ from contextlib import contextmanager from enum import Enum from typing import cast +from typing_extensions import override import typing_extensions import basedtyping -# ruff: noqa: N802, S101 +# rasdfuff: noqa: N802, S101 class CringeTransformer(ast.NodeTransformer): """Transforms `1 | 2` into `Literal[1] | Literal[2]` etc""" @@ -48,6 +49,7 @@ def __init__( self.basedtyping_name: basedtyping, } + @override def visit(self, node: ast.AST) -> ast.AST: return cast(ast.AST, super().visit(node)) @@ -107,6 +109,7 @@ def implicit_tuple(self, *, value=True) -> typing.Iterator[None]: finally: self._implicit_tuple = implicit_tuple + @override def visit_Subscript(self, node: ast.Subscript) -> ast.AST: node_type = self.eval_type(node.value) if self.eval_type(node.value) is typing_extensions.Literal: @@ -133,6 +136,7 @@ def visit_Subscript(self, node: ast.Subscript) -> ast.AST: node = self.subscript(self._typing("Callable"), slice2_) return node + @override def visit_Attribute(self, node: ast.Attribute) -> ast.AST: node = self.generic_visit(node) assert isinstance(node, ast.expr) @@ -142,12 +146,14 @@ def visit_Attribute(self, node: ast.Attribute) -> ast.AST: return self._literal(node) return node + @override def visit_Name(self, node: ast.Name) -> ast.AST: name_type = self.eval_type(node) if isinstance(name_type, Enum): return self._literal(node) return node + @override def visit_Constant(self, node: ast.Constant) -> ast.AST: value = cast(object, node.value) if not self.string_literals and isinstance(value, str): @@ -156,6 +162,7 @@ def visit_Constant(self, node: ast.Constant) -> ast.AST: return self._literal(node) return node + @override def visit_Tuple(self, node: ast.Tuple) -> ast.AST: with self.implicit_tuple(value=False): result = self.generic_visit(node) @@ -163,6 +170,7 @@ def visit_Tuple(self, node: ast.Tuple) -> ast.AST: return self.subscript(self._typing("Tuple"), cast(ast.expr, result)) return result + @override def visit_Compare(self, node: ast.Compare) -> ast.AST: if len(node.ops) == 1 and isinstance(node.ops[0], ast.Is): result = self.subscript( @@ -171,6 +179,7 @@ def visit_Compare(self, node: ast.Compare) -> ast.AST: return self.generic_visit(result) return self.generic_visit(node) + @override def visit_IfExp(self, node: ast.IfExp) -> ast.AST: if ( isinstance(node.body, ast.Compare) @@ -192,6 +201,7 @@ def visit_FunctionType(self, node: ast.FunctionType) -> ast.AST: ast.Tuple([ast.List(node.argtypes, ctx=ast.Load()), node.returns], ctx=ast.Load()), ) ) + @override def visit_BinOp(self, node: ast.BinOp) -> ast.AST: node = self.generic_visit(node)