Skip to content

Commit

Permalink
inference: refine branched Conditional types (JuliaLang#55216)
Browse files Browse the repository at this point in the history
Separated from JuliaLang#40880.
This subtle adjustment allows for more accurate type inference in the
following kind of cases:
```julia
function condition_object_update2(x)
    cond = x isa Int
    if cond # `cond` is known to be `Const(true)` within this branch
        return !cond ? nothing : x # ::Int
    else
        return  cond ? nothing : 1 # ::Int
    end
end
@test Base.infer_return_type(condition_object_update2, (Any,)) == Int
```

Also cleans up typelattice.jl a bit.
  • Loading branch information
aviatesk authored Jul 24, 2024
1 parent 0055747 commit 24cfe6e
Show file tree
Hide file tree
Showing 3 changed files with 82 additions and 88 deletions.
60 changes: 40 additions & 20 deletions base/compiler/abstractinterpretation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3067,7 +3067,7 @@ end
@inline function abstract_eval_basic_statement(interp::AbstractInterpreter,
@nospecialize(stmt), pc_vartable::VarTable, frame::InferenceState)
if isa(stmt, NewvarNode)
changes = StateUpdate(stmt.slot, VarState(Bottom, true), pc_vartable, false)
changes = StateUpdate(stmt.slot, VarState(Bottom, true), false)
return BasicStmtChange(changes, nothing, Union{})
elseif !isa(stmt, Expr)
(; rt, exct) = abstract_eval_statement(interp, stmt, pc_vartable, frame)
Expand All @@ -3082,7 +3082,7 @@ end
end
lhs = stmt.args[1]
if isa(lhs, SlotNumber)
changes = StateUpdate(lhs, VarState(rt, false), pc_vartable, false)
changes = StateUpdate(lhs, VarState(rt, false), false)
elseif isa(lhs, GlobalRef)
handle_global_assignment!(interp, frame, lhs, rt)
elseif !isa(lhs, SSAValue)
Expand All @@ -3092,7 +3092,7 @@ end
elseif hd === :method
fname = stmt.args[1]
if isa(fname, SlotNumber)
changes = StateUpdate(fname, VarState(Any, false), pc_vartable, false)
changes = StateUpdate(fname, VarState(Any, false), false)
end
return BasicStmtChange(changes, nothing, Union{})
elseif (hd === :code_coverage_effect || (
Expand Down Expand Up @@ -3242,18 +3242,18 @@ function typeinf_local(interp::AbstractInterpreter, frame::InferenceState)
@goto branch
elseif isa(stmt, GotoIfNot)
condx = stmt.cond
condxslot = ssa_def_slot(condx, frame)
condslot = ssa_def_slot(condx, frame)
condt = abstract_eval_value(interp, condx, currstate, frame)
if condt === Bottom
ssavaluetypes[currpc] = Bottom
empty!(frame.pclimitations)
@goto find_next_bb
end
orig_condt = condt
if !(isa(condt, Const) || isa(condt, Conditional)) && isa(condxslot, SlotNumber)
if !(isa(condt, Const) || isa(condt, Conditional)) && isa(condslot, SlotNumber)
# if this non-`Conditional` object is a slot, we form and propagate
# the conditional constraint on it
condt = Conditional(condxslot, Const(true), Const(false))
condt = Conditional(condslot, Const(true), Const(false))
end
condval = maybe_extract_const_bool(condt)
nothrow = (condval !== nothing) || (𝕃ᵢ, orig_condt, Bool)
Expand Down Expand Up @@ -3299,21 +3299,31 @@ function typeinf_local(interp::AbstractInterpreter, frame::InferenceState)
# We continue with the true branch, but process the false
# branch here.
if isa(condt, Conditional)
else_change = conditional_change(𝕃ᵢ, currstate, condt.elsetype, condt.slot)
else_change = conditional_change(𝕃ᵢ, currstate, condt, #=then_or_else=#false)
if else_change !== nothing
false_vartable = stoverwrite1!(copy(currstate), else_change)
elsestate = copy(currstate)
stoverwrite1!(elsestate, else_change)
elseif condslot isa SlotNumber
elsestate = copy(currstate)
else
false_vartable = currstate
elsestate = currstate
end
changed = update_bbstate!(𝕃ᵢ, frame, falsebb, false_vartable)
then_change = conditional_change(𝕃ᵢ, currstate, condt.thentype, condt.slot)
if condslot isa SlotNumber # refine the type of this conditional object itself for this else branch
stoverwrite1!(elsestate, condition_object_change(currstate, condt, condslot, #=then_or_else=#false))
end
else_changed = update_bbstate!(𝕃ᵢ, frame, falsebb, elsestate)
then_change = conditional_change(𝕃ᵢ, currstate, condt, #=then_or_else=#true)
thenstate = currstate
if then_change !== nothing
stoverwrite1!(currstate, then_change)
stoverwrite1!(thenstate, then_change)
end
if condslot isa SlotNumber # refine the type of this conditional object itself for this then branch
stoverwrite1!(thenstate, condition_object_change(currstate, condt, condslot, #=then_or_else=#true))
end
else
changed = update_bbstate!(𝕃ᵢ, frame, falsebb, currstate)
else_changed = update_bbstate!(𝕃ᵢ, frame, falsebb, currstate)
end
if changed
if else_changed
handle_control_backedge!(interp, frame, currpc, stmt.dest)
push!(W, falsebb)
end
Expand Down Expand Up @@ -3412,13 +3422,14 @@ function typeinf_local(interp::AbstractInterpreter, frame::InferenceState)
nothing
end

function conditional_change(𝕃ᵢ::AbstractLattice, state::VarTable, @nospecialize(typ), slot::Int)
vtype = state[slot]
function conditional_change(𝕃ᵢ::AbstractLattice, currstate::VarTable, condt::Conditional, then_or_else::Bool)
vtype = currstate[condt.slot]
oldtyp = vtype.typ
if iskindtype(typ)
newtyp = then_or_else ? condt.thentype : condt.elsetype
if iskindtype(newtyp)
# this code path corresponds to the special handling for `isa(x, iskindtype)` check
# implemented within `abstract_call_builtin`
elseif (𝕃ᵢ, ignorelimited(typ), ignorelimited(oldtyp))
elseif (𝕃ᵢ, ignorelimited(newtyp), ignorelimited(oldtyp))
# approximate test for `typ ∩ oldtyp` being better than `oldtyp`
# since we probably formed these types with `typesubstract`,
# the comparison is likely simple
Expand All @@ -3428,9 +3439,18 @@ function conditional_change(𝕃ᵢ::AbstractLattice, state::VarTable, @nospecia
if oldtyp isa LimitedAccuracy
# typ is better unlimited, but we may still need to compute the tmeet with the limit
# "causes" since we ignored those in the comparison
typ = tmerge(𝕃ᵢ, typ, LimitedAccuracy(Bottom, oldtyp.causes))
newtyp = tmerge(𝕃ᵢ, newtyp, LimitedAccuracy(Bottom, oldtyp.causes))
end
return StateUpdate(SlotNumber(slot), VarState(typ, vtype.undef), state, true)
return StateUpdate(SlotNumber(condt.slot), VarState(newtyp, vtype.undef), true)
end

function condition_object_change(currstate::VarTable, condt::Conditional,
condslot::SlotNumber, then_or_else::Bool)
vtype = currstate[slot_id(condslot)]
newcondt = Conditional(condt.slot,
then_or_else ? condt.thentype : Union{},
then_or_else ? Union{} : condt.elsetype)
return StateUpdate(condslot, VarState(newcondt, vtype.undef), false)
end

# make as much progress on `frame` as possible (by handling cycles)
Expand Down
23 changes: 0 additions & 23 deletions base/compiler/typelattice.jl
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,6 @@ end
struct StateUpdate
var::SlotNumber
vtype::VarState
state::VarTable
conditional::Bool
end

Expand Down Expand Up @@ -724,28 +723,6 @@ function invalidate_slotwrapper(vt::VarState, changeid::Int, ignore_conditional:
return nothing
end

function stupdate!(lattice::AbstractLattice, state::VarTable, changes::StateUpdate)
changed = false
changeid = slot_id(changes.var)
for i = 1:length(state)
if i == changeid
newtype = changes.vtype
else
newtype = changes.state[i]
end
invalidated = invalidate_slotwrapper(newtype, changeid, changes.conditional)
if invalidated !== nothing
newtype = invalidated
end
oldtype = state[i]
if schanged(lattice, newtype, oldtype)
state[i] = smerge(lattice, oldtype, newtype)
changed = true
end
end
return changed
end

function stupdate!(lattice::AbstractLattice, state::VarTable, changes::VarTable)
changed = false
for i = 1:length(state)
Expand Down
87 changes: 42 additions & 45 deletions test/compiler/inference.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2151,78 +2151,75 @@ end

@testset "branching on conditional object" begin
# simple
@test Base.return_types((Union{Nothing,Int},)) do a
@test Base.infer_return_type((Union{Nothing,Int},)) do a
b = a === nothing
return b ? 0 : a # ::Int
end == Any[Int]
end == Int

# can use multiple times (as far as the subject of condition hasn't changed)
@test Base.return_types((Union{Nothing,Int},)) do a
@test Base.infer_return_type((Union{Nothing,Int},)) do a
b = a === nothing
c = b ? 0 : a # c::Int
d = !b ? a : 0 # d::Int
return c, d # ::Tuple{Int,Int}
end == Any[Tuple{Int,Int}]
end == Tuple{Int,Int}

# should invalidate old constraint when the subject of condition has changed
@test Base.return_types((Union{Nothing,Int},)) do a
@test Base.infer_return_type((Union{Nothing,Int},)) do a
cond = a === nothing
r1 = cond ? 0 : a # r1::Int
a = 0
r2 = cond ? a : 1 # r2::Int, not r2::Union{Nothing,Int}
return r1, r2 # ::Tuple{Int,Int}
end == Any[Tuple{Int,Int}]
end == Tuple{Int,Int}
end

# https://github.com/JuliaLang/julia/issues/42090#issuecomment-911824851
# `PartialStruct` shouldn't wrap `Conditional`
let M = Module()
@eval M begin
struct BePartialStruct
val::Int
cond
end
end

rt = @eval M begin
Base.return_types((Union{Nothing,Int},)) do a
cond = a === nothing
obj = $(Expr(:new, M.BePartialStruct, 42, :cond))
r1 = getfield(obj, :cond) ? 0 : a # r1::Union{Nothing,Int}, not r1::Int (because PartialStruct doesn't wrap Conditional)
a = $(gensym(:anyvar))::Any
r2 = getfield(obj, :cond) ? a : nothing # r2::Any, not r2::Const(nothing) (we don't need to worry about constraint invalidation here)
return r1, r2 # ::Tuple{Union{Nothing,Int},Any}
end |> only
end
@test rt == Tuple{Union{Nothing,Int},Any}
struct BePartialStruct
val::Int
cond
end
@test Tuple{Union{Nothing,Int},Any} == @eval Base.infer_return_type((Union{Nothing,Int},)) do a
cond = a === nothing
obj = $(Expr(:new, BePartialStruct, 42, :cond))
r1 = getfield(obj, :cond) ? 0 : a # r1::Union{Nothing,Int}, not r1::Int (because PartialStruct doesn't wrap Conditional)
a = $(gensym(:anyvar))::Any
r2 = getfield(obj, :cond) ? a : nothing # r2::Any, not r2::Const(nothing) (we don't need to worry about constraint invalidation here)
return r1, r2 # ::Tuple{Union{Nothing,Int},Any}
end

# make sure we never form nested `Conditional` (https://github.com/JuliaLang/julia/issues/46207)
@test Base.return_types((Any,)) do a
@test Base.infer_return_type((Any,)) do a
c = isa(a, Integer)
42 === c ? :a : "b"
end |> only === String
@test Base.return_types((Any,)) do a
end == String
@test Base.infer_return_type((Any,)) do a
c = isa(a, Integer)
c === 42 ? :a : "b"
end |> only === String
end == String

@testset "conditional constraint propagation from non-`Conditional` object" begin
@test Base.return_types((Bool,)) do b
if b
return !b ? nothing : 1 # ::Int
else
return 0
end
end == Any[Int]

@test Base.return_types((Any,)) do b
if b
return b # ::Bool
else
return nothing
end
end == Any[Union{Bool,Nothing}]
function condition_object_update1(cond)
if cond # `cond` is known to be `Const(true)` within this branch
return !cond ? nothing : 1 # ::Int
else
return cond ? nothing : 1 # ::Int
end
end
function condition_object_update2(x)
cond = x isa Int
if cond # `cond` is known to be `Const(true)` within this branch
return !cond ? nothing : x # ::Int
else
return cond ? nothing : 1 # ::Int
end
end
@testset "state update for condition object" begin
# refine the type of condition object into constant boolean values on branching
@test Base.infer_return_type(condition_object_update1, (Bool,)) == Int
@test Base.infer_return_type(condition_object_update1, (Any,)) == Int
# refine even when their original type is `Conditional`
@test Base.infer_return_type(condition_object_update2, (Any,)) == Int
end

@testset "`from_interprocedural!`: translate inter-procedural information" begin
Expand Down

0 comments on commit 24cfe6e

Please sign in to comment.