Skip to content

Commit

Permalink
Merge pull request #1090 from AayushSabharwal/as/arr-isequal
Browse files Browse the repository at this point in the history
fix: fix substitute, isequal for Arr
  • Loading branch information
YingboMa authored Mar 14, 2024
2 parents 75cc676 + 92a8125 commit fc74b9a
Show file tree
Hide file tree
Showing 4 changed files with 28 additions and 1 deletion.
4 changes: 4 additions & 0 deletions src/Symbolics.jl
Original file line number Diff line number Diff line change
Expand Up @@ -175,6 +175,10 @@ for T in [Num, Complex{Num}]
end
end

for sType in [Pair, Vector, Dict]
@eval substitute(expr::Arr, s::$sType; kw...) = wrap(substituter(s)(unwrap(expr); kw...))
end

function symbolics_to_sympy end
export symbolics_to_sympy

Expand Down
2 changes: 2 additions & 0 deletions src/arrays.jl
Original file line number Diff line number Diff line change
Expand Up @@ -470,6 +470,8 @@ end

Base.hash(x::Arr, u::UInt) = hash(unwrap(x), u)
Base.isequal(a::Arr, b::Arr) = isequal(unwrap(a), unwrap(b))
Base.isequal(a::Arr, b::Symbolic) = isequal(unwrap(a), b)
Base.isequal(a::Symbolic, b::Arr) = isequal(b, a)

ArrayOp(x::Arr) = unwrap(x)

Expand Down
2 changes: 1 addition & 1 deletion src/num.jl
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ substitute(expr, s::Vector; kw...) = substituter(s)(expr; kw...)
substituter(pair::Pair) = substituter((pair,))
function substituter(pairs)
dict = Dict(value(k) => value(v) for (k, v) in pairs)
(expr; kw...) -> SymbolicUtils.substitute(expr, dict; kw...)
(expr; kw...) -> SymbolicUtils.substitute(value(expr), dict; kw...)
end

SymbolicUtils.symtype(n::Num) = symtype(value(n))
Expand Down
21 changes: 21 additions & 0 deletions test/arrays.jl
Original file line number Diff line number Diff line change
Expand Up @@ -382,6 +382,27 @@ end
@test isequal(collect(dtv), collect(A .* u .- u.^2 .* v .+ alpha .* lapv))
end

@testset "Unwrapped array equality" begin
@variables x[1:3]
ux = unwrap(x)
@test isequal(x, x)
@test isequal(x, ux)
@test isequal(ux, x)
end

@testset "Array expression substitution" begin
@variables x[1:3] p[1:3, 1:3]
bar(x, p) = p * x
@register_array_symbolic bar(x::AbstractVector, p::AbstractMatrix) begin
size = size(x)
eltype = promote_type(eltype(x), eltype(p))
end

@test isequal(substitute(bar(x, p), x => ones(3)), bar(ones(3), p))
@test isequal(substitute(bar(x, p), Dict(x => ones(3), p => ones(3, 3))), wrap(3ones(3)))
@test isequal(substitute(bar(x, p), [x => ones(3), p => ones(3, 3)]), wrap(3ones(3)))
end

@testset "Partial array substitution" begin
@variables x[1:3] A[1:2, 1:2, 1:2]

Expand Down

0 comments on commit fc74b9a

Please sign in to comment.