Skip to content

Commit 750df9f

Browse files
authored
inference: permit non-direct recursion reducers (#50696)
Fix #45759 Fix #46557 Fix #31485 Depends on #50694 due to a failing broadcast test without it (related to #50695)
1 parent 90494c2 commit 750df9f

File tree

2 files changed

+41
-7
lines changed

2 files changed

+41
-7
lines changed

base/compiler/abstractinterpretation.jl

+13-7
Original file line numberDiff line numberDiff line change
@@ -543,25 +543,30 @@ function abstract_call_method(interp::AbstractInterpreter,
543543
if topmost !== nothing
544544
msig = unwrap_unionall(method.sig)::DataType
545545
spec_len = length(msig.parameters) + 1
546-
ls = length(sigtuple.parameters)
547546
mi = frame_instance(sv)
548547

548+
if isdefined(method, :recursion_relation)
549+
# We don't require the recursion_relation to be transitive, so
550+
# apply a hard limit
551+
hardlimit = true
552+
end
553+
549554
if method === mi.def
550555
# Under direct self-recursion, permit much greater use of reducers.
551556
# here we assume that complexity(specTypes) :>= complexity(sig)
552557
comparison = mi.specTypes
553558
l_comparison = length((unwrap_unionall(comparison)::DataType).parameters)
554559
spec_len = max(spec_len, l_comparison)
560+
elseif !hardlimit && isa(topmost, InferenceState)
561+
# Without a hardlimit, permit use of reducers too.
562+
comparison = frame_instance(topmost).specTypes
563+
# n.b. currently don't allow vararg reducers
564+
#l_comparison = length((unwrap_unionall(comparison)::DataType).parameters)
565+
#spec_len = max(spec_len, l_comparison)
555566
else
556567
comparison = method.sig
557568
end
558569

559-
if isdefined(method, :recursion_relation)
560-
# We don't require the recursion_relation to be transitive, so
561-
# apply a hard limit
562-
hardlimit = true
563-
end
564-
565570
# see if the type is actually too big (relative to the caller), and limit it if required
566571
newsig = limit_type_size(sig, comparison, hardlimit ? comparison : mi.specTypes, InferenceParams(interp).tuple_complexity_limit_depth, spec_len)
567572

@@ -588,6 +593,7 @@ function abstract_call_method(interp::AbstractInterpreter,
588593
poison_callstack!(sv, parentframe === nothing ? topmost : parentframe)
589594
end
590595
end
596+
# n.b. this heuristic depends on the non-local state, so we must record the limit later
591597
sig = newsig
592598
sparams = svec()
593599
edgelimited = true

test/compiler/inference.jl

+28
Original file line numberDiff line numberDiff line change
@@ -5100,6 +5100,34 @@ end
51005100
end |> only === String
51015101
# JET.test_call(s::AbstractString->Base._string(s, 'c'))
51025102

5103+
# issue #45759 #46557
5104+
g45759(x::Tuple{Any,Vararg}) = x[1] + _g45759(x[2:end])
5105+
g45759(x::Tuple{}) = 0
5106+
_g45759(x) = g45759(x)
5107+
@test only(Base.return_types(g45759, Tuple{Tuple{Int,Int,Int,Int,Int,Int,Int}})) == Int
5108+
5109+
h45759(x::Tuple{Any,Vararg}; kwargs...) = x[1] + h45759(x[2:end]; kwargs...)
5110+
h45759(x::Tuple{}; kwargs...) = 0
5111+
@test only(Base.return_types(h45759, Tuple{Tuple{Int,Int,Int,Int,Int,Int,Int}})) == Int
5112+
5113+
@test only(Base.return_types((typeof([[[1]]]),)) do x
5114+
sum(x) do v
5115+
sum(length, v)
5116+
end
5117+
end) == Int
5118+
5119+
struct FunctionSum{Tf}
5120+
functions::Tf
5121+
end
5122+
(F::FunctionSum)(x) = sum(f -> f(x), F.functions)
5123+
F = FunctionSum((x -> sqrt(x), FunctionSum((x -> x^2, x -> x^3))))
5124+
@test @inferred(F(1.)) === 3.0
5125+
5126+
f31485(arr::AbstractArray{T, 0}) where {T} = arr
5127+
indirect31485(arr) = f31485(arr)
5128+
f31485(arr::AbstractArray{T, N}) where {T, N} = indirect31485(view(arr, 1, ntuple(i -> :, Val(N-1))...))
5129+
@test @inferred(f31485(zeros(3,3,3,3,3),)) == fill(0.0)
5130+
51035131
# override const-prop' return type with the concrete-eval result
51045132
# if concrete-eval returns non-inlineable constant
51055133
Base.@assume_effects :foldable function continue_const_prop(i, j)

0 commit comments

Comments
 (0)