Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Inline statically known method errors. #54972

Merged
merged 3 commits into from
Sep 17, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 21 additions & 20 deletions base/compiler/abstractinterpretation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -209,8 +209,7 @@ function abstract_call_gf_by_type(interp::AbstractInterpreter, @nospecialize(f),
rettype = exctype = Any
all_effects = Effects()
else
if (matches isa MethodMatches ? (!matches.fullmatch || any_ambig(matches)) :
(!all(matches.fullmatches) || any_ambig(matches)))
if !fully_covering(matches) || any_ambig(matches)
# Account for the fact that we may encounter a MethodError with a non-covered or ambiguous signature.
all_effects = Effects(all_effects; nothrow=false)
exctype = exctype ⊔ₚ MethodError
Expand Down Expand Up @@ -275,21 +274,23 @@ struct MethodMatches
applicable::Vector{Any}
info::MethodMatchInfo
valid_worlds::WorldRange
mt::MethodTable
fullmatch::Bool
end
any_ambig(info::MethodMatchInfo) = info.results.ambig
any_ambig(result::MethodLookupResult) = result.ambig
any_ambig(info::MethodMatchInfo) = any_ambig(info.results)
any_ambig(m::MethodMatches) = any_ambig(m.info)
fully_covering(info::MethodMatchInfo) = info.fullmatch
fully_covering(m::MethodMatches) = fully_covering(m.info)

struct UnionSplitMethodMatches
applicable::Vector{Any}
applicable_argtypes::Vector{Vector{Any}}
info::UnionSplitInfo
valid_worlds::WorldRange
mts::Vector{MethodTable}
fullmatches::Vector{Bool}
end
any_ambig(m::UnionSplitMethodMatches) = any(any_ambig, m.info.matches)
any_ambig(info::UnionSplitInfo) = any(any_ambig, info.matches)
any_ambig(m::UnionSplitMethodMatches) = any_ambig(m.info)
fully_covering(info::UnionSplitInfo) = all(info.fullmatches)
fully_covering(m::UnionSplitMethodMatches) = fully_covering(m.info)

function find_method_matches(interp::AbstractInterpreter, argtypes::Vector{Any}, @nospecialize(atype);
max_union_splitting::Int = InferenceParams(interp).max_union_splitting,
Expand All @@ -307,7 +308,7 @@ is_union_split_eligible(𝕃::AbstractLattice, argtypes::Vector{Any}, max_union_
function find_union_split_method_matches(interp::AbstractInterpreter, argtypes::Vector{Any},
@nospecialize(atype), max_methods::Int)
split_argtypes = switchtupleunion(typeinf_lattice(interp), argtypes)
infos = MethodMatchInfo[]
infos = MethodLookupResult[]
applicable = Any[]
applicable_argtypes = Vector{Any}[] # arrays like `argtypes`, including constants, for each match
valid_worlds = WorldRange()
Expand All @@ -323,29 +324,29 @@ function find_union_split_method_matches(interp::AbstractInterpreter, argtypes::
if matches === nothing
return FailedMethodMatch("For one of the union split cases, too many methods matched")
end
push!(infos, MethodMatchInfo(matches))
push!(infos, matches)
for m in matches
push!(applicable, m)
push!(applicable_argtypes, arg_n)
end
valid_worlds = intersect(valid_worlds, matches.valid_worlds)
thisfullmatch = any(match::MethodMatch->match.fully_covers, matches)
found = false
mt_found = false
for (i, mt′) in enumerate(mts)
if mt′ === mt
fullmatches[i] &= thisfullmatch
found = true
mt_found = true
break
end
end
if !found
if !mt_found
push!(mts, mt)
push!(fullmatches, thisfullmatch)
end
end
info = UnionSplitInfo(infos)
info = UnionSplitInfo(infos, mts, fullmatches)
return UnionSplitMethodMatches(
applicable, applicable_argtypes, info, valid_worlds, mts, fullmatches)
applicable, applicable_argtypes, info, valid_worlds)
end

function find_simple_method_matches(interp::AbstractInterpreter, @nospecialize(atype), max_methods::Int)
Expand All @@ -360,10 +361,9 @@ function find_simple_method_matches(interp::AbstractInterpreter, @nospecialize(a
# (assume this will always be true, so we don't compute / update valid age in this case)
return FailedMethodMatch("Too many methods matched")
end
info = MethodMatchInfo(matches)
fullmatch = any(match::MethodMatch->match.fully_covers, matches)
return MethodMatches(
matches.matches, info, matches.valid_worlds, mt, fullmatch)
info = MethodMatchInfo(matches, mt, fullmatch)
return MethodMatches(matches.matches, info, matches.valid_worlds)
end

"""
Expand Down Expand Up @@ -584,9 +584,10 @@ function add_call_backedges!(interp::AbstractInterpreter, @nospecialize(rettype)
# also need an edge to the method table in case something gets
# added that did not intersect with any existing method
if isa(matches, MethodMatches)
matches.fullmatch || add_mt_backedge!(sv, matches.mt, atype)
fully_covering(matches) || add_mt_backedge!(sv, matches.info.mt, atype)
else
for (thisfullmatch, mt) in zip(matches.fullmatches, matches.mts)
matches::UnionSplitMethodMatches
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
matches::UnionSplitMethodMatches

for (thisfullmatch, mt) in zip(matches.info.fullmatches, matches.info.mts)
thisfullmatch || add_mt_backedge!(sv, mt, atype)
end
end
Expand Down
52 changes: 30 additions & 22 deletions base/compiler/ssair/inlining.jl
Original file line number Diff line number Diff line change
Expand Up @@ -50,12 +50,13 @@ struct InliningCase
end

struct UnionSplit
fully_covered::Bool
handled_all_cases::Bool # All possible dispatches are included in the cases
fully_covered::Bool # All handled cases are fully covering
atype::DataType
cases::Vector{InliningCase}
bbs::Vector{Int}
UnionSplit(fully_covered::Bool, atype::DataType, cases::Vector{InliningCase}) =
new(fully_covered, atype, cases, Int[])
UnionSplit(handled_all_cases::Bool, fully_covered::Bool, atype::DataType, cases::Vector{InliningCase}) =
new(handled_all_cases, fully_covered, atype, cases, Int[])
end

struct InliningEdgeTracker
Expand Down Expand Up @@ -215,7 +216,7 @@ end

function cfg_inline_unionsplit!(ir::IRCode, idx::Int, union_split::UnionSplit,
state::CFGInliningState, params::OptimizationParams)
(; fully_covered, #=atype,=# cases, bbs) = union_split
(; handled_all_cases, fully_covered, #=atype,=# cases, bbs) = union_split
inline_into_block!(state, block_for_inst(ir, idx))
from_bbs = Int[]
delete!(state.split_targets, length(state.new_cfg_blocks))
Expand All @@ -235,7 +236,7 @@ function cfg_inline_unionsplit!(ir::IRCode, idx::Int, union_split::UnionSplit,
end
end
push!(from_bbs, length(state.new_cfg_blocks))
if !(i == length(cases) && fully_covered)
if !(i == length(cases) && (handled_all_cases && fully_covered))
# This block will have the next condition or the final else case
push!(state.new_cfg_blocks, BasicBlock(StmtRange(idx, idx)))
push!(state.new_cfg_blocks[cond_bb].succs, length(state.new_cfg_blocks))
Expand All @@ -244,7 +245,10 @@ function cfg_inline_unionsplit!(ir::IRCode, idx::Int, union_split::UnionSplit,
end
end
# The edge from the fallback block.
fully_covered || push!(from_bbs, length(state.new_cfg_blocks))
# NOTE This edge is only required for `!handled_all_cases` and not `!fully_covered`,
# since in the latter case we inline `Core.throw_methoderror` into the fallback
# block, which is must-throw, making the subsequent code path unreachable.
!handled_all_cases && push!(from_bbs, length(state.new_cfg_blocks))
topolarity marked this conversation as resolved.
Show resolved Hide resolved
# This block will be the block everyone returns to
push!(state.new_cfg_blocks, BasicBlock(StmtRange(idx, idx), from_bbs, orig_succs))
join_bb = length(state.new_cfg_blocks)
Expand Down Expand Up @@ -523,7 +527,7 @@ assuming their order stays the same post-discovery in `ml_matches`.
function ir_inline_unionsplit!(compact::IncrementalCompact, idx::Int, argexprs::Vector{Any},
union_split::UnionSplit, boundscheck::Symbol,
todo_bbs::Vector{Tuple{Int,Int}}, interp::AbstractInterpreter)
(; fully_covered, atype, cases, bbs) = union_split
(; handled_all_cases, fully_covered, atype, cases, bbs) = union_split
stmt, typ, line = compact.result[idx][:stmt], compact.result[idx][:type], compact.result[idx][:line]
join_bb = bbs[end]
pn = PhiNode()
Expand All @@ -538,7 +542,7 @@ function ir_inline_unionsplit!(compact::IncrementalCompact, idx::Int, argexprs::
cond = true
nparams = fieldcount(atype)
@assert nparams == fieldcount(mtype)
if !(i == ncases && fully_covered)
if !(i == ncases && fully_covered && handled_all_cases)
for i = 1:nparams
aft, mft = fieldtype(atype, i), fieldtype(mtype, i)
# If this is always true, we don't need to check for it
Expand Down Expand Up @@ -597,14 +601,18 @@ function ir_inline_unionsplit!(compact::IncrementalCompact, idx::Int, argexprs::
end
bb += 1
# We're now in the fall through block, decide what to do
if !fully_covered
if !handled_all_cases
ssa = insert_node_here!(compact, NewInstruction(stmt, typ, line))
push!(pn.edges, bb)
push!(pn.values, ssa)
insert_node_here!(compact, NewInstruction(GotoNode(join_bb), Any, line))
finish_current_bb!(compact, 0)
elseif !fully_covered
insert_node_here!(compact, NewInstruction(Expr(:call, GlobalRef(Core, :throw_methoderror), argexprs...), Union{}, line))
insert_node_here!(compact, NewInstruction(ReturnNode(), Union{}, line))
finish_current_bb!(compact, 0)
ncases == 0 && return insert_node_here!(compact, NewInstruction(nothing, Any, line))
end

# We're now in the join block.
return insert_node_here!(compact, NewInstruction(pn, typ, line))
end
Expand Down Expand Up @@ -1348,10 +1356,6 @@ function compute_inlining_cases(@nospecialize(info::CallInfo), flag::UInt32, sig
# Too many applicable methods
# Or there is a (partial?) ambiguity
return nothing
elseif length(meth) == 0
# No applicable methods; try next union split
handled_all_cases = false
continue
end
local split_fully_covered = false
for (j, match) in enumerate(meth)
Expand Down Expand Up @@ -1392,22 +1396,26 @@ function compute_inlining_cases(@nospecialize(info::CallInfo), flag::UInt32, sig
handled_all_cases &= handle_any_const_result!(cases,
result, match, argtypes, info, flag, state; allow_typevars=true)
end
if !fully_covered
atype = argtypes_to_type(sig.argtypes)
# We will emit an inline MethodError so we need a backedge to the MethodTable
add_uncovered_edges!(state.edges, info, atype)
end
elseif !isempty(cases)
# if we've not seen all candidates, union split is valid only for dispatch tuples
filter!(case::InliningCase->isdispatchtuple(case.sig), cases)
end

return cases, (handled_all_cases & fully_covered), joint_effects
return cases, handled_all_cases, fully_covered, joint_effects
end

function handle_call!(todo::Vector{Pair{Int,Any}},
ir::IRCode, idx::Int, stmt::Expr, @nospecialize(info::CallInfo), flag::UInt32, sig::Signature,
state::InliningState)
cases = compute_inlining_cases(info, flag, sig, state)
cases === nothing && return nothing
cases, all_covered, joint_effects = cases
cases, handled_all_cases, fully_covered, joint_effects = cases
atype = argtypes_to_type(sig.argtypes)
handle_cases!(todo, ir, idx, stmt, atype, cases, all_covered, joint_effects)
handle_cases!(todo, ir, idx, stmt, atype, cases, handled_all_cases, fully_covered, joint_effects)
end

function handle_match!(cases::Vector{InliningCase},
Expand Down Expand Up @@ -1496,19 +1504,19 @@ function concrete_result_item(result::ConcreteResult, @nospecialize(info::CallIn
end

function handle_cases!(todo::Vector{Pair{Int,Any}}, ir::IRCode, idx::Int, stmt::Expr,
@nospecialize(atype), cases::Vector{InliningCase}, all_covered::Bool,
@nospecialize(atype), cases::Vector{InliningCase}, handled_all_cases::Bool, fully_covered::Bool,
joint_effects::Effects)
# If we only have one case and that case is fully covered, we may either
# be able to do the inlining now (for constant cases), or push it directly
# onto the todo list
if all_covered && length(cases) == 1
if fully_covered && handled_all_cases && length(cases) == 1
handle_single_case!(todo, ir, idx, stmt, cases[1].item)
elseif length(cases) > 0
elseif length(cases) > 0 || handled_all_cases
isa(atype, DataType) || return nothing
for case in cases
isa(case.sig, DataType) || return nothing
end
push!(todo, idx=>UnionSplit(all_covered, atype, cases))
push!(todo, idx=>UnionSplit(handled_all_cases, fully_covered, atype, cases))
else
add_flag!(ir[SSAValue(idx)], flags_for_effects(joint_effects))
end
Expand Down
17 changes: 14 additions & 3 deletions base/compiler/stmtinfo.jl
Original file line number Diff line number Diff line change
Expand Up @@ -33,10 +33,13 @@ not a call to a generic function.
"""
struct MethodMatchInfo <: CallInfo
results::MethodLookupResult
mt::MethodTable
fullmatch::Bool
end
nsplit_impl(info::MethodMatchInfo) = 1
getsplit_impl(info::MethodMatchInfo, idx::Int) = (@assert idx == 1; info.results)
getresult_impl(::MethodMatchInfo, ::Int) = nothing
add_uncovered_edges_impl(edges::Vector{Any}, info::MethodMatchInfo, @nospecialize(atype)) = (!info.fullmatch && push!(edges, info.mt, atype); )

"""
info::UnionSplitInfo <: CallInfo
Expand All @@ -48,20 +51,27 @@ each partition (`info.matches::Vector{MethodMatchInfo}`).
This info is illegal on any statement that is not a call to a generic function.
"""
struct UnionSplitInfo <: CallInfo
matches::Vector{MethodMatchInfo}
matches::Vector{MethodLookupResult}
mts::Vector{MethodTable}
fullmatches::Vector{Bool}
Comment on lines -51 to +56
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Would it be okay to revert back to storing Vector{MethodMatchInfo} like before? The current design, where matches/mts/fullmatches can have different lengths, feels a bit tricky to handle. I understand it's meant to avoid adding duplicate edges to the same mt in add_uncovered_edges_impl, but in practice, adding duplicate edges isn't really an issue. From a complexity standpoint, I think the previous approach is preferable.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How would you feel about:

Suggested change
matches::Vector{MethodMatchInfo}
matches::Vector{MethodLookupResult}
mts::Vector{MethodTable}
fullmatches::Vector{Bool}
matches::Vector{MethodLookupResult}
mt_edges::Vector{@NamedTuple{mt::MethodTable, fullmatch::Bool}}

or similar?

That's more consistent with what UnionSplitMethodMatches carried before, which always had these edges de-duplicated.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Right, the previous design of UnionSplitMethodMatches also avoided adding duplicated edges. But I still feel that it's better for UnionSplitInfo or UnionSplitMethodMatches to have a simpler data structure. We can avoid duplication when adding edges.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sure, feel free to open a PR

end

nmatches(info::MethodMatchInfo) = length(info.results)
function nmatches(info::UnionSplitInfo)
n = 0
for mminfo in info.matches
n += nmatches(mminfo)
n += length(mminfo)
end
return n
end
nsplit_impl(info::UnionSplitInfo) = length(info.matches)
getsplit_impl(info::UnionSplitInfo, idx::Int) = getsplit_impl(info.matches[idx], 1)
getsplit_impl(info::UnionSplitInfo, idx::Int) = info.matches[idx]
getresult_impl(::UnionSplitInfo, ::Int) = nothing
function add_uncovered_edges_impl(edges::Vector{Any}, info::UnionSplitInfo, @nospecialize(atype))
for (mt, fullmatch) in zip(info.mts, info.fullmatches)
!fullmatch && push!(edges, mt, atype)
end
end

abstract type ConstResult end

Expand Down Expand Up @@ -105,6 +115,7 @@ end
nsplit_impl(info::ConstCallInfo) = nsplit(info.call)
getsplit_impl(info::ConstCallInfo, idx::Int) = getsplit(info.call, idx)
getresult_impl(info::ConstCallInfo, idx::Int) = info.results[idx]
add_uncovered_edges_impl(edges::Vector{Any}, info::ConstCallInfo, @nospecialize(atype)) = add_uncovered_edges!(edges, info.call, atype)

"""
info::MethodResultPure <: CallInfo
Expand Down
7 changes: 3 additions & 4 deletions base/compiler/tfuncs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2983,9 +2983,9 @@ function abstract_applicable(interp::AbstractInterpreter, argtypes::Vector{Any},
# also need an edge to the method table in case something gets
# added that did not intersect with any existing method
if isa(matches, MethodMatches)
matches.fullmatch || add_mt_backedge!(sv, matches.mt, atype)
fully_covering(matches) || add_mt_backedge!(sv, matches.info.mt, atype)
else
for (thisfullmatch, mt) in zip(matches.fullmatches, matches.mts)
for (thisfullmatch, mt) in zip(matches.info.fullmatches, matches.info.mts)
thisfullmatch || add_mt_backedge!(sv, mt, atype)
end
end
Expand All @@ -3001,8 +3001,7 @@ function abstract_applicable(interp::AbstractInterpreter, argtypes::Vector{Any},
add_backedge!(sv, edge)
end

if isa(matches, MethodMatches) ? (!matches.fullmatch || any_ambig(matches)) :
(!all(matches.fullmatches) || any_ambig(matches))
if !fully_covering(matches) || any_ambig(matches)
# Account for the fact that we may encounter a MethodError with a non-covered or ambiguous signature.
rt = Bool
end
Expand Down
6 changes: 6 additions & 0 deletions base/compiler/types.jl
Original file line number Diff line number Diff line change
Expand Up @@ -450,10 +450,16 @@ abstract type CallInfo end

nsplit(info::CallInfo) = nsplit_impl(info)::Union{Nothing,Int}
getsplit(info::CallInfo, idx::Int) = getsplit_impl(info, idx)::MethodLookupResult
add_uncovered_edges!(edges::Vector{Any}, info::CallInfo, @nospecialize(atype)) = add_uncovered_edges_impl(edges, info, atype)

getresult(info::CallInfo, idx::Int) = getresult_impl(info, idx)

# must implement `nsplit`, `getsplit`, and `add_uncovered_edges!` to opt in to inlining
nsplit_impl(::CallInfo) = nothing
getsplit_impl(::CallInfo, ::Int) = error("unexpected call into `getsplit`")
add_uncovered_edges_impl(edges::Vector{Any}, info::CallInfo, @nospecialize(atype)) = error("unexpected call into `add_uncovered_edges!`")

# must implement `getresult` to opt in to extended lattice return information
getresult_impl(::CallInfo, ::Int) = nothing

@specialize
9 changes: 9 additions & 0 deletions test/compiler/AbstractInterpreter.jl
Original file line number Diff line number Diff line change
Expand Up @@ -409,6 +409,7 @@ end
CC.nsplit_impl(info::NoinlineCallInfo) = CC.nsplit(info.info)
CC.getsplit_impl(info::NoinlineCallInfo, idx::Int) = CC.getsplit(info.info, idx)
CC.getresult_impl(info::NoinlineCallInfo, idx::Int) = CC.getresult(info.info, idx)
CC.add_uncovered_edges_impl(edges::Vector{Any}, info::NoinlineCallInfo, @nospecialize(atype)) = CC.add_uncovered_edges!(edges, info.info, atype)

function CC.abstract_call(interp::NoinlineInterpreter,
arginfo::CC.ArgInfo, si::CC.StmtInfo, sv::CC.InferenceState, max_methods::Int)
Expand All @@ -431,6 +432,8 @@ end
@inline function inlined_usually(x, y, z)
return x * y + z
end
foo_split(x::Float64) = 1
foo_split(x::Int) = 2

# check if the inlining algorithm works as expected
let src = code_typed1((Float64,Float64,Float64)) do x, y, z
Expand All @@ -444,6 +447,7 @@ let NoinlineModule = Module()
main_func(x, y, z) = inlined_usually(x, y, z)
@eval NoinlineModule noinline_func(x, y, z) = $inlined_usually(x, y, z)
@eval OtherModule other_func(x, y, z) = $inlined_usually(x, y, z)
@eval NoinlineModule bar_split_error() = $foo_split(Core.compilerbarrier(:type, nothing))

interp = NoinlineInterpreter(Set((NoinlineModule,)))

Expand Down Expand Up @@ -473,6 +477,11 @@ let NoinlineModule = Module()
@test count(isinvoke(:inlined_usually), src.code) == 0
@test count(iscall((src, inlined_usually)), src.code) == 0
end

let src = code_typed1(NoinlineModule.bar_split_error)
@test count(iscall((src, foo_split)), src.code) == 0
@test count(iscall((src, Core.throw_methoderror)), src.code) > 0
end
end

# Make sure that Core.Compiler has enough NamedTuple infrastructure
Expand Down
Loading