From 22ff50f0704fb9f29513e35a4618d55ce4f3cd98 Mon Sep 17 00:00:00 2001 From: Yingbo Ma Date: Mon, 11 Mar 2024 15:10:54 -0400 Subject: [PATCH 1/3] Struct v2 --- src/struct.jl | 130 ++++++++++--------------------------------------- test/struct.jl | 66 ++++++------------------- 2 files changed, 43 insertions(+), 153 deletions(-) diff --git a/src/struct.jl b/src/struct.jl index 8cfc61889..ac70cc59b 100644 --- a/src/struct.jl +++ b/src/struct.jl @@ -1,123 +1,47 @@ -const TypeT = UInt32 -const ISINTEGER = TypeT(0) -const SIGNED_OFFSET = TypeT(1) -const SIZE_OFFSET = TypeT(2) - -const EMPTY_DIMS = Int[] - -struct StructElement - typ::TypeT - name::Symbol - size::Vector{Int} - function StructElement(::Type{T}, name, size = EMPTY_DIMS) where {T} - c = encodetyp(T) - c == typemax(TypeT) && error("Cannot handle type $T") - new(c, name, size) - end -end - -_sizeofrepr(typ::TypeT) = typ >> SIZE_OFFSET -sizeofrepr(s::StructElement) = _sizeofrepr(s.typ) -Base.size(s::StructElement) = s.size -Base.length(s::StructElement) = prod(size(s)) -Base.nameof(s::StructElement) = s.name -function Base.show(io::IO, s::StructElement) - print(io, nameof(s), "::", decodetyp(s.typ)) - if length(s) > 1 - print(io, "::(", join(size(s), " × "), ")") - end -end - -function encodetyp(::Type{T}) where {T} - typ = zero(UInt32) - if T <: Integer - typ |= TypeT(1) << ISINTEGER - if T <: Signed - typ |= TypeT(1) << SIGNED_OFFSET - elseif !(T <: Unsigned) - return typemax(TypeT) - end - elseif !(T <: AbstractFloat) - return typemax(TypeT) - end - typ |= TypeT(sizeof(T)) << SIZE_OFFSET -end - -function decodetyp(typ::TypeT) - _size = TypeT(8) * (typ >> SIZE_OFFSET) - if !iszero(typ & (TypeT(1) << ISINTEGER)) - if !iszero(typ & TypeT(1) << SIGNED_OFFSET) - _size == 8 ? Int8 : - _size == 16 ? Int16 : - _size == 32 ? Int32 : - _size == 64 ? Int64 : - error("invalid type $(typ)!") - else # unsigned - _size == 8 ? UInt8 : - _size == 16 ? UInt16 : - _size == 32 ? UInt32 : - _size == 64 ? UInt64 : - error("invalid type $(typ)!") - end - else # float - _size == 16 ? Float16 : - _size == 32 ? Float32 : - _size == 64 ? Float64 : - error("invalid type $(typ)!") - end -end - -struct Struct <: Real - juliatype::DataType - v::Vector{StructElement} -end - -function Base.hash(x::Struct, seed::UInt) - h1 = hash(juliatype(x), seed) - h2 = foldr(hash, getelements(x), init = h1) - h2 ⊻ (0x0e39036b7de2101a % UInt) +struct Struct{T} <: Real end """ symstruct(T) -Create a symbolic struct from a given type `T`. +Create a symbolic wrapper for struct from a given struct `T`. """ -function symstruct(T) - elems = map(fieldnames(T)) do fieldname - StructElement(fieldtype(T, fieldname), fieldname) - end |> collect - Struct(T, elems) +Struct(::Type{T}) where T = Struct{T} + +function Base.hash(x::Struct{T}, seed::UInt) where T + h1 = hash(T, seed) + h2 ⊻ (0x0e39036b7de2101a % UInt) end """ - juliatype(s::Struct) + juliatype(s::Type{<:Struct}) Get the Julia type that `s` is representing. """ -juliatype(s::Struct) = getfield(s, :juliatype) -getelements(s::Struct) = getfield(s, :v) +juliatype(::Type{Struct{T}}) where T = T +getelements(s::Type{<:Struct}) = fieldnames(juliatype(s)) +getelementtypes(s::Type{<:Struct}) = fieldtypes(juliatype(s)) -function Base.getproperty(s::Struct, name::Symbol) - v = getfield(s, :v) - idx = findfirst(x -> nameof(x) == name, v) - idx === nothing && error("no field $name in struct") - SymbolicUtils.term(getfield, s, idx, type = Real) +function symbolic_getproperty(ss, name::Symbol) + s = symtype(ss) + idx = findfirst(isequal(name), getelements(s)) + idx === nothing && error("$(juliatype(s)) doesn't have field $(name)!") + T = getelementtypes(s)[idx] + SymbolicUtils.term(getfield, ss, Meta.quot(name), type = T) end - -function Base.setproperty!(s::Struct, name::Symbol, x) - v = getfield(s, :v) - idx = findfirst(x -> nameof(x) == name, v) - idx === nothing && error("no field $name in struct") - type = SymbolicUtils.symtype(x) - SymbolicUtils.term(setfield!, s, idx, x; type) +function symbolic_getproperty(s::Union{Arr, Num}, name::Symbol) + wrap(symbolic_getproperty(unwrap(s), name)) end -function symbolic_getproperty(s, name::Symbol) - SymbolicUtils.term(getfield, s, Meta.quot(name), type = Real) +function symbolic_setproperty!(ss, name::Symbol, val) + s = symtype(ss) + idx = findfirst(isequal(name), getelements(s)) + idx === nothing && error("$(juliatype(s)) doesn't have field $(name)!") + T = getelementtypes(s)[idx] + SymbolicUtils.term(setfield!, ss, Meta.quot(name), val, type = T) end -function symbolic_getproperty(s::Union{Arr, Num}, name::Symbol) - wrap(symbolic_getproperty(unwrap(s), name)) +function symbolic_setproperty!(s::Union{Arr, Num}, name::Symbol, val) + wrap(symbolic_setproperty!(unwrap(s), name, val)) end # We cannot precisely derive the type after `getfield` due to SU limitations, diff --git a/test/struct.jl b/test/struct.jl index c9b178bc6..0ae1ea54e 100644 --- a/test/struct.jl +++ b/test/struct.jl @@ -1,57 +1,23 @@ using Test, Symbolics -using Symbolics: StructElement, Struct, operation, arguments, symstruct, juliatype - -handledtypes = [Int8, - Int16, - Int32, - Int64, - UInt8, - UInt16, - UInt32, - UInt64, - Float16, - Float32, - Float64] -for t in handledtypes - @test Symbolics.decodetyp(Symbolics.encodetyp(t)) === t -end - -@variables t x(t) -struct Fisk - a::Int8 - b::Int -end -a = StructElement(Int8, :a) -b = StructElement(Int, :b) -for s in [Struct(Fisk, [a, b]), symstruct(Fisk)] - sa = s.a - sb = s.b - @test operation(sa) === getfield - @test arguments(sa) == Any[s, 1] - @test arguments(sa) isa Any - @test operation(sb) === getfield - @test arguments(sb) == Any[s, 2] - @test arguments(sb) isa Any - @test juliatype(s) == Fisk -end - -s = Struct(Fisk, [a, b]) - -sa1 = (setproperty!(s, :a, UInt8(1))) -@test operation(sa1) === setfield! -@test arguments(sa1) == Any[s, 1, UInt8(1)] -@test arguments(sa1) isa Any - -sb1 = (setproperty!(s, :b, "hi")) -@test operation(sb1) === setfield! -@test arguments(sb1) == Any[s, 2, "hi"] -@test arguments(sb1) isa Any +using Symbolics: Struct, juliatype, symbolic_getproperty, symbolic_setproperty! struct Jörgen a::Int b::Float64 end -ss = symstruct(Jörgen) - -@test getfield(ss, :v) == [StructElement(Int, :a), StructElement(Float64, :b)] +S = Struct(Jörgen) +@variables x::S +xa = Symbolics.unwrap(symbolic_getproperty(x, :a)) +@test Symbolics.symtype(xa) == Int +@test Symbolics.operation(xa) == getfield +@test isequal(Symbolics.arguments(xa), [Symbolics.unwrap(x), Meta.quot(:a)]) +xa = Symbolics.unwrap(symbolic_setproperty!(x, :a, 10)) +@test Symbolics.operation(xa) == setfield! +@test isequal(Symbolics.arguments(xa), [Symbolics.unwrap(x), Meta.quot(:a), 10]) +@test Symbolics.symtype(xa) == Int + +xb = Symbolics.unwrap(symbolic_setproperty!(x, :b, 10)) +@test Symbolics.operation(xb) == setfield! +@test isequal(Symbolics.arguments(xb), [Symbolics.unwrap(x), Meta.quot(:b), 10]) +@test Symbolics.symtype(xb) == Float64 From acec4bc72fc53b8a34f78abdb896226e15a4b306 Mon Sep 17 00:00:00 2001 From: Yingbo Ma Date: Mon, 11 Mar 2024 15:37:30 -0400 Subject: [PATCH 2/3] Add `symbolic_constructor` --- src/struct.jl | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/src/struct.jl b/src/struct.jl index ac70cc59b..421a10968 100644 --- a/src/struct.jl +++ b/src/struct.jl @@ -6,7 +6,8 @@ end Create a symbolic wrapper for struct from a given struct `T`. """ -Struct(::Type{T}) where T = Struct{T} +symstruct(::Type{T}) where T = Struct{T} +Struct{T}(vals...) where T = T(vals...) function Base.hash(x::Struct{T}, seed::UInt) where T h1 = hash(T, seed) @@ -44,7 +45,15 @@ function symbolic_setproperty!(s::Union{Arr, Num}, name::Symbol, val) wrap(symbolic_setproperty!(unwrap(s), name, val)) end +function symbolic_constructor(s::Type{<:Struct}, vals...) + N = length(getelements(s)) + N′ = length(vals) + N′ == N || error("$(juliatype(s)) needs $N field. Got $N′ fields!") + SymbolicUtils.term(s, vals..., type = s) +end + # We cannot precisely derive the type after `getfield` due to SU limitations, # so give up and just say Real. SymbolicUtils.promote_symtype(::typeof(getfield), ::Type{<:Struct}, _...) = Real SymbolicUtils.promote_symtype(::typeof(setfield!), ::Type{<:Struct}, _, ::Type{T}) where T = T +SymbolicUtils.promote_symtype(s::Type{<:Struct{T}}, _...) where T = s From db932e43d10b52debfebc003c52fe09d2ad36ec6 Mon Sep 17 00:00:00 2001 From: Yingbo Ma Date: Mon, 11 Mar 2024 15:39:54 -0400 Subject: [PATCH 3/3] Add a constructor test --- test/struct.jl | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/test/struct.jl b/test/struct.jl index 0ae1ea54e..3e842932c 100644 --- a/test/struct.jl +++ b/test/struct.jl @@ -1,12 +1,12 @@ using Test, Symbolics -using Symbolics: Struct, juliatype, symbolic_getproperty, symbolic_setproperty! +using Symbolics: symstruct, juliatype, symbolic_getproperty, symbolic_setproperty!, symbolic_constructor struct Jörgen a::Int b::Float64 end -S = Struct(Jörgen) +S = symstruct(Jörgen) @variables x::S xa = Symbolics.unwrap(symbolic_getproperty(x, :a)) @test Symbolics.symtype(xa) == Int @@ -21,3 +21,6 @@ xb = Symbolics.unwrap(symbolic_setproperty!(x, :b, 10)) @test Symbolics.operation(xb) == setfield! @test isequal(Symbolics.arguments(xb), [Symbolics.unwrap(x), Meta.quot(:b), 10]) @test Symbolics.symtype(xb) == Float64 + +s = Symbolics.symbolic_constructor(S, 1, 1.0) +@test Symbolics.symtype(s) == S