Skip to content

Commit

Permalink
Merge pull request #1088 from JuliaSymbolics/myb/struct2
Browse files Browse the repository at this point in the history
Struct v2
  • Loading branch information
YingboMa authored Mar 11, 2024
2 parents 97c7b13 + db932e4 commit ca9a95d
Show file tree
Hide file tree
Showing 2 changed files with 55 additions and 153 deletions.
139 changes: 36 additions & 103 deletions src/struct.jl
Original file line number Diff line number Diff line change
@@ -1,126 +1,59 @@
const TypeT = UInt32
const ISINTEGER = TypeT(0)
const SIGNED_OFFSET = TypeT(1)
const SIZE_OFFSET = TypeT(2)

const EMPTY_DIMS = Int[]

struct StructElement
typ::TypeT
name::Symbol
size::Vector{Int}
function StructElement(::Type{T}, name, size = EMPTY_DIMS) where {T}
c = encodetyp(T)
c == typemax(TypeT) && error("Cannot handle type $T")
new(c, name, size)
end
end

_sizeofrepr(typ::TypeT) = typ >> SIZE_OFFSET
sizeofrepr(s::StructElement) = _sizeofrepr(s.typ)
Base.size(s::StructElement) = s.size
Base.length(s::StructElement) = prod(size(s))
Base.nameof(s::StructElement) = s.name
function Base.show(io::IO, s::StructElement)
print(io, nameof(s), "::", decodetyp(s.typ))
if length(s) > 1
print(io, "::(", join(size(s), " × "), ")")
end
end

function encodetyp(::Type{T}) where {T}
typ = zero(UInt32)
if T <: Integer
typ |= TypeT(1) << ISINTEGER
if T <: Signed
typ |= TypeT(1) << SIGNED_OFFSET
elseif !(T <: Unsigned)
return typemax(TypeT)
end
elseif !(T <: AbstractFloat)
return typemax(TypeT)
end
typ |= TypeT(sizeof(T)) << SIZE_OFFSET
end

function decodetyp(typ::TypeT)
_size = TypeT(8) * (typ >> SIZE_OFFSET)
if !iszero(typ & (TypeT(1) << ISINTEGER))
if !iszero(typ & TypeT(1) << SIGNED_OFFSET)
_size == 8 ? Int8 :
_size == 16 ? Int16 :
_size == 32 ? Int32 :
_size == 64 ? Int64 :
error("invalid type $(typ)!")
else # unsigned
_size == 8 ? UInt8 :
_size == 16 ? UInt16 :
_size == 32 ? UInt32 :
_size == 64 ? UInt64 :
error("invalid type $(typ)!")
end
else # float
_size == 16 ? Float16 :
_size == 32 ? Float32 :
_size == 64 ? Float64 :
error("invalid type $(typ)!")
end
end

struct Struct <: Real
juliatype::DataType
v::Vector{StructElement}
end

function Base.hash(x::Struct, seed::UInt)
h1 = hash(juliatype(x), seed)
h2 = foldr(hash, getelements(x), init = h1)
h2 (0x0e39036b7de2101a % UInt)
struct Struct{T} <: Real
end

"""
symstruct(T)
Create a symbolic struct from a given type `T`.
Create a symbolic wrapper for struct from a given struct `T`.
"""
function symstruct(T)
elems = map(fieldnames(T)) do fieldname
StructElement(fieldtype(T, fieldname), fieldname)
end |> collect
Struct(T, elems)
symstruct(::Type{T}) where T = Struct{T}
Struct{T}(vals...) where T = T(vals...)

function Base.hash(x::Struct{T}, seed::UInt) where T
h1 = hash(T, seed)
h2 (0x0e39036b7de2101a % UInt)
end

"""
juliatype(s::Struct)
juliatype(s::Type{<:Struct})
Get the Julia type that `s` is representing.
"""
juliatype(s::Struct) = getfield(s, :juliatype)
getelements(s::Struct) = getfield(s, :v)
juliatype(::Type{Struct{T}}) where T = T
getelements(s::Type{<:Struct}) = fieldnames(juliatype(s))
getelementtypes(s::Type{<:Struct}) = fieldtypes(juliatype(s))

function Base.getproperty(s::Struct, name::Symbol)
v = getfield(s, :v)
idx = findfirst(x -> nameof(x) == name, v)
idx === nothing && error("no field $name in struct")
SymbolicUtils.term(getfield, s, idx, type = Real)
function symbolic_getproperty(ss, name::Symbol)
s = symtype(ss)
idx = findfirst(isequal(name), getelements(s))
idx === nothing && error("$(juliatype(s)) doesn't have field $(name)!")
T = getelementtypes(s)[idx]
SymbolicUtils.term(getfield, ss, Meta.quot(name), type = T)
end

