diff --git a/test/runtests.jl b/test/runtests.jl index e73d5ffca..27d013a61 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -47,6 +47,15 @@ if GROUP == "All" || GROUP == "Core" @safetestset "LogExpFunctions Test" begin include("logexpfunctions.jl") end end +if GROUP == "All" || GROUP == "Core" || GROUP == "SymbolicIndexingInterface" + @safetestset "SymbolicIndexingInterface Trait Test" begin + include("symbolic_indexing_interface_trait.jl") + end + @safetestset "SymbolicIndexingInterface Parameter Indexing Test" begin + include("symbolic_indexing_interface_parameter_indexing.jl") + end +end + if GROUP == "Downstream" activate_downstream_env() #@time @safetestset "ParameterizedFunctions MATLABDiffEq Regression Test" begin include("downstream/ParameterizedFunctions_MATLAB.jl") end diff --git a/test/symbolic_indexing_interface_parameter_indexing.jl b/test/symbolic_indexing_interface_parameter_indexing.jl new file mode 100644 index 000000000..a05831484 --- /dev/null +++ b/test/symbolic_indexing_interface_parameter_indexing.jl @@ -0,0 +1,23 @@ +using SymbolicIndexingInterface +using Symbolics + +struct FakeIntegrator{P} + p::P +end + +SymbolicIndexingInterface.symbolic_container(fp::FakeIntegrator) = fp.sys +SymbolicIndexingInterface.parameter_values(fp::FakeIntegrator) = fp.p + +@variables a[1:2] b +sys = SymbolCache([:x, :y, :z], [a[1], a[2], b], [:t]) +p = [1.0, 2.0, 3.0] +fi = FakeIntegrator(copy(p)) +for (i, sym) in [(1, a[1]), (2, a[2]), (3, b), ([1,2], a), ([1, 3], [a[1], b]), ((2, 3), (a[2], b))] + get = getp(sys, sym) + set! = setp(sys, sym) + true_value = i isa Tuple ? getindex.((p,), i) : p[i] + @test get(fi) == true_value + set!(fi, 0.5 .* i) + @test get(fi) == 0.5 .* i + set!(fi, true_value) +end diff --git a/test/symbolic_indexing_interface_trait.jl b/test/symbolic_indexing_interface_trait.jl new file mode 100644 index 000000000..52d1579ae --- /dev/null +++ b/test/symbolic_indexing_interface_trait.jl @@ -0,0 +1,12 @@ +using Symbolics +using SymbolicUtils +using SymbolicIndexingInterface + +@test all(symbolic_type.([SymbolicUtils.BasicSymbolic, Symbolics.Num]) .== + (ScalarSymbolic(),)) +@test symbolic_type(Symbolics.Arr) == ArraySymbolic() +@variables x +@test symbolic_type(x) == ScalarSymbolic() +@variables y[1:3] +@test symbolic_type(y) == ArraySymbolic() +@test all(symbolic_type.(collect(y)) .== (ScalarSymbolic(),))