Skip to content

Commit

Permalink
inference: model partially initialized structs with PartialStruct (#…
Browse files Browse the repository at this point in the history
…55297)

There is still room for improvement in the accuracy of `getfield` and
`isdefined` for structs with uninitialized fields. This commit aims to
enhance the accuracy of struct field defined-ness by propagating such
struct as `PartialStruct` in cases where fields that might be
uninitialized are confirmed to be defined. Specifically, the
improvements are made in the following situations:
1. when a `:new` expression receives arguments greater than the minimum
number of initialized fields.
2. when new information about the initialized fields of `x` can be
obtained in the `then` branch of `if isdefined(x, :f)`.

Combined with the existing optimizations, these improvements enable DCE
in scenarios such as:
```julia
julia> @noinline broadcast_noescape1(a) = (broadcast(identity, a); nothing);

julia> @allocated broadcast_noescape1(Ref("x"))
16 # master
0  # this PR
```

One important point to note is that, as revealed in
#48999, fields and globals can revert to `undef` during
precompilation. This commit does not affect globals. Furthermore, even
for fields, the refinements made by 1. and 2. are propagated along with
data-flow, and field defined-ness information is only used when fields
are confirmed to be initialized. Therefore, the same issues as
#48999 will not occur by this commit.
  • Loading branch information