function Base.setproperty!(s::Struct, name::Symbol, x)
v = getfield(s, :v)
idx = findfirst(x -> nameof(x) == name, v)
idx === nothing && error("no field $name in struct")
type = SymbolicUtils.symtype(x)
SymbolicUtils.term(setfield!, s, idx, x; type)
function symbolic_getproperty(s::Union{Arr, Num}, name::Symbol)
wrap(symbolic_getproperty(unwrap(s), name))
end

function symbolic_getproperty(s, name::Symbol)
SymbolicUtils.term(getfield, s, Meta.quot(name), type = Real)
function symbolic_setproperty!(ss, name::Symbol, val)
s = symtype(ss)
idx = findfirst(isequal(name), getelements(s))
idx === nothing && error("$(juliatype(s)) doesn't have field $(name)!")
T = getelementtypes(s)[idx]
SymbolicUtils.term(setfield!, ss, Meta.quot(name), val, type = T)
end
function symbolic_getproperty(s::Union{Arr, Num}, name::Symbol)
wrap(symbolic_getproperty(unwrap(s), name))
function symbolic_setproperty!(s::Union{Arr, Num}, name::Symbol, val)
wrap(symbolic_setproperty!(unwrap(s), name, val))
end

function symbolic_constructor(s::Type{<:Struct}, vals...)
N = length(getelements(s))
N′ = length(vals)
N′ == N || error("$(juliatype(s)) needs $N field. Got $N′ fields!")
SymbolicUtils.term(s, vals..., type = s)
end

# We cannot precisely derive the type after `getfield` due to SU limitations,
# so give up and just say Real.
SymbolicUtils.promote_symtype(::typeof(getfield), ::Type{<:Struct}, _...) = Real
SymbolicUtils.promote_symtype(::typeof(setfield!), ::Type{<:Struct}, _, ::Type{T}) where T = T
SymbolicUtils.promote_symtype(s::Type{<:Struct{T}}, _...) where T = s
69 changes: 19 additions & 50 deletions test/struct.jl
Original file line number Diff line number Diff line change
@@ -1,57 +1,26 @@
using Test, Symbolics
using Symbolics: StructElement, Struct, operation, arguments, symstruct, juliatype

handledtypes = [Int8,
Int16,
Int32,
Int64,
UInt8,
UInt16,
UInt32,
UInt64,
Float16,
Float32,
Float64]
for t in handledtypes
@test Symbolics.decodetyp(Symbolics.encodetyp(t)) === t
end

@variables t x(t)
struct Fisk
a::Int8
b::Int
end
a = StructElement(Int8, :a)
b = StructElement(Int, :b)
for s in [Struct(Fisk, [a, b]), symstruct(Fisk)]
sa = s.a
sb = s.b
@test operation(sa) === getfield
@test arguments(sa) == Any[s, 1]
@test arguments(sa) isa Any
@test operation(sb) === getfield
@test arguments(sb) == Any[s, 2]
@test arguments(sb) isa Any
@test juliatype(s) == Fisk
end

s = Struct(Fisk, [a, b])

sa1 = (setproperty!(s, :a, UInt8(1)))
@test operation(sa1) === setfield!
@test arguments(sa1) == Any[s, 1, UInt8(1)]
@test arguments(sa1) isa Any

sb1 = (setproperty!(s, :b, "hi"))
@test operation(sb1) === setfield!
@test arguments(sb1) == Any[s, 2, "hi"]
@test arguments(sb1) isa Any
using Symbolics: symstruct, juliatype, symbolic_getproperty, symbolic_setproperty!, symbolic_constructor

struct Jörgen
a::Int
b::Float64
end

ss = symstruct(Jörgen)

@test getfield(ss, :v) == [StructElement(Int, :a), StructElement(Float64, :b)]
S = symstruct(Jörgen)
@variables x::S
xa = Symbolics.unwrap(symbolic_getproperty(x, :a))
@test Symbolics.symtype(xa) == Int
@test Symbolics.operation(xa) == getfield
@test isequal(Symbolics.arguments(xa), [Symbolics.unwrap(x), Meta.quot(:a)])
xa = Symbolics.unwrap(symbolic_setproperty!(x, :a, 10))
@test Symbolics.operation(xa) == setfield!
@test isequal(Symbolics.arguments(xa), [Symbolics.unwrap(x), Meta.quot(:a), 10])
@test Symbolics.symtype(xa) == Int

xb = Symbolics.unwrap(symbolic_setproperty!(x, :b, 10))
@test Symbolics.operation(xb) == setfield!
@test isequal(Symbolics.arguments(xb), [Symbolics.unwrap(x), Meta.quot(:b), 10])
@test Symbolics.symtype(xb) == Float64

s = Symbolics.symbolic_constructor(S, 1, 1.0)
@test Symbolics.symtype(s) == S

0 comments on commit ca9a95d

Please sign in to comment.