Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Struct v2 #1088

Merged
merged 3 commits into from
Mar 11, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
139 changes: 36 additions & 103 deletions src/struct.jl
Original file line number Diff line number Diff line change
@@ -1,126 +1,59 @@
const TypeT = UInt32
const ISINTEGER = TypeT(0)
const SIGNED_OFFSET = TypeT(1)
const SIZE_OFFSET = TypeT(2)

const EMPTY_DIMS = Int[]

struct StructElement
typ::TypeT
name::Symbol
size::Vector{Int}
function StructElement(::Type{T}, name, size = EMPTY_DIMS) where {T}
c = encodetyp(T)
c == typemax(TypeT) && error("Cannot handle type $T")
new(c, name, size)
end
end

_sizeofrepr(typ::TypeT) = typ >> SIZE_OFFSET
sizeofrepr(s::StructElement) = _sizeofrepr(s.typ)
Base.size(s::StructElement) = s.size
Base.length(s::StructElement) = prod(size(s))
Base.nameof(s::StructElement) = s.name
function Base.show(io::IO, s::StructElement)
print(io, nameof(s), "::", decodetyp(s.typ))
if length(s) > 1
print(io, "::(", join(size(s), " × "), ")")
end
end

function encodetyp(::Type{T}) where {T}
typ = zero(UInt32)
if T <: Integer
typ |= TypeT(1) << ISINTEGER
if T <: Signed
typ |= TypeT(1) << SIGNED_OFFSET
elseif !(T <: Unsigned)
return typemax(TypeT)
end
elseif !(T <: AbstractFloat)
return typemax(TypeT)
end
typ |= TypeT(sizeof(T)) << SIZE_OFFSET
end

function decodetyp(typ::TypeT)
_size = TypeT(8) * (typ >> SIZE_OFFSET)
if !iszero(typ & (TypeT(1) << ISINTEGER))
if !iszero(typ & TypeT(1) << SIGNED_OFFSET)
_size == 8 ? Int8 :
_size == 16 ? Int16 :
_size == 32 ? Int32 :
_size == 64 ? Int64 :
error("invalid type $(typ)!")
else # unsigned
_size == 8 ? UInt8 :
_size == 16 ? UInt16 :
_size == 32 ? UInt32 :
_size == 64 ? UInt64 :
error("invalid type $(typ)!")
end
else # float
_size == 16 ? Float16 :
_size == 32 ? Float32 :
_size == 64 ? Float64 :
error("invalid type $(typ)!")
end
end

struct Struct <: Real
juliatype::DataType
v::Vector{StructElement}
end

function Base.hash(x::Struct, seed::UInt)
h1 = hash(juliatype(x), seed)
h2 = foldr(hash, getelements(x), init = h1)
h2 (0x0e39036b7de2101a % UInt)
struct Struct{T} <: Real
end

"""
symstruct(T)
Create a symbolic struct from a given type `T`.
Create a symbolic wrapper for struct from a given struct `T`.
"""
function symstruct(T)
elems = map(fieldnames(T)) do fieldname
StructElement(fieldtype(T, fieldname), fieldname)
end |> collect
Struct(T, elems)
symstruct(::Type{T}) where T = Struct{T}
Struct{T}(vals...) where T = T(vals...)

function Base.hash(x::Struct{T}, seed::UInt) where T
h1 = hash(T, seed)
h2 (0x0e39036b7de2101a % UInt)
end

"""
juliatype(s::Struct)
juliatype(s::Type{<:Struct})
Get the Julia type that `s` is representing.
"""
juliatype(s::Struct) = getfield(s, :juliatype)
getelements(s::Struct) = getfield(s, :v)
juliatype(::Type{Struct{T}}) where T = T
getelements(s::Type{<:Struct}) = fieldnames(juliatype(s))
getelementtypes(s::Type{<:Struct}) = fieldtypes(juliatype(s))

function Base.getproperty(s::Struct, name::Symbol)
v = getfield(s, :v)
idx = findfirst(x -> nameof(x) == name, v)
idx === nothing && error("no field $name in struct")
SymbolicUtils.term(getfield, s, idx, type = Real)
function symbolic_getproperty(ss, name::Symbol)
s = symtype(ss)
idx = findfirst(isequal(name), getelements(s))
idx === nothing && error("$(juliatype(s)) doesn't have field $(name)!")
T = getelementtypes(s)[idx]
SymbolicUtils.term(getfield, ss, Meta.quot(name), type = T)
end

