Skip to content

Commit d48f32d

Browse files
committed
type functions
1 parent c1cbfa3 commit d48f32d

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

58 files changed

+838
-229
lines changed

.idea/watcherTasks.xml

+4
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

mypy/binder.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from typing import DefaultDict, Generator, Iterator, List, NamedTuple, Optional, Tuple, Union
66
from typing_extensions import TypeAlias as _TypeAlias
77

8+
import mypy.options
89
from mypy.erasetype import remove_instance_last_known_values
910
from mypy.join import join_simple
1011
from mypy.literals import Key, literal, literal_hash, subkeys
@@ -331,7 +332,8 @@ def assign_type(
331332
) -> None:
332333
# We should erase last known value in binder, because if we are using it,
333334
# it means that the target is not final, and therefore can't hold a literal.
334-
type = remove_instance_last_known_values(type)
335+
# HUUHHH?????
336+
# type = remove_instance_last_known_values(type)
335337

336338
if self.type_assignments is not None:
337339
# We are in a multiassign from union, defer the actual binding,

mypy/checker.py

+16-3
Original file line numberDiff line numberDiff line change
@@ -3589,19 +3589,30 @@ def check_assignment(
35893589
):
35903590
lvalue.node.type = remove_instance_last_known_values(lvalue_type)
35913591

3592+
elif lvalue.node and lvalue.node.is_inferred and rvalue_type:
3593+
# for literal values
3594+
# Don't use type binder for definitions of special forms, like named tuples.
3595+
if not (isinstance(lvalue, NameExpr) and lvalue.is_special_form):
3596+
self.binder.assign_type(lvalue, rvalue_type, lvalue_type, False)
3597+
35923598
elif index_lvalue:
35933599
self.check_indexed_assignment(index_lvalue, rvalue, lvalue)
35943600

35953601
if inferred:
35963602
type_context = self.get_variable_type_context(inferred)
35973603
rvalue_type = self.expr_checker.accept(rvalue, type_context=type_context)
3604+
original_rvalue_type = rvalue_type
35983605
if not (
35993606
inferred.is_final
36003607
or inferred.is_index_var
36013608
or (isinstance(lvalue, NameExpr) and lvalue.name == "__match_args__")
36023609
):
36033610
rvalue_type = remove_instance_last_known_values(rvalue_type)
3604-
self.infer_variable_type(inferred, lvalue, rvalue_type, rvalue)
3611+
if self.infer_variable_type(inferred, lvalue, rvalue_type, rvalue):
3612+
self.binder.assign_type(
3613+
lvalue, original_rvalue_type, original_rvalue_type, False
3614+
)
3615+
36053616
self.check_assignment_to_slots(lvalue)
36063617

36073618
# (type, operator) tuples for augmented assignments supported with partial types
@@ -4553,12 +4564,13 @@ def is_definition(self, s: Lvalue) -> bool:
45534564

