Skip to content

Commit eba432a

Browse files
committed
check function argument names for subtypes
1 parent 4f6e90f commit eba432a

27 files changed

+156
-136
lines changed

.mypy/baseline.json

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1977,7 +1977,7 @@
19771977
"code": "explicit-override",
19781978
"column": 4,
19791979
"message": "Method \"visit_class_def\" is not using @override but is overriding a method in class \"mypy.visitor.NodeVisitor\"",
1980-
"offset": 432,
1980+
"offset": 435,
19811981
"src": "def visit_class_def(self, defn: ClassDef) -> None:",
19821982
"target": "mypy.checker.TypeChecker.visit_class_def"
19831983
},
@@ -2353,7 +2353,7 @@
23532353
"code": "explicit-override",
23542354
"column": 4,
23552355
"message": "Method \"visit_uninhabited_type\" is not using @override but is overriding a method in class \"mypy.type_visitor.BoolTypeQuery\"",
2356-
"offset": 197,
2356+
"offset": 192,
23572357
"src": "def visit_uninhabited_type(self, t: UninhabitedType) -> bool:",
23582358
"target": "mypy.checker.InvalidInferredTypes.visit_uninhabited_type"
23592359
},
@@ -10771,7 +10771,7 @@
1077110771
"code": "no-any-expr",
1077210772
"column": 67,
1077310773
"message": "Expression type contains \"Any\" (has type \"list[Any]\")",
10774-
"offset": 285,
10774+
"offset": 289,
1077510775
"src": "\"--package-root\", metavar=\"ROOT\", action=\"append\", default=[], help=argparse.SUPPRESS",
1077610776
"target": "mypy.main.process_options"
1077710777
},
@@ -11429,7 +11429,7 @@
1142911429
"code": "explicit-override",
1143011430
"column": 4,
1143111431
"message": "Method \"visit_unbound_type\" is not using @override but is overriding a method in class \"mypy.type_visitor.TypeVisitor\"",
11432-
"offset": 642,
11432+
"offset": 638,
1143311433
"src": "def visit_unbound_type(self, t: UnboundType) -> ProperType:",
1143411434
"target": "mypy.meet.TypeMeetVisitor.visit_unbound_type"
1143511435
},
@@ -14119,7 +14119,7 @@
1411914119
"code": "no-any-expr",
1412014120
"column": 18,
1412114121
"message": "Expression has type \"Any\"",
14122-
"offset": 122,
14122+
"offset": 125,
1412314123
"src": "MACHDEP = sysconfig.get_config_var(\"MACHDEP\")",
1412414124
"target": "mypy.options.Options.__init__"
1412514125
},
@@ -24737,7 +24737,7 @@
2473724737
"code": "no-any-explicit",
2473824738
"column": 4,
2473924739
"message": "Explicit \"Any\" is not allowed",
24740-
"offset": 749,
24740+
"offset": 747,
2474124741
"src": "def report(*args: Any) -> None:",
2474224742
"target": "mypy.subtypes.unify_generic_callable"
2474324743
},

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
- Narrow type on initial assignment (#547)
77
- Annotations in function bodies are not analyzed as evaluated (#564)
88
- Invalid `cast`s show an error (#573)
9+
- Argument names are validated for subtypes (#562)
910
### Enhancements
1011
- Show 'narrowed from' in `reveal_type` (#550)
1112
- `--color-output` is enabled by default (#531)

docs/source/based_features.rst

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -296,3 +296,15 @@ Basedmypy handles type annotations in function bodies as unevaluated:
296296
297297
def f():
298298
a: list[int] # no error, this annotation isn't evaluated
299+
300+
Checked Argument Names
301+
----------------------
302+
303+
.. code-block::python
304+
305+
class A:
306+
def f(self, a: int): ...
307+
308+
class B(A):
309+
@override
310+
def f(self, b: int): ... # error: Signature of "f" incompatible with supertype "A"

mypy/checker.py

Lines changed: 7 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -2351,7 +2351,10 @@ def check_override(
23512351
# Use boolean variable to clarify code.
23522352
fail = False
23532353
op_method_wider_note = False
2354-
if not is_subtype(override, original, ignore_pos_arg_names=True):
2354+
2355+
if not is_subtype(
2356+
override, original, ignore_pos_arg_names=self.options.work_not_properly_function_names
2357+
):
23552358
fail = True
23562359
elif isinstance(override, Overloaded) and self.is_forward_op_method(name):
23572360
# Operator method overrides cannot extend the domain, as
@@ -2837,7 +2840,7 @@ class C(B, A[int]): ... # this is unsafe because...
28372840
call = find_member("__call__", first_type, first_type, is_operator=True)
28382841
if call and isinstance(second_type, FunctionLike):
28392842
second_sig = self.bind_and_map_method(second, second_type, ctx, base2)
2840-
ok = is_subtype(call, second_sig, ignore_pos_arg_names=True)
2843+
ok = is_subtype(call, second_sig)
28412844
elif isinstance(first_type, FunctionLike) and isinstance(second_type, FunctionLike):
28422845
if first_type.is_type_obj() and second_type.is_type_obj():
28432846
# For class objects only check the subtype relationship of the classes,
@@ -2850,7 +2853,7 @@ class C(B, A[int]): ... # this is unsafe because...
28502853
# First bind/map method types when necessary.
28512854
first_sig = self.bind_and_map_method(first, first_type, ctx, base1)
28522855
second_sig = self.bind_and_map_method(second, second_type, ctx, base2)
2853-
ok = is_subtype(first_sig, second_sig, ignore_pos_arg_names=True)
2856+
ok = is_subtype(first_sig, second_sig)
28542857
elif first_type and second_type:
28552858
if isinstance(first.node, Var):
28562859
first_type = expand_self_type(first.node, first_type, fill_typevars(ctx))
@@ -7754,12 +7757,7 @@ def is_more_general_arg_prefix(t: FunctionLike, s: FunctionLike) -> bool:
77547757

77557758
def is_same_arg_prefix(t: CallableType, s: CallableType) -> bool:
77567759
return is_callable_compatible(
7757-
t,
7758-
s,
7759-
is_compat=is_same_type,
7760-
ignore_return=True,
7761-
check_args_covariantly=True,
7762-
ignore_pos_arg_names=True,
7760+
t, s, is_compat=is_same_type, ignore_return=True, check_args_covariantly=True
77637761
)
77647762

77657763

mypy/main.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1139,6 +1139,10 @@ def add_invertible_flag(
11391139
)
11401140
# This undocumented feature exports limited line-level dependency information.
11411141
internals_group.add_argument("--export-ref-info", action="store_true", help=argparse.SUPPRESS)
1142+
# This undocumented feature makes callable subtypes with incorrect names count as valid (needed to check mypy itself)
1143+
internals_group.add_argument(
1144+
"--work-not-properly-function-names", action="store_true", help=argparse.SUPPRESS
1145+
)
11421146

11431147
report_group = parser.add_argument_group(
11441148
title="Report generation", description="Generate a report in the specified format."

mypy/meet.py

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -467,11 +467,7 @@ def _type_object_overlap(left: Type, right: Type) -> bool:
467467

468468
if isinstance(left, CallableType) and isinstance(right, CallableType):
469469
return is_callable_compatible(
470-
left,
471-
right,
472-
is_compat=_is_overlapping_types,
473-
ignore_pos_arg_names=True,
474-
allow_partial_overlap=True,
470+
left, right, is_compat=_is_overlapping_types, allow_partial_overlap=True
475471
)
476472
elif isinstance(left, CallableType):
477473
left = left.fallback

mypy/messages.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3026,7 +3026,7 @@ def get_conflict_protocol_types(
30263026
subtype = mypy.typeops.get_protocol_member(left, member, class_obj)
30273027
if not subtype:
30283028
continue
3029-
is_compat = is_subtype(subtype, supertype, ignore_pos_arg_names=True, options=options)
3029+
is_compat = is_subtype(subtype, supertype, options=options)
30303030
if IS_SETTABLE in get_member_flags(member, right):
30313031
is_compat = is_compat and is_subtype(supertype, subtype, options=options)
30323032
if not is_compat:

mypy/options.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -105,6 +105,9 @@ class Options:
105105
"""Options collected from flags."""
106106

107107
def __init__(self) -> None:
108+
# stupid thing to make mypy project check properly
109+
self.work_not_properly_function_names = False
110+
108111
# Cache for clone_for_module()
109112
self._per_module_cache: dict[str, Options] | None = None
110113
# Despite the warnings about _per_module_cache being slow, this one might be good

mypy/subtypes.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1081,7 +1081,6 @@ def f(self) -> A: ...
10811081
for member in right.type.protocol_members:
10821082
if member in members_not_to_check:
10831083
continue
1084-
ignore_names = member != "__call__" # __call__ can be passed kwargs
10851084
# The third argument below indicates to what self type is bound.
10861085
# We always bind self to the subtype. (Similarly to nominal types).
10871086
supertype = get_proper_type(find_member(member, right, left))
@@ -1103,9 +1102,7 @@ def f(self) -> A: ...
11031102
# Nominal check currently ignores arg names
11041103
# NOTE: If we ever change this, be sure to also change the call to
11051104
# SubtypeVisitor.build_subtype_kind(...) down below.
1106-
is_compat = is_subtype(
1107-
subtype, supertype, ignore_pos_arg_names=ignore_names, options=options
1108-
)
1105+
is_compat = is_subtype(subtype, supertype, options=options)
11091106
else:
11101107
is_compat = is_proper_subtype(subtype, supertype)
11111108
if not is_compat:
@@ -1141,6 +1138,7 @@ def f(self) -> A: ...
11411138
if not proper_subtype:
11421139
# Nominal check currently ignores arg names, but __call__ is special for protocols
11431140
ignore_names = right.type.protocol_members != ["__call__"]
1141+
ignore_names = False
11441142
else:
11451143
ignore_names = False
11461144
subtype_kind = SubtypeVisitor.build_subtype_kind(

mypy/typeshed/stdlib/_weakrefset.pyi

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,10 +16,10 @@ class WeakSet(MutableSet[_T], Generic[_T]):
1616
def __init__(self, data: None = None) -> None: ...
1717
@overload
1818
def __init__(self, data: Iterable[_T]) -> None: ...
19-
def add(self, item: _T) -> None: ...
20-
def discard(self, item: _T) -> None: ...
19+
def add(self, item: _T) -> None: ... # type: ignore[override]
20+
def discard(self, item: _T) -> None: ... # type: ignore[override]
2121
def copy(self) -> Self: ...
22-
def remove(self, item: _T) -> None: ...
22+
def remove(self, item: _T) -> None: ... # type: ignore[override]
2323
def update(self, other: Iterable[_T]) -> None: ...
2424
def __contains__(self, item: object) -> bool: ...
2525
def __len__(self) -> int: ...

mypy/typeshed/stdlib/builtins.pyi

Lines changed: 27 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -444,7 +444,7 @@ class str(Sequence[str]):
444444
def capitalize(self) -> str: ... # type: ignore[misc]
445445
def casefold(self) -> str: ... # type: ignore[misc]
446446
def center(self, __width: SupportsIndex, __fillchar: str = " ") -> str: ... # type: ignore[misc]
447-
def count(self, x: str, __start: SupportsIndex | None = ..., __end: SupportsIndex | None = ...) -> int: ...
447+
def count(self, x: str, __start: SupportsIndex | None = ..., __end: SupportsIndex | None = ...) -> int: ... # type: ignore[override]
448448
def encode(self, encoding: str = "utf-8", errors: str = "strict") -> bytes: ...
449449
def endswith(
450450
self, __suffix: str | tuple[str, ...], __start: SupportsIndex | None = ..., __end: SupportsIndex | None = ...
@@ -457,7 +457,7 @@ class str(Sequence[str]):
457457
def find(self, __sub: str, __start: SupportsIndex | None = ..., __end: SupportsIndex | None = ...) -> int: ...
458458
def format(self, *args: object, **kwargs: object) -> str: ...
459459
def format_map(self, map: _FormatMapMapping) -> str: ...
460-
def index(self, __sub: str, __start: SupportsIndex | None = ..., __end: SupportsIndex | None = ...) -> int: ...
460+
def index(self, __sub: str, __start: SupportsIndex | None = ..., __end: SupportsIndex | None = ...) -> int: ... # type: ignore[override]
461461
def isalnum(self) -> bool: ...
462462
def isalpha(self) -> bool: ...
463463
def isascii(self) -> bool: ...
@@ -533,7 +533,7 @@ class bytes(Sequence[int]):
533533
def __new__(cls) -> Self: ...
534534
def capitalize(self) -> bytes: ...
535535
def center(self, __width: SupportsIndex, __fillchar: bytes = b" ") -> bytes: ...
536-
def count(
536+
def count( # type: ignore[override]
537537
self, __sub: ReadableBuffer | SupportsIndex, __start: SupportsIndex | None = ..., __end: SupportsIndex | None = ...
538538
) -> int: ...
539539
def decode(self, encoding: str = "utf-8", errors: str = "strict") -> str: ...
@@ -556,7 +556,7 @@ class bytes(Sequence[int]):
556556
else:
557557
def hex(self) -> str: ...
558558

559-
def index(
559+
def index( # type: ignore[override]
560560
self, __sub: ReadableBuffer | SupportsIndex, __start: SupportsIndex | None = ..., __end: SupportsIndex | None = ...
561561
) -> int: ...
562562
def isalnum(self) -> bool: ...
@@ -637,10 +637,10 @@ class bytearray(MutableSequence[int]):
637637
def __init__(self, __ints: Iterable[SupportsIndex] | SupportsIndex | ReadableBuffer) -> None: ...
638638
@overload
639639
def __init__(self, __string: str, encoding: str, errors: str = ...) -> None: ...
640-
def append(self, __item: SupportsIndex) -> None: ...
640+
def append(self, __item: SupportsIndex) -> None: ... # type: ignore[override]
641641
def capitalize(self) -> bytearray: ...
642642
def center(self, __width: SupportsIndex, __fillchar: bytes = b" ") -> bytearray: ...
643-
def count(
643+
def count( # type: ignore[override]
644644
self, __sub: ReadableBuffer | SupportsIndex, __start: SupportsIndex | None = ..., __end: SupportsIndex | None = ...
645645
) -> int: ...
646646
def copy(self) -> bytearray: ...
@@ -656,7 +656,7 @@ class bytearray(MutableSequence[int]):
656656
else:
657657
def expandtabs(self, tabsize: int = ...) -> bytearray: ...
658658

659-
def extend(self, __iterable_of_ints: Iterable[SupportsIndex]) -> None: ...
659+
def extend(self, __iterable_of_ints: Iterable[SupportsIndex]) -> None: ... # type: ignore[override]
660660
def find(
661661
self, __sub: ReadableBuffer | SupportsIndex, __start: SupportsIndex | None = ..., __end: SupportsIndex | None = ...
662662
) -> int: ...
@@ -665,10 +665,10 @@ class bytearray(MutableSequence[int]):
665665
else:
666666
def hex(self) -> str: ...
667667

668-
def index(
668+
def index( # type: ignore[override]
669669
self, __sub: ReadableBuffer | SupportsIndex, __start: SupportsIndex | None = ..., __end: SupportsIndex | None = ...
670670
) -> int: ...
671-
def insert(self, __index: SupportsIndex, __item: SupportsIndex) -> None: ...
671+
def insert(self, __index: SupportsIndex, __item: SupportsIndex) -> None: ... # type: ignore[override]
672672
def isalnum(self) -> bool: ...
673673
def isalpha(self) -> bool: ...
674674
def isascii(self) -> bool: ...
@@ -682,8 +682,8 @@ class bytearray(MutableSequence[int]):
682682
def lower(self) -> bytearray: ...
683683
def lstrip(self, __bytes: ReadableBuffer | None = None) -> bytearray: ...
684684
def partition(self, __sep: ReadableBuffer) -> tuple[bytearray, bytearray, bytearray]: ...
685-
def pop(self, __index: int = -1) -> int: ...
686-
def remove(self, __value: int) -> None: ...
685+
def pop(self, __index: int = -1) -> int: ... # type: ignore[override]
686+
def remove(self, __value: int) -> None: ... # type: ignore[override]
687687
if sys.version_info >= (3, 9):
688688
def removeprefix(self, __prefix: ReadableBuffer) -> bytearray: ...
689689
def removesuffix(self, __suffix: ReadableBuffer) -> bytearray: ...
@@ -881,8 +881,8 @@ class tuple(Sequence[_T_co], Generic[_T_co]):
881881
def __add__(self, __value: tuple[_T, ...]) -> tuple[_T_co | _T, ...]: ...
882882
def __mul__(self, __value: SupportsIndex) -> tuple[_T_co, ...]: ...
883883
def __rmul__(self, __value: SupportsIndex) -> tuple[_T_co, ...]: ...
884-
def count(self, __value: Any) -> int: ...
885-
def index(self, __value: Any, __start: SupportsIndex = 0, __stop: SupportsIndex = sys.maxsize) -> int: ...
884+
def count(self, __value: Any) -> int: ... # type: ignore[override]
885+
def index(self, __value: Any, __start: SupportsIndex = 0, __stop: SupportsIndex = sys.maxsize) -> int: ... # type: ignore[override]
886886
if sys.version_info >= (3, 9):
887887
def __class_getitem__(cls, __item: Any) -> GenericAlias: ...
888888

@@ -918,15 +918,15 @@ class list(MutableSequence[_T], Generic[_T]):
918918
@overload
919919
def __init__(self, __iterable: Iterable[_T]) -> None: ...
920920
def copy(self) -> list[_T]: ...
921-
def append(self, __object: _T) -> None: ...
922-
def extend(self, __iterable: Iterable[_T]) -> None: ...
923-
def pop(self, __index: SupportsIndex = -1) -> _T: ...
921+
def append(self, __object: _T) -> None: ... # type: ignore[override]
922+
def extend(self, __iterable: Iterable[_T]) -> None: ... # type: ignore[override]
923+
def pop(self, __index: SupportsIndex = -1) -> _T: ... # type: ignore[override]
924924
# Signature of `list.index` should be kept in line with `collections.UserList.index()`
925925
# and multiprocessing.managers.ListProxy.index()
926-
def index(self, __value: _T, __start: SupportsIndex = 0, __stop: SupportsIndex = sys.maxsize) -> int: ...
927-
def count(self, __value: _T) -> int: ...
928-
def insert(self, __index: SupportsIndex, __object: _T) -> None: ...
929-
def remove(self, __value: _T) -> None: ...
926+
def index(self, __value: _T, __start: SupportsIndex = 0, __stop: SupportsIndex = sys.maxsize) -> int: ... # type: ignore[override]
927+
def count(self, __value: _T) -> int: ... # type: ignore[override]
928+
def insert(self, __index: SupportsIndex, __object: _T) -> None: ... # type: ignore[override]
929+
def remove(self, __value: _T) -> None: ... # type: ignore[override]
930930
# Signature of `list.sort` should be kept inline with `collections.UserList.sort()`
931931
# and multiprocessing.managers.ListProxy.sort()
932932
#
@@ -1009,7 +1009,7 @@ class dict(MutableMapping[_KT, _VT], Generic[_KT, _VT]):
10091009
def get(self, __key: _KT, __default: _VT) -> _VT: ...
10101010
@overload
10111011
def get(self, __key: _KT, __default: _T) -> _VT | _T: ...
1012-
@overload
1012+
@overload # type: ignore[override]
10131013
def pop(self, __key: _KT) -> _VT: ...
10141014
@overload
10151015
def pop(self, __key: _KT, __default: _VT) -> _VT: ...
@@ -1045,17 +1045,17 @@ class set(MutableSet[_T], Generic[_T]):
10451045
def __init__(self) -> None: ...
10461046
@overload
10471047
def __init__(self, __iterable: Iterable[_T]) -> None: ...
1048-
def add(self, __element: _T) -> None: ...
1048+
def add(self, __element: _T) -> None: ... # type: ignore[override]
10491049
def copy(self) -> set[_T]: ...
10501050
def difference(self, *s: Iterable[Any]) -> set[_T]: ...
10511051
def difference_update(self, *s: Iterable[Any]) -> None: ...
1052-
def discard(self, __element: _T) -> None: ...
1052+
def discard(self, __element: _T) -> None: ... # type: ignore[override]
10531053
def intersection(self, *s: Iterable[Any]) -> set[_T]: ...
10541054
def intersection_update(self, *s: Iterable[Any]) -> None: ...
1055-
def isdisjoint(self, __s: Iterable[Any]) -> bool: ...
1055+
def isdisjoint(self, __s: Iterable[Any]) -> bool: ... # type: ignore[override]
10561056
def issubset(self, __s: Iterable[Any]) -> bool: ...
10571057
def issuperset(self, __s: Iterable[Any]) -> bool: ...
1058-
def remove(self, __element: _T) -> None: ...
1058+
def remove(self, __element: _T) -> None: ... # type: ignore[override]
10591059
def symmetric_difference(self, __s: Iterable[_T]) -> set[_T]: ...
10601060
def symmetric_difference_update(self, __s: Iterable[_T]) -> None: ...
10611061
def union(self, *s: Iterable[_S]) -> set[_T | _S]: ...
@@ -1088,7 +1088,7 @@ class frozenset(AbstractSet[_T_co], Generic[_T_co]):
10881088
def copy(self) -> frozenset[_T_co]: ...
10891089
def difference(self, *s: Iterable[object]) -> frozenset[_T_co]: ...
10901090
def intersection(self, *s: Iterable[object]) -> frozenset[_T_co]: ...
1091-
def isdisjoint(self, __s: Iterable[_T_co]) -> bool: ...
1091+
def isdisjoint(self, __s: Iterable[_T_co]) -> bool: ... # type: ignore[override]
10921092
def issubset(self, __s: Iterable[object]) -> bool: ...
10931093
def issuperset(self, __s: Iterable[object]) -> bool: ...
10941094
def symmetric_difference(self, __s: Iterable[_T_co]) -> frozenset[_T_co]: ...
@@ -1128,7 +1128,7 @@ class range(Sequence[int]):
11281128
def __init__(self, __stop: SupportsIndex) -> None: ...
11291129
@overload
11301130
def __init__(self, __start: SupportsIndex, __stop: SupportsIndex, __step: SupportsIndex = ...) -> None: ...
1131-
def count(self, __value: int) -> int: ...
1131+
def count(self, __value: int) -> int: ... # type: ignore[override]
11321132
def index(self, __value: int) -> int: ... # type: ignore[override]
11331133
def __len__(self) -> int: ...
11341134
def __eq__(self, __value: object) -> bool: ...

0 commit comments

Comments
 (0)