aviatesk authored Aug 20, 2024
1 parent d4bd540 commit 9650510
Show file tree
Hide file tree
Showing 8 changed files with 393 additions and 126 deletions.
86 changes: 58 additions & 28 deletions base/compiler/abstractinterpretation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2006,33 +2006,64 @@ function abstract_call_builtin(interp::AbstractInterpreter, f::Builtin, (; fargs
return Conditional(aty.slot, thentype, elsetype)
end
elseif f === isdefined
uty = argtypes[2]
a = ssa_def_slot(fargs[2], sv)
if isa(uty, Union) && isa(a, SlotNumber)
fld = argtypes[3]
thentype = Bottom
elsetype = Bottom
for ty in uniontypes(uty)
cnd = isdefined_tfunc(𝕃ᵢ, ty, fld)
if isa(cnd, Const)
if cnd.val::Bool
thentype = thentype ty
if isa(a, SlotNumber)
argtype2 = argtypes[2]
if isa(argtype2, Union)
fld = argtypes[3]
thentype = Bottom
elsetype = Bottom
for ty in uniontypes(argtype2)
cnd = isdefined_tfunc(𝕃ᵢ, ty, fld)
if isa(cnd, Const)
if cnd.val::Bool
thentype = thentype ty
else
elsetype = elsetype ty
end
else
thentype = thentype ty
elsetype = elsetype ty
end
else
thentype = thentype ty
elsetype = elsetype ty
end
return Conditional(a, thentype, elsetype)
else
thentype = form_partially_defined_struct(argtype2, argtypes[3])
if thentype !== nothing
elsetype = argtype2
if rt === Const(false)
thentype = Bottom
elseif rt === Const(true)
elsetype = Bottom
end
return Conditional(a, thentype, elsetype)
end
end
return Conditional(a, thentype, elsetype)
end
end
end
@assert !isa(rt, TypeVar) "unhandled TypeVar"
return rt
end

function form_partially_defined_struct(@nospecialize(obj), @nospecialize(name))
obj isa Const && return nothing # nothing to refine
name isa Const || return nothing
objt0 = widenconst(obj)
objt = unwrap_unionall(objt0)
objt isa DataType || return nothing
isabstracttype(objt) && return nothing
fldidx = try_compute_fieldidx(objt, name.val)
fldidx === nothing && return nothing
nminfld = datatype_min_ninitialized(objt)
if ismutabletype(objt)
fldidx == nminfld+1 || return nothing
else
fldidx > nminfld || return nothing
end
return PartialStruct(objt0, Any[fieldtype(objt0, i) for i = 1:fldidx])
end

function abstract_call_unionall(interp::AbstractInterpreter, argtypes::Vector{Any}, call::CallMeta)
na = length(argtypes)
if isvarargtype(argtypes[end])
Expand Down Expand Up @@ -2573,20 +2604,18 @@ function abstract_eval_new(interp::AbstractInterpreter, e::Expr, vtypes::Union{V
end
ats[i] = at
end
# For now, don't allow:
# - Const/PartialStruct of mutables (but still allow PartialStruct of mutables
# with `const` fields if anything refined)
# - partially initialized Const/PartialStruct
if fcount == nargs
if consistent === ALWAYS_TRUE && allconst
argvals = Vector{Any}(undef, nargs)
for j in 1:nargs
argvals[j] = (ats[j]::Const).val
end
rt = Const(ccall(:jl_new_structv, Any, (Any, Ptr{Cvoid}, UInt32), rt, argvals, nargs))
elseif anyrefine
rt = PartialStruct(rt, ats)
if fcount == nargs && consistent === ALWAYS_TRUE && allconst
argvals = Vector{Any}(undef, nargs)
for j in 1:nargs
argvals[j] = (ats[j]::Const).val
end
rt = Const(ccall(:jl_new_structv, Any, (Any, Ptr{Cvoid}, UInt32), rt, argvals, nargs))
elseif anyrefine || nargs > datatype_min_ninitialized(rt)
# propagate partially initialized struct as `PartialStruct` when:
# - any refinement information is available (`anyrefine`), or when
# - `nargs` is greater than `n_initialized` derived from the struct type
# information alone
rt = PartialStruct(rt, ats)
end
else
rt = refine_partial_type(rt)
Expand Down Expand Up @@ -3094,7 +3123,8 @@ end
@nospecializeinfer function widenreturn_partials(𝕃ᵢ::PartialsLattice, @nospecialize(rt), info::BestguessInfo)
if isa(rt, PartialStruct)
fields = copy(rt.fields)
local anyrefine = false
anyrefine = !isvarargtype(rt.fields[end]) &&
length(rt.fields) > datatype_min_ninitialized(unwrap_unionall(rt.typ))
𝕃 = typeinf_lattice(info.interp)
= strictpartialorder(𝕃)
for i in 1:length(fields)
Expand Down
7 changes: 6 additions & 1 deletion base/compiler/ssair/passes.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1166,7 +1166,12 @@ struct IntermediaryCollector <: WalkerCallback
intermediaries::SPCSet
end
function (walker_callback::IntermediaryCollector)(@nospecialize(def), @nospecialize(defssa::AnySSAValue))
isa(def, Expr) || push!(walker_callback.intermediaries, defssa.id)
if !(def isa Expr)
push!(walker_callback.intermediaries, defssa.id)
if def isa PiNode
return LiftedValue(def.val)
end
end
return nothing
end

Expand Down
42 changes: 30 additions & 12 deletions base/compiler/tfuncs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -419,23 +419,29 @@ end
else
return Bottom
end
if 1 <= idx <= datatype_min_ninitialized(a1)
if 1 idx datatype_min_ninitialized(a1)
return Const(true)
elseif a1.name === _NAMEDTUPLE_NAME
if isconcretetype(a1)
return Const(false)
else
ns = a1.parameters[1]
if isa(ns, Tuple)
return Const(1 <= idx <= length(ns))
return Const(1 idx length(ns))
end
end
elseif idx <= 0 || (!isvatuple(a1) && idx > fieldcount(a1))
elseif idx 0 || (!isvatuple(a1) && idx > fieldcount(a1))
return Const(false)
elseif isa(arg1, Const)
if !ismutabletype(a1) || isconst(a1, idx)
return Const(isdefined(arg1.val, idx))
end
elseif isa(arg1, PartialStruct)
if !isvarargtype(arg1.fields[end])
if 1 idx length(arg1.fields)
return Const(true)
end
end
elseif !isvatuple(a1)
fieldT = fieldtype(a1, idx)
if isa(fieldT, DataType) && isbitstype(fieldT)
Expand Down Expand Up @@ -989,27 +995,39 @@ end
= partialorder(𝕃)

# If we have s00 being a const, we can potentially refine our type-based analysis above
if isa(s00, Const) || isconstType(s00)
if !isa(s00, Const)
sv = (s00::DataType).parameters[1]
else
if isa(s00, Const) || isconstType(s00) || isa(s00, PartialStruct)
if isa(s00, Const)
sv = s00.val
sty = typeof(sv)
nflds = nfields(sv)
ismod = sv isa Module
elseif isa(s00, PartialStruct)
sty = unwrap_unionall(s00.typ)
nflds = fieldcount_noerror(sty)
ismod = false
else
sv = (s00::DataType).parameters[1]
sty = typeof(sv)
nflds = nfields(sv)
ismod = sv isa Module
end
if isa(name, Const)
nval = name.val
if !isa(nval, Symbol)
isa(sv, Module) && return false
ismod && return false
isa(nval, Int) || return false
end
return isdefined_tfunc(𝕃, s00, name) === Const(true)
end
boundscheck && return false

# If bounds checking is disabled and all fields are assigned,
# we may assume that we don't throw
isa(sv, Module) && return false
@assert !boundscheck
ismod && return false
name Int || name Symbol || return false
typeof(sv).name.n_uninitialized == 0 && return true
for i = (datatype_min_ninitialized(typeof(sv)) + 1):nfields(sv)
sty.name.n_uninitialized == 0 && return true
nflds === nothing && return false
for i = (datatype_min_ninitialized(sty)+1):nflds
isdefined_tfunc(𝕃, s00, Const(i)) === Const(true) || return false
end
return true
Expand Down
87 changes: 60 additions & 27 deletions base/compiler/typelattice.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,17 +6,42 @@

# N.B.: Const/PartialStruct/InterConditional are defined in Core, to allow them to be used
# inside the global code cache.
#
# # The type of a value might be constant
# struct Const
# val
# end
#
# struct PartialStruct
# typ
# fields::Vector{Any} # elements are other type lattice members
# end

import Core: Const, PartialStruct

"""
struct Const
val
end
The type representing a constant value.
"""
:(Const)

"""
struct PartialStruct
typ
fields::Vector{Any} # elements are other type lattice members
end
This extended lattice element is introduced when we have information about an object's
fields beyond what can be obtained from the object type. E.g. it represents a tuple where
some elements are known to be constants or a struct whose `Any`-typed field is initialized
with `Int` values.
- `typ` indicates the type of the object
- `fields` holds the lattice elements corresponding to each field of the object
If `typ` is a struct, `fields` represents the fields of the struct that are guaranteed to be
initialized. For instance, if the length of `fields` of `PartialStruct` representing a
struct with 4 fields is 3, the 4th field may not be initialized. If the length is 4, all
fields are guaranteed to be initialized.
If `typ` is a tuple, the last element of `fields` may be `Vararg`. In this case, it is
guaranteed that the number of elements in the tuple is at least `length(fields)-1`, but the
exact number of elements is unknown.
"""
:(PartialStruct)
function PartialStruct(@nospecialize(typ), fields::Vector{Any})
for i = 1:length(fields)
assert_nested_slotwrapper(fields[i])
Expand Down Expand Up @@ -57,23 +82,20 @@ end
Conditional(var::SlotNumber, @nospecialize(thentype), @nospecialize(elsetype)) =
Conditional(slot_id(var), thentype, elsetype)

import Core: InterConditional
"""
cnd::InterConditional
struct InterConditional
slot::Int
thentype
elsetype
end
Similar to `Conditional`, but conveys inter-procedural constraints imposed on call arguments.
This is separate from `Conditional` to catch logic errors: the lattice element name is `InterConditional`
while processing a call, then `Conditional` everywhere else. Thus `InterConditional` does not appear in
`CompilerTypes`—these type's usages are disjoint—though we define the lattice for `InterConditional`.
"""
:(InterConditional)
import Core: InterConditional
# struct InterConditional
# slot::Int
# thentype
# elsetype
# InterConditional(slot::Int, @nospecialize(thentype), @nospecialize(elsetype)) =
# new(slot, thentype, elsetype)
# end
InterConditional(var::SlotNumber, @nospecialize(thentype), @nospecialize(elsetype)) =
InterConditional(slot_id(var), thentype, elsetype)

Expand Down Expand Up @@ -447,8 +469,13 @@ end
@nospecializeinfer function (lattice::PartialsLattice, @nospecialize(a), @nospecialize(b))
if isa(a, PartialStruct)
if isa(b, PartialStruct)
if !(length(a.fields) == length(b.fields) && a.typ <: b.typ)
return false
a.typ <: b.typ || return false
if length(a.fields) length(b.fields)
if !(isvarargtype(a.fields[end]) || isvarargtype(b.fields[end]))
length(a.fields) length(b.fields) || return false
else
return false
end
end
for i in 1:length(b.fields)
af = a.fields[i]
Expand All @@ -471,19 +498,25 @@ end
return isa(b, Type) && a.typ <: b
elseif isa(b, PartialStruct)
if isa(a, Const)
nf = nfields(a.val)
nf == length(b.fields) || return false
widea = widenconst(a)::DataType
wideb = widenconst(b)
wideb′ = unwrap_unionall(wideb)::DataType
widea.name === wideb′.name || return false
# We can skip the subtype check if b is a Tuple, since in that
# case, the ⊑ of the elements is sufficient.
if wideb′.name !== Tuple.name && !(widea <: wideb)
return false
if wideb′.name === Tuple.name
# We can skip the subtype check if b is a Tuple, since in that
# case, the ⊑ of the elements is sufficient.
# But for tuple comparisons, we need their lengths to be the same for now.
# TODO improve accuracy for cases when `b` contains vararg element
nfields(a.val) == length(b.fields) || return false
else
widea <: wideb || return false
# for structs we need to check that `a` has more information than `b` that may be partially initialized
n_initialized(a) length(b.fields) || return false
end
nf = nfields(a.val)
for i in 1:nf
isdefined(a.val, i) || continue # since ∀ T Union{} ⊑ T
i > length(b.fields) && break # `a` has more information than `b` that is partially initialized struct
bfᵢ = b.fields[i]
if i == nf
bfᵢ = unwrapva(bfᵢ)
Expand Down
Loading

8 comments on commit 9650510

@aviatesk
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@nanosoldier runbenchmarks(ALL, isdaily = true)

@nanosoldier
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Your benchmark job has completed - possible performance regressions were detected. A full report can be found here.

@aviatesk
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@nanosoldier runbenchmarks("inference", vs="@1d7b036b7318215b418a369dff700ea2a406ebec")

@nanosoldier
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Your benchmark job has completed - possible performance regressions were detected. A full report can be found here.

@aviatesk
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@nanosoldier runbenchmarks("inference", vs="@c3d0d67ac424c889fd7f24552a557c3a7ea8f1e6")

@nanosoldier
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Executing the daily package evaluation, I will reply here when finished:

@nanosoldier runtests(isdaily = true)

@nanosoldier
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Your benchmark job has completed - possible performance regressions were detected. A full report can be found here.

@nanosoldier
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The package evaluation job you requested has completed - possible new issues were detected.
The full report is available.

Please sign in to comment.