45544565
def infer_variable_type(
45554566
self, name: Var, lvalue: Lvalue, init_type: Type, context: Context
4556-
) -> None:
4567+
) -> bool:
45574568
"""Infer the type of initialized variables from initializer type."""
4569+
valid = True
45584570
if isinstance(init_type, DeletedType):
45594571
self.msg.deleted_as_rvalue(init_type, context)
45604572
elif (
4561-
not is_valid_inferred_type(init_type, is_lvalue_final=name.is_final)
4573+
not (valid := is_valid_inferred_type(init_type, is_lvalue_final=name.is_final))
45624574
and not self.no_partial_types
45634575
):
45644576
# We cannot use the type of the initialization expression for full type
@@ -4585,6 +4597,7 @@ def infer_variable_type(
45854597
init_type = strip_type(init_type)
45864598

45874599
self.set_inferred_type(name, lvalue, init_type)
4600+
return valid
45884601

45894602
def infer_partial_type(self, name: Var, lvalue: Lvalue, init_type: Type) -> bool:
45904603
init_type = get_proper_type(init_type)

mypy/checkexpr.py

+131-4
Original file line numberDiff line numberDiff line change
@@ -2,14 +2,21 @@
22

33
from __future__ import annotations
44

5+
import builtins
6+
import contextlib
57
import enum
8+
import importlib
9+
import io
610
import itertools
711
import time
812
from collections import defaultdict
913
from contextlib import contextmanager
14+
from types import GetSetDescriptorType
1015
from typing import Callable, ClassVar, Final, Iterable, Iterator, List, Optional, Sequence, cast
1116
from typing_extensions import TypeAlias as _TypeAlias, assert_never, overload
1217

18+
from basedtyping import TypeFunctionError
19+
1320
import mypy.checker
1421
import mypy.errorcodes as codes
1522
from mypy import applytype, erasetype, errorcodes, join, message_registry, nodes, operators, types
@@ -23,6 +30,7 @@
2330
)
2431
from mypy.checkstrformat import StringFormatterChecker
2532
from mypy.erasetype import erase_type, remove_instance_last_known_values, replace_meta_vars
33+
from mypy.errorcodes import ErrorCode
2634
from mypy.errors import ErrorInfo, ErrorWatcher, report_internal_error
2735
from mypy.expandtype import (
2836
expand_type,
@@ -205,12 +213,14 @@
205213
from mypy.typestate import type_state
206214
from mypy.typevars import fill_typevars
207215
from mypy.util import split_module_names
216+
from mypy.valuetotype import type_to_value, value_to_type
208217
from mypy.visitor import ExpressionVisitor
209218

210219
# Type of callback user for checking individual function arguments. See
211220
# check_args() below for details.
212221
ArgChecker: _TypeAlias = Callable[
213-
[Type, Type, ArgKind, Type, int, int, CallableType, Optional[Type], Context, Context], None
222+
[Type, Type, ArgKind, Type, int, int, CallableType, Optional[Type], Context, Context, bool],
223+
None,
214224
]
215225

216226
# Maximum nesting level for math union in overloads, setting this to large values
@@ -1846,12 +1856,13 @@ def check_callable_call(
18461856
fresh_ret_type = freshen_all_functions_type_vars(callee.ret_type)
18471857
freeze_all_type_vars(fresh_ret_type)
18481858
callee = callee.copy_modified(ret_type=fresh_ret_type)
1849-
18501859
if callee.is_generic():
18511860
need_refresh = any(
18521861
isinstance(v, (ParamSpecType, TypeVarTupleType)) for v in callee.variables
18531862
)
1863+
# IT"S HERE!
18541864
callee = freshen_function_type_vars(callee)
1865+
# IT"S HERE!
18551866
callee = self.infer_function_type_arguments_using_context(callee, context)
18561867
if need_refresh:
18571868
# Argument kinds etc. may have changed due to
@@ -1909,7 +1920,6 @@ def check_callable_call(
19091920
self.check_argument_types(
19101921
arg_types, arg_kinds, args, callee, formal_to_actual, context, object_type=object_type
19111922
)
1912-
19131923
if (
19141924
callee.is_type_obj()
19151925
and (len(arg_types) == 1)
@@ -1921,6 +1931,38 @@ def check_callable_call(
19211931
# Store the inferred callable type.
19221932
self.chk.store_type(callable_node, callee)
19231933

1934+
if callee.is_type_function:
1935+
with self.msg.filter_errors(filter_errors=True) as error_watcher:
1936+
if object_type:
1937+
self.check_arg(
1938+
caller_type=object_type,
1939+
original_caller_type=object_type,
1940+
caller_kind=ArgKind.ARG_POS,
1941+
callee_type=callee.bound_args[0],
1942+
n=0,
1943+
m=0,
1944+
callee=callee,
1945+
object_type=object_type,
1946+
context=context,
1947+
outer_context=context,
1948+
type_function=True,
1949+
)
1950+
1951+
self.check_argument_types(
1952+
arg_types,
1953+
arg_kinds,
1954+
args,
1955+
callee,
1956+
formal_to_actual,
1957+
context,
1958+
object_type=object_type,
1959+
type_function=True,
1960+
)
1961+
if not error_watcher.has_new_errors() and "." in callable_name:
1962+
ret_type = self.call_type_function(callable_name, object_type, arg_types, context)
1963+
if ret_type:
1964+
callee = callee.copy_modified(ret_type=ret_type)
1965+
19241966
if callable_name and (
19251967
(object_type is None and self.plugin.get_function_hook(callable_name))
19261968
or (object_type is not None and self.plugin.get_method_hook(callable_name))
@@ -1939,6 +1981,71 @@ def check_callable_call(
19391981
callee = callee.copy_modified(ret_type=new_ret_type)
19401982
return callee.ret_type, callee
19411983

1984+
def call_type_function(
1985+
self,
1986+
callable_name: str,
1987+
object_type: ProperType | None,
1988+
arg_types: list[ProperType],
1989+
context: Context,
1990+
) -> Type | None:
1991+
container_name, fn_name = callable_name.rsplit(".", maxsplit=1)
1992+
resolved = None
1993+
for part in container_name.split("."):
1994+
if resolved:
1995+
m = resolved.names.get(part)
1996+
else:
1997+
m = self.chk.modules.get(part)
1998+
if m:
1999+
resolved = m
2000+
is_method = not isinstance(resolved, MypyFile)
2001+
if is_method:
2002+
container = resolved.node
2003+
module_name = container.module_name
2004+
else:
2005+
container = resolved
2006+
module_name = container.fullname
2007+
2008+
all_sigs = []
2009+
object_type = object_type and [object_type] or []
2010+
for arg in object_type + arg_types:
2011+
if isinstance(arg, UnionType):
2012+
if not all_sigs:
2013+
all_sigs = [[x] for x in arg.items]
2014+
else:
2015+
from itertools import product
2016+
2017+
all_sigs = product(all_sigs, arg.items)
2018+
all_sigs = all_sigs or [object_type + arg_types]
2019+
all_rets = []
2020+
with contextlib.redirect_stdout(io.StringIO()), contextlib.redirect_stderr(io.StringIO()):
2021+
mod = importlib.import_module(module_name)
2022+
container = getattr(mod, container.name) if is_method else mod
2023+
fn = getattr(container, fn_name)
2024+
for sig in all_sigs:
2025+
if isinstance(fn, (GetSetDescriptorType, property)):
2026+
fn = fn.__get__
2027+
args = [type_to_value(arg, self.chk) for arg in sig]
2028+
try:
2029+
return_value = fn(*args)
2030+
except RecursionError:
2031+
self.chk.fail(
2032+
"maximum recursion depth exceeded while evaluating type function",
2033+
context=context,
2034+
)
2035+
except TypeFunctionError as type_function_error:
2036+
code = type_function_error.code and ErrorCode(type_function_error.code, "", "")
2037+
self.chk.fail(type_function_error.message, code=code, context=context)
2038+
except Exception as exception:
2039+
self.chk.fail(
2040+
f"Invocation raises {type(exception).__name__}: {exception}",
2041+
context,
2042+
code=errorcodes.CALL_RAISES,
2043+
)
2044+
else:
2045+
all_rets.append(value_to_type(return_value, chk=self.chk))
2046+
2047+
return make_simplified_union(all_rets)
2048+
19422049
def can_return_none(self, type: TypeInfo, attr_name: str) -> bool:
19432050
"""Is the given attribute a method with a None-compatible return type?
19442051
@@ -2175,6 +2282,13 @@ def infer_function_type_arguments(
21752282
Return a derived callable type that has the arguments applied.
21762283
"""
21772284
if self.chk.in_checked_function():
2285+
if isinstance(callee_type.ret_type, TypeVarType):
2286+
# if the return type is constant, infer as literal
2287+
rvalue_type = [
2288+
remove_instance_last_known_values(arg) if isinstance(arg, Instance) else arg
2289+
for arg in args
2290+
]
2291+
21782292
# Disable type errors during type inference. There may be errors
21792293
# due to partial available context information at this time, but
21802294
# these errors can be safely ignored as the arguments will be
@@ -2581,6 +2695,8 @@ def check_argument_types(
25812695
context: Context,
25822696
check_arg: ArgChecker | None = None,
25832697
object_type: Type | None = None,
2698+
*,
2699+
type_function=False,
25842700
) -> None:
25852701
"""Check argument types against a callable type.
25862702
@@ -2712,6 +2828,7 @@ def check_argument_types(
27122828
object_type,
27132829
args[actual],
27142830
context,
2831+
type_function,
27152832
)
27162833

27172834
def check_arg(
@@ -2726,12 +2843,16 @@ def check_arg(
27262843
object_type: Type | None,
27272844
context: Context,
27282845
outer_context: Context,
2846+
type_function=False,
27292847
) -> None:
27302848
"""Check the type of a single argument in a call."""
27312849
caller_type = get_proper_type(caller_type)
27322850
original_caller_type = get_proper_type(original_caller_type)
27332851
callee_type = get_proper_type(callee_type)
2734-
2852+
if type_function:
2853+
# TODO: make this work at all
2854+
if not isinstance(caller_type, Instance) or not caller_type.last_known_value:
2855+
caller_type = self.named_type("builtins.object")
27352856
if isinstance(caller_type, DeletedType):
27362857
self.msg.deleted_as_rvalue(caller_type, context)
27372858
# Only non-abstract non-protocol class can be given where Type[...] is expected...
@@ -3348,6 +3469,7 @@ def check_arg(
33483469
object_type: Type | None,
33493470
context: Context,
33503471
outer_context: Context,
3472+
type_function: bool,
33513473
) -> None:
33523474
if not arg_approximate_similarity(caller_type, callee_type):
33533475
# No match -- exit early since none of the remaining work can change
@@ -3580,10 +3702,14 @@ def visit_bytes_expr(self, e: BytesExpr) -> Type:
35803702

35813703
def visit_float_expr(self, e: FloatExpr) -> Type:
35823704
"""Type check a float literal (trivial)."""
3705+
if mypy.options._based:
3706+
return self.infer_literal_expr_type(e.value, "builtins.float")
35833707
return self.named_type("builtins.float")
35843708

35853709
def visit_complex_expr(self, e: ComplexExpr) -> Type:
35863710
"""Type check a complex literal."""
3711+
if mypy.options._based:
3712+
return self.infer_literal_expr_type(e.value, "builtins.complex")
35873713
return self.named_type("builtins.complex")
35883714

35893715
def visit_ellipsis(self, e: EllipsisExpr) -> Type:
@@ -6502,6 +6628,7 @@ def narrow_type_from_binder(
65026628
known_type, restriction, prohibit_none_typevar_overlap=True
65036629
):
65046630
return None
6631+
65056632
return narrow_declared_type(known_type, restriction)
65066633
return known_type
65076634

mypy/errorcodes.py

+3
Original file line numberDiff line numberDiff line change
@@ -315,6 +315,9 @@ def __hash__(self) -> int:
315315
TYPE_CHECK_ONLY: Final[ErrorCode] = ErrorCode(
316316
"type-check-only", "Value doesn't exist at runtime", "General"
317317
)
318+
CALL_RAISES: Final[ErrorCode] = ErrorCode(
319+
"call-raises", "function call raises an error", "General"
320+
)
318321
REVEAL: Final = ErrorCode("reveal", "Reveal types at check time", "General")
319322

320323
# Syntax errors are often blocking.

0 commit comments

Comments
 (0)