2
2
3
3
from __future__ import annotations
4
4
5
+ import builtins
6
+ import contextlib
5
7
import enum
8
+ import importlib
9
+ import io
6
10
import itertools
7
11
import time
8
12
from collections import defaultdict
9
13
from contextlib import contextmanager
14
+ from types import GetSetDescriptorType
10
15
from typing import Callable , ClassVar , Final , Iterable , Iterator , List , Optional , Sequence , cast
11
16
from typing_extensions import TypeAlias as _TypeAlias , assert_never , overload
12
17
18
+ from basedtyping import TypeFunctionError
19
+
13
20
import mypy .checker
14
21
import mypy .errorcodes as codes
15
22
from mypy import applytype , erasetype , errorcodes , join , message_registry , nodes , operators , types
23
30
)
24
31
from mypy .checkstrformat import StringFormatterChecker
25
32
from mypy .erasetype import erase_type , remove_instance_last_known_values , replace_meta_vars
33
+ from mypy .errorcodes import ErrorCode
26
34
from mypy .errors import ErrorInfo , ErrorWatcher , report_internal_error
27
35
from mypy .expandtype import (
28
36
expand_type ,
205
213
from mypy .typestate import type_state
206
214
from mypy .typevars import fill_typevars
207
215
from mypy .util import split_module_names
216
+ from mypy .valuetotype import type_to_value , value_to_type
208
217
from mypy .visitor import ExpressionVisitor
209
218
210
219
# Type of callback user for checking individual function arguments. See
211
220
# check_args() below for details.
212
221
ArgChecker : _TypeAlias = Callable [
213
- [Type , Type , ArgKind , Type , int , int , CallableType , Optional [Type ], Context , Context ], None
222
+ [Type , Type , ArgKind , Type , int , int , CallableType , Optional [Type ], Context , Context , bool ],
223
+ None ,
214
224
]
215
225
216
226
# Maximum nesting level for math union in overloads, setting this to large values
@@ -1846,12 +1856,13 @@ def check_callable_call(
1846
1856
fresh_ret_type = freshen_all_functions_type_vars (callee .ret_type )
1847
1857
freeze_all_type_vars (fresh_ret_type )
1848
1858
callee = callee .copy_modified (ret_type = fresh_ret_type )
1849
-
1850
1859
if callee .is_generic ():
1851
1860
need_refresh = any (
1852
1861
isinstance (v , (ParamSpecType , TypeVarTupleType )) for v in callee .variables
1853
1862
)
1863
+ # IT"S HERE!
1854
1864
callee = freshen_function_type_vars (callee )
1865
+ # IT"S HERE!
1855
1866
callee = self .infer_function_type_arguments_using_context (callee , context )
1856
1867
if need_refresh :
1857
1868
# Argument kinds etc. may have changed due to
@@ -1909,7 +1920,6 @@ def check_callable_call(
1909
1920
self .check_argument_types (
1910
1921
arg_types , arg_kinds , args , callee , formal_to_actual , context , object_type = object_type
1911
1922
)
1912
-
1913
1923
if (
1914
1924
callee .is_type_obj ()
1915
1925
and (len (arg_types ) == 1 )
@@ -1921,6 +1931,38 @@ def check_callable_call(
1921
1931
# Store the inferred callable type.
1922
1932
self .chk .store_type (callable_node , callee )
1923
1933
1934
+ if callee .is_type_function :
1935
+ with self .msg .filter_errors (filter_errors = True ) as error_watcher :
1936
+ if object_type :
1937
+ self .check_arg (
1938
+ caller_type = object_type ,
1939
+ original_caller_type = object_type ,
1940
+ caller_kind = ArgKind .ARG_POS ,
1941
+ callee_type = callee .bound_args [0 ],
1942
+ n = 0 ,
1943
+ m = 0 ,
1944
+ callee = callee ,
1945
+ object_type = object_type ,
1946
+ context = context ,
1947
+ outer_context = context ,
1948
+ type_function = True ,
1949
+ )
1950
+
1951
+ self .check_argument_types (
1952
+ arg_types ,
1953
+ arg_kinds ,
1954
+ args ,
1955
+ callee ,
1956
+ formal_to_actual ,
1957
+ context ,
1958
+ object_type = object_type ,
1959
+ type_function = True ,
1960
+ )
1961
+ if not error_watcher .has_new_errors () and "." in callable_name :
1962
+ ret_type = self .call_type_function (callable_name , object_type , arg_types , context )
1963
+ if ret_type :
1964
+ callee = callee .copy_modified (ret_type = ret_type )
1965
+
1924
1966
if callable_name and (
1925
1967
(object_type is None and self .plugin .get_function_hook (callable_name ))
1926
1968
or (object_type is not None and self .plugin .get_method_hook (callable_name ))
@@ -1939,6 +1981,71 @@ def check_callable_call(
1939
1981
callee = callee .copy_modified (ret_type = new_ret_type )
1940
1982
return callee .ret_type , callee
1941
1983
1984
+ def call_type_function (
1985
+ self ,
1986
+ callable_name : str ,
1987
+ object_type : ProperType | None ,
1988
+ arg_types : list [ProperType ],
1989
+ context : Context ,
1990
+ ) -> Type | None :
1991
+ container_name , fn_name = callable_name .rsplit ("." , maxsplit = 1 )
1992
+ resolved = None
1993
+ for part in container_name .split ("." ):
1994
+ if resolved :
1995
+ m = resolved .names .get (part )
1996
+ else :
1997
+ m = self .chk .modules .get (part )
1998
+ if m :
1999
+ resolved = m
2000
+ is_method = not isinstance (resolved , MypyFile )
2001
+ if is_method :
2002
+ container = resolved .node
2003
+ module_name = container .module_name
2004
+ else :
2005
+ container = resolved
2006
+ module_name = container .fullname
2007
+
2008
+ all_sigs = []
2009
+ object_type = object_type and [object_type ] or []
2010
+ for arg in object_type + arg_types :
2011
+ if isinstance (arg , UnionType ):
2012
+ if not all_sigs :
2013
+ all_sigs = [[x ] for x in arg .items ]
2014
+ else :
2015
+ from itertools import product
2016
+
2017
+ all_sigs = product (all_sigs , arg .items )
2018
+ all_sigs = all_sigs or [object_type + arg_types ]
2019
+ all_rets = []
2020
+ with contextlib .redirect_stdout (io .StringIO ()), contextlib .redirect_stderr (io .StringIO ()):
2021
+ mod = importlib .import_module (module_name )
2022
+ container = getattr (mod , container .name ) if is_method else mod
2023
+ fn = getattr (container , fn_name )
2024
+ for sig in all_sigs :
2025
+ if isinstance (fn , (GetSetDescriptorType , property )):
2026
+ fn = fn .__get__
2027
+ args = [type_to_value (arg , self .chk ) for arg in sig ]
2028
+ try :
2029
+ return_value = fn (* args )
2030
+ except RecursionError :
2031
+ self .chk .fail (
2032
+ "maximum recursion depth exceeded while evaluating type function" ,
2033
+ context = context ,
2034
+ )
2035
+ except TypeFunctionError as type_function_error :
2036
+ code = type_function_error .code and ErrorCode (type_function_error .code , "" , "" )
2037
+ self .chk .fail (type_function_error .message , code = code , context = context )
2038
+ except Exception as exception :
2039
+ self .chk .fail (
2040
+ f"Invocation raises { type (exception ).__name__ } : { exception } " ,
2041
+ context ,
2042
+ code = errorcodes .CALL_RAISES ,
2043
+ )
2044
+ else :
2045
+ all_rets .append (value_to_type (return_value , chk = self .chk ))
2046
+
2047
+ return make_simplified_union (all_rets )
2048
+
1942
2049
def can_return_none (self , type : TypeInfo , attr_name : str ) -> bool :
1943
2050
"""Is the given attribute a method with a None-compatible return type?
1944
2051
@@ -2175,6 +2282,13 @@ def infer_function_type_arguments(
2175
2282
Return a derived callable type that has the arguments applied.
2176
2283
"""
2177
2284
if self .chk .in_checked_function ():
2285
+ if isinstance (callee_type .ret_type , TypeVarType ):
2286
+ # if the return type is constant, infer as literal
2287
+ rvalue_type = [
2288
+ remove_instance_last_known_values (arg ) if isinstance (arg , Instance ) else arg
2289
+ for arg in args
2290
+ ]
2291
+
2178
2292
# Disable type errors during type inference. There may be errors
2179
2293
# due to partial available context information at this time, but
2180
2294
# these errors can be safely ignored as the arguments will be
@@ -2581,6 +2695,8 @@ def check_argument_types(
2581
2695
context : Context ,
2582
2696
check_arg : ArgChecker | None = None ,
2583
2697
object_type : Type | None = None ,
2698
+ * ,
2699
+ type_function = False ,
2584
2700
) -> None :
2585
2701
"""Check argument types against a callable type.
2586
2702
@@ -2712,6 +2828,7 @@ def check_argument_types(
2712
2828
object_type ,
2713
2829
args [actual ],
2714
2830
context ,
2831
+ type_function ,
2715
2832
)
2716
2833
2717
2834
def check_arg (
@@ -2726,12 +2843,16 @@ def check_arg(
2726
2843
object_type : Type | None ,
2727
2844
context : Context ,
2728
2845
outer_context : Context ,
2846
+ type_function = False ,
2729
2847
) -> None :
2730
2848
"""Check the type of a single argument in a call."""
2731
2849
caller_type = get_proper_type (caller_type )
2732
2850
original_caller_type = get_proper_type (original_caller_type )
2733
2851
callee_type = get_proper_type (callee_type )
2734
-
2852
+ if type_function :
2853
+ # TODO: make this work at all
2854
+ if not isinstance (caller_type , Instance ) or not caller_type .last_known_value :
2855
+ caller_type = self .named_type ("builtins.object" )
2735
2856
if isinstance (caller_type , DeletedType ):
2736
2857
self .msg .deleted_as_rvalue (caller_type , context )
2737
2858
# Only non-abstract non-protocol class can be given where Type[...] is expected...
@@ -3348,6 +3469,7 @@ def check_arg(
3348
3469
object_type : Type | None ,
3349
3470
context : Context ,
3350
3471
outer_context : Context ,
3472
+ type_function : bool ,
3351
3473
) -> None :
3352
3474
if not arg_approximate_similarity (caller_type , callee_type ):
3353
3475
# No match -- exit early since none of the remaining work can change
@@ -3580,10 +3702,14 @@ def visit_bytes_expr(self, e: BytesExpr) -> Type:
3580
3702
3581
3703
def visit_float_expr (self , e : FloatExpr ) -> Type :
3582
3704
"""Type check a float literal (trivial)."""
3705
+ if mypy .options ._based :
3706
+ return self .infer_literal_expr_type (e .value , "builtins.float" )
3583
3707
return self .named_type ("builtins.float" )
3584
3708
3585
3709
def visit_complex_expr (self , e : ComplexExpr ) -> Type :
3586
3710
"""Type check a complex literal."""
3711
+ if mypy .options ._based :
3712
+ return self .infer_literal_expr_type (e .value , "builtins.complex" )
3587
3713
return self .named_type ("builtins.complex" )
3588
3714
3589
3715
def visit_ellipsis (self , e : EllipsisExpr ) -> Type :
@@ -6502,6 +6628,7 @@ def narrow_type_from_binder(
6502
6628
known_type , restriction , prohibit_none_typevar_overlap = True
6503
6629
):
6504
6630
return None
6631
+
6505
6632
return narrow_declared_type (known_type , restriction )
6506
6633
return known_type
6507
6634
0 commit comments