Skip to content

Commit

Permalink
EA: use is_mutation_free_argtype for the escapability check
Browse files Browse the repository at this point in the history
EA has been using `isbitstype` for type-level escapability checks, but
a better criterion (`is_mutation_free`) is available these days, so we
would like to use that instead.
  • Loading branch information
aviatesk committed Oct 7, 2024
1 parent c7071e1 commit 72ba7ca
Show file tree
Hide file tree
Showing 2 changed files with 63 additions and 63 deletions.
10 changes: 5 additions & 5 deletions base/compiler/ssair/EscapeAnalysis/EscapeAnalysis.jl
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,10 @@ using ._TOP_MOD: # Base definitions
isempty, ismutabletype, keys, last, length, max, min, missing, pop!, push!, pushfirst!,
unwrap_unionall, !, !=, !==, &, *, +, -, :, <, <<, =>, >, |, , , , , , , ,
using Core.Compiler: # Core.Compiler specific definitions
Bottom, IRCode, IR_FLAG_NOTHROW, InferenceResult, SimpleInferenceLattice,
AbstractLattice, Bottom, IRCode, IR_FLAG_NOTHROW, InferenceResult, SimpleInferenceLattice,
argextype, fieldcount_noerror, hasintersect, has_flag, intrinsic_nothrow,
is_meta_expr_head, isbitstype, isexpr, println, setfield!_nothrow, singleton_type,
try_compute_field, try_compute_fieldidx, widenconst, , AbstractLattice
is_meta_expr_head, is_mutation_free_argtype, isexpr, println, setfield!_nothrow,
singleton_type, try_compute_field, try_compute_fieldidx, widenconst,

