Skip to content

infer return types #854

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 11 additions & 19 deletions .mypy/baseline.json
Original file line number Diff line number Diff line change
Expand Up @@ -1989,15 +1989,15 @@
"code": "explicit-override",
"column": 4,
"message": "Method \"type_context\" is not using @override but is overriding a method in class \"mypy.plugin.CheckerPluginInterface\"",
"offset": 456,
"offset": 464,
"src": "def type_context(self) -> list[Type | None]:",
"target": "mypy.checker"
},
{
"code": "explicit-override",
"column": 4,
"message": "Method \"visit_overloaded_func_def\" is not using @override but is overriding a method in class \"mypy.visitor.NodeVisitor\"",
"offset": 191,
"offset": 198,
"src": "def visit_overloaded_func_def(self, defn: OverloadedFuncDef, do_items=True) -> None:",
"target": "mypy.checker.TypeChecker.visit_overloaded_func_def"
},
Expand Down Expand Up @@ -2029,7 +2029,7 @@
"code": "explicit-override",
"column": 4,
"message": "Method \"visit_class_def\" is not using @override but is overriding a method in class \"mypy.visitor.NodeVisitor\"",
"offset": 1113,
"offset": 1151,
"src": "def visit_class_def(self, defn: ClassDef) -> None:",
"target": "mypy.checker.TypeChecker.visit_class_def"
},
Expand Down Expand Up @@ -2093,7 +2093,7 @@
"code": "truthy-bool",
"column": 23,
"message": "\"signature\" has type \"Type\" which does not implement __bool__ or __len__ so it could always be true in boolean context",
"offset": 169,
"offset": 172,
"src": "if signature:",
"target": "mypy.checker.TypeChecker.check_assignment"
},
Expand Down Expand Up @@ -2245,7 +2245,7 @@
"code": "explicit-override",
"column": 4,
"message": "Method \"visit_if_stmt\" is not using @override but is overriding a method in class \"mypy.visitor.NodeVisitor\"",
"offset": 111,
"offset": 123,
"src": "def visit_if_stmt(self, s: IfStmt) -> None:",
"target": "mypy.checker.TypeChecker.visit_if_stmt"
},
Expand Down Expand Up @@ -2639,7 +2639,7 @@
"code": "truthy-bool",
"column": 34,
"message": "\"item_name_expr\" has type \"Expression\" which does not implement __bool__ or __len__ so it could always be true in boolean context",
"offset": 424,
"offset": 430,
"src": "key_context = item_name_expr or item_arg",
"target": "mypy.checkexpr.ExpressionChecker.validate_typeddict_kwargs"
},
Expand Down Expand Up @@ -3055,7 +3055,7 @@
"code": "explicit-override",
"column": 4,
"message": "Method \"visit_await_expr\" is not using @override but is overriding a method in class \"mypy.visitor.ExpressionVisitor\"",
"offset": 21,
"offset": 33,
"src": "def visit_await_expr(self, e: AwaitExpr, allow_none_return: bool = False) -> Type:",
"target": "mypy.checkexpr.ExpressionChecker.visit_await_expr"
},
Expand Down Expand Up @@ -6859,7 +6859,7 @@
"code": "redundant-expr",
"column": 19,
"message": "Condition is always false",
"offset": 375,
"offset": 377,
"src": "if e.type is None:",
"target": "mypy.errors.Errors.render_messages"
},
Expand Down Expand Up @@ -32695,7 +32695,7 @@
"code": "explicit-override",
"column": 4,
"message": "Method \"__repr__\" is not using @override but is overriding a method in class \"builtins.object\"",
"offset": 96,
"offset": 97,
"src": "def __repr__(self) -> str:",
"target": "mypy.types.Type.__repr__"
},
Expand Down Expand Up @@ -33631,7 +33631,7 @@
"code": "explicit-override",
"column": 4,
"message": "Method \"describe\" is not using @override but is overriding a method in class \"mypy.types.AnyType\"",
"offset": 59,
"offset": 61,
"src": "def describe(self) -> str:",
"target": "mypy.types.UntypedType.describe"
},
Expand Down Expand Up @@ -36585,19 +36585,11 @@
"src": "def attrgetter(name: str) -> operator.attrgetter[Any]:",
"target": "mypy.util.attrgetter"
},
{
"code": "no-any-expr",
"column": 11,
"message": "Expression type contains \"Any\" (has type \"attrgetter[Any]\")",
"offset": 1,
"src": "return operator.attrgetter(name)",
"target": "mypy.util.attrgetter"
},
{
"code": "no-any-expr",
"column": 7,
"message": "Expression has type \"Any\"",
"offset": 4,
"offset": 5,
"src": "if orjson is not None:",
"target": "mypy.util.json_dumps"
},
Expand Down
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
# Basedmypy Changelog

