Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

inference: model partially initialized structs with PartialStruct #55297

Merged
merged 4 commits into from
Aug 20, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
86 changes: 58 additions & 28 deletions base/compiler/abstractinterpretation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2006,33 +2006,64 @@ function abstract_call_builtin(interp::AbstractInterpreter, f::Builtin, (; fargs
return Conditional(aty.slot, thentype, elsetype)
end
elseif f === isdefined
uty = argtypes[2]
a = ssa_def_slot(fargs[2], sv)
if isa(uty, Union) && isa(a, SlotNumber)
fld = argtypes[3]
thentype = Bottom
elsetype = Bottom
for ty in uniontypes(uty)
cnd = isdefined_tfunc(𝕃ᵢ, ty, fld)
if isa(cnd, Const)
if cnd.val::Bool
thentype = thentype ⊔ ty
if isa(a, SlotNumber)
argtype2 = argtypes[2]
if isa(argtype2, Union)
fld = argtypes[3]
thentype = Bottom
elsetype = Bottom
for ty in uniontypes(argtype2)
cnd = isdefined_tfunc(𝕃ᵢ, ty, fld)
if isa(cnd, Const)
if cnd.val::Bool
thentype = thentype ⊔ ty
else
elsetype = elsetype ⊔ ty
end
else
thentype = thentype ⊔ ty
elsetype = elsetype ⊔ ty
end
else
thentype = thentype ⊔ ty
elsetype = elsetype ⊔ ty
end
return Conditional(a, thentype, elsetype)
else
thentype = form_partially_defined_struct(argtype2, argtypes[3])
if thentype !== nothing
elsetype = argtype2
if rt === Const(false)
thentype = Bottom
elseif rt === Const(true)
elsetype = Bottom
end
return Conditional(a, thentype, elsetype)
end
end
return Conditional(a, thentype, elsetype)
end
end
end
@assert !isa(rt, TypeVar) "unhandled TypeVar"
return rt
end

# Try to build a `PartialStruct` describing `obj` under the assumption that the field
# designated by `name` is defined (the `isdefined` handling uses the result as the
# "then" type of a `Conditional`).
#
# Returns `nothing` when no refinement is formed, i.e. when:
# - `obj` is already a `Const`, or `name` is not a `Const`
# - the widened, `UnionAll`-unwrapped type of `obj` is not a `DataType`, or is abstract
# - `try_compute_fieldidx` cannot resolve `name.val` to a field index
# - the field index does not pass the check against `datatype_min_ninitialized`:
#   it must be exactly one past that count for mutable types, and strictly greater
#   than that count otherwise
#   (NOTE(review): presumably mutable types are treated more strictly because their
#   fields can be assigned individually after construction -- confirm)
#
# Otherwise returns a `PartialStruct` over the widened type whose `fields` hold the
# declared field types of fields `1:fldidx`.
function form_partially_defined_struct(@nospecialize(obj), @nospecialize(name))
    if isa(obj, Const)
        return nothing # nothing to refine
    end
    if !isa(name, Const)
        return nothing
    end
    widened = widenconst(obj)
    dt = unwrap_unionall(widened)
    if !isa(dt, DataType) || isabstracttype(dt)
        return nothing
    end
    idx = try_compute_fieldidx(dt, name.val)
    if idx === nothing
        return nothing
    end
    ninit = datatype_min_ninitialized(dt)
    refinable = ismutabletype(dt) ? (idx == ninit + 1) : (idx > ninit)
    if !refinable
        return nothing
    end
    fields = Vector{Any}(undef, idx)
    for i in 1:idx
        fields[i] = fieldtype(widened, i)
    end
    return PartialStruct(widened, fields)
end

function abstract_call_unionall(interp::AbstractInterpreter, argtypes::Vector{Any}, call::CallMeta)
na = length(argtypes)
if isvarargtype(argtypes[end])
Expand Down Expand Up @@ -2573,20 +2604,18 @@ function abstract_eval_new(interp::AbstractInterpreter, e::Expr, vtypes::Union{V
end
ats[i] = at
end
# For now, don't allow:
# - Const/PartialStruct of mutables (but still allow PartialStruct of mutables
# with `const` fields if anything refined)
# - partially initialized Const/PartialStruct
if fcount == nargs
if consistent === ALWAYS_TRUE && allconst
argvals = Vector{Any}(undef, nargs)
for j in 1:nargs
argvals[j] = (ats[j]::Const).val
end
rt = Const(ccall(:jl_new_structv, Any, (Any, Ptr{Cvoid}, UInt32), rt, argvals, nargs))
elseif anyrefine
rt = PartialStruct(rt, ats)
if fcount == nargs && consistent === ALWAYS_TRUE && allconst
argvals = Vector{Any}(undef, nargs)
for j in 1:nargs
argvals[j] = (ats[j]::Const).val
end
rt = Const(ccall(:jl_new_structv, Any, (Any, Ptr{Cvoid}, UInt32), rt, argvals, nargs))
elseif anyrefine || nargs > datatype_min_ninitialized(rt)
# propagate partially initialized struct as `PartialStruct` when:
# - any refinement information is available (`anyrefine`), or when
# - `nargs` is greater than `n_initialized` derived from the struct type
# information alone
rt = PartialStruct(rt, ats)
end
else
rt = refine_partial_type(rt)
Expand Down Expand Up @@ -3094,7 +3123,8 @@ end
@nospecializeinfer function widenreturn_partials(𝕃ᵢ::PartialsLattice, @nospecialize(rt), info::BestguessInfo)
if isa(rt, PartialStruct)
fields = copy(rt.fields)
local anyrefine = false
anyrefine = !isvarargtype(rt.fields[end]) &&
length(rt.fields) > datatype_min_ninitialized(unwrap_unionall(rt.typ))
𝕃 = typeinf_lattice(info.interp)
⊏ = strictpartialorder(𝕃)
for i in 1:length(fields)
Expand Down
7 changes: 6 additions & 1 deletion base/compiler/ssair/passes.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1166,7 +1166,12 @@ struct IntermediaryCollector <: WalkerCallback
intermediaries::SPCSet
end
function (walker_callback::IntermediaryCollector)(@nospecialize(def), @nospecialize(defssa::AnySSAValue))
isa(def, Expr) || push!(walker_callback.intermediaries, defssa.id)
if !(def isa Expr)
push!(walker_callback.intermediaries, defssa.id)
if def isa PiNode
return LiftedValue(def.val)
end
end
return nothing
end

Expand Down
42 changes: 30 additions & 12 deletions base/compiler/tfuncs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -419,23 +419,29 @@ end
else
return Bottom
end
if 1 <= idx <= datatype_min_ninitialized(a1)
if 1 ≤ idx ≤ datatype_min_ninitialized(a1)
return Const(true)
elseif a1.name === _NAMEDTUPLE_NAME
if isconcretetype(a1)
return Const(false)
else
ns = a1.parameters[1]
if isa(ns, Tuple)
return Const(1 <= idx <= length(ns))
return Const(1 ≤ idx ≤ length(ns))
end
end
elseif idx <= 0 || (!isvatuple(a1) && idx > fieldcount(a1))
elseif idx ≤ 0 || (!isvatuple(a1) && idx > fieldcount(a1))
return Const(false)
elseif isa(arg1, Const)
if !ismutabletype(a1) || isconst(a1, idx)
return Const(isdefined(arg1.val, idx))
end
elseif isa(arg1, PartialStruct)
if !isvarargtype(arg1.fields[end])
if 1 ≤ idx ≤ length(arg1.fields)
return Const(true)
end
end
elseif !isvatuple(a1)
fieldT = fieldtype(a1, idx)
if isa(fieldT, DataType) && isbitstype(fieldT)
Expand Down Expand Up @@ -989,27 +995,39 @@ end
⊑ = partialorder(𝕃)

# If we have s00 being a const, we can potentially refine our type-based analysis above
if isa(s00, Const) || isconstType(s00)
if !isa(s00, Const)
sv = (s00::DataType).parameters[1]
else
if isa(s00, Const) || isconstType(s00) || isa(s00, PartialStruct)
if isa(s00, Const)
sv = s00.val
sty = typeof(sv)
nflds = nfields(sv)
ismod = sv isa Module
elseif isa(s00, PartialStruct)
sty = unwrap_unionall(s00.typ)
nflds = fieldcount_noerror(sty)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This appears to impose additional semantics on PartialStruct, and it is not clear to me that they are satisfied by the existing lattice elements. I'm thinking of cases like:

struct Foo
    x::Int
    y::Any
    Foo(x) = rand() ? new(1) : new(1, x)
end

It's not clear to me that the tmerge will respect the semantics you're imposing here. Perhaps it would be cleaner to give PartialStruct a new ninitialized field?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The information equivalent to ninitialized can be derived from the length of fields, and when PartialStructs with different lengths of fields are joined, the joined PartialStructs are simply widened to the simple object type, so I don't see any problem?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ok, that loses some precision, but I guess that's fine. That said, this needs a comment somewhere to describe the semantics of a PartialStruct with short fields. We should also audit all uses of fields and probably pkgeval this. Lastly, I'm not sure this logic is correct, because it doesn't look like it's actually looking at the length of fields, but I don't have the time to check right now.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

> That said, this needs a comment somewhere to describe the semantics of a PartialStruct with short fields

Does c57d34c seem to address this?

> We should also audit all uses of fields and probably pkgeval this.

Yeah, I did audit all uses, but it would be a good idea to exercise with pkgeval.
@nanosoldier runtests()

> I'm not sure this logic is correct, because it doesn't look like it's actually looking at the length of fields, but I don't have the time to check right now.

Specifically regarding this logic, the length of fields is later referenced within isdefined_tfunc. fieldcount_noerror is used for iteration when applying isdefined_tfunc to all fields that might be uninitialized. However, that code path is hit only when bounds checking is turned off and the field is not constant, making it quite niche.

ismod = false
else
sv = (s00::DataType).parameters[1]
sty = typeof(sv)
nflds = nfields(sv)
ismod = sv isa Module
end
if isa(name, Const)
nval = name.val
if !isa(nval, Symbol)
isa(sv, Module) && return false
ismod && return false
isa(nval, Int) || return false
end
return isdefined_tfunc(𝕃, s00, name) === Const(true)
end
boundscheck && return false

# If bounds checking is disabled and all fields are assigned,
# we may assume that we don't throw
isa(sv, Module) && return false
@assert !boundscheck
ismod && return false
name ⊑ Int || name ⊑ Symbol || return false
typeof(sv).name.n_uninitialized == 0 && return true
for i = (datatype_min_ninitialized(typeof(sv)) + 1):nfields(sv)
sty.name.n_uninitialized == 0 && return true
nflds === nothing && return false
for i = (datatype_min_ninitialized(sty)+1):nflds
isdefined_tfunc(𝕃, s00, Const(i)) === Const(true) || return false
end
return true
Expand Down
87 changes: 60 additions & 27 deletions base/compiler/typelattice.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,17 +6,42 @@

# N.B.: Const/PartialStruct/InterConditional are defined in Core, to allow them to be used
# inside the global code cache.
#
# # The type of a value might be constant
# struct Const
# val
# end
#
# struct PartialStruct
# typ
# fields::Vector{Any} # elements are other type lattice members
# end

import Core: Const, PartialStruct

"""
struct Const
val
end

The type representing a constant value.
"""
:(Const)

"""
struct PartialStruct
typ
fields::Vector{Any} # elements are other type lattice members
end

This extended lattice element is introduced when we have information about an object's
fields beyond what can be obtained from the object type. E.g. it represents a tuple where
some elements are known to be constants or a struct whose `Any`-typed field is initialized
with `Int` values.

- `typ` indicates the type of the object
- `fields` holds the lattice elements corresponding to each field of the object

If `typ` is a struct, `fields` represents the fields of the struct that are guaranteed to be
initialized. For instance, if the length of `fields` of `PartialStruct` representing a
struct with 4 fields is 3, the 4th field may not be initialized. If the length is 4, all
fields are guaranteed to be initialized.

If `typ` is a tuple, the last element of `fields` may be `Vararg`. In this case, it is
guaranteed that the number of elements in the tuple is at least `length(fields)-1`, but the
exact number of elements is unknown.
"""
:(PartialStruct)
function PartialStruct(@nospecialize(typ), fields::Vector{Any})
for i = 1:length(fields)
assert_nested_slotwrapper(fields[i])
Expand Down Expand Up @@ -57,23 +82,20 @@ end
Conditional(var::SlotNumber, @nospecialize(thentype), @nospecialize(elsetype)) =
Conditional(slot_id(var), thentype, elsetype)

import Core: InterConditional
"""
cnd::InterConditional
struct InterConditional
slot::Int
thentype
elsetype
end

Similar to `Conditional`, but conveys inter-procedural constraints imposed on call arguments.
This is separate from `Conditional` to catch logic errors: the lattice element name is `InterConditional`
while processing a call, then `Conditional` everywhere else. Thus `InterConditional` does not appear in
`CompilerTypes`—these type's usages are disjoint—though we define the lattice for `InterConditional`.
"""
:(InterConditional)
import Core: InterConditional
# struct InterConditional
# slot::Int
# thentype
# elsetype
# InterConditional(slot::Int, @nospecialize(thentype), @nospecialize(elsetype)) =
# new(slot, thentype, elsetype)
# end
InterConditional(var::SlotNumber, @nospecialize(thentype), @nospecialize(elsetype)) =
InterConditional(slot_id(var), thentype, elsetype)

Expand Down Expand Up @@ -447,8 +469,13 @@ end
@nospecializeinfer function ⊑(lattice::PartialsLattice, @nospecialize(a), @nospecialize(b))
if isa(a, PartialStruct)
if isa(b, PartialStruct)
if !(length(a.fields) == length(b.fields) && a.typ <: b.typ)
return false
a.typ <: b.typ || return false
if length(a.fields) ≠ length(b.fields)
if !(isvarargtype(a.fields[end]) || isvarargtype(b.fields[end]))
length(a.fields) ≥ length(b.fields) || return false
else
return false
end
end
for i in 1:length(b.fields)
af = a.fields[i]
Expand All @@ -471,19 +498,25 @@ end
return isa(b, Type) && a.typ <: b
elseif isa(b, PartialStruct)
if isa(a, Const)
nf = nfields(a.val)
nf == length(b.fields) || return false
widea = widenconst(a)::DataType
wideb = widenconst(b)
wideb′ = unwrap_unionall(wideb)::DataType
widea.name === wideb′.name || return false
# We can skip the subtype check if b is a Tuple, since in that
# case, the ⊑ of the elements is sufficient.
if wideb′.name !== Tuple.name && !(widea <: wideb)
return false
if wideb′.name === Tuple.name
# We can skip the subtype check if b is a Tuple, since in that
# case, the ⊑ of the elements is sufficient.
# But for tuple comparisons, we need their lengths to be the same for now.
# TODO improve accuracy for cases when `b` contains vararg element
nfields(a.val) == length(b.fields) || return false
else
widea <: wideb || return false
# for structs we need to check that `a` has more information than `b` that may be partially initialized
n_initialized(a) ≥ length(b.fields) || return false
end
nf = nfields(a.val)
for i in 1:nf
isdefined(a.val, i) || continue # since ∀ T Union{} ⊑ T
i > length(b.fields) && break # `a` has more information than `b` that is partially initialized struct
bfᵢ = b.fields[i]
if i == nf
bfᵢ = unwrapva(bfᵢ)
Expand Down
Loading