Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Teach compiler about partitioned bindings #56299

Merged
merged 6 commits into from
Nov 2, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
365 changes: 329 additions & 36 deletions base/compiler/abstractinterpretation.jl

Large diffs are not rendered by default.

2 changes: 2 additions & 0 deletions base/compiler/cicache.jl
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,8 @@ WorldRange(r::UnitRange) = WorldRange(first(r), last(r))
first(wr::WorldRange) = wr.min_world
last(wr::WorldRange) = wr.max_world
in(world::UInt, wr::WorldRange) = wr.min_world <= world <= wr.max_world
min_world(wr::WorldRange) = first(wr)
max_world(wr::WorldRange) = last(wr)

function intersect(a::WorldRange, b::WorldRange)
ret = WorldRange(max(a.min_world, b.min_world), min(a.max_world, b.max_world))
Expand Down
10 changes: 6 additions & 4 deletions base/compiler/optimize.jl
Original file line number Diff line number Diff line change
Expand Up @@ -307,8 +307,10 @@ function stmt_effect_flags(𝕃ₒ::AbstractLattice, @nospecialize(stmt), @nospe
isa(stmt, GotoNode) && return (true, false, true)
isa(stmt, GotoIfNot) && return (true, false, ⊑(𝕃ₒ, argextype(stmt.cond, src), Bool))
if isa(stmt, GlobalRef)
nothrow = consistent = isdefinedconst_globalref(stmt)
return (consistent, nothrow, nothrow)
# Modeled more precisely in abstract_eval_globalref. In general, if a
# GlobalRef was moved to statement position, it is probably not `const`,
# so we can't say much about it anyway.
return (false, false, false)
elseif isa(stmt, Expr)
(; head, args) = stmt
if head === :static_parameter
Expand Down Expand Up @@ -444,7 +446,7 @@ function argextype(
elseif isa(x, QuoteNode)
return Const(x.value)
elseif isa(x, GlobalRef)
return abstract_eval_globalref_type(x)
return abstract_eval_globalref_type(x, src)
elseif isa(x, PhiNode) || isa(x, PhiCNode) || isa(x, UpsilonNode)
return Any
elseif isa(x, PiNode)
Expand Down Expand Up @@ -1277,7 +1279,7 @@ function convert_to_ircode(ci::CodeInfo, sv::OptimizationState)
# types of call arguments only once `slot2reg` converts this `IRCode` to the SSA form
# and eliminates slots (see below)
argtypes = sv.slottypes
return IRCode(stmts, sv.cfg, di, argtypes, meta, sv.sptypes)
return IRCode(stmts, sv.cfg, di, argtypes, meta, sv.sptypes, WorldRange(ci.min_world, ci.max_world))
end

function process_meta!(meta::Vector{Expr}, @nospecialize stmt)
Expand Down
5 changes: 0 additions & 5 deletions base/compiler/ssair/inlining.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1696,11 +1696,6 @@ function early_inline_special_case(ir::IRCode, stmt::Expr, flag::UInt32,
if has_flag(flag, IR_FLAG_NOTHROW)
return SomeCase(quoted(val))
end
elseif f === Core.get_binding_type
length(argtypes) == 3 || return nothing
if get_binding_type_effect_free(argtypes[2], argtypes[3])
return SomeCase(quoted(val))
end
end
end
if f === compilerbarrier
Expand Down
9 changes: 6 additions & 3 deletions base/compiler/ssair/ir.jl
Original file line number Diff line number Diff line change
Expand Up @@ -430,22 +430,25 @@ struct IRCode
cfg::CFG
new_nodes::NewNodeStream
meta::Vector{Expr}
valid_worlds::WorldRange

function IRCode(stmts::InstructionStream, cfg::CFG, debuginfo::DebugInfoStream, argtypes::Vector{Any}, meta::Vector{Expr}, sptypes::Vector{VarState})
function IRCode(stmts::InstructionStream, cfg::CFG, debuginfo::DebugInfoStream,
argtypes::Vector{Any}, meta::Vector{Expr}, sptypes::Vector{VarState},
valid_worlds=WorldRange(typemin(UInt), typemax(UInt)))
return new(stmts, argtypes, sptypes, debuginfo, cfg, NewNodeStream(), meta)
end
function IRCode(ir::IRCode, stmts::InstructionStream, cfg::CFG, new_nodes::NewNodeStream)
di = ir.debuginfo
@assert di.codelocs === stmts.line
return new(stmts, ir.argtypes, ir.sptypes, di, cfg, new_nodes, ir.meta)
return new(stmts, ir.argtypes, ir.sptypes, di, cfg, new_nodes, ir.meta, ir.valid_worlds)
end
global function copy(ir::IRCode)
di = ir.debuginfo
stmts = copy(ir.stmts)
di = copy(di)
di.edges = copy(di.edges)
di.codelocs = stmts.line
return new(stmts, copy(ir.argtypes), copy(ir.sptypes), di, copy(ir.cfg), copy(ir.new_nodes), copy(ir.meta))
return new(stmts, copy(ir.argtypes), copy(ir.sptypes), di, copy(ir.cfg), copy(ir.new_nodes), copy(ir.meta), ir.valid_worlds)
end
end

Expand Down
2 changes: 1 addition & 1 deletion base/compiler/ssair/legacy.jl
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ function inflate_ir!(ci::CodeInfo, sptypes::Vector{VarState}, argtypes::Vector{A
di = DebugInfoStream(nothing, ci.debuginfo, nstmts)
stmts = InstructionStream(code, ssavaluetypes, info, di.codelocs, ci.ssaflags)
meta = Expr[]
return IRCode(stmts, cfg, di, argtypes, meta, sptypes)
return IRCode(stmts, cfg, di, argtypes, meta, sptypes, WorldRange(ci.min_world, ci.max_world))
end

"""
Expand Down
6 changes: 3 additions & 3 deletions base/compiler/ssair/passes.jl
Original file line number Diff line number Diff line change
Expand Up @@ -474,9 +474,9 @@ function lift_leaves(compact::IncrementalCompact, field::Int,
elseif isa(leaf, QuoteNode)
leaf = leaf.value
elseif isa(leaf, GlobalRef)
mod, name = leaf.mod, leaf.name
if isdefined(mod, name) && isconst(mod, name)
leaf = getglobal(mod, name)
typ = argextype(leaf, compact)
if isa(typ, Const)
leaf = typ.val
else
return nothing
end
Expand Down
2 changes: 1 addition & 1 deletion base/compiler/ssair/slot2ssa.jl
Original file line number Diff line number Diff line change
Expand Up @@ -176,7 +176,7 @@ function typ_for_val(@nospecialize(x), ci::CodeInfo, ir::IRCode, idx::Int, slott
end
return (ci.ssavaluetypes::Vector{Any})[idx]
end
isa(x, GlobalRef) && return abstract_eval_globalref_type(x)
isa(x, GlobalRef) && return abstract_eval_globalref_type(x, ci)
isa(x, SSAValue) && return (ci.ssavaluetypes::Vector{Any})[x.id]
isa(x, Argument) && return slottypes[x.n]
isa(x, NewSSAValue) && return types(ir)[new_to_regular(x, length(ir.stmts))]
Expand Down
180 changes: 31 additions & 149 deletions base/compiler/tfuncs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -407,10 +407,7 @@ end
if isa(a1, DataType) && !isabstracttype(a1)
if a1 === Module
hasintersect(widenconst(sym), Symbol) || return Bottom
if isa(sym, Const) && isa(sym.val, Symbol) && isa(arg1, Const) &&
isdefinedconst_globalref(GlobalRef(arg1.val::Module, sym.val::Symbol))
return Const(true)
end
# isa(sym, Const) case intercepted in abstract interpretation
elseif isa(sym, Const)
val = sym.val
if isa(val, Symbol)
Expand Down Expand Up @@ -1160,7 +1157,9 @@ end
if isa(sv, Module)
setfield && return Bottom
if isa(nv, Symbol)
return abstract_eval_global(sv, nv)
# In ordinary inference, this case is intercepted early and
# re-routed to `getglobal`.
return Any
end
return Bottom
end
Expand Down Expand Up @@ -1407,8 +1406,9 @@ end
elseif ff === Core.modifyglobal!
o = unwrapva(argtypes[2])
f = unwrapva(argtypes[3])
RT = modifyglobal!_tfunc(𝕃ᵢ, o, f, Any, Any, Symbol)
TF = getglobal_tfunc(𝕃ᵢ, o, f, Symbol)
GT = abstract_eval_get_binding_type(interp, sv, o, f).rt
RT = isa(GT, Const) ? Pair{GT.val, GT.val} : Pair
TF = isa(GT, Const) ? GT.val : Any
elseif ff === Core.memoryrefmodify!
o = unwrapva(argtypes[2])
RT = memoryrefmodify!_tfunc(𝕃ᵢ, o, Any, Any, Symbol, Bool)
Expand Down Expand Up @@ -2277,20 +2277,6 @@ function _builtin_nothrow(𝕃::AbstractLattice, @nospecialize(f::Builtin), argt
elseif f === typeassert
na == 2 || return false
return typeassert_nothrow(𝕃, argtypes[1], argtypes[2])
elseif f === getglobal
if na == 2
return getglobal_nothrow(argtypes[1], argtypes[2])
elseif na == 3
return getglobal_nothrow(argtypes[1], argtypes[2], argtypes[3])
end
return false
elseif f === setglobal!
if na == 3
return setglobal!_nothrow(argtypes[1], argtypes[2], argtypes[3])
elseif na == 4
return setglobal!_nothrow(argtypes[1], argtypes[2], argtypes[3], argtypes[4])
end
return false
elseif f === Core.get_binding_type
na == 2 || return false
return get_binding_type_nothrow(𝕃, argtypes[1], argtypes[2])
Expand Down Expand Up @@ -2473,7 +2459,8 @@ function getfield_effects(𝕃::AbstractLattice, argtypes::Vector{Any}, @nospeci
end
end
if hasintersect(widenconst(obj), Module)
inaccessiblememonly = getglobal_effects(argtypes, rt).inaccessiblememonly
# Modeled more precisely in abstract_eval_getglobal
inaccessiblememonly = ALWAYS_FALSE
elseif is_mutation_free_argtype(obj)
inaccessiblememonly = ALWAYS_TRUE
else
Expand All @@ -2482,24 +2469,7 @@ function getfield_effects(𝕃::AbstractLattice, argtypes::Vector{Any}, @nospeci
return Effects(EFFECTS_TOTAL; consistent, nothrow, inaccessiblememonly, noub)
end

function getglobal_effects(argtypes::Vector{Any}, @nospecialize(rt))
2 ≤ length(argtypes) ≤ 3 || return EFFECTS_THROWS
consistent = inaccessiblememonly = ALWAYS_FALSE
nothrow = false
M, s = argtypes[1], argtypes[2]
if (length(argtypes) == 3 ? getglobal_nothrow(M, s, argtypes[3]) : getglobal_nothrow(M, s))
nothrow = true
# typeasserts below are already checked in `getglobal_nothrow`
Mval, sval = (M::Const).val::Module, (s::Const).val::Symbol
if isconst(Mval, sval)
consistent = ALWAYS_TRUE
if is_mutation_free_argtype(rt)
inaccessiblememonly = ALWAYS_TRUE
end
end
end
return Effects(EFFECTS_TOTAL; consistent, nothrow, inaccessiblememonly)
end


"""
builtin_effects(𝕃::AbstractLattice, f::Builtin, argtypes::Vector{Any}, rt) -> Effects
Expand All @@ -2525,11 +2495,13 @@ function builtin_effects(𝕃::AbstractLattice, @nospecialize(f::Builtin), argty
if f === isdefined
return isdefined_effects(𝕃, argtypes)
elseif f === getglobal
return getglobal_effects(argtypes, rt)
2 ≤ length(argtypes) ≤ 3 || return EFFECTS_THROWS
# Modeled more precisely in abstract_eval_getglobal
return Effects(EFFECTS_TOTAL; consistent=ALWAYS_FALSE, nothrow=false, inaccessiblememonly=ALWAYS_FALSE)
elseif f === Core.get_binding_type
length(argtypes) == 2 || return EFFECTS_THROWS
effect_free = get_binding_type_effect_free(argtypes[1], argtypes[2]) ? ALWAYS_TRUE : ALWAYS_FALSE
return Effects(EFFECTS_TOTAL; effect_free)
# Modeled more precisely in abstract_eval_get_binding_type
return Effects(EFFECTS_TOTAL; effect_free=ALWAYS_FALSE)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nothrow=false instead of effect_free=ALWAYS_FALSE? This code path might not be too important though.

elseif f === compilerbarrier
length(argtypes) == 2 || return Effects(EFFECTS_THROWS; consistent=ALWAYS_FALSE)
setting = argtypes[1]
Expand Down Expand Up @@ -3070,118 +3042,28 @@ function typename_static(@nospecialize(t))
return isType(t) ? _typename(t.parameters[1]) : Core.TypeName
end

function global_order_nothrow(@nospecialize(o), loading::Bool, storing::Bool)
o isa Const || return false
function global_order_exct(@nospecialize(o), loading::Bool, storing::Bool)
if !(o isa Const)
if o === Symbol
return ConcurrencyViolationError
elseif !hasintersect(o, Symbol)
return TypeError
else
return Union{ConcurrencyViolationError, TypeError}
end
end
sym = o.val
if sym isa Symbol
order = get_atomic_order(sym, loading, storing)
return order !== MEMORY_ORDER_INVALID && order !== MEMORY_ORDER_NOTATOMIC
end
return false
end
@nospecs function getglobal_nothrow(M, s, o)
global_order_nothrow(o, #=loading=#true, #=storing=#false) || return false
return getglobal_nothrow(M, s)
end
@nospecs function getglobal_nothrow(M, s)
if M isa Const && s isa Const
M, s = M.val, s.val
if M isa Module && s isa Symbol
return isdefinedconst_globalref(GlobalRef(M, s))
end
end
return false
end
@nospecs function getglobal_tfunc(𝕃::AbstractLattice, M, s, order=Symbol)
if M isa Const && s isa Const
M, s = M.val, s.val
if M isa Module && s isa Symbol
return abstract_eval_global(M, s)
end
return Bottom
elseif !(hasintersect(widenconst(M), Module) && hasintersect(widenconst(s), Symbol))
return Bottom
end
T = get_binding_type_tfunc(𝕃, M, s)
T isa Const && return T.val
return Any
end
@nospecs function setglobal!_tfunc(𝕃::AbstractLattice, M, s, v, order=Symbol)
if !(hasintersect(widenconst(M), Module) && hasintersect(widenconst(s), Symbol))
return Bottom
end
return v
end
@nospecs function swapglobal!_tfunc(𝕃::AbstractLattice, M, s, v, order=Symbol)
setglobal!_tfunc(𝕃, M, s, v) === Bottom && return Bottom
return getglobal_tfunc(𝕃, M, s)
end
@nospecs function modifyglobal!_tfunc(𝕃::AbstractLattice, M, s, op, v, order=Symbol)
T = get_binding_type_tfunc(𝕃, M, s)
T === Bottom && return Bottom
T isa Const || return Pair
T = T.val
return Pair{T, T}
end
@nospecs function replaceglobal!_tfunc(𝕃::AbstractLattice, M, s, x, v, success_order=Symbol, failure_order=Symbol)
v = setglobal!_tfunc(𝕃, M, s, v)
v === Bottom && return Bottom
T = get_binding_type_tfunc(𝕃, M, s)
T === Bottom && return Bottom
T isa Const || return ccall(:jl_apply_cmpswap_type, Any, (Any,), T) where T
T = T.val
return ccall(:jl_apply_cmpswap_type, Any, (Any,), T)
end
@nospecs function setglobalonce!_tfunc(𝕃::AbstractLattice, M, s, v, success_order=Symbol, failure_order=Symbol)
setglobal!_tfunc(𝕃, M, s, v) === Bottom && return Bottom
return Bool
end

add_tfunc(Core.getglobal, 2, 3, getglobal_tfunc, 1)
add_tfunc(Core.setglobal!, 3, 4, setglobal!_tfunc, 3)
add_tfunc(Core.swapglobal!, 3, 4, swapglobal!_tfunc, 3)
add_tfunc(Core.modifyglobal!, 4, 5, modifyglobal!_tfunc, 3)
add_tfunc(Core.replaceglobal!, 4, 6, replaceglobal!_tfunc, 3)
add_tfunc(Core.setglobalonce!, 3, 5, setglobalonce!_tfunc, 3)

@nospecs function setglobal!_nothrow(M, s, newty, o)
global_order_nothrow(o, #=loading=#false, #=storing=#true) || return false
return setglobal!_nothrow(M, s, newty)
end
@nospecs function setglobal!_nothrow(M, s, newty)
if M isa Const && s isa Const
M, s = M.val, s.val
if isa(M, Module) && isa(s, Symbol)
return global_assignment_nothrow(M, s, newty)
end
end
return false
end

function global_assignment_nothrow(M::Module, s::Symbol, @nospecialize(newty))
if !isconst(M, s)
ty = ccall(:jl_get_binding_type, Any, (Any, Any), M, s)
return ty isa Type && widenconst(newty) <: ty
end
return false
end

@nospecs function get_binding_type_effect_free(M, s)
if M isa Const && s isa Const
M, s = M.val, s.val
if M isa Module && s isa Symbol
return ccall(:jl_get_binding_type, Any, (Any, Any), M, s) !== nothing
if order !== MEMORY_ORDER_INVALID && order !== MEMORY_ORDER_NOTATOMIC
return Union{}
else
return ConcurrencyViolationError
end
else
return TypeError
end
return false
end
@nospecs function get_binding_type_tfunc(𝕃::AbstractLattice, M, s)
if get_binding_type_effect_free(M, s)
return Const(Core.get_binding_type((M::Const).val::Module, (s::Const).val::Symbol))
end
return Type
end
add_tfunc(Core.get_binding_type, 2, 2, get_binding_type_tfunc, 0)

@nospecs function get_binding_type_nothrow(𝕃::AbstractLattice, M, s)
⊑ = partialorder(𝕃)
Expand Down
9 changes: 8 additions & 1 deletion base/runtime_internals.jl
Original file line number Diff line number Diff line change
Expand Up @@ -230,13 +230,20 @@ const BINDING_KIND_DECLARED = 0x7
const BINDING_KIND_GUARD = 0x8

is_some_const_binding(kind::UInt8) = (kind == BINDING_KIND_CONST || kind == BINDING_KIND_CONST_IMPORT)
is_some_imported(kind::UInt8) = (kind == BINDING_KIND_IMPLICIT || kind == BINDING_KIND_EXPLICIT || kind == BINDING_KIND_IMPORTED)
is_some_guard(kind::UInt8) = (kind == BINDING_KIND_GUARD || kind == BINDING_KIND_DECLARED || kind == BINDING_KIND_FAILED)

function lookup_binding_partition(world::UInt, b::Core.Binding)
ccall(:jl_get_binding_partition, Ref{Core.BindingPartition}, (Any, UInt), b, world)
end

function lookup_binding_partition(world::UInt, gr::Core.GlobalRef)
ccall(:jl_get_globalref_partition, Ref{Core.BindingPartition}, (Any, UInt), gr, world)
if isdefined(gr, :binding)
b = gr.binding
else
b = ccall(:jl_get_module_binding, Ref{Core.Binding}, (Any, Any, Cint), gr.mod, gr.name, true)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Shouldn't #=alloc=#true::Cint here be #=alloc=false::Cint given the previous definition of jl_get_globalref_partition?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That definition was probably wrong, but regardless, this code can't handle the binding not existing.

end
return lookup_binding_partition(world, b)
end

partition_restriction(bpart::Core.BindingPartition) = ccall(:jl_bpart_get_restriction_value, Any, (Any,), bpart)
Expand Down
1 change: 0 additions & 1 deletion src/julia_internal.h
Original file line number Diff line number Diff line change
Expand Up @@ -955,7 +955,6 @@ STATIC_INLINE int jl_bkind_is_some_guard(enum jl_partition_kind kind) JL_NOTSAFE
}

JL_DLLEXPORT jl_binding_partition_t *jl_get_binding_partition(jl_binding_t *b JL_PROPAGATES_ROOT, size_t world);
JL_DLLEXPORT jl_binding_partition_t *jl_get_globalref_partition(jl_globalref_t *gr JL_PROPAGATES_ROOT, size_t world);

EXTERN_INLINE_DECLARE uint8_t jl_bpart_get_kind(jl_binding_partition_t *bpart) JL_NOTSAFEPOINT {
return decode_restriction_kind(jl_atomic_load_relaxed(&bpart->restriction));
Expand Down
Loading
Loading