diff --git a/Compiler/src/inferencestate.jl b/Compiler/src/inferencestate.jl index 6988e74310fc5..046f98df1f41d 100644 --- a/Compiler/src/inferencestate.jl +++ b/Compiler/src/inferencestate.jl @@ -219,16 +219,29 @@ const CACHE_MODE_GLOBAL = 0x01 << 0 # cached globally, optimization required const CACHE_MODE_LOCAL = 0x01 << 1 # cached locally, optimization required const CACHE_MODE_VOLATILE = 0x01 << 2 # not cached, optimization required -mutable struct TryCatchFrame +abstract type Handler end +get_enter_idx(handler::Handler) = get_enter_idx_impl(handler)::Int + +mutable struct TryCatchFrame <: Handler exct scopet const enter_idx::Int scope_uses::Vector{Int} - TryCatchFrame(@nospecialize(exct), @nospecialize(scopet), enter_idx::Int) = new(exct, scopet, enter_idx) + TryCatchFrame(@nospecialize(exct), @nospecialize(scopet), enter_idx::Int) = + new(exct, scopet, enter_idx) +end +TryCatchFrame(stmt::EnterNode, pc::Int) = + TryCatchFrame(Bottom, isdefined(stmt, :scope) ? Bottom : nothing, pc) +get_enter_idx_impl((; enter_idx)::TryCatchFrame) = enter_idx + +struct SimpleHandler <: Handler + enter_idx::Int end +SimpleHandler(::EnterNode, pc::Int) = SimpleHandler(pc) +get_enter_idx_impl((; enter_idx)::SimpleHandler) = enter_idx -struct HandlerInfo - handlers::Vector{TryCatchFrame} +struct HandlerInfo{T<:Handler} + handlers::Vector{T} handler_at::Vector{Tuple{Int,Int}} # tuple of current (handler, exception stack) value at the pc end @@ -261,7 +274,7 @@ mutable struct InferenceState currbb::Int currpc::Int ip::BitSet#=TODO BoundedMinPrioritySet=# # current active instruction pointers - handler_info::Union{Nothing,HandlerInfo} + handler_info::Union{Nothing,HandlerInfo{TryCatchFrame}} ssavalue_uses::Vector{BitSet} # ssavalue sparsity and restart info # TODO: Could keep this sparsely by doing structural liveness analysis ahead of time. bb_vartables::Vector{Union{Nothing,VarTable}} # nothing if not analyzed yet @@ -318,7 +331,7 @@ mutable struct InferenceState currbb = currpc = 1 ip = BitSet(1) # TODO BitSetBoundedMinPrioritySet(1) - handler_info = compute_trycatch(code) + handler_info = ComputeTryCatch{TryCatchFrame}()(code) nssavalues = src.ssavaluetypes::Int ssavalue_uses = find_ssavalue_uses(code, nssavalues) nstmts = length(code) @@ -421,10 +434,16 @@ is_inferred(result::InferenceResult) = result.result !== nothing was_reached(sv::InferenceState, pc::Int) = sv.ssavaluetypes[pc] !== NOT_FOUND -compute_trycatch(ir::IRCode) = compute_trycatch(ir.stmts.stmt, ir.cfg.blocks) +struct ComputeTryCatch{T<:Handler} end + +const compute_trycatch = ComputeTryCatch{SimpleHandler}() + +(compute_trycatch::ComputeTryCatch{SimpleHandler})(ir::IRCode) = + compute_trycatch(ir.stmts.stmt, ir.cfg.blocks) """ - compute_trycatch(code, [, bbs]) -> handler_info::Union{Nothing,HandlerInfo} + (::ComputeTryCatch{Handler})(code, [, bbs]) -> handler_info::Union{Nothing,HandlerInfo{Handler}} + const compute_trycatch = ComputeTryCatch{SimpleHandler}() Given the code of a function, compute, at every statement, the current try/catch handler, and the current exception stack top. This function returns @@ -433,9 +452,9 @@ a tuple of: 1. `handler_info.handler_at`: A statement length vector of tuples `(catch_handler, exception_stack)`, which are indices into `handlers` - 2. `handler_info.handlers`: A `TryCatchFrame` vector of handlers + 2. `handler_info.handlers`: A `Handler` vector of handlers """ -function compute_trycatch(code::Vector{Any}, bbs::Union{Vector{BasicBlock},Nothing}=nothing) +function (::ComputeTryCatch{Handler})(code::Vector{Any}, bbs::Union{Vector{BasicBlock},Nothing}=nothing) where Handler # The goal initially is to record the frame like this for the state at exit: # 1: (enter 3) # == 0 # 3: (expr) # == 1 @@ -454,10 +473,10 @@ function compute_trycatch(code::Vector{Any}, bbs::Union{Vector{BasicBlock},Nothi stmt = code[pc] if isa(stmt, EnterNode) (;handlers, handler_at) = handler_info = - (handler_info === nothing ? HandlerInfo(TryCatchFrame[], fill((0, 0), n)) : handler_info) + (handler_info === nothing ? HandlerInfo{Handler}(Handler[], fill((0, 0), n)) : handler_info) l = stmt.catch_dest (bbs !== nothing) && (l = first(bbs[l].stmts)) - push!(handlers, TryCatchFrame(Bottom, isdefined(stmt, :scope) ? Bottom : nothing, pc)) + push!(handlers, Handler(stmt, pc)) handler_id = length(handlers) handler_at[pc + 1] = (handler_id, 0) push!(ip, pc + 1) @@ -526,7 +545,7 @@ function compute_trycatch(code::Vector{Any}, bbs::Union{Vector{BasicBlock},Nothi end cur_hand = cur_stacks[1] for i = 1:l - cur_hand = handler_at[handlers[cur_hand].enter_idx][1] + cur_hand = handler_at[get_enter_idx(handlers[cur_hand])][1] end cur_stacks = (cur_hand, cur_stacks[2]) cur_stacks == (0, 0) && break diff --git a/Compiler/src/ssair/slot2ssa.jl b/Compiler/src/ssair/slot2ssa.jl index 6fc87934d3bc5..80dffdab23243 100644 --- a/Compiler/src/ssair/slot2ssa.jl +++ b/Compiler/src/ssair/slot2ssa.jl @@ -801,9 +801,11 @@ function construct_ssa!(ci::CodeInfo, ir::IRCode, sv::OptimizationState, has_pinode[id] = false enter_idx = idx while (handler = gethandler(handler_info, enter_idx)) !== nothing - (; enter_idx) = handler - leave_block = block_for_inst(cfg, (code[enter_idx]::EnterNode).catch_dest) - cidx = findfirst((; slot)::NewPhiCNode2->slot_id(slot)==id, new_phic_nodes[leave_block]) + enter_idx = get_enter_idx(handler) + enter_node = code[enter_idx]::EnterNode + leave_block = block_for_inst(cfg, enter_node.catch_dest) + cidx = findfirst((; slot)::NewPhiCNode2->slot_id(slot)==id, + new_phic_nodes[leave_block]) if cidx !== nothing node = thisdef ? UpsilonNode(thisval) : UpsilonNode() if incoming_vals[id] === UNDEF_TOKEN diff --git a/Compiler/test/inference.jl b/Compiler/test/inference.jl index c8b599adb1323..e272ff6de8d99 100644 --- a/Compiler/test/inference.jl +++ b/Compiler/test/inference.jl @@ -4436,7 +4436,7 @@ let x = Tuple{Int,Any}[ #=20=# (0, Core.ReturnNode(Core.SlotNumber(3))) ] (;handler_at, handlers) = Compiler.compute_trycatch(last.(x)) - @test map(x->x[1] == 0 ? 0 : handlers[x[1]].enter_idx, handler_at) == first.(x) + @test map(x->x[1] == 0 ? 0 : Compiler.get_enter_idx(handlers[x[1]]), handler_at) == first.(x) end @test only(Base.return_types((Bool,)) do y