## [Unreleased]
### Added
- infer return types and generator types

## [2.9.1]
### Fixed
Expand Down
19 changes: 19 additions & 0 deletions docs/source/based_inference.rst
Original file line number Diff line number Diff line change
Expand Up @@ -60,3 +60,22 @@ When a parameter is named `_`, it's type will be inferred as `object`:
reveal_type(_) # Revealed type is "object"

This is to help with writing functions for callbacks where you don't care about certain parameters.


Return Type Inferred
--------------------

.. code-block:: python

def f(): # Revealed type is "() -> 1"
return 1


Generator Type Inferred
-----------------------

.. code-block:: python

def f(): # Revealed type is "() -> Generator[1, str, 2]"
a: str = yield 1
return 2
96 changes: 82 additions & 14 deletions mypy/checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

import itertools
from collections import defaultdict
from contextlib import ExitStack, contextmanager
from contextlib import ExitStack, contextmanager, nullcontext
from typing import (
AbstractSet,
Callable,
Expand All @@ -23,7 +23,7 @@
cast,
overload,
)
from typing_extensions import TypeAlias as _TypeAlias
from typing_extensions import ContextManager, TypeAlias as _TypeAlias

import mypy.checkexpr
from mypy import errorcodes as codes, join, message_registry, nodes, operators
Expand Down Expand Up @@ -249,7 +249,9 @@
# Maximum length of fixed tuple types inferred when narrowing from variadic tuples.
MAX_PRECISE_TUPLE_SIZE: Final = 8

DeferredNodeType: _TypeAlias = Union[FuncDef, LambdaExpr, OverloadedFuncDef, Decorator]
DeferredNodeType: _TypeAlias = Union[
FuncDef, LambdaExpr, OverloadedFuncDef, Decorator, AssignmentStmt
]
FineGrainedDeferredNodeType: _TypeAlias = Union[FuncDef, MypyFile, OverloadedFuncDef]


Expand All @@ -260,7 +262,7 @@
class DeferredNode(NamedTuple):
node: DeferredNodeType
# And its TypeInfo (for semantic analysis self type handling
active_typeinfo: TypeInfo | None
active_typeinfo: TypeInfo | FuncItem | None