include(x) = _TOP_MOD.include(@__MODULE__, x)
if _TOP_MOD === Core.Compiler
Expand Down Expand Up @@ -859,7 +859,7 @@ function add_escape_change!(astate::AnalysisState, @nospecialize(x), xinfo::Esca
xinfo ===&& return nothing # performance optimization
xidx = iridx(x, astate.estate)
if xidx !== nothing
if force || !isbitstype(widenconst(argextype(x, astate.ir)))
if force || !is_mutation_free_argtype(argextype(x, astate.ir))
push!(astate.changes, EscapeChange(xidx, xinfo))
end
end
Expand All @@ -869,7 +869,7 @@ end
function add_liveness_change!(astate::AnalysisState, @nospecialize(x), livepc::Int)
xidx = iridx(x, astate.estate)
if xidx !== nothing
if !isbitstype(widenconst(argextype(x, astate.ir)))
if !is_mutation_free_argtype(argextype(x, astate.ir))
push!(astate.changes, LivenessChange(xidx, livepc))
end
end
Expand Down
116 changes: 58 additions & 58 deletions test/compiler/EscapeAnalysis/EscapeAnalysis.jl
Original file line number Diff line number Diff line change
Expand Up @@ -290,7 +290,7 @@ end

let # typeassert
result = code_escapes((Any,)) do x
y = x::String
y = x::Base.RefValue{Any}
return y
end
r = only(findall(isreturn, result.ir.stmts.stmt))
Expand All @@ -305,11 +305,6 @@ end
r = only(findall(isreturn, result.ir.stmts.stmt))
@test has_return_escape(result.state[Argument(2)], r)
@test !has_all_escape(result.state[Argument(2)])

result = code_escapes((Module,)) do m
isdefined(m, 10) # throws
end
@test has_thrown_escape(result.state[Argument(2)])
end
end

Expand Down Expand Up @@ -685,8 +680,8 @@ end
@test has_all_escape(result.state[Argument(2)])
end
let result = @eval EATModule() begin
const Rx = SafeRef{String}("Rx")
$code_escapes((String,)) do s
const Rx = SafeRef{Any}(nothing)
$code_escapes((Base.RefValue{String},)) do s
setfield!(Rx, :x, s)
Core.sizeof(Rx[])
end
Expand All @@ -712,7 +707,7 @@ end
# ------------

# field escape should propagate to :new arguments
let result = code_escapes((String,)) do a
let result = code_escapes((Base.RefValue{String},)) do a
o = SafeRef(a)
Core.donotdelete(o)
return o[]
Expand All @@ -722,7 +717,7 @@ end
@test has_return_escape(result.state[Argument(2)], r)
@test is_load_forwardable(result.state[SSAValue(i)])
end
let result = code_escapes((String,)) do a
let result = code_escapes((Base.RefValue{String},)) do a
t = SafeRef((a,))
f = t[][1]
return f
Expand All @@ -731,9 +726,8 @@ end
r = only(findall(isreturn, result.ir.stmts.stmt))
@test has_return_escape(result.state[Argument(2)], r)
@test is_load_forwardable(result.state[SSAValue(i)])
result.state[SSAValue(i)].AliasInfo
end
let result = code_escapes((String, String)) do a, b
let result = code_escapes((Base.RefValue{String}, Base.RefValue{String})) do a, b
obj = SafeRefs(a, b)
Core.donotdelete(obj)
fld1 = obj[1]
Expand All @@ -748,31 +742,31 @@ end
end

# field escape should propagate to `setfield!` argument
let result = code_escapes((String,)) do a
o = SafeRef("foo")
let result = code_escapes((Base.RefValue{String},)) do a
o = SafeRef(Ref("foo"))
Core.donotdelete(o)
o[] = a
return o[]
end
i = only(findall(isnew, result.ir.stmts.stmt))
i = last(findall(isnew, result.ir.stmts.stmt))
r = only(findall(isreturn, result.ir.stmts.stmt))
@test has_return_escape(result.state[Argument(2)], r)
@test is_load_forwardable(result.state[SSAValue(i)])
end
# propagate escape information imposed on return value of `setfield!` call
let result = code_escapes((String,)) do a
obj = SafeRef("foo")
let result = code_escapes((Base.RefValue{String},)) do a
obj = SafeRef(Ref("foo"))
Core.donotdelete(obj)
return (obj[] = a)
end
i = only(findall(isnew, result.ir.stmts.stmt))
i = last(findall(isnew, result.ir.stmts.stmt))
r = only(findall(isreturn, result.ir.stmts.stmt))
@test has_return_escape(result.state[Argument(2)], r)
@test is_load_forwardable(result.state[SSAValue(i)])
end

# nested allocations
let result = code_escapes((String,)) do a
let result = code_escapes((Base.RefValue{String},)) do a
o1 = SafeRef(a)
o2 = SafeRef(o1)
return o2[]
Expand All @@ -787,7 +781,7 @@ end
end
end
end
let result = code_escapes((String,)) do a
let result = code_escapes((Base.RefValue{String},)) do a
o1 = (a,)
o2 = (o1,)
return o2[1]
Expand All @@ -802,7 +796,7 @@ end
end
end
end
let result = code_escapes((String,)) do a
let result = code_escapes((Base.RefValue{String},)) do a
o1 = SafeRef(a)
o2 = SafeRef(o1)
o1′ = o2[]
Expand Down Expand Up @@ -844,7 +838,7 @@ end
@test has_return_escape(result.state[SSAValue(i)], r)
end
end
let result = code_escapes((String,)) do x
let result = code_escapes((Base.RefValue{String},)) do x
o = Ref(x)
Core.donotdelete(o)
broadcast(identity, o)
Expand Down Expand Up @@ -892,7 +886,7 @@ end
end
end
# when ϕ-node merges values with different types
let result = code_escapes((Bool,String,String,String)) do cond, x, y, z
let result = code_escapes((Bool,Base.RefValue{String},Base.RefValue{String},Base.RefValue{String})) do cond, x, y, z
local out
if cond
ϕ = SafeRef(x)
Expand All @@ -904,7 +898,7 @@ end
end
r = only(findall(isreturn, result.ir.stmts.stmt))
t = only(findall(iscall((result.ir, throw)), result.ir.stmts.stmt))
ϕ = only(findall(==(Union{SafeRef{String},SafeRefs{String,String}}), result.ir.stmts.type))
ϕ = only(findall(==(Union{SafeRef{Base.RefValue{String}},SafeRefs{Base.RefValue{String},Base.RefValue{String}}}), result.ir.stmts.type))
@test has_return_escape(result.state[Argument(3)], r) # x
@test !has_return_escape(result.state[Argument(4)], r) # y
@test has_return_escape(result.state[Argument(5)], r) # z
Expand Down Expand Up @@ -1038,7 +1032,7 @@ end
end
# alias via typeassert
let result = code_escapes((Any,)) do a
r = a::String
r = a::Base.RefValue{String}
return r
end
r = only(findall(isreturn, result.ir.stmts.stmt))
Expand Down Expand Up @@ -1077,11 +1071,11 @@ end
@test has_all_escape(result.state[Argument(3)]) # a
end
# alias via ϕ-node
let result = code_escapes((Bool,String)) do cond, x
let result = code_escapes((Bool,Base.RefValue{String})) do cond, x
if cond
ϕ2 = ϕ1 = SafeRef("foo")
ϕ2 = ϕ1 = SafeRef(Ref("foo"))
else
ϕ2 = ϕ1 = SafeRef("bar")
ϕ2 = ϕ1 = SafeRef(Ref("bar"))
end
ϕ2[] = x
return ϕ1[]
Expand All @@ -1094,14 +1088,16 @@ end
@test is_load_forwardable(result.state[SSAValue(i)])
end
for i in findall(isnew, result.ir.stmts.stmt)
@test is_load_forwardable(result.state[SSAValue(i)])
if result.ir[SSAValue(i)][:type] <: SafeRef
@test is_load_forwardable(result.state[SSAValue(i)])
end
end
end
let result = code_escapes((Bool,Bool,String)) do cond1, cond2, x
let result = code_escapes((Bool,Bool,Base.RefValue{String})) do cond1, cond2, x
if cond1
ϕ2 = ϕ1 = SafeRef("foo")
ϕ2 = ϕ1 = SafeRef(Ref("foo"))
else
ϕ2 = ϕ1 = SafeRef("bar")
ϕ2 = ϕ1 = SafeRef(Ref("bar"))
end
cond2 && (ϕ2[] = x)
return ϕ1[]
Expand All @@ -1114,12 +1110,14 @@ end
@test is_load_forwardable(result.state[SSAValue(i)])
end
for i in findall(isnew, result.ir.stmts.stmt)
@test is_load_forwardable(result.state[SSAValue(i)])
if result.ir[SSAValue(i)][:type] <: SafeRef
@test is_load_forwardable(result.state[SSAValue(i)])
end
end
end
# alias via π-node
let result = code_escapes((Any,)) do x
if isa(x, String)
if isa(x, Base.RefValue{String})
return x
end
throw("error!")
Expand Down Expand Up @@ -1213,7 +1211,7 @@ end

# conservatively handle unknown field:
# all fields should be escaped, but the allocation itself doesn't need to be escaped
let result = code_escapes((String, Symbol)) do a, fld
let result = code_escapes((Base.RefValue{String}, Symbol)) do a, fld
obj = SafeRef(a)
return getfield(obj, fld)
end
Expand All @@ -1222,7 +1220,7 @@ end
@test has_return_escape(result.state[Argument(2)], r) # a
@test !is_load_forwardable(result.state[SSAValue(i)]) # obj
end
let result = code_escapes((String, String, Symbol)) do a, b, fld
let result = code_escapes((Base.RefValue{String}, Base.RefValue{String}, Symbol)) do a, b, fld
obj = SafeRefs(a, b)
return getfield(obj, fld) # should escape both `a` and `b`
end
Expand All @@ -1232,7 +1230,7 @@ end
@test has_return_escape(result.state[Argument(3)], r) # b
@test !is_load_forwardable(result.state[SSAValue(i)]) # obj
end
let result = code_escapes((String, String, Int)) do a, b, idx
let result = code_escapes((Base.RefValue{String}, Base.RefValue{String}, Int)) do a, b, idx
obj = SafeRefs(a, b)
return obj[idx] # should escape both `a` and `b`
end
Expand All @@ -1242,33 +1240,33 @@ end
@test has_return_escape(result.state[Argument(3)], r) # b
@test !is_load_forwardable(result.state[SSAValue(i)]) # obj
end
let result = code_escapes((String, String, Symbol)) do a, b, fld
obj = SafeRefs("a", "b")
let result = code_escapes((Base.RefValue{String}, Base.RefValue{String}, Symbol)) do a, b, fld
obj = SafeRefs(Ref("a"), Ref("b"))
setfield!(obj, fld, a)
return obj[2] # should escape `a`
end
i = only(findall(isnew, result.ir.stmts.stmt))
i = last(findall(isnew, result.ir.stmts.stmt))
r = only(findall(isreturn, result.ir.stmts.stmt))
@test has_return_escape(result.state[Argument(2)], r) # a
@test !has_return_escape(result.state[Argument(3)], r) # b
@test !is_load_forwardable(result.state[SSAValue(i)]) # obj
end
let result = code_escapes((String, Symbol)) do a, fld
obj = SafeRefs("a", "b")
let result = code_escapes((Base.RefValue{String}, Symbol)) do a, fld
obj = SafeRefs(Ref("a"), Ref("b"))
setfield!(obj, fld, a)
return obj[1] # this should escape `a`
end
i = only(findall(isnew, result.ir.stmts.stmt))
i = last(findall(isnew, result.ir.stmts.stmt))
r = only(findall(isreturn, result.ir.stmts.stmt))
@test has_return_escape(result.state[Argument(2)], r) # a
@test !is_load_forwardable(result.state[SSAValue(i)]) # obj
end
let result = code_escapes((String, String, Int)) do a, b, idx
obj = SafeRefs("a", "b")
let result = code_escapes((Base.RefValue{String}, Base.RefValue{String}, Int)) do a, b, idx
obj = SafeRefs(Ref("a"), Ref("b"))
obj[idx] = a
return obj[2] # should escape `a`
end
i = only(findall(isnew, result.ir.stmts.stmt))
i = last(findall(isnew, result.ir.stmts.stmt))
r = only(findall(isreturn, result.ir.stmts.stmt))
@test has_return_escape(result.state[Argument(2)], r) # a
@test !has_return_escape(result.state[Argument(3)], r) # b
Expand All @@ -1280,7 +1278,7 @@ end

let result = @eval EATModule() begin
@noinline getx(obj) = obj[]
$code_escapes((String,)) do a
$code_escapes((Base.RefValue{String},)) do a
obj = SafeRef(a)
fld = getx(obj)
return fld
Expand All @@ -1294,8 +1292,8 @@ end
end

# TODO interprocedural alias analysis
let result = code_escapes((SafeRef{String},)) do s
s[] = "bar"
let result = code_escapes((SafeRef{Base.RefValue{String}},)) do s
s[] = Ref("bar")
global GV = s[]
nothing
end
Expand Down Expand Up @@ -1335,7 +1333,7 @@ end
let result = @eval EATModule() begin
@noinline mysetindex!(x, a) = x[1] = a
const Ax = Vector{Any}(undef, 1)
$code_escapes((String,)) do s
$code_escapes((Base.RefValue{String},)) do s
mysetindex!(Ax, s)
end
end
Expand Down Expand Up @@ -1391,11 +1389,11 @@ end
end

# handle conflicting field information correctly
let result = code_escapes((Bool,String,String,)) do cnd, baz, qux
let result = code_escapes((Bool,Base.RefValue{String},Base.RefValue{String},)) do cnd, baz, qux
if cnd
o = SafeRef("foo")
o = SafeRef(Ref("foo"))
else
o = SafeRefs("bar", baz)
o = SafeRefs(Ref("bar"), baz)
r = getfield(o, 2)
end
if cnd
Expand All @@ -1409,12 +1407,14 @@ end
@test has_return_escape(result.state[Argument(3)], r) # baz
@test has_return_escape(result.state[Argument(4)], r) # qux
for new in findall(isnew, result.ir.stmts.stmt)
@test is_load_forwardable(result.state[SSAValue(new)])
if !(result.ir[SSAValue(new)][:type] <: Base.RefValue)
@test is_load_forwardable(result.state[SSAValue(new)])
end
end
end
let result = code_escapes((Bool,String,String,)) do cnd, baz, qux
let result = code_escapes((Bool,Base.RefValue{String},Base.RefValue{String},)) do cnd, baz, qux
if cnd
o = SafeRefs("foo", "bar")
o = SafeRefs(Ref("foo"), Ref("bar"))
r = setfield!(o, 2, baz)
else
o = SafeRef(qux)
Expand Down Expand Up @@ -2141,9 +2141,9 @@ end
# propagate escapes imposed on call arguments
@noinline broadcast_noescape2(b) = broadcast(identity, b)
let result = code_escapes() do
broadcast_noescape2(Ref("Hi"))
broadcast_noescape2(Ref(Ref("Hi")))
end
i = only(findall(isnew, result.ir.stmts.stmt))
i = last(findall(isnew, result.ir.stmts.stmt))
@test_broken !has_return_escape(result.state[SSAValue(i)]) # TODO interprocedural alias analysis
@test !has_thrown_escape(result.state[SSAValue(i)])
end
Expand Down

0 comments on commit 72ba7ca

Please sign in to comment.