Skip to content

Commit c165fdf

Browse files
committed
infer return types
1 parent a8d99fe commit c165fdf

16 files changed

+206
-84
lines changed

.mypy/baseline.json

+11-19
Original file line numberDiff line numberDiff line change
@@ -1989,15 +1989,15 @@
19891989
"code": "explicit-override",
19901990
"column": 4,
19911991
"message": "Method \"type_context\" is not using @override but is overriding a method in class \"mypy.plugin.CheckerPluginInterface\"",
1992-
"offset": 456,
1992+
"offset": 464,
19931993
"src": "def type_context(self) -> list[Type | None]:",
19941994
"target": "mypy.checker"
19951995
},
19961996
{
19971997
"code": "explicit-override",
19981998
"column": 4,
19991999
"message": "Method \"visit_overloaded_func_def\" is not using @override but is overriding a method in class \"mypy.visitor.NodeVisitor\"",
2000-
"offset": 191,
2000+
"offset": 198,
20012001
"src": "def visit_overloaded_func_def(self, defn: OverloadedFuncDef, do_items=True) -> None:",
20022002
"target": "mypy.checker.TypeChecker.visit_overloaded_func_def"
20032003
},
@@ -2029,7 +2029,7 @@
20292029
"code": "explicit-override",
20302030
"column": 4,
20312031
"message": "Method \"visit_class_def\" is not using @override but is overriding a method in class \"mypy.visitor.NodeVisitor\"",
2032-
"offset": 1113,
2032+
"offset": 1151,
20332033
"src": "def visit_class_def(self, defn: ClassDef) -> None:",
20342034
"target": "mypy.checker.TypeChecker.visit_class_def"
20352035
},
@@ -2093,7 +2093,7 @@
20932093
"code": "truthy-bool",
20942094
"column": 23,
20952095
"message": "\"signature\" has type \"Type\" which does not implement __bool__ or __len__ so it could always be true in boolean context",
2096-
"offset": 169,
2096+
"offset": 172,
20972097
"src": "if signature:",
20982098
"target": "mypy.checker.TypeChecker.check_assignment"
20992099
},
@@ -2245,7 +2245,7 @@
22452245
"code": "explicit-override",
22462246
"column": 4,
22472247
"message": "Method \"visit_if_stmt\" is not using @override but is overriding a method in class \"mypy.visitor.NodeVisitor\"",
2248-
"offset": 111,
2248+
"offset": 123,
22492249
"src": "def visit_if_stmt(self, s: IfStmt) -> None:",
22502250
"target": "mypy.checker.TypeChecker.visit_if_stmt"
22512251
},
@@ -2639,7 +2639,7 @@
26392639
"code": "truthy-bool",
26402640
"column": 34,
26412641
"message": "\"item_name_expr\" has type \"Expression\" which does not implement __bool__ or __len__ so it could always be true in boolean context",
2642-
"offset": 424,
2642+
"offset": 430,
26432643
"src": "key_context = item_name_expr or item_arg",
26442644
"target": "mypy.checkexpr.ExpressionChecker.validate_typeddict_kwargs"
26452645
},
@@ -3055,7 +3055,7 @@
30553055
"code": "explicit-override",
30563056
"column": 4,
30573057
"message": "Method \"visit_await_expr\" is not using @override but is overriding a method in class \"mypy.visitor.ExpressionVisitor\"",
3058-
"offset": 21,
3058+
"offset": 33,
30593059
"src": "def visit_await_expr(self, e: AwaitExpr, allow_none_return: bool = False) -> Type:",
30603060
"target": "mypy.checkexpr.ExpressionChecker.visit_await_expr"
30613061
},
@@ -6859,7 +6859,7 @@
68596859
"code": "redundant-expr",
68606860
"column": 19,
68616861
"message": "Condition is always false",
6862-
"offset": 375,
6862+
"offset": 377,
68636863
"src": "if e.type is None:",
68646864
"target": "mypy.errors.Errors.render_messages"
68656865
},
@@ -32695,7 +32695,7 @@
3269532695
"code": "explicit-override",
3269632696
"column": 4,
3269732697
"message": "Method \"__repr__\" is not using @override but is overriding a method in class \"builtins.object\"",
32698-
"offset": 96,
32698+
"offset": 97,
3269932699
"src": "def __repr__(self) -> str:",
3270032700
"target": "mypy.types.Type.__repr__"
3270132701
},
@@ -33631,7 +33631,7 @@
3363133631
"code": "explicit-override",
3363233632
"column": 4,
3363333633
"message": "Method \"describe\" is not using @override but is overriding a method in class \"mypy.types.AnyType\"",
33634-
"offset": 59,
33634+
"offset": 61,
3363533635
"src": "def describe(self) -> str:",
3363633636
"target": "mypy.types.UntypedType.describe"
3363733637
},
@@ -36585,19 +36585,11 @@
3658536585
"src": "def attrgetter(name: str) -> operator.attrgetter[Any]:",
3658636586
"target": "mypy.util.attrgetter"
3658736587
},
36588-
{
36589-
"code": "no-any-expr",
36590-
"column": 11,
36591-
"message": "Expression type contains \"Any\" (has type \"attrgetter[Any]\")",
36592-
"offset": 1,
36593-
"src": "return operator.attrgetter(name)",
36594-
"target": "mypy.util.attrgetter"
36595-
},
3659636588
{
3659736589
"code": "no-any-expr",
3659836590
"column": 7,
3659936591
"message": "Expression has type \"Any\"",
36600-
"offset": 4,
36592+
"offset": 5,
3660136593
"src": "if orjson is not None:",
3660236594
"target": "mypy.util.json_dumps"
3660336595
},