function Base.setproperty!(s::Struct, name::Symbol, x)
v = getfield(s, :v)
idx = findfirst(x -> nameof(x) == name, v)
idx === nothing && error("no field $name in struct")
type = SymbolicUtils.symtype(x)
SymbolicUtils.term(setfield!, s, idx, x; type)
function symbolic_getproperty(s::Union{Arr, Num}, name::Symbol)
wrap(symbolic_getproperty(unwrap(s), name))
end

function symbolic_getproperty(s, name::Symbol)
SymbolicUtils.term(getfield, s, Meta.quot(name), type = Real)
function symbolic_setproperty!(ss, name::Symbol, val)
s = symtype(ss)
idx = findfirst(isequal(name), getelements(s))
idx === nothing && error("$(juliatype(s)) doesn't have field $(name)!")
T = getelementtypes(s)[idx]
SymbolicUtils.term(setfield!, ss, Meta.quot(name), val, type = T)
end
function symbolic_getproperty(s::Union{Arr, Num}, name::Symbol)
wrap(symbolic_getproperty(unwrap(s), name))
function symbolic_setproperty!(s::Union{Arr, Num}, name::Symbol, val)
wrap(symbolic_setproperty!(unwrap(s), name, val))
end

function symbolic_constructor(s::Type{<:Struct}, vals...)
N = length(getelements(s))
N′ = length(vals)
N′ == N || error("$(juliatype(s)) needs $N field. Got $N′ fields!")
SymbolicUtils.term(s, vals..., type = s)
end

# We cannot precisely derive the type after `getfield` due to SU limitations,
# so give up and just say Real.
SymbolicUtils.promote_symtype(::typeof(getfield), ::Type{<:Struct}, _...) = Real
SymbolicUtils.promote_symtype(::typeof(setfield!), ::Type{<:Struct}, _, ::Type{T}) where T = T
SymbolicUtils.promote_symtype(s::Type{<:Struct{T}}, _...) where T = s
69 changes: 19 additions & 50 deletions test/struct.jl
Original file line number Diff line number Diff line change
@@ -1,57 +1,26 @@
using Test, Symbolics
using Symbolics: StructElement, Struct, operation, arguments, symstruct, juliatype

handledtypes = [Int8,
Int16,
Int32,
Int64,
UInt8,
UInt16,
UInt32,
UInt64,
Float16,
Float32,
Float64]
for t in handledtypes
@test Symbolics.decodetyp(Symbolics.encodetyp(t)) === t
end

@variables t x(t)
struct Fisk
a::Int8
b::Int
end
a = StructElement(Int8, :a)
b = StructElement(Int, :b)
for s in [Struct(Fisk, [a, b]), symstruct(Fisk)]
sa = s.a
sb = s.b
@test operation(sa) === getfield
@test arguments(sa) == Any[s, 1]
@test arguments(sa) isa Any
@test operation(sb) === getfield
@test arguments(sb) == Any[s, 2]
@test arguments(sb) isa Any
@test juliatype(s) == Fisk
end

s = Struct(Fisk, [a, b])

sa1 = (setproperty!(s, :a, UInt8(1)))
@test operation(sa1) === setfield!
@test arguments(sa1) == Any[s, 1, UInt8(1)]
@test arguments(sa1) isa Any

sb1 = (setproperty!(s, :b, "hi"))
@test operation(sb1) === setfield!
@test arguments(sb1) == Any[s, 2, "hi"]
@test arguments(sb1) isa Any
using Symbolics: symstruct, juliatype, symbolic_getproperty, symbolic_setproperty!, symbolic_constructor

struct Jörgen
a::Int
b::Float64
end

ss = symstruct(Jörgen)

@test getfield(ss, :v) == [StructElement(Int, :a), StructElement(Float64, :b)]
S = symstruct(Jörgen)
@variables x::S
xa = Symbolics.unwrap(symbolic_getproperty(x, :a))
@test Symbolics.symtype(xa) == Int
@test Symbolics.operation(xa) == getfield
@test isequal(Symbolics.arguments(xa), [Symbolics.unwrap(x), Meta.quot(:a)])
xa = Symbolics.unwrap(symbolic_setproperty!(x, :a, 10))
@test Symbolics.operation(xa) == setfield!
@test isequal(Symbolics.arguments(xa), [Symbolics.unwrap(x), Meta.quot(:a), 10])
@test Symbolics.symtype(xa) == Int

xb = Symbolics.unwrap(symbolic_setproperty!(x, :b, 10))
@test Symbolics.operation(xb) == setfield!
@test isequal(Symbolics.arguments(xb), [Symbolics.unwrap(x), Meta.quot(:b), 10])
@test Symbolics.symtype(xb) == Float64

s = Symbolics.symbolic_constructor(S, 1, 1.0)
@test Symbolics.symtype(s) == S
Loading