Skip to content

Commit

Permalink
inference: don't allocate TryCatchFrame for `compute_trycatch(::IRC…
Browse files Browse the repository at this point in the history
…ode)` (#56835)

`TryCatchFrame` is only required for the abstract interpretation and is
not necessary in `compute_trycatch` within slot2ssa.jl.

@nanosoldier `runbenchmarks("inference", vs=":master")`
  • Loading branch information
aviatesk authored Dec 16, 2024
1 parent dafaa61 commit 0551039
Show file tree
Hide file tree
Showing 3 changed files with 38 additions and 17 deletions.
45 changes: 32 additions & 13 deletions Compiler/src/inferencestate.jl
Original file line number Diff line number Diff line change
Expand Up @@ -219,16 +219,29 @@ const CACHE_MODE_GLOBAL = 0x01 << 0 # cached globally, optimization required
const CACHE_MODE_LOCAL = 0x01 << 1 # cached locally, optimization required
const CACHE_MODE_VOLATILE = 0x01 << 2 # not cached, optimization required

mutable struct TryCatchFrame
abstract type Handler end
get_enter_idx(handler::Handler) = get_enter_idx_impl(handler)::Int

mutable struct TryCatchFrame <: Handler
exct
scopet
const enter_idx::Int
scope_uses::Vector{Int}
TryCatchFrame(@nospecialize(exct), @nospecialize(scopet), enter_idx::Int) = new(exct, scopet, enter_idx)
TryCatchFrame(@nospecialize(exct), @nospecialize(scopet), enter_idx::Int) =
new(exct, scopet, enter_idx)
end
TryCatchFrame(stmt::EnterNode, pc::Int) =
TryCatchFrame(Bottom, isdefined(stmt, :scope) ? Bottom : nothing, pc)
get_enter_idx_impl((; enter_idx)::TryCatchFrame) = enter_idx

struct SimpleHandler <: Handler
enter_idx::Int
end
SimpleHandler(::EnterNode, pc::Int) = SimpleHandler(pc)
get_enter_idx_impl((; enter_idx)::SimpleHandler) = enter_idx

struct HandlerInfo
handlers::Vector{TryCatchFrame}
struct HandlerInfo{T<:Handler}
handlers::Vector{T}
handler_at::Vector{Tuple{Int,Int}} # tuple of current (handler, exception stack) value at the pc
end

Expand Down Expand Up @@ -261,7 +274,7 @@ mutable struct InferenceState
currbb::Int
currpc::Int
ip::BitSet#=TODO BoundedMinPrioritySet=# # current active instruction pointers
handler_info::Union{Nothing,HandlerInfo}
handler_info::Union{Nothing,HandlerInfo{TryCatchFrame}}
ssavalue_uses::Vector{BitSet} # ssavalue sparsity and restart info
# TODO: Could keep this sparsely by doing structural liveness analysis ahead of time.
bb_vartables::Vector{Union{Nothing,VarTable}} # nothing if not analyzed yet
Expand Down Expand Up @@ -318,7 +331,7 @@ mutable struct InferenceState

currbb = currpc = 1
ip = BitSet(1) # TODO BitSetBoundedMinPrioritySet(1)
handler_info = compute_trycatch(code)
handler_info = ComputeTryCatch{TryCatchFrame}()(code)
nssavalues = src.ssavaluetypes::Int
ssavalue_uses = find_ssavalue_uses(code, nssavalues)
nstmts = length(code)
Expand Down Expand Up @@ -421,10 +434,16 @@ is_inferred(result::InferenceResult) = result.result !== nothing

was_reached(sv::InferenceState, pc::Int) = sv.ssavaluetypes[pc] !== NOT_FOUND

compute_trycatch(ir::IRCode) = compute_trycatch(ir.stmts.stmt, ir.cfg.blocks)
struct ComputeTryCatch{T<:Handler} end

const compute_trycatch = ComputeTryCatch{SimpleHandler}()

(compute_trycatch::ComputeTryCatch{SimpleHandler})(ir::IRCode) =
compute_trycatch(ir.stmts.stmt, ir.cfg.blocks)

"""
compute_trycatch(code, [, bbs]) -> handler_info::Union{Nothing,HandlerInfo}
(::ComputeTryCatch{Handler})(code, [, bbs]) -> handler_info::Union{Nothing,HandlerInfo{Handler}}
const compute_trycatch = ComputeTryCatch{SimpleHandler}()
Given the code of a function, compute, at every statement, the current
try/catch handler, and the current exception stack top. This function returns
Expand All @@ -433,9 +452,9 @@ a tuple of:
1. `handler_info.handler_at`: A statement length vector of tuples
`(catch_handler, exception_stack)`, which are indices into `handlers`
2. `handler_info.handlers`: A `TryCatchFrame` vector of handlers
2. `handler_info.handlers`: A `Handler` vector of handlers
"""
function compute_trycatch(code::Vector{Any}, bbs::Union{Vector{BasicBlock},Nothing}=nothing)
function (::ComputeTryCatch{Handler})(code::Vector{Any}, bbs::Union{Vector{BasicBlock},Nothing}=nothing) where Handler
# The goal initially is to record the frame like this for the state at exit:
# 1: (enter 3) # == 0
# 3: (expr) # == 1
Expand All @@ -454,10 +473,10 @@ function compute_trycatch(code::Vector{Any}, bbs::Union{Vector{BasicBlock},Nothi
stmt = code[pc]
if isa(stmt, EnterNode)
(;handlers, handler_at) = handler_info =
(handler_info === nothing ? HandlerInfo(TryCatchFrame[], fill((0, 0), n)) : handler_info)
(handler_info === nothing ? HandlerInfo{Handler}(Handler[], fill((0, 0), n)) : handler_info)
l = stmt.catch_dest
(bbs !== nothing) && (l = first(bbs[l].stmts))
push!(handlers, TryCatchFrame(Bottom, isdefined(stmt, :scope) ? Bottom : nothing, pc))
push!(handlers, Handler(stmt, pc))
handler_id = length(handlers)
handler_at[pc + 1] = (handler_id, 0)
push!(ip, pc + 1)
Expand Down Expand Up @@ -526,7 +545,7 @@ function compute_trycatch(code::Vector{Any}, bbs::Union{Vector{BasicBlock},Nothi
end
cur_hand = cur_stacks[1]
for i = 1:l
cur_hand = handler_at[handlers[cur_hand].enter_idx][1]
cur_hand = handler_at[get_enter_idx(handlers[cur_hand])][1]
end
cur_stacks = (cur_hand, cur_stacks[2])
cur_stacks == (0, 0) && break
Expand Down
8 changes: 5 additions & 3 deletions Compiler/src/ssair/slot2ssa.jl
Original file line number Diff line number Diff line change
Expand Up @@ -801,9 +801,11 @@ function construct_ssa!(ci::CodeInfo, ir::IRCode, sv::OptimizationState,
has_pinode[id] = false
enter_idx = idx
while (handler = gethandler(handler_info, enter_idx)) !== nothing
(; enter_idx) = handler
leave_block = block_for_inst(cfg, (code[enter_idx]::EnterNode).catch_dest)
cidx = findfirst((; slot)::NewPhiCNode2->slot_id(slot)==id, new_phic_nodes[leave_block])
enter_idx = get_enter_idx(handler)
enter_node = code[enter_idx]::EnterNode
leave_block = block_for_inst(cfg, enter_node.catch_dest)
cidx = findfirst((; slot)::NewPhiCNode2->slot_id(slot)==id,
new_phic_nodes[leave_block])
if cidx !== nothing
node = thisdef ? UpsilonNode(thisval) : UpsilonNode()
if incoming_vals[id] === UNDEF_TOKEN
Expand Down
2 changes: 1 addition & 1 deletion Compiler/test/inference.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4436,7 +4436,7 @@ let x = Tuple{Int,Any}[
#=20=# (0, Core.ReturnNode(Core.SlotNumber(3)))
]
(;handler_at, handlers) = Compiler.compute_trycatch(last.(x))
@test map(x->x[1] == 0 ? 0 : handlers[x[1]].enter_idx, handler_at) == first.(x)
@test map(x->x[1] == 0 ? 0 : Compiler.get_enter_idx(handlers[x[1]]), handler_at) == first.(x)
end

@test only(Base.return_types((Bool,)) do y
Expand Down

0 comments on commit 0551039

Please sign in to comment.