CHANGELOG.md

+2
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
# Basedmypy Changelog
22

33
## [Unreleased]
4+
### Added
5+
- infer return types and generator types
46

57
## [2.9.1]
68
### Fixed

docs/source/based_inference.rst

+18
Original file line numberDiff line numberDiff line change
@@ -60,3 +60,21 @@ When a parameter is named `_`, it's type will be inferred as `object`:
6060
reveal_type(_) # Revealed type is "object"
6161
6262
This is to help with writing functions for callbacks where you don't care about certain parameters.
63+
64+
65+
Return Type Inferred
66+
--------------------
67+
68+
.. code-block:: python
69+
70+
def f(): # Revealed type is "() -> 1"
71+
return 1
72+
73+
Generator Type Inferred
74+
-----------------------
75+
76+
.. code-block:: python
77+
78+
def f(): # Revealed type is "() -> Generator[1, str, 2]"
79+
a: str = yield 1
80+
return 2

mypy/checker.py

+82-14
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44

55
import itertools
66
from collections import defaultdict
7-
from contextlib import ExitStack, contextmanager
7+
from contextlib import ExitStack, contextmanager, nullcontext
88
from typing import (
99
AbstractSet,
1010
Callable,
@@ -23,7 +23,7 @@
2323
cast,
2424
overload,
2525
)
26-
from typing_extensions import TypeAlias as _TypeAlias
26+
from typing_extensions import ContextManager, TypeAlias as _TypeAlias
2727

2828
import mypy.checkexpr
2929
from mypy import errorcodes as codes, join, message_registry, nodes, operators
@@ -249,7 +249,9 @@
249249
# Maximum length of fixed tuple types inferred when narrowing from variadic tuples.
250250
MAX_PRECISE_TUPLE_SIZE: Final = 8
251251

252-
DeferredNodeType: _TypeAlias = Union[FuncDef, LambdaExpr, OverloadedFuncDef, Decorator]
252+
DeferredNodeType: _TypeAlias = Union[
253+
FuncDef, LambdaExpr, OverloadedFuncDef, Decorator, AssignmentStmt
254+
]
253255
FineGrainedDeferredNodeType: _TypeAlias = Union[FuncDef, MypyFile, OverloadedFuncDef]
254256

255257

@@ -260,7 +262,7 @@
260262
class DeferredNode(NamedTuple):
261263
node: DeferredNodeType
262264
# And its TypeInfo (for semantic analysis self type handling
263-
active_typeinfo: TypeInfo | None
265+
active_typeinfo: TypeInfo | FuncItem | None
264266

265267

