4
4
5
5
import itertools
6
6
from collections import defaultdict
7
- from contextlib import ExitStack , contextmanager
7
+ from contextlib import ExitStack , contextmanager , nullcontext
8
8
from typing import (
9
9
AbstractSet ,
10
10
Callable ,
23
23
cast ,
24
24
overload ,
25
25
)
26
- from typing_extensions import TypeAlias as _TypeAlias
26
+ from typing_extensions import ContextManager , TypeAlias as _TypeAlias
27
27
28
28
import mypy .checkexpr
29
29
from mypy import errorcodes as codes , join , message_registry , nodes , operators
249
249
# Maximum length of fixed tuple types inferred when narrowing from variadic tuples.
250
250
MAX_PRECISE_TUPLE_SIZE : Final = 8
251
251
252
- DeferredNodeType : _TypeAlias = Union [FuncDef , LambdaExpr , OverloadedFuncDef , Decorator ]
252
+ DeferredNodeType : _TypeAlias = Union [
253
+ FuncDef , LambdaExpr , OverloadedFuncDef , Decorator , AssignmentStmt
254
+ ]
253
255
FineGrainedDeferredNodeType : _TypeAlias = Union [FuncDef , MypyFile , OverloadedFuncDef ]
254
256
255
257
260
262
class DeferredNode (NamedTuple ):
261
263
node : DeferredNodeType
262
264
# And its TypeInfo (for semantic analysis self type handling
263
- active_typeinfo : TypeInfo | None
265
+ active_typeinfo : TypeInfo | FuncItem | None
264
266
265
267
266
268
# Same as above, but for fine-grained mode targets. Only top-level functions/methods
@@ -452,6 +454,12 @@ def __init__(
452
454
# always the next if statement to have a redundant expression
453
455
self .allow_redundant_expr = False
454
456
457
+ self .should_defer_current_node = False
458
+ """when a parent node should be defered"""
459
+ self .inferred_return_types : list [Type ] = []
460
+ self .inferred_yield_types : list [Type ] = []
461
+ self .inferred_send_types : list [Type ] = []
462
+
455
463
@property
456
464
def type_context (self ) -> list [Type | None ]:
457
465
return self .expr_checker .type_context
@@ -551,10 +559,15 @@ def check_second_pass(
551
559
# (self.pass_num, type_name, node.fullname or node.name))
552
560
done .add (node )
553
561
with ExitStack () as stack :
554
- if active_typeinfo :
562
+ cm : ContextManager [object ] = nullcontext ()
563
+ if isinstance (active_typeinfo , FuncItem ):
564
+ stack .enter_context (self .scope .push_function (active_typeinfo ))
565
+ cm = self .tscope .function_scope (active_typeinfo )
566
+ elif active_typeinfo :
555
567
stack .enter_context (self .tscope .class_scope (active_typeinfo ))
556
568
stack .enter_context (self .scope .push_class (active_typeinfo ))
557
- self .check_partial (node )
569
+ with cm :
570
+ self .check_partial (node )
558
571
return True
559
572
560
573
def check_partial (self , node : DeferredNodeType | FineGrainedDeferredNodeType ) -> None :
@@ -579,7 +592,9 @@ def check_top_level(self, node: MypyFile) -> None:
579
592
assert not self .current_node_deferred
580
593
# TODO: Handle __all__
581
594
582
- def defer_node (self , node : DeferredNodeType , enclosing_class : TypeInfo | None ) -> None :
595
+ def defer_node (
596
+ self , node : DeferredNodeType , enclosing_class : TypeInfo | FuncItem | None
597
+ ) -> None :
583
598
"""Defer a node for processing during next type-checking pass.
584
599
585
600
Args:
@@ -1577,6 +1592,12 @@ def check_func_def(
1577
1592
new_frame = self .binder .push_frame ()
1578
1593
new_frame .types [key ] = narrowed_type
1579
1594
self .binder .declarations [key ] = old_binder .declarations [key ]
1595
+ inferred_return_types = self .inferred_return_types
1596
+ self .inferred_return_types = []
1597
+ inferred_yield_types = self .inferred_yield_types
1598
+ self .inferred_yield_types = []
1599
+ inferred_send_types = self .inferred_send_types
1600
+ self .inferred_send_types = []
1580
1601
with self .scope .push_function (defn ):
1581
1602
# We suppress reachability warnings for empty generator functions
1582
1603
# (return; yield) which have a "yield" that's unreachable by definition
@@ -1591,6 +1612,37 @@ def check_func_def(
1591
1612
if _is_empty_generator_function (item ) or len (expanded ) >= 2 :
1592
1613
self .binder .suppress_unreachable_warnings ()
1593
1614
self .accept (item .body )
1615
+ if self .options .default_return :
1616
+ if not self .binder .is_unreachable ():
1617
+ self .inferred_return_types .append (NoneType ())
1618
+
1619
+ ret_type = get_proper_type (typ .ret_type )
1620
+ if (
1621
+ isinstance (ret_type , UntypedType )
1622
+ and ret_type .type_of_any == TypeOfAny .to_be_inferred
1623
+ ):
1624
+ ret_type = make_simplified_union (self .inferred_return_types )
1625
+ item .type = typ .copy_modified (ret_type = ret_type )
1626
+ if self .inferred_yield_types or self .inferred_send_types :
1627
+ yield_type = (
1628
+ make_simplified_union (self .inferred_yield_types )
1629
+ if self .inferred_yield_types
1630
+ else self .named_type ("builtins.object" )
1631
+ )
1632
+ # `Never` here isn't ideal, and neither is `object`, so we just go with the default typevar value
1633
+ send_type = (
1634
+ make_simplified_intersection (self .inferred_send_types )
1635
+ if self .inferred_send_types
1636
+ else NoneType ()
1637
+ )
1638
+ assert isinstance (item .type , CallableType )
1639
+ item .type .ret_type = Instance (
1640
+ self .lookup_typeinfo ("typing.Generator" ),
1641
+ [yield_type , send_type , item .type .ret_type ],
1642
+ )
1643
+ self .inferred_return_types = inferred_return_types
1644
+ self .inferred_yield_types = inferred_yield_types
1645
+ self .inferred_send_types = inferred_send_types
1594
1646
unreachable = self .binder .is_unreachable ()
1595
1647
if new_frame is not None :
1596
1648
self .binder .pop_frame (True , 0 )
@@ -1783,13 +1835,14 @@ def check_for_missing_annotations(self, fdef: FuncItem) -> None:
1783
1835
if not fdef .arguments or (
1784
1836
len (fdef .arguments ) == 1 and (fdef .arg_names [0 ] in ("self" , "cls" ))
1785
1837
):
1786
- self .fail (message_registry .RETURN_TYPE_EXPECTED , fdef )
1787
- if not has_return_statement (fdef ) and not fdef .is_generator :
1788
- self .note (
1789
- 'Use "-> None" if function does not return a value' ,
1790
- fdef ,
1791
- code = codes .NO_UNTYPED_DEF ,
1792
- )
1838
+ if not self .options .default_return :
1839
+ self .fail (message_registry .RETURN_TYPE_EXPECTED , fdef )
1840
+ if not has_return_statement (fdef ) and not fdef .is_generator :
1841
+ self .note (
1842
+ 'Use "-> None" if function does not return a value' ,
1843
+ fdef ,
1844
+ code = codes .NO_UNTYPED_DEF ,
1845
+ )
1793
1846
else :
1794
1847
self .fail (message_registry .FUNCTION_TYPE_EXPECTED , fdef )
1795
1848
elif isinstance (fdef .type , CallableType ):
@@ -3260,6 +3313,9 @@ def visit_assignment_stmt(self, s: AssignmentStmt) -> None:
3260
3313
s .new_syntax ,
3261
3314
override_infer = s .unanalyzed_type is not None ,
3262
3315
)
3316
+ if self .should_defer_current_node :
3317
+ self .defer_node (s , self .scope .top_function ())
3318
+ self .should_defer_current_node = False
3263
3319
if s .is_alias_def :
3264
3320
self .check_type_alias_rvalue (s )
3265
3321
@@ -5055,6 +5111,18 @@ def check_return_stmt(self, s: ReturnStmt) -> None:
5055
5111
if defn .is_async_generator :
5056
5112
self .fail (message_registry .RETURN_IN_ASYNC_GENERATOR , s )
5057
5113
return
5114
+ if defn .type :
5115
+ assert isinstance (defn .type , CallableType )
5116
+ proper_type = get_proper_type (defn .type .ret_type )
5117
+ infer = (
5118
+ isinstance (proper_type , UntypedType )
5119
+ and proper_type .type_of_any == TypeOfAny .to_be_inferred
5120
+ )
5121
+ else :
5122
+ infer = True
5123
+ if infer :
5124
+ self .inferred_return_types .append (typ )
5125
+ return
5058
5126
# Returning a value of type Any is always fine.
5059
5127
if isinstance (typ , AnyType ):
5060
5128
# (Unless you asked to be warned in that case, and the
0 commit comments