# Same as above, but for fine-grained mode targets. Only top-level functions/methods
Expand Down Expand Up @@ -452,6 +454,12 @@ def __init__(
# always the next if statement to have a redundant expression
self.allow_redundant_expr = False

self.should_defer_current_node = False
"""when a parent node should be defered"""
self.inferred_return_types: list[Type] = []
self.inferred_yield_types: list[Type] = []
self.inferred_send_types: list[Type] = []

@property
def type_context(self) -> list[Type | None]:
return self.expr_checker.type_context
Expand Down Expand Up @@ -551,10 +559,15 @@ def check_second_pass(
# (self.pass_num, type_name, node.fullname or node.name))
done.add(node)
with ExitStack() as stack:
if active_typeinfo:
cm: ContextManager[object] = nullcontext()
if isinstance(active_typeinfo, FuncItem):
stack.enter_context(self.scope.push_function(active_typeinfo))
cm = self.tscope.function_scope(active_typeinfo)
elif active_typeinfo:
stack.enter_context(self.tscope.class_scope(active_typeinfo))
stack.enter_context(self.scope.push_class(active_typeinfo))
self.check_partial(node)
with cm:
self.check_partial(node)
return True

def check_partial(self, node: DeferredNodeType | FineGrainedDeferredNodeType) -> None:
Expand All @@ -579,7 +592,9 @@ def check_top_level(self, node: MypyFile) -> None:
assert not self.current_node_deferred
# TODO: Handle __all__

def defer_node(self, node: DeferredNodeType, enclosing_class: TypeInfo | None) -> None:
def defer_node(
self, node: DeferredNodeType, enclosing_class: TypeInfo | FuncItem | None
) -> None:
"""Defer a node for processing during next type-checking pass.

Args:
Expand Down Expand Up @@ -1577,6 +1592,12 @@ def check_func_def(
new_frame = self.binder.push_frame()
new_frame.types[key] = narrowed_type
self.binder.declarations[key] = old_binder.declarations[key]
inferred_return_types = self.inferred_return_types
self.inferred_return_types = []
inferred_yield_types = self.inferred_yield_types
self.inferred_yield_types = []
inferred_send_types = self.inferred_send_types
self.inferred_send_types = []
with self.scope.push_function(defn):
# We suppress reachability warnings for empty generator functions
# (return; yield) which have a "yield" that's unreachable by definition
Expand All @@ -1591,6 +1612,37 @@ def check_func_def(
if _is_empty_generator_function(item) or len(expanded) >= 2:
self.binder.suppress_unreachable_warnings()
self.accept(item.body)
if self.options.default_return:
if not self.binder.is_unreachable():
self.inferred_return_types.append(NoneType())

ret_type = get_proper_type(typ.ret_type)
if (
isinstance(ret_type, UntypedType)
and ret_type.type_of_any == TypeOfAny.to_be_inferred
):
ret_type = make_simplified_union(self.inferred_return_types)
item.type = typ.copy_modified(ret_type=ret_type)
if self.inferred_yield_types or self.inferred_send_types:
yield_type = (
make_simplified_union(self.inferred_yield_types)
if self.inferred_yield_types
else self.named_type("builtins.object")
)
# `Never` here isn't ideal, and neither is `object`, so we just go with the default typevar value
send_type = (
make_simplified_intersection(self.inferred_send_types)
if self.inferred_send_types
else NoneType()
)
assert isinstance(item.type, CallableType)
item.type.ret_type = Instance(
self.lookup_typeinfo("typing.Generator"),
[yield_type, send_type, item.type.ret_type],
)
self.inferred_return_types = inferred_return_types
self.inferred_yield_types = inferred_yield_types
self.inferred_send_types = inferred_send_types
unreachable = self.binder.is_unreachable()
if new_frame is not None:
self.binder.pop_frame(True, 0)
Expand Down Expand Up @@ -1783,13 +1835,14 @@ def check_for_missing_annotations(self, fdef: FuncItem) -> None:
if not fdef.arguments or (
len(fdef.arguments) == 1 and (fdef.arg_names[0] in ("self", "cls"))
):
self.fail(message_registry.RETURN_TYPE_EXPECTED, fdef)
if not has_return_statement(fdef) and not fdef.is_generator:
self.note(
'Use "-> None" if function does not return a value',
fdef,
code=codes.NO_UNTYPED_DEF,
)
if not self.options.default_return:
self.fail(message_registry.RETURN_TYPE_EXPECTED, fdef)
if not has_return_statement(fdef) and not fdef.is_generator:
self.note(
'Use "-> None" if function does not return a value',
fdef,
code=codes.NO_UNTYPED_DEF,
)
else:
self.fail(message_registry.FUNCTION_TYPE_EXPECTED, fdef)
elif isinstance(fdef.type, CallableType):
Expand Down Expand Up @@ -3260,6 +3313,9 @@ def visit_assignment_stmt(self, s: AssignmentStmt) -> None:
s.new_syntax,
override_infer=s.unanalyzed_type is not None,
)
if self.should_defer_current_node:
self.defer_node(s, self.scope.top_function())
self.should_defer_current_node = False
if s.is_alias_def:
self.check_type_alias_rvalue(s)

Expand Down Expand Up @@ -5055,6 +5111,18 @@ def check_return_stmt(self, s: ReturnStmt) -> None:
if defn.is_async_generator:
self.fail(message_registry.RETURN_IN_ASYNC_GENERATOR, s)
return
if defn.type:
assert isinstance(defn.type, CallableType)
proper_type = get_proper_type(defn.type.ret_type)
infer = (
isinstance(proper_type, UntypedType)
and proper_type.type_of_any == TypeOfAny.to_be_inferred
)
else:
infer = True
if infer:
self.inferred_return_types.append(typ)
return
# Returning a value of type Any is always fine.
if isinstance(typ, AnyType):
# (Unless you asked to be warned in that case, and the
Expand Down
40 changes: 29 additions & 11 deletions mypy/checkexpr.py
Original file line number Diff line number Diff line change
Expand Up @@ -629,7 +629,7 @@ def visit_call_expr_inner(self, e: CallExpr, allow_none_return: bool = False) ->
fullname == p or fullname.startswith(f"{p}.")
for p in self.chk.options.untyped_calls_exclude
):
if callee_type.implicit:
if callee_type.implicit and not self.chk.options.infer_function_types:
self.msg.untyped_function_call(callee_type, e)
if fullname is None and member is not None:
assert object_type is not None
Expand All @@ -638,7 +638,13 @@ def visit_call_expr_inner(self, e: CallExpr, allow_none_return: bool = False) ->
fullname == p or fullname.startswith(f"{p}.")
for p in self.chk.options.untyped_calls_exclude
):
if callee_type.implicit:
proper_type = get_proper_type(callee_type.ret_type)
if (
isinstance(proper_type, UntypedType)
and proper_type.type_of_any == TypeOfAny.to_be_inferred
):
self.chk.current_node_deferred = True
elif callee_type.implicit and not self.chk.options.infer_function_types:
self.msg.untyped_function_call(callee_type, e)
elif has_untyped_type(callee_type):
# Get module of the function, to get its settings
Expand Down Expand Up @@ -6290,22 +6296,34 @@ def not_ready_callback(self, name: str, context: Context) -> None:
def visit_yield_expr(self, e: YieldExpr) -> Type:
return_type = self.chk.return_types[-1]
expected_item_type = self.chk.get_generator_yield_type(return_type, False)
proper_type = get_proper_type(return_type)
infer = (
isinstance(proper_type, UntypedType)
and proper_type.type_of_any == TypeOfAny.to_be_inferred
)
if infer and self.type_context[-1]:
self.chk.inferred_send_types.append(self.type_context[-1])
if e.expr is None:
if (
if infer:
self.chk.inferred_yield_types.append(NoneType())
elif (
not isinstance(get_proper_type(expected_item_type), (NoneType, AnyType))
and self.chk.in_checked_function()
):
self.chk.fail(message_registry.YIELD_VALUE_EXPECTED, e)
else:
actual_item_type = self.accept(e.expr, expected_item_type)
self.chk.check_subtype(
actual_item_type,
expected_item_type,
e,
message_registry.INCOMPATIBLE_TYPES_IN_YIELD,
"actual type",
"expected type",
)
if infer:
self.chk.inferred_yield_types.append(actual_item_type)
else:
self.chk.check_subtype(
actual_item_type,
expected_item_type,
e,
message_registry.INCOMPATIBLE_TYPES_IN_YIELD,
"actual type",
"expected type",
)
return self.chk.get_generator_receive_type(return_type, False)

def visit_await_expr(self, e: AwaitExpr, allow_none_return: bool = False) -> Type:
Expand Down
Loading
Loading