266268
# Same as above, but for fine-grained mode targets. Only top-level functions/methods
@@ -452,6 +454,12 @@ def __init__(
452454
# always the next if statement to have a redundant expression
453455
self.allow_redundant_expr = False
454456

457+
self.should_defer_current_node = False
458+
"""when a parent node should be defered"""
459+
self.inferred_return_types: list[Type] = []
460+
self.inferred_yield_types: list[Type] = []
461+
self.inferred_send_types: list[Type] = []
462+
455463
@property
456464
def type_context(self) -> list[Type | None]:
457465
return self.expr_checker.type_context
@@ -551,10 +559,15 @@ def check_second_pass(
551559
# (self.pass_num, type_name, node.fullname or node.name))
552560
done.add(node)
553561
with ExitStack() as stack:
554-
if active_typeinfo:
562+
cm: ContextManager[object] = nullcontext()
563+
if isinstance(active_typeinfo, FuncItem):
564+
stack.enter_context(self.scope.push_function(active_typeinfo))
565+
cm = self.tscope.function_scope(active_typeinfo)
566+
elif active_typeinfo:
555567
stack.enter_context(self.tscope.class_scope(active_typeinfo))
556568
stack.enter_context(self.scope.push_class(active_typeinfo))
557-
self.check_partial(node)
569+
with cm:
570+
self.check_partial(node)
558571
return True
559572

560573
def check_partial(self, node: DeferredNodeType | FineGrainedDeferredNodeType) -> None:
@@ -579,7 +592,9 @@ def check_top_level(self, node: MypyFile) -> None:
579592
assert not self.current_node_deferred
580593
# TODO: Handle __all__
581594

582-
def defer_node(self, node: DeferredNodeType, enclosing_class: TypeInfo | None) -> None:
595+
def defer_node(
596+
self, node: DeferredNodeType, enclosing_class: TypeInfo | FuncItem | None
597+
) -> None:
583598
"""Defer a node for processing during next type-checking pass.
584599
585600
Args:
@@ -1577,6 +1592,12 @@ def check_func_def(
15771592
new_frame = self.binder.push_frame()
15781593
new_frame.types[key] = narrowed_type
15791594
self.binder.declarations[key] = old_binder.declarations[key]
1595+
inferred_return_types = self.inferred_return_types
1596+
self.inferred_return_types = []
1597+
inferred_yield_types = self.inferred_yield_types
1598+
self.inferred_yield_types = []
1599+
inferred_send_types = self.inferred_send_types
1600+
self.inferred_send_types = []
15801601
with self.scope.push_function(defn):
15811602
# We suppress reachability warnings for empty generator functions
15821603
# (return; yield) which have a "yield" that's unreachable by definition
@@ -1591,6 +1612,37 @@ def check_func_def(
15911612
if _is_empty_generator_function(item) or len(expanded) >= 2:
15921613
self.binder.suppress_unreachable_warnings()
15931614
self.accept(item.body)
1615+
if self.options.default_return:
1616+
if not self.binder.is_unreachable():
1617+
self.inferred_return_types.append(NoneType())
1618+
1619+
ret_type = get_proper_type(typ.ret_type)
1620+
if (
1621+
isinstance(ret_type, UntypedType)
1622+
and ret_type.type_of_any == TypeOfAny.to_be_inferred
1623+
):
1624+
ret_type = make_simplified_union(self.inferred_return_types)
1625+
item.type = typ.copy_modified(ret_type=ret_type)
1626+
if self.inferred_yield_types or self.inferred_send_types:
1627+
yield_type = (
1628+
make_simplified_union(self.inferred_yield_types)
1629+
if self.inferred_yield_types
1630+
else self.named_type("builtins.object")
1631+
)
1632+
# `Never` here isn't ideal, and neither is `object`, so we just go with the default typevar value
1633+
send_type = (
1634+
make_simplified_intersection(self.inferred_send_types)
1635+
if self.inferred_send_types
1636+
else NoneType()
1637+
)
1638+
assert isinstance(item.type, CallableType)
1639+
item.type.ret_type = Instance(
1640+
self.lookup_typeinfo("typing.Generator"),
1641+
[yield_type, send_type, item.type.ret_type],
1642+
)
1643+
self.inferred_return_types = inferred_return_types
1644+
self.inferred_yield_types = inferred_yield_types
1645+
self.inferred_send_types = inferred_send_types
15941646
unreachable = self.binder.is_unreachable()
15951647
if new_frame is not None:
15961648
self.binder.pop_frame(True, 0)
@@ -1783,13 +1835,14 @@ def check_for_missing_annotations(self, fdef: FuncItem) -> None:
17831835
if not fdef.arguments or (
17841836
len(fdef.arguments) == 1 and (fdef.arg_names[0] in ("self", "cls"))
17851837
):
1786-
self.fail(message_registry.RETURN_TYPE_EXPECTED, fdef)
1787-
if not has_return_statement(fdef) and not fdef.is_generator:
1788-
self.note(
1789-
'Use "-> None" if function does not return a value',
1790-
fdef,
1791-
code=codes.NO_UNTYPED_DEF,
1792-
)
1838+
if not self.options.default_return:
1839+
self.fail(message_registry.RETURN_TYPE_EXPECTED, fdef)
1840+
if not has_return_statement(fdef) and not fdef.is_generator:
1841+
self.note(
1842+
'Use "-> None" if function does not return a value',
1843+
fdef,
1844+
code=codes.NO_UNTYPED_DEF,
1845+
)
17931846
else:
17941847
self.fail(message_registry.FUNCTION_TYPE_EXPECTED, fdef)
17951848
elif isinstance(fdef.type, CallableType):
@@ -3260,6 +3313,9 @@ def visit_assignment_stmt(self, s: AssignmentStmt) -> None:
32603313
s.new_syntax,
32613314
override_infer=s.unanalyzed_type is not None,
32623315
)
3316+
if self.should_defer_current_node:
3317+
self.defer_node(s, self.scope.top_function())
3318+
self.should_defer_current_node = False
32633319
if s.is_alias_def:
32643320
self.check_type_alias_rvalue(s)
32653321

@@ -5055,6 +5111,18 @@ def check_return_stmt(self, s: ReturnStmt) -> None:
50555111
if defn.is_async_generator:
50565112
self.fail(message_registry.RETURN_IN_ASYNC_GENERATOR, s)
50575113
return
5114+
if defn.type:
5115+
assert isinstance(defn.type, CallableType)
5116+
proper_type = get_proper_type(defn.type.ret_type)
5117+
infer = (
5118+
isinstance(proper_type, UntypedType)
5119+
and proper_type.type_of_any == TypeOfAny.to_be_inferred
5120+
)
5121+
else:
5122+
infer = True
5123+
if infer:
5124+
self.inferred_return_types.append(typ)
5125+
return
50585126
# Returning a value of type Any is always fine.
50595127
if isinstance(typ, AnyType):
50605128
# (Unless you asked to be warned in that case, and the

mypy/checkexpr.py

+29-11
Original file line numberDiff line numberDiff line change
@@ -629,7 +629,7 @@ def visit_call_expr_inner(self, e: CallExpr, allow_none_return: bool = False) ->
629629
fullname == p or fullname.startswith(f"{p}.")
630630
for p in self.chk.options.untyped_calls_exclude
631631
):
632-
if callee_type.implicit:
632+
if callee_type.implicit and not self.chk.options.infer_function_types:
633633
self.msg.untyped_function_call(callee_type, e)
634634
if fullname is None and member is not None:
635635
assert object_type is not None
@@ -638,7 +638,13 @@ def visit_call_expr_inner(self, e: CallExpr, allow_none_return: bool = False) ->
638638
fullname == p or fullname.startswith(f"{p}.")
639639
for p in self.chk.options.untyped_calls_exclude
640640
):
641-
if callee_type.implicit:
641+
proper_type = get_proper_type(callee_type.ret_type)
642+
if (
643+
isinstance(proper_type, UntypedType)
644+
and proper_type.type_of_any == TypeOfAny.to_be_inferred
645+
):
646+
self.chk.current_node_deferred = True
647+
elif callee_type.implicit and not self.chk.options.infer_function_types:
642648
self.msg.untyped_function_call(callee_type, e)
643649
elif has_untyped_type(callee_type):
644650
# Get module of the function, to get its settings
@@ -6290,22 +6296,34 @@ def not_ready_callback(self, name: str, context: Context) -> None:
62906296
def visit_yield_expr(self, e: YieldExpr) -> Type:
62916297
return_type = self.chk.return_types[-1]
62926298
expected_item_type = self.chk.get_generator_yield_type(return_type, False)
6299+
proper_type = get_proper_type(return_type)
6300+
infer = (
6301+
isinstance(proper_type, UntypedType)
6302+
and proper_type.type_of_any == TypeOfAny.to_be_inferred
6303+
)
6304+
if infer and self.type_context[-1]:
6305+
self.chk.inferred_send_types.append(self.type_context[-1])
62936306
if e.expr is None:
6294-
if (
6307+
if infer:
6308+
self.chk.inferred_yield_types.append(NoneType())
6309+
elif (
62956310
not isinstance(get_proper_type(expected_item_type), (NoneType, AnyType))
62966311
and self.chk.in_checked_function()
62976312
):
62986313
self.chk.fail(message_registry.YIELD_VALUE_EXPECTED, e)
62996314
else:
63006315
actual_item_type = self.accept(e.expr, expected_item_type)
6301-
self.chk.check_subtype(
6302-
actual_item_type,
6303-
expected_item_type,
6304-
e,
6305-
message_registry.INCOMPATIBLE_TYPES_IN_YIELD,
6306-
"actual type",
6307-
"expected type",
6308-
)
6316+
if infer:
6317+
self.chk.inferred_yield_types.append(actual_item_type)
6318+
else:
6319+
self.chk.check_subtype(
6320+
actual_item_type,
6321+
expected_item_type,
6322+
e,
6323+
message_registry.INCOMPATIBLE_TYPES_IN_YIELD,
6324+
"actual type",
6325+
"expected type",
6326+
)
63096327
return self.chk.get_generator_receive_type(return_type, False)
63106328

63116329
def visit_await_expr(self, e: AwaitExpr, allow_none_return: bool = False) -> Type:

0 commit comments

Comments
 (0)