From 7b80b2128e5c281a55d3e025a1d01521b30ced61 Mon Sep 17 00:00:00 2001 From: Keno Fischer Date: Thu, 1 Aug 2024 09:47:19 -0400 Subject: [PATCH] inference: Remove special casing for `!` (#55271) We have a handful of cases in inference where we look up functions by name (using `istopfunction`) and give them special behavior. I'd like to remove these. They're not only aesthetically ugly, but because they depend on binding lookups, rather than values, they have unclear semantics as those bindings change. They are also unsound should a user use the same name for something differnt in their own top modules (of course, it's unlikely that a user would do such a thing, but it's bad that they can't). This particular PR removes the special case for `!`, which was there to strengthen the inference result for `!` on Conditional. However, with a little bit of strengthening of the rest of the system, this can be equally well evaluated through the existing InterConditional mechanism. --- base/compiler/abstractinterpretation.jl | 46 ++++++++++++++++++++----- base/compiler/inferencestate.jl | 3 ++ base/compiler/tfuncs.jl | 11 +++++- test/compiler/inference.jl | 5 +++ 4 files changed, 56 insertions(+), 9 deletions(-) diff --git a/base/compiler/abstractinterpretation.jl b/base/compiler/abstractinterpretation.jl index 22bc1cf046d98..a7602f8cd4134 100644 --- a/base/compiler/abstractinterpretation.jl +++ b/base/compiler/abstractinterpretation.jl @@ -13,6 +13,31 @@ call_result_unused(sv::InferenceState, currpc::Int) = isexpr(sv.src.code[currpc], :call) && isempty(sv.ssavalue_uses[currpc]) call_result_unused(si::StmtInfo) = !si.used +is_const_bool_or_bottom(@nospecialize(b)) = (isa(b, Const) && isa(b.val, Bool)) || b == Bottom +function can_propagate_conditional(@nospecialize(rt), argtypes::Vector{Any}) + isa(rt, InterConditional) || return false + if rt.slot > length(argtypes) + # In the vararg tail - can't be conditional + @assert isvarargtype(argtypes[end]) + return false + end + return isa(argtypes[rt.slot], Conditional) && + is_const_bool_or_bottom(rt.thentype) && is_const_bool_or_bottom(rt.thentype) +end + +function propagate_conditional(rt::InterConditional, cond::Conditional) + new_thentype = rt.thentype === Const(false) ? cond.elsetype : cond.thentype + new_elsetype = rt.elsetype === Const(true) ? cond.thentype : cond.elsetype + if rt.thentype == Bottom + @assert rt.elsetype != Bottom + return Conditional(cond.slot, Bottom, new_elsetype) + elseif rt.elsetype == Bottom + @assert rt.thentype != Bottom + return Conditional(cond.slot, new_thentype, Bottom) + end + return Conditional(cond.slot, new_thentype, new_elsetype) +end + function abstract_call_gf_by_type(interp::AbstractInterpreter, @nospecialize(f), arginfo::ArgInfo, si::StmtInfo, @nospecialize(atype), sv::AbsIntState, max_methods::Int) @@ -156,6 +181,15 @@ function abstract_call_gf_by_type(interp::AbstractInterpreter, @nospecialize(f), end @assert !(this_conditional isa Conditional || this_rt isa MustAlias) "invalid lattice element returned from inter-procedural context" seen += 1 + + if can_propagate_conditional(this_conditional, argtypes) + # The only case where we need to keep this in rt is where + # we can directly propagate the conditional to a slot argument + # that is not one of our arguments, otherwise we keep all the + # relevant information in `conditionals` below. + this_rt = this_conditional + end + rettype = rettype โŠ”โ‚š this_rt exctype = exctype โŠ”โ‚š this_exct if has_conditional(๐•ƒโ‚š, sv) && this_conditional !== Bottom && is_lattice_bool(๐•ƒโ‚š, rettype) && fargs !== nothing @@ -409,6 +443,9 @@ function from_interconditional(๐•ƒแตข::AbstractLattice, @nospecialize(rt), sv:: has_conditional(๐•ƒแตข, sv) || return widenconditional(rt) (; fargs, argtypes) = arginfo fargs === nothing && return widenconditional(rt) + if can_propagate_conditional(rt, argtypes) + return propagate_conditional(rt, argtypes[rt.slot]::Conditional) + end slot = 0 alias = nothing thentype = elsetype = Any @@ -2217,13 +2254,6 @@ function abstract_call_known(interp::AbstractInterpreter, @nospecialize(f), end elseif is_return_type(f) return return_type_tfunc(interp, argtypes, si, sv) - elseif la == 2 && istopfunction(f, :!) - # handle Conditional propagation through !Bool - aty = argtypes[2] - if isa(aty, Conditional) - call = abstract_call_gf_by_type(interp, f, ArgInfo(fargs, Any[Const(f), Bool]), si, Tuple{typeof(f), Bool}, sv, max_methods) # make sure we've inferred `!(::Bool)` - return CallMeta(Conditional(aty.slot, aty.elsetype, aty.thentype), Any, call.effects, call.info) - end elseif la == 3 && istopfunction(f, :!==) # mark !== as exactly a negated call to === call = abstract_call_gf_by_type(interp, f, ArgInfo(fargs, Any[Const(f), Any, Any]), si, Tuple{typeof(f), Any, Any}, sv, max_methods) @@ -3194,7 +3224,7 @@ function update_bestguess!(interp::AbstractInterpreter, frame::InferenceState, # narrow representation of bestguess slightly to prepare for tmerge with rt if rt isa InterConditional && bestguess isa Const slot_id = rt.slot - old_id_type = slottypes[slot_id] + old_id_type = widenconditional(slottypes[slot_id]) if bestguess.val === true && rt.elsetype !== Bottom bestguess = InterConditional(slot_id, old_id_type, Bottom) elseif bestguess.val === false && rt.thentype !== Bottom diff --git a/base/compiler/inferencestate.jl b/base/compiler/inferencestate.jl index c358b1177251f..38011656e41ea 100644 --- a/base/compiler/inferencestate.jl +++ b/base/compiler/inferencestate.jl @@ -312,6 +312,9 @@ mutable struct InferenceState nargtypes = length(argtypes) for i = 1:nslots argtyp = (i > nargtypes) ? Bottom : argtypes[i] + if argtyp === Bool && has_conditional(typeinf_lattice(interp)) + argtyp = Conditional(i, Const(true), Const(false)) + end slottypes[i] = argtyp bb_vartable1[i] = VarState(argtyp, i > nargtypes) end diff --git a/base/compiler/tfuncs.jl b/base/compiler/tfuncs.jl index 28e883d83312c..b40f65ab3ca1d 100644 --- a/base/compiler/tfuncs.jl +++ b/base/compiler/tfuncs.jl @@ -227,10 +227,19 @@ end @nospecs shift_tfunc(๐•ƒ::AbstractLattice, x, y) = shift_tfunc(widenlattice(๐•ƒ), x, y) @nospecs shift_tfunc(::JLTypeLattice, x, y) = widenconst(x) +function not_tfunc(๐•ƒ::AbstractLattice, @nospecialize(b)) + if isa(b, Conditional) + return Conditional(b.slot, b.elsetype, b.thentype) + elseif isa(b, Const) + return Const(not_int(b.val)) + end + return math_tfunc(๐•ƒ, b) +end + add_tfunc(and_int, 2, 2, and_int_tfunc, 1) add_tfunc(or_int, 2, 2, or_int_tfunc, 1) add_tfunc(xor_int, 2, 2, math_tfunc, 1) -add_tfunc(not_int, 1, 1, math_tfunc, 0) # usually used as not_int(::Bool) to negate a condition +add_tfunc(not_int, 1, 1, not_tfunc, 0) # usually used as not_int(::Bool) to negate a condition add_tfunc(shl_int, 2, 2, shift_tfunc, 1) add_tfunc(lshr_int, 2, 2, shift_tfunc, 1) add_tfunc(ashr_int, 2, 2, shift_tfunc, 1) diff --git a/test/compiler/inference.jl b/test/compiler/inference.jl index 8b6da828af54d..9ae98b884bef4 100644 --- a/test/compiler/inference.jl +++ b/test/compiler/inference.jl @@ -5866,3 +5866,8 @@ end bar54341(args...) = foo54341(4, args...) @test Core.Compiler.return_type(bar54341, Tuple{Vararg{Int}}) === Int + +# InterConditional rt with Vararg argtypes +fcondvarargs(a, b, c, d) = isa(d, Int64) +gcondvarargs(a, x...) = return fcondvarargs(a, x...) ? isa(a, Int64) : !isa(a, Int64) +@test Core.Compiler.return_type(gcondvarargs, Tuple{Vararg{Any}}) === Bool