Skip to content

Commit

Permalink
Extend sparse kron to adjortrans of dense matrices (#474)
Browse files Browse the repository at this point in the history
  • Loading branch information
dkarrasch authored Nov 21, 2023
1 parent 951837f commit 8308232
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 18 deletions.
10 changes: 5 additions & 5 deletions src/linalg.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1483,7 +1483,7 @@ const _Symmetric_DenseArrays{T,A<:Matrix} = Symmetric{T,A}
const _Hermitian_DenseArrays{T,A<:Matrix} = Hermitian{T,A}
const _Triangular_DenseArrays{T,A<:Matrix} = UpperOrLowerTriangular{<:Any,A} # AbstractTriangular{T,A}
const _Annotated_DenseArrays = Union{_SpecialArrays, _Triangular_DenseArrays, _Symmetric_DenseArrays, _Hermitian_DenseArrays}
const _DenseConcatGroup = Union{Number, Vector, Adjoint{<:Any,<:Vector}, Transpose{<:Any,<:Vector}, Matrix, _Annotated_DenseArrays}
const _DenseKronGroup = Union{Number, Vector, Matrix, AdjOrTrans{<:Any,<:VecOrMat}, _Annotated_DenseArrays}

@inline function kron!(C::SparseMatrixCSC, A::AbstractSparseMatrixCSC, B::AbstractSparseMatrixCSC)
mA, nA = size(A); mB, nB = size(B)
Expand Down Expand Up @@ -1541,9 +1541,9 @@ end
end
return z
end
kron!(C::SparseMatrixCSC, A::_SparseKronGroup, B::_DenseConcatGroup) =
kron!(C::SparseMatrixCSC, A::_SparseKronGroup, B::_DenseKronGroup) =
kron!(C, convert(SparseMatrixCSC, A), convert(SparseMatrixCSC, B))
kron!(C::SparseMatrixCSC, A::_DenseConcatGroup, B::_SparseKronGroup) =
kron!(C::SparseMatrixCSC, A::_DenseKronGroup, B::_SparseKronGroup) =
kron!(C, convert(SparseMatrixCSC, A), convert(SparseMatrixCSC, B))
kron!(C::SparseMatrixCSC, A::_SparseKronGroup, B::_SparseKronGroup) =
kron!(C, convert(SparseMatrixCSC, A), convert(SparseMatrixCSC, B))
Expand Down Expand Up @@ -1580,8 +1580,8 @@ end
# extend to annotated sparse arrays, but leave out the (dense ⊗ dense)-case
kron(A::_SparseKronGroup, B::_SparseKronGroup) =
kron(convert(SparseMatrixCSC, A), convert(SparseMatrixCSC, B))
kron(A::_SparseKronGroup, B::_DenseConcatGroup) = kron(A, sparse(B))
kron(A::_DenseConcatGroup, B::_SparseKronGroup) = kron(sparse(A), B)
kron(A::_SparseKronGroup, B::_DenseKronGroup) = kron(A, sparse(B))
kron(A::_DenseKronGroup, B::_SparseKronGroup) = kron(sparse(A), B)
kron(A::_SparseVectorUnion, B::_AdjOrTransSparseVectorUnion) = A .* B
# disambiguation
kron(A::AbstractCompressedVector, B::AdjOrTrans{<:Any,<:AbstractCompressedVector}) = A .* B
Expand Down
22 changes: 9 additions & 13 deletions test/linalg.jl
Original file line number Diff line number Diff line change
Expand Up @@ -751,31 +751,27 @@ end
@test Array(kron(t(a), c_di)::SparseMatrixCSC) == kron(t(a_d), c_d)
@test Array(kron(a, t(c_di))::SparseMatrixCSC) == kron(a_d, t(c_d))
@test Array(kron(t(a), t(c_di))::SparseMatrixCSC) == kron(t(a_d), t(c_d))
@test issparse(kron(c_di, y))
@test Array(kron(c_di, y)) == kron(c_di, y_d)
@test issparse(kron(x, d_di))
@test Array(kron(x, d_di)) == kron(x_d, d_di)
@test Array(kron(c_di, y)::SparseMatrixCSC) == kron(c_di, y_d)
@test Array(kron(x, d_di)::SparseMatrixCSC) == kron(x_d, d_di)
end
end
# vec ⊗ vec
@test Vector(kron(x, y)) == kron(x_d, y_d)
@test Vector(kron(x_d, y)) == kron(x_d, y_d)
@test Vector(kron(x, y_d)) == kron(x_d, y_d)
@test Vector(kron(x, y)::SparseVector) == kron(x_d, y_d)
@test Vector(kron(x_d, y)::SparseVector) == kron(x_d, y_d)
@test Vector(kron(x, y_d)::SparseVector) == kron(x_d, y_d)
for t in (identity, adjoint, transpose)
# mat ⊗ vec
@test Array(kron(t(a), y)::SparseMatrixCSC) == kron(t(a_d), y_d)
@test Array(kron(t(a_d), y)) == kron(t(a_d), y_d)
@test Array(kron(t(a_d), y)::SparseMatrixCSC) == kron(t(a_d), y_d)
@test Array(kron(t(a), y_d)::SparseMatrixCSC) == kron(t(a_d), y_d)
# vec ⊗ mat
@test Array(kron(x, t(b))::SparseMatrixCSC) == kron(x_d, t(b_d))
@test Array(kron(x_d, t(b))::SparseMatrixCSC) == kron(x_d, t(b_d))
@test Array(kron(x, t(b_d))) == kron(x_d, t(b_d))
@test Array(kron(x, t(b_d))::SparseMatrixCSC) == kron(x_d, t(b_d))
end
# vec ⊗ vec'
@test issparse(kron(v, y'))
@test issparse(kron(x, y'))
@test Array(kron(v, y')) == kron(v_d, y_d')
@test Array(kron(x, y')) == kron(x_d, y_d')
@test Array(kron(v, y')::SparseMatrixCSC) == kron(v_d, y_d')
@test Array(kron(x, y')::SparseMatrixCSC) == kron(x_d, y_d')
# test different types
z = convert(SparseVector{Float16, Int8}, y); z_d = Vector(z)
@test Vector(kron(x, z)) == kron(x_d, z_d)
Expand Down

0 comments on commit 8308232

Please sign in to comment.