From 10219bf1138030f0a9078054b3aa1687c3e35c85 Mon Sep 17 00:00:00 2001 From: Yingbo Ma Date: Thu, 25 Jan 2024 10:39:29 -0500 Subject: [PATCH] Add struct tracing Co-authored-by: Fredrik Bagge Carlson --- src/Symbolics.jl | 13 +++---- src/struct.jl | 92 ++++++++++++++++++++++++++++++++++++++++++++++++ test/runtests.jl | 1 + test/struct.jl | 40 +++++++++++++++++++++ 4 files changed, 140 insertions(+), 6 deletions(-) create mode 100644 src/struct.jl create mode 100644 test/struct.jl diff --git a/src/Symbolics.jl b/src/Symbolics.jl index 14a43a8d8..9481525ab 100644 --- a/src/Symbolics.jl +++ b/src/Symbolics.jl @@ -18,19 +18,19 @@ using PrecompileTools using Setfield import DomainSets: Domain - + import SymbolicUtils: similarterm, istree, operation, arguments, symtype, metadata - + import SymbolicUtils: Term, Add, Mul, Pow, Sym, Div, BasicSymbolic, FnType, @rule, Rewriters, substitute, promote_symtype, isadd, ismul, ispow, isterm, issym, isdiv - + using SymbolicUtils.Code - + import SymbolicUtils.Rewriters: Chain, Prewalk, Postwalk, Fixpoint - + import SymbolicUtils.Code: toexpr - + import ArrayInterface using RuntimeGeneratedFunctions using SciMLBase, IfElse @@ -145,6 +145,7 @@ include("parsing.jl") export parse_expr_to_symbolic include("error_hints.jl") +include("struct.jl") # Hacks to make wrappers "nicer" const NumberTypes = Union{AbstractFloat,Integer,Complex{<:AbstractFloat},Complex{<:Integer}} diff --git a/src/struct.jl b/src/struct.jl new file mode 100644 index 000000000..e6ff592f9 --- /dev/null +++ b/src/struct.jl @@ -0,0 +1,92 @@ +const TypeT = UInt32 +const ISINTEGER = TypeT(0) +const SIGNED_OFFSET = TypeT(1) +const SIZE_OFFSET = TypeT(2) + +const EMPTY_DIMS = Int[] + +struct StructElement + typ::TypeT + name::Symbol + size::Vector{Int} + function StructElement(::Type{T}, name, size = EMPTY_DIMS) where {T} + c = encodetyp(T) + c == typemax(TypeT) && error("Cannot handle type $T") + new(c, name, size) + end +end + +_sizeofrepr(typ::TypeT) = typ >> SIZE_OFFSET +sizeofrepr(s::StructElement) = _sizeofrepr(s.typ) +Base.size(s::StructElement) = s.size +Base.length(s::StructElement) = prod(size(s)) +Base.nameof(s::StructElement) = s.name +function Base.show(io::IO, s::StructElement) + print(io, nameof(s), "::", decodetyp(s.typ)) + if length(s) > 1 + print(io, "::(", join(size(s), " × "), ")") + end +end + +function encodetyp(::Type{T}) where {T} + typ = zero(UInt32) + if T <: Integer + typ |= TypeT(1) << ISINTEGER + if T <: Signed + typ |= TypeT(1) << SIGNED_OFFSET + elseif !(T <: Unsigned) + return typemax(TypeT) + end + elseif !(T <: AbstractFloat) + return typemax(TypeT) + end + typ |= TypeT(sizeof(T)) << SIZE_OFFSET +end + +function decodetyp(typ::TypeT) + siz = TypeT(8) * (typ >> SIZE_OFFSET) + if !iszero(typ & (TypeT(1) << ISINTEGER)) + if !iszero(typ & TypeT(1) << SIGNED_OFFSET) + siz == 8 ? Int8 : + siz == 16 ? Int16 : + siz == 32 ? Int32 : + siz == 64 ? Int64 : + error("invalid type $(typ)!") + else # unsigned + siz == 8 ? UInt8 : + siz == 16 ? UInt16 : + siz == 32 ? UInt32 : + siz == 64 ? UInt64 : + error("invalid type $(typ)!") + end + else # float + siz == 16 ? Float16 : + siz == 32 ? Float32 : + siz == 64 ? Float64 : + error("invalid type $(typ)!") + end +end + +struct Struct + v::Vector{StructElement} +end + +function Base.getproperty(s::Struct, name::Symbol) + v = getfield(s, :v) + idx = findfirst(x -> nameof(x) == name, v) + idx === nothing && error("no field $name in struct") + SymbolicUtils.term(getfield, s, idx, type = Real) +end + +function Base.setproperty!(s::Struct, name::Symbol, x) + v = getfield(s, :v) + idx = findfirst(x -> nameof(x) == name, v) + idx === nothing && error("no field $name in struct") + type = SymbolicUtils.symtype(x) + SymbolicUtils.term(setfield!, s, idx, x; type) +end + +# We cannot precisely derive the type after `getfield` due to SU limitations, +# so give up and just say Real. +SymbolicUtils.promote_symtype(::typeof(getfield), ::Type{<:Struct}, _...) = Real +SymbolicUtils.promote_symtype(::typeof(setfield!), ::Type{<:Struct}, _, ::Type{T}) where T = T diff --git a/test/runtests.jl b/test/runtests.jl index 7d42fb555..521d86cfc 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -18,6 +18,7 @@ limit(a, N) = a == N + 1 ? 1 : a == 0 ? N : a @register_symbolic limit(a, N)::Integer if GROUP == "All" || GROUP == "Core" + @safetestset "Struct Test" begin include("struct.jl") end @safetestset "Macro Test" begin include("macro.jl") end @safetestset "Arrays" begin include("arrays.jl") end @safetestset "View-setting" begin include("stencils.jl") end diff --git a/test/struct.jl b/test/struct.jl new file mode 100644 index 000000000..b509abc51 --- /dev/null +++ b/test/struct.jl @@ -0,0 +1,40 @@ +using Test, Symbolics +using Symbolics: StructElement, Struct, operation, arguments + +handledtypes = [Int8, + Int16, + Int32, + Int64, + UInt8, + UInt16, + UInt32, + UInt64, + Float16, + Float32, + Float64] +for t in handledtypes + @test Symbolics.decodetyp(Symbolics.encodetyp(t)) === t +end + +@variables t x(t) +a = StructElement(Int8, :a) +b = StructElement(Int, :b) +s = Struct([a, b]) +sa = s.a +sb = s.b +@test operation(sa) === getfield +@test arguments(sa) == Any[s, 1] +@test arguments(sa) isa Any +@test operation(sb) === getfield +@test arguments(sb) == Any[s, 2] +@test arguments(sb) isa Any + +sa1 = (setproperty!(s, :a, UInt8(1))) +@test operation(sa1) === setfield! +@test arguments(sa1) == Any[s, 1, UInt8(1)] +@test arguments(sa1) isa Any + +sb1 = (setproperty!(s, :b, "hi")) +@test operation(sb1) === setfield! +@test arguments(sb1) == Any[s, 2, "hi"] +@test arguments(sb1) isa Any