From a4b1920a25bf4e18aadea6bb22a8bf1f800a4df9 Mon Sep 17 00:00:00 2001 From: Yingbo Ma Date: Thu, 25 Apr 2024 19:35:47 -0400 Subject: [PATCH] Fix getindex symtype --- src/array-lib.jl | 28 ++++++++++++++++++---------- test/overloads.jl | 4 ++-- 2 files changed, 20 insertions(+), 12 deletions(-) diff --git a/src/array-lib.jl b/src/array-lib.jl index b0204341e..74f11e420 100644 --- a/src/array-lib.jl +++ b/src/array-lib.jl @@ -1,4 +1,5 @@ import Base: getindex +inner_unwrap(x) = x isa AbstractArray ? unwrap.(x) : x ##### getindex ##### struct GetindexPosthookCtx end @@ -167,10 +168,11 @@ isonedim(x, i) = shape(x) == Unknown() ? false : isone(size(x, i)) function Broadcast.copy(bc::Broadcast.Broadcasted{SymBroadcast}) # Do the thing here - ndim = mapfoldl(ndims, max, bc.args, init=0) + args = inner_unwrap.(bc.args) + ndim = mapfoldl(ndims, max, args, init=0) subscripts = makesubscripts(ndim) - onedim_count = mapreduce(+, bc.args) do x + onedim_count = mapreduce(+, args) do x if ndims(x) != 0 map(i -> isonedim(x, i) ? 1 : 0, 1:ndim) else @@ -178,9 +180,9 @@ function Broadcast.copy(bc::Broadcast.Broadcasted{SymBroadcast}) end end - extruded = map(x -> x < length(bc.args), onedim_count) + extruded = map(x -> x < length(args), onedim_count) - expr_args′ = map(bc.args) do x + expr_args′ = map(args) do x if ndims(x) != 0 subs = map(i -> extruded[i] && isonedim(x, i) ? 1 : subscripts[i], 1:ndims(x)) @@ -194,8 +196,8 @@ function Broadcast.copy(bc::Broadcast.Broadcasted{SymBroadcast}) expr = term(bc.f, expr_args′...) # Imagine x .=> y -- if you don't have a term # then you get pairs, and index matcher cannot # recurse into pairs - Atype = propagate_atype(broadcast, bc.f, bc.args...) - args = map(x -> x isa Base.RefValue ? Term{Any}(Ref, [x[]]) : x, bc.args) + Atype = propagate_atype(broadcast, bc.f, args...) + args = map(x -> x isa Base.RefValue ? Term{Any}(Ref, [x[]]) : x, args) ArrayOp(Atype{symtype(expr),ndim}, (subscripts...,), expr, @@ -261,17 +263,22 @@ end isadjointvec(A::ArrayOp) = isadjointvec(A.term) +__symtype(x::Type{<:Symbolic{T}}) where T = T +function symeltype(A) + T = eltype(A) + T <: Symbolic ? __symtype(T) : T +end # TODO: add more such methods function getindex(A::AbstractArray, i::Symbolic{<:Integer}, ii::Symbolic{<:Integer}...) - Term{eltype(A)}(getindex, [A, i, ii...]) + Term{symeltype(A)}(getindex, [A, i, ii...]) end function getindex(A::AbstractArray, i::Int, j::Symbolic{<:Integer}) - Term{eltype(A)}(getindex, [A, i, j]) + Term{symeltype(A)}(getindex, [A, i, j]) end function getindex(A::AbstractArray, j::Symbolic{<:Integer}, i::Int) - Term{eltype(A)}(getindex, [A, j, i]) + Term{symeltype(A)}(getindex, [A, j, i]) end function getindex(A::Arr, i::Int, j::Symbolic{<:Integer}) @@ -282,7 +289,6 @@ function getindex(A::Arr, j::Symbolic{<:Integer}, i::Int) wrap(unwrap(A)[j, i]) end -inner_unwrap(x) = x isa AbstractArray ? unwrap.(x) : x function _matmul(A, B) A = inner_unwrap(A) B = inner_unwrap(B) @@ -325,6 +331,8 @@ end function _map(f, x, xs...) N = ndims(x) idx = makesubscripts(N) + x = inner_unwrap(x) + xs = inner_unwrap.(xs) expr = f(map(a -> a[idx...], [x, xs...])...) diff --git a/test/overloads.jl b/test/overloads.jl index 86d618bad..0bfbc0a90 100644 --- a/test/overloads.jl +++ b/test/overloads.jl @@ -257,6 +257,6 @@ using Symbolics: scalarize @variables X[1:3, 1:3] x sX = fill(x, 3, 3) sx = fill(x, 3) -@test isequal(scalarize(X + XX), scalarize(X) + XX) -@test isequal(scalarize(X * XX), scalarize(X) * XX) +@test isequal(scalarize(X + sX), scalarize(X) + sX) +@test isequal(scalarize(X * sX), scalarize(X) * sX) @test isequal(scalarize(X * sx), scalarize(X) * sx)