Skip to content

Commit

Permalink
inference: represent callers_in_cycle with view slices of a stack
Browse files Browse the repository at this point in the history
Inspired by Tarjan's SCC, this changes the recursion representation to
use a single list instead of a linked-list + merged array of cycles.
  • Loading branch information
vtjnash committed Aug 7, 2024
1 parent b43e247 commit 94310a3
Show file tree
Hide file tree
Showing 5 changed files with 178 additions and 130 deletions.
33 changes: 22 additions & 11 deletions base/compiler/abstractinterpretation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -802,10 +802,10 @@ function edge_matches_sv(interp::AbstractInterpreter, frame::AbsIntState,
# otherwise, we don't

# check in the cycle list first
# all items in here are mutual parents of all others
# all items in here are considered mutual parents of all others
if !any(p::AbsIntState->matches_sv(p, sv), callers_in_cycle(frame))
let parent = frame_parent(frame)
parent !== nothing || return false
parent === nothing && return false
(is_cached(parent) || frame_parent(parent) !== nothing) || return false
matches_sv(parent, sv) || return false
end
Expand Down Expand Up @@ -1307,7 +1307,7 @@ function semi_concrete_eval_call(interp::AbstractInterpreter,
if code !== nothing
irsv = IRInterpretationState(interp, code, mi, arginfo.argtypes, world)
if irsv !== nothing
irsv.parent = sv
assign_parentchild(irsv, sv)
rt, (nothrow, noub) = ir_abstract_constant_propagation(interp, irsv)
@assert !(rt isa Conditional || rt isa MustAlias) "invalid lattice element returned from irinterp"
if !(isa(rt, Type) && hasintersect(rt, Bool))
Expand Down Expand Up @@ -1385,11 +1385,17 @@ function const_prop_call(interp::AbstractInterpreter,
add_remark!(interp, sv, "[constprop] Could not retrieve the source")
return nothing # this is probably a bad generated function (unsound), but just ignore it
end
frame.parent = sv
assign_parentchild(frame, sv)
if !typeinf(interp, frame)
add_remark!(interp, sv, "[constprop] Fresh constant inference hit a cycle")
@assert frame.frameid != 0 && frame.cycleid == frame.frameid
callstack = frame.callstack::Vector{AbsIntState}
@assert callstack[end] === frame && length(callstack) == frame.frameid
pop!(callstack)
return nothing
end
@assert frame.frameid != 0 && frame.cycleid == frame.frameid
@assert frame.parentid == sv.frameid
@assert inf_result.result !== nothing
# ConditionalSimpleArgtypes is allowed, because the only case in which it modifies
# the argtypes is when one of the argtypes is a `Conditional`, which case
Expand Down Expand Up @@ -3306,7 +3312,6 @@ end
# make as much progress on `frame` as possible (without handling cycles)
function typeinf_local(interp::AbstractInterpreter, frame::InferenceState)
@assert !is_inferred(frame)
frame.dont_work_on_me = true # mark that this function is currently on the stack
W = frame.ip
ssavaluetypes = frame.ssavaluetypes
bbs = frame.cfg.blocks
Expand Down Expand Up @@ -3527,7 +3532,6 @@ function typeinf_local(interp::AbstractInterpreter, frame::InferenceState)
end
end # while currbb <= nbbs

frame.dont_work_on_me = false
nothing
end

Expand Down Expand Up @@ -3579,16 +3583,23 @@ end
# make as much progress on `frame` as possible (by handling cycles)
function typeinf_nocycle(interp::AbstractInterpreter, frame::InferenceState)
typeinf_local(interp, frame)
@assert isempty(frame.ip)
callstack = frame.callstack::Vector{AbsIntState}
frame.cycleid == length(callstack) && return true

# If the current frame is part of a cycle, solve the cycle before finishing
no_active_ips_in_callers = false
while !no_active_ips_in_callers
while true
# If the current frame is not the top part of a cycle, continue to the top of the cycle before resuming work
frame.cycleid == frame.frameid || return false
# If done, return and finalize this cycle
no_active_ips_in_callers && return true
# Otherwise, do at least one iteration over the entire current cycle
no_active_ips_in_callers = true
for caller in frame.callers_in_cycle
caller.dont_work_on_me && return false # cycle is above us on the stack
for i = reverse(frame.cycleid:length(callstack))
caller = callstack[i]::InferenceState
if !isempty(caller.ip)
# Note that `typeinf_local(interp, caller)` can potentially modify the other frames
# `frame.callers_in_cycle`, which is why making incremental progress requires the
# `frame.cycleid`, which is why making incremental progress requires the
# outer while loop.
typeinf_local(interp, caller)
no_active_ips_in_callers = false
Expand Down
149 changes: 98 additions & 51 deletions base/compiler/inferencestate.jl
Original file line number Diff line number Diff line change
Expand Up @@ -254,9 +254,10 @@ mutable struct InferenceState
pclimitations::IdSet{InferenceState} # causes of precision restrictions (LimitedAccuracy) on currpc ssavalue
limitations::IdSet{InferenceState} # causes of precision restrictions (LimitedAccuracy) on return
cycle_backedges::Vector{Tuple{InferenceState, Int}} # call-graph backedges connecting from callee to caller
callers_in_cycle::Vector{InferenceState}
dont_work_on_me::Bool
parent # ::Union{Nothing,AbsIntState}
callstack #::Vector{AbsIntState}
parentid::Int
frameid::Int
cycleid::Int

#= results =#
result::InferenceResult # remember where to put the result
Expand Down Expand Up @@ -324,9 +325,7 @@ mutable struct InferenceState
pclimitations = IdSet{InferenceState}()
limitations = IdSet{InferenceState}()
cycle_backedges = Vector{Tuple{InferenceState,Int}}()
callers_in_cycle = Vector{InferenceState}()
dont_work_on_me = false
parent = nothing
callstack = AbsIntState[]

valid_worlds = WorldRange(1, get_world_counter())
bestguess = Bottom
Expand Down Expand Up @@ -354,7 +353,7 @@ mutable struct InferenceState
this = new(
mi, world, mod, sptypes, slottypes, src, cfg, method_info,
currbb, currpc, ip, handler_info, ssavalue_uses, bb_vartables, ssavaluetypes, stmt_edges, stmt_info,
pclimitations, limitations, cycle_backedges, callers_in_cycle, dont_work_on_me, parent,
pclimitations, limitations, cycle_backedges, callstack, 0, 0, 0,
result, unreachable, valid_worlds, bestguess, exc_bestguess, ipo_effects,
restrict_abstract_call_sites, cache_mode, insert_coverage,
interp)
Expand Down Expand Up @@ -836,30 +835,6 @@ function empty_backedges!(frame::InferenceState, currpc::Int=frame.currpc)
return nothing
end

function print_callstack(sv::InferenceState)
print("=================== Callstack: ==================\n")
idx = 0
while sv !== nothing
print("[")
print(idx)
if !isa(sv.interp, NativeInterpreter)
print(", ")
print(typeof(sv.interp))
end
print("] ")
print(sv.linfo)
is_cached(sv) || print(" [uncached]")
println()
for cycle in sv.callers_in_cycle
print(' ', cycle.linfo)
println()
end
sv = sv.parent
idx += 1
end
print("================= End callstack ==================\n")
end

function narguments(sv::InferenceState, include_va::Bool=true)
nargs = Int(sv.src.nargs)
if !include_va
Expand All @@ -885,7 +860,9 @@ mutable struct IRInterpretationState
const lazyreachability::LazyCFGReachability
valid_worlds::WorldRange
const edges::Vector{Any}
parent # ::Union{Nothing,AbsIntState}
callstack #::Vector{AbsIntState}
frameid::Int
parentid::Int

function IRInterpretationState(interp::AbstractInterpreter,
method_info::MethodInfo, ir::IRCode, mi::MethodInstance, argtypes::Vector{Any},
Expand All @@ -908,9 +885,9 @@ mutable struct IRInterpretationState
lazyreachability = LazyCFGReachability(ir)
valid_worlds = WorldRange(min_world, max_world == typemax(UInt) ? get_world_counter() : max_world)
edges = Any[]
parent = nothing
callstack = AbsIntState[]
return new(method_info, ir, mi, world, curridx, argtypes_refined, ir.sptypes, tpdum,
ssa_refined, lazyreachability, valid_worlds, edges, parent)
ssa_refined, lazyreachability, valid_worlds, edges, callstack, 0, 0)
end
end

Expand All @@ -930,11 +907,34 @@ function IRInterpretationState(interp::AbstractInterpreter,
codeinst.min_world, codeinst.max_world)
end


# AbsIntState
# ===========

const AbsIntState = Union{InferenceState,IRInterpretationState}

function print_callstack(frame::AbsIntState)
print("=================== Callstack: ==================\n")
frames = frame.callstack::Vector{AbsIntState}
for idx = (frame.frameid == 0 ? 0 : 1):length(frames)
sv = (idx == 0 ? frame : frames[idx])
idx == frame.frameid && print("*")
print("[")
print(idx)
if sv isa InferenceState && !isa(sv.interp, NativeInterpreter)
print(", ")
print(typeof(sv.interp))
end
print("] ")
print(frame_instance(sv))
is_cached(sv) || print(" [uncached]")
sv.parentid == idx - 1 || print(" [parent=", sv.parentid, "]")
println()
@assert sv.frameid == idx
end
print("================= End callstack ==================\n")
end

frame_instance(sv::InferenceState) = sv.linfo
frame_instance(sv::IRInterpretationState) = sv.mi

Expand All @@ -945,8 +945,39 @@ function frame_module(sv::AbsIntState)
return def.module
end

frame_parent(sv::InferenceState) = sv.parent::Union{Nothing,AbsIntState}
frame_parent(sv::IRInterpretationState) = sv.parent::Union{Nothing,AbsIntState}
function frame_parent(sv::InferenceState)
sv.parentid == 0 && return nothing
callstack = sv.callstack::Vector{AbsIntState}
sv = callstack[sv.cycleid]::InferenceState
sv.parentid == 0 && return nothing
return callstack[sv.parentid]
end
frame_parent(sv::IRInterpretationState) = sv.parentid == 0 ? nothing : (sv.callstack::Vector{AbsIntState})[sv.parentid]

# add the orphan child to the parent and the parent to the child
function assign_parentchild(child::InferenceState, parent::AbsIntState)
@assert child.frameid == 0
child.callstack = callstack = parent.callstack::Vector{AbsIntState}
child.parentid = parent.frameid
push!(callstack, child)
child.cycleid = child.frameid = length(callstack)
nothing
end
function assign_parentchild(child::IRInterpretationState, parent::AbsIntState)
@assert child.frameid == 0
child.callstack = callstack = parent.callstack::Vector{AbsIntState}
child.parentid = parent.frameid
push!(callstack, child)
child.frameid = length(callstack)
nothing
end
function assign_parentchild(child::InferenceState, parent::Nothing)
@assert child.frameid == 0
callstack = child.callstack::Vector{AbsIntState}
push!(callstack, child)
child.cycleid = child.frameid = length(callstack)
nothing
end

function is_constproped(sv::InferenceState)
(;overridden_by_const) = sv.result
Expand All @@ -966,9 +997,6 @@ method_for_inference_limit_heuristics(sv::AbsIntState) = method_info(sv).method_
frame_world(sv::InferenceState) = sv.world
frame_world(sv::IRInterpretationState) = sv.world

callers_in_cycle(sv::InferenceState) = sv.callers_in_cycle
callers_in_cycle(sv::IRInterpretationState) = ()

function is_effect_overridden(sv::AbsIntState, effect::Symbol)
if is_effect_overridden(frame_instance(sv), effect)
return true
Expand Down Expand Up @@ -1005,20 +1033,39 @@ Note that cycles may be visited in any order.
struct AbsIntStackUnwind
sv::AbsIntState
end
iterate(unw::AbsIntStackUnwind) = (unw.sv, (unw.sv, 0))
function iterate(unw::AbsIntStackUnwind, (sv, cyclei)::Tuple{AbsIntState, Int})
# iterate through the cycle before walking to the parent
callers = callers_in_cycle(sv)
if callers !== () && cyclei < length(callers)
cyclei += 1
parent = callers[cyclei]
else
cyclei = 0
parent = frame_parent(sv)
iterate(unw::AbsIntStackUnwind) = (unw.sv, length(unw.sv.callstack::Vector{AbsIntState}))
function iterate(unw::AbsIntStackUnwind, frame::Int)
frame == 0 && return nothing
return ((unw.sv.callstack::Vector{AbsIntState})[frame], frame - 1)
end

struct AbsIntCycle
frames::Vector{AbsIntState}
cycleid::Int
cycletop::Int
end
iterate(unw::AbsIntCycle) = unw.cycleid == 0 ? nothing : (unw.frames[unw.cycletop], unw.cycletop)
function iterate(unw::AbsIntCycle, frame::Int)
frame == unw.cycleid && return nothing
return (unw.frames[frame - 1], frame - 1)
end

"""
callers_in_cycle(sv::AbsIntState)
Iterate through all callers of the given `AbsIntState` in the abstract
interpretation stack (including the given `AbsIntState` itself) that are part
of the same cycle, only if it is part of a cycle with multiple frames.
"""
function callers_in_cycle(sv::InferenceState)
callstack = sv.callstack::Vector{AbsIntState}
cycletop = cycleid = sv.cycleid
while cycletop < length(callstack) && (callstack[cycletop + 1]::InferenceState).cycleid == cycleid
cycletop += 1
end
parent === nothing && return nothing
return (parent, (parent, cyclei))
return AbsIntCycle(callstack, cycletop == cycleid ? 0 : cycleid, cycletop)
end
callers_in_cycle(sv::IRInterpretationState) = AbsIntCycle(sv.callstack::Vector{AbsIntState}, 0, 0)

# temporarily accumulate our edges to later add as backedges in the callee
function add_backedge!(caller::InferenceState, mi::MethodInstance)
Expand Down
8 changes: 7 additions & 1 deletion base/compiler/ssair/irinterp.jl
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ function concrete_eval_invoke(interp::AbstractInterpreter, ci::CodeInstance, arg
end
newirsv = IRInterpretationState(interp, ci, mi, argtypes, world)
if newirsv !== nothing
newirsv.parent = parent
assign_parentchild(newirsv, parent)
return ir_abstract_constant_propagation(interp, newirsv)
end
return Pair{Any,Tuple{Bool,Bool}}(nothing, (is_nothrow(effects), is_noub(effects)))
Expand Down Expand Up @@ -440,6 +440,12 @@ function _ir_abstract_constant_propagation(interp::AbstractInterpreter, irsv::IR
store_backedges(frame_instance(irsv), irsv.edges)
end

if irsv.frameid != 0
callstack = irsv.callstack::Vector{AbsIntState}
@assert callstack[end] === irsv && length(callstack) == irsv.frameid
pop!(callstack)
end

return Pair{Any,Tuple{Bool,Bool}}(maybe_singleton_const(ultimate_rt), (nothrow, noub))
end

Expand Down
Loading

0 comments on commit 94310a3

Please sign in to comment.