inference: represent callers_in_cycle with view slices of a stack instead of separate lists #55364

Merged (2 commits), Aug 20, 2024

Changes from all commits
33 changes: 22 additions & 11 deletions base/compiler/abstractinterpretation.jl
@@ -793,10 +793,10 @@ function edge_matches_sv(interp::AbstractInterpreter, frame::AbsIntState,
# otherwise, we don't

# check in the cycle list first
# all items in here are mutual parents of all others
# all items in here are considered mutual parents of all others
if !any(p::AbsIntState->matches_sv(p, sv), callers_in_cycle(frame))
let parent = frame_parent(frame)
parent !== nothing || return false
parent === nothing && return false
(is_cached(parent) || frame_parent(parent) !== nothing) || return false
matches_sv(parent, sv) || return false
end
@@ -1298,7 +1298,7 @@ function semi_concrete_eval_call(interp::AbstractInterpreter,
if code !== nothing
irsv = IRInterpretationState(interp, code, mi, arginfo.argtypes, world)
if irsv !== nothing
irsv.parent = sv
assign_parentchild!(irsv, sv)
rt, (nothrow, noub) = ir_abstract_constant_propagation(interp, irsv)
@assert !(rt isa Conditional || rt isa MustAlias) "invalid lattice element returned from irinterp"
if !(isa(rt, Type) && hasintersect(rt, Bool))
@@ -1376,11 +1376,17 @@ function const_prop_call(interp::AbstractInterpreter,
add_remark!(interp, sv, "[constprop] Could not retrieve the source")
return nothing # this is probably a bad generated function (unsound), but just ignore it
end
frame.parent = sv
assign_parentchild!(frame, sv)
if !typeinf(interp, frame)
add_remark!(interp, sv, "[constprop] Fresh constant inference hit a cycle")
@assert frame.frameid != 0 && frame.cycleid == frame.frameid
callstack = frame.callstack::Vector{AbsIntState}
@assert callstack[end] === frame && length(callstack) == frame.frameid
pop!(callstack)
return nothing
end
@assert frame.frameid != 0 && frame.cycleid == frame.frameid
@assert frame.parentid == sv.frameid
@assert inf_result.result !== nothing
# ConditionalSimpleArgtypes is allowed, because the only case in which it modifies
# the argtypes is when one of the argtypes is a `Conditional`, which case
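
For context, here is a minimal, self-contained sketch (not part of this PR; ToyFrame and push_frame! are made-up stand-ins) of the stack discipline the new code relies on: a frame that was pushed onto the shared callstack must still be the topmost entry when it is abandoned, so popping it keeps frameid == length(callstack) for every frame that remains.

# Hypothetical stand-ins for InferenceState and the shared callstack.
mutable struct ToyFrame
    frameid::Int
end

function push_frame!(stack::Vector{ToyFrame})
    frame = ToyFrame(0)
    push!(stack, frame)
    frame.frameid = length(stack)   # mirrors the frameid bookkeeping in assign_parentchild!
    return frame
end

callstack = ToyFrame[]
caller = push_frame!(callstack)
callee = push_frame!(callstack)     # e.g. a fresh const-prop frame

# Abandoning the callee (as when fresh constant inference hits a cycle):
@assert callstack[end] === callee && length(callstack) == callee.frameid
pop!(callstack)
@assert length(callstack) == caller.frameid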
@@ -3328,7 +3334,6 @@ end
# make as much progress on `frame` as possible (without handling cycles)
function typeinf_local(interp::AbstractInterpreter, frame::InferenceState)
@assert !is_inferred(frame)
frame.dont_work_on_me = true # mark that this function is currently on the stack
W = frame.ip
ssavaluetypes = frame.ssavaluetypes
bbs = frame.cfg.blocks
@@ -3549,7 +3554,6 @@ function typeinf_local(interp::AbstractInterpreter, frame::InferenceState)
end
end # while currbb <= nbbs

frame.dont_work_on_me = false
nothing
end

@@ -3601,16 +3605,23 @@ end
# make as much progress on `frame` as possible (by handling cycles)
function typeinf_nocycle(interp::AbstractInterpreter, frame::InferenceState)
typeinf_local(interp, frame)
@assert isempty(frame.ip)
callstack = frame.callstack::Vector{AbsIntState}
frame.cycleid == length(callstack) && return true

# If the current frame is part of a cycle, solve the cycle before finishing
no_active_ips_in_callers = false
while !no_active_ips_in_callers
while true
# If the current frame is not the top part of a cycle, continue to the top of the cycle before resuming work
frame.cycleid == frame.frameid || return false
# If done, return and finalize this cycle
no_active_ips_in_callers && return true
# Otherwise, do at least one iteration over the entire current cycle
no_active_ips_in_callers = true
for caller in frame.callers_in_cycle
caller.dont_work_on_me && return false # cycle is above us on the stack
for i = reverse(frame.cycleid:length(callstack))
caller = callstack[i]::InferenceState
if !isempty(caller.ip)
# Note that `typeinf_local(interp, caller)` can potentially modify the other frames
# `frame.callers_in_cycle`, which is why making incremental progress requires the
# `frame.cycleid`, which is why making incremental progress requires the
# outer while loop.
typeinf_local(interp, caller)
no_active_ips_in_callers = false
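As a rough illustration of the sweep above (a simplified model that ignores cycle merging and frame.cycleid changing mid-iteration; pending stands in for each frame's remaining work in caller.ip):

function sweep_cycle!(pending::Vector{Int}, cycleid::Int)
    # Frames cycleid:length(pending) form the current cycle; sweep them until
    # none has pending statements, mirroring the outer while loop above.
    while true
        no_active_ips = true
        for i in reverse(cycleid:length(pending))
            if pending[i] > 0
                pending[i] -= 1        # stand-in for typeinf_local(interp, caller)
                no_active_ips = false
            end
        end
        no_active_ips && return nothing
    end
end

pending = [0, 2, 1, 3]   # frames 2..4 are in a cycle and still have work
sweep_cycle!(pending, 2)
@assert all(iszero, pending[2:end])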
164 changes: 106 additions & 58 deletions base/compiler/inferencestate.jl
@@ -209,10 +209,10 @@ to enable flow-sensitive analysis.
"""
const VarTable = Vector{VarState}

const CACHE_MODE_NULL = 0x00 # not cached, without optimization
const CACHE_MODE_GLOBAL = 0x01 << 0 # cached globally, optimization allowed
const CACHE_MODE_LOCAL = 0x01 << 1 # cached locally, optimization allowed
const CACHE_MODE_VOLATILE = 0x01 << 2 # not cached, optimization allowed
const CACHE_MODE_NULL = 0x00 # not cached, optimization optional
const CACHE_MODE_GLOBAL = 0x01 << 0 # cached globally, optimization required
const CACHE_MODE_LOCAL = 0x01 << 1 # cached locally, optimization required
const CACHE_MODE_VOLATILE = 0x01 << 2 # not cached, optimization required
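
A small usage sketch of these bit flags (the frame and checks below are hypothetical, with the constants repeated so the sketch runs standalone; they mirror how the constructor consumes cache_mode further down in this file):

const CACHE_MODE_NULL     = 0x00
const CACHE_MODE_GLOBAL   = 0x01 << 0
const CACHE_MODE_LOCAL    = 0x01 << 1
const CACHE_MODE_VOLATILE = 0x01 << 2

cache_mode = CACHE_MODE_GLOBAL                    # a globally cached frame
@assert !iszero(cache_mode & CACHE_MODE_GLOBAL)   # gets pushed onto the shared callstack
@assert iszero(cache_mode & CACHE_MODE_LOCAL)     # not added to the local inference cache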

mutable struct TryCatchFrame
exct
@@ -254,9 +254,12 @@ mutable struct InferenceState
pclimitations::IdSet{InferenceState} # causes of precision restrictions (LimitedAccuracy) on currpc ssavalue
limitations::IdSet{InferenceState} # causes of precision restrictions (LimitedAccuracy) on return
cycle_backedges::Vector{Tuple{InferenceState, Int}} # call-graph backedges connecting from callee to caller
callers_in_cycle::Vector{InferenceState}
dont_work_on_me::Bool
parent # ::Union{Nothing,AbsIntState}

# IPO tracking of in-process work, shared with all frames given AbstractInterpreter
callstack #::Vector{AbsIntState}
parentid::Int # index into callstack of the parent frame that originally added this frame (call frame_parent to extract the current parent of the SCC)
frameid::Int # index into callstack at which this object is found (or zero, if this is not a cached frame and has no parent)
cycleid::Int # index into the callstack of the topmost frame in the cycle (all frames in the same cycle share the same cycleid)

#= results =#
result::InferenceResult # remember where to put the result
@@ -324,9 +327,7 @@ mutable struct InferenceState
pclimitations = IdSet{InferenceState}()
limitations = IdSet{InferenceState}()
cycle_backedges = Vector{Tuple{InferenceState,Int}}()
callers_in_cycle = Vector{InferenceState}()
dont_work_on_me = false
parent = nothing
callstack = AbsIntState[]

valid_worlds = WorldRange(1, get_world_counter())
bestguess = Bottom
@@ -347,17 +348,23 @@

restrict_abstract_call_sites = isa(def, Module)

# some more setups
!iszero(cache_mode & CACHE_MODE_LOCAL) && push!(get_inference_cache(interp), result)

this = new(
mi, world, mod, sptypes, slottypes, src, cfg, method_info,
currbb, currpc, ip, handler_info, ssavalue_uses, bb_vartables, ssavaluetypes, stmt_edges, stmt_info,
pclimitations, limitations, cycle_backedges, callers_in_cycle, dont_work_on_me, parent,
pclimitations, limitations, cycle_backedges, callstack, 0, 0, 0,
result, unreachable, valid_worlds, bestguess, exc_bestguess, ipo_effects,
restrict_abstract_call_sites, cache_mode, insert_coverage,
interp)

# some more setups
if !iszero(cache_mode & CACHE_MODE_LOCAL)
push!(get_inference_cache(interp), result)
end
if !iszero(cache_mode & CACHE_MODE_GLOBAL)
push!(callstack, this)
this.cycleid = this.frameid = length(callstack)
end

# Apply generated function restrictions
if src.min_world != 1 || src.max_world != typemax(UInt)
# From generated functions
@@ -769,30 +776,6 @@ function empty_backedges!(frame::InferenceState, currpc::Int=frame.currpc)
return nothing
end

function print_callstack(sv::InferenceState)
print("=================== Callstack: ==================\n")
idx = 0
while sv !== nothing
print("[")
print(idx)
if !isa(sv.interp, NativeInterpreter)
print(", ")
print(typeof(sv.interp))
end
print("] ")
print(sv.linfo)
is_cached(sv) || print(" [uncached]")
println()
for cycle in sv.callers_in_cycle
print(' ', cycle.linfo)
println()
end
sv = sv.parent
idx += 1
end
print("================= End callstack ==================\n")
end

function narguments(sv::InferenceState, include_va::Bool=true)
nargs = Int(sv.src.nargs)
if !include_va
@@ -818,7 +801,9 @@ mutable struct IRInterpretationState
const lazyreachability::LazyCFGReachability
valid_worlds::WorldRange
const edges::Vector{Any}
parent # ::Union{Nothing,AbsIntState}
callstack #::Vector{AbsIntState}
frameid::Int
parentid::Int

function IRInterpretationState(interp::AbstractInterpreter,
method_info::MethodInfo, ir::IRCode, mi::MethodInstance, argtypes::Vector{Any},
@@ -841,9 +826,9 @@
lazyreachability = LazyCFGReachability(ir)
valid_worlds = WorldRange(min_world, max_world == typemax(UInt) ? get_world_counter() : max_world)
edges = Any[]
parent = nothing
callstack = AbsIntState[]
return new(method_info, ir, mi, world, curridx, argtypes_refined, ir.sptypes, tpdum,
ssa_refined, lazyreachability, valid_worlds, edges, parent)
ssa_refined, lazyreachability, valid_worlds, edges, callstack, 0, 0)
end
end

@@ -863,11 +848,34 @@ function IRInterpretationState(interp::AbstractInterpreter,
codeinst.min_world, codeinst.max_world)
end


# AbsIntState
# ===========

const AbsIntState = Union{InferenceState,IRInterpretationState}

function print_callstack(frame::AbsIntState)
print("=================== Callstack: ==================\n")
frames = frame.callstack::Vector{AbsIntState}
for idx = (frame.frameid == 0 ? 0 : 1):length(frames)
sv = (idx == 0 ? frame : frames[idx])
idx == frame.frameid && print("*")
print("[")
print(idx)
if sv isa InferenceState && !isa(sv.interp, NativeInterpreter)
print(", ")
print(typeof(sv.interp))
end
print("] ")
print(frame_instance(sv))
is_cached(sv) || print(" [uncached]")
sv.parentid == idx - 1 || print(" [parent=", sv.parentid, "]")
println()
@assert sv.frameid == idx
end
print("================= End callstack ==================\n")
end
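
For a hypothetical stack of three frames, where frame 3 is the current (starred) frame, is uncached, and was added by frame 1 rather than frame 2, the output would look roughly like:

=================== Callstack: ==================
[1] MethodInstance for f(::Int64)
[2] MethodInstance for g(::Int64)
*[3] MethodInstance for h(::Int64) [uncached] [parent=1]
================= End callstack ==================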

frame_instance(sv::InferenceState) = sv.linfo
frame_instance(sv::IRInterpretationState) = sv.mi

@@ -878,8 +886,32 @@ function frame_module(sv::AbsIntState)
return def.module
end

frame_parent(sv::InferenceState) = sv.parent::Union{Nothing,AbsIntState}
frame_parent(sv::IRInterpretationState) = sv.parent::Union{Nothing,AbsIntState}
function frame_parent(sv::InferenceState)
sv.parentid == 0 && return nothing
callstack = sv.callstack::Vector{AbsIntState}
sv = callstack[sv.cycleid]::InferenceState
sv.parentid == 0 && return nothing
return callstack[sv.parentid]
end
frame_parent(sv::IRInterpretationState) = sv.parentid == 0 ? nothing : (sv.callstack::Vector{AbsIntState})[sv.parentid]

# add the orphan child to the parent and the parent to the child
function assign_parentchild!(child::InferenceState, parent::AbsIntState)
Review comment (Member): Would it be safer to create a new constructor InferenceState(..., parent::AbsIntState) that internally calls assign_parentchild! and use it in IPO inference situations, rather than using assign_parentchild! individually in each case?

Reply (Member, PR author): I think eventually we may like to move this state into the AbstractInterpreter, since it doesn't properly belong to this InferenceState frame, but rather to the process as a whole. But some of it depends on whether and how we want to parallelize this process. For now, this was a more minimal change to eventually permit pause-able inference on a single thread.

@assert child.frameid in (0, 1)
child.callstack = callstack = parent.callstack::Vector{AbsIntState}
child.parentid = parent.frameid
push!(callstack, child)
child.cycleid = child.frameid = length(callstack)
nothing
end
function assign_parentchild!(child::IRInterpretationState, parent::AbsIntState)
@assert child.frameid in (0, 1)
child.callstack = callstack = parent.callstack::Vector{AbsIntState}
child.parentid = parent.frameid
push!(callstack, child)
child.frameid = length(callstack)
nothing
end
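
A runnable toy version (made-up ToyState type, not the compiler's) of the linking performed by these two methods, showing the invariants they establish:

mutable struct ToyState
    callstack::Vector{Any}
    parentid::Int
    frameid::Int
    cycleid::Int
    ToyState() = new(Any[], 0, 0, 0)
end

function toy_assign_parentchild!(child::ToyState, parent::ToyState)
    @assert child.frameid in (0, 1)
    child.callstack = callstack = parent.callstack   # child now shares the parent's stack
    child.parentid = parent.frameid
    push!(callstack, child)
    child.cycleid = child.frameid = length(callstack)
    return nothing
end

root = ToyState()
push!(root.callstack, root)
root.cycleid = root.frameid = 1                      # as the InferenceState constructor does for cached frames

child = ToyState()
toy_assign_parentchild!(child, root)
@assert child.parentid == root.frameid == 1
@assert root.callstack[child.frameid] === child      # frameid indexes the shared callstack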

function is_constproped(sv::InferenceState)
(;overridden_by_const) = sv.result
@@ -899,9 +931,6 @@ method_for_inference_limit_heuristics(sv::AbsIntState) = method_info(sv).method_
frame_world(sv::InferenceState) = sv.world
frame_world(sv::IRInterpretationState) = sv.world

callers_in_cycle(sv::InferenceState) = sv.callers_in_cycle
callers_in_cycle(sv::IRInterpretationState) = ()

function is_effect_overridden(sv::AbsIntState, effect::Symbol)
if is_effect_overridden(frame_instance(sv), effect)
return true
@@ -938,20 +967,39 @@ Note that cycles may be visited in any order.
struct AbsIntStackUnwind
sv::AbsIntState
end
iterate(unw::AbsIntStackUnwind) = (unw.sv, (unw.sv, 0))
function iterate(unw::AbsIntStackUnwind, (sv, cyclei)::Tuple{AbsIntState, Int})
# iterate through the cycle before walking to the parent
callers = callers_in_cycle(sv)
if callers !== () && cyclei < length(callers)
cyclei += 1
parent = callers[cyclei]
else
cyclei = 0
parent = frame_parent(sv)
iterate(unw::AbsIntStackUnwind) = (unw.sv, length(unw.sv.callstack::Vector{AbsIntState}))
function iterate(unw::AbsIntStackUnwind, frame::Int)
frame == 0 && return nothing
return ((unw.sv.callstack::Vector{AbsIntState})[frame], frame - 1)
end

struct AbsIntCycle
frames::Vector{AbsIntState}
cycleid::Int
cycletop::Int
end
iterate(unw::AbsIntCycle) = unw.cycleid == 0 ? nothing : (unw.frames[unw.cycletop], unw.cycletop)
function iterate(unw::AbsIntCycle, frame::Int)
frame == unw.cycleid && return nothing
return (unw.frames[frame - 1], frame - 1)
end

"""
callers_in_cycle(sv::AbsIntState)

Iterate through all callers of the given `AbsIntState` in the abstract
interpretation stack (including the given `AbsIntState` itself) that are part
of the same cycle, only if it is part of a cycle with multiple frames.
"""
function callers_in_cycle(sv::InferenceState)
callstack = sv.callstack::Vector{AbsIntState}
cycletop = cycleid = sv.cycleid
while cycletop < length(callstack) && (callstack[cycletop + 1]::InferenceState).cycleid == cycleid
cycletop += 1
end
parent === nothing && return nothing
return (parent, (parent, cyclei))
return AbsIntCycle(callstack, cycletop == cycleid ? 0 : cycleid, cycletop)
end
callers_in_cycle(sv::IRInterpretationState) = AbsIntCycle(sv.callstack::Vector{AbsIntState}, 0, 0)
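
In the new scheme a cycle is simply a contiguous slice of the callstack whose frames share one cycleid; frame_parent of a cycle member therefore jumps to callstack[cycleid] first and then to that frame's parentid. A minimal sketch with made-up MiniFrame values:

struct MiniFrame
    frameid::Int
    cycleid::Int
end

callstack = [MiniFrame(1, 1), MiniFrame(2, 2), MiniFrame(3, 2), MiniFrame(4, 2), MiniFrame(5, 5)]

function cycle_slice(callstack::Vector{MiniFrame}, sv::MiniFrame)
    cycletop = cycleid = sv.cycleid
    while cycletop < length(callstack) && callstack[cycletop + 1].cycleid == cycleid
        cycletop += 1
    end
    cycletop == cycleid && return MiniFrame[]   # not part of a multi-frame cycle
    return callstack[cycleid:cycletop]          # the real iterator walks this slice top-down
end

@assert cycle_slice(callstack, callstack[3]) == callstack[2:4]   # frames 2, 3 and 4 form one cycle
@assert isempty(cycle_slice(callstack, callstack[5]))            # frame 5 is not in a cycle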

# temporarily accumulate our edges to later add as backedges in the callee
function add_backedge!(caller::InferenceState, mi::MethodInstance)
8 changes: 7 additions & 1 deletion base/compiler/ssair/irinterp.jl
@@ -24,7 +24,7 @@ function concrete_eval_invoke(interp::AbstractInterpreter, ci::CodeInstance, arg
end
newirsv = IRInterpretationState(interp, ci, mi, argtypes, world)
if newirsv !== nothing
newirsv.parent = parent
assign_parentchild!(newirsv, parent)
return ir_abstract_constant_propagation(interp, newirsv)
end
return Pair{Any,Tuple{Bool,Bool}}(nothing, (is_nothrow(effects), is_noub(effects)))
@@ -440,6 +440,12 @@ function _ir_abstract_constant_propagation(interp::AbstractInterpreter, irsv::IR
store_backedges(frame_instance(irsv), irsv.edges)
end

if irsv.frameid != 0
callstack = irsv.callstack::Vector{AbsIntState}
@assert callstack[end] === irsv && length(callstack) == irsv.frameid
pop!(callstack)
end

return Pair{Any,Tuple{Bool,Bool}}(maybe_singleton_const(ultimate_rt), (nothrow, noub))
end
