Skip to content

Commit 8308232

Browse files
authored
Extend sparse kron to adjortrans of dense matrices (#474)
1 parent 951837f commit 8308232

File tree

2 files changed

+14
-18
lines changed

2 files changed

+14
-18
lines changed

src/linalg.jl

+5-5
Original file line numberDiff line numberDiff line change
@@ -1483,7 +1483,7 @@ const _Symmetric_DenseArrays{T,A<:Matrix} = Symmetric{T,A}
14831483
const _Hermitian_DenseArrays{T,A<:Matrix} = Hermitian{T,A}
14841484
const _Triangular_DenseArrays{T,A<:Matrix} = UpperOrLowerTriangular{<:Any,A} # AbstractTriangular{T,A}
14851485
const _Annotated_DenseArrays = Union{_SpecialArrays, _Triangular_DenseArrays, _Symmetric_DenseArrays, _Hermitian_DenseArrays}
1486-
const _DenseConcatGroup = Union{Number, Vector, Adjoint{<:Any,<:Vector}, Transpose{<:Any,<:Vector}, Matrix, _Annotated_DenseArrays}
1486+
const _DenseKronGroup = Union{Number, Vector, Matrix, AdjOrTrans{<:Any,<:VecOrMat}, _Annotated_DenseArrays}
14871487

14881488
@inline function kron!(C::SparseMatrixCSC, A::AbstractSparseMatrixCSC, B::AbstractSparseMatrixCSC)
14891489
mA, nA = size(A); mB, nB = size(B)
@@ -1541,9 +1541,9 @@ end
15411541
end
15421542
return z
15431543
end
1544-
kron!(C::SparseMatrixCSC, A::_SparseKronGroup, B::_DenseConcatGroup) =
1544+
kron!(C::SparseMatrixCSC, A::_SparseKronGroup, B::_DenseKronGroup) =
15451545
kron!(C, convert(SparseMatrixCSC, A), convert(SparseMatrixCSC, B))
1546-
kron!(C::SparseMatrixCSC, A::_DenseConcatGroup, B::_SparseKronGroup) =
1546+
kron!(C::SparseMatrixCSC, A::_DenseKronGroup, B::_SparseKronGroup) =
15471547
kron!(C, convert(SparseMatrixCSC, A), convert(SparseMatrixCSC, B))
15481548
kron!(C::SparseMatrixCSC, A::_SparseKronGroup, B::_SparseKronGroup) =
15491549
kron!(C, convert(SparseMatrixCSC, A), convert(SparseMatrixCSC, B))
@@ -1580,8 +1580,8 @@ end
15801580
# extend to annotated sparse arrays, but leave out the (dense ⊗ dense)-case
15811581
kron(A::_SparseKronGroup, B::_SparseKronGroup) =
15821582
kron(convert(SparseMatrixCSC, A), convert(SparseMatrixCSC, B))
1583-
kron(A::_SparseKronGroup, B::_DenseConcatGroup) = kron(A, sparse(B))
1584-
kron(A::_DenseConcatGroup, B::_SparseKronGroup) = kron(sparse(A), B)
1583+
kron(A::_SparseKronGroup, B::_DenseKronGroup) = kron(A, sparse(B))
1584+
kron(A::_DenseKronGroup, B::_SparseKronGroup) = kron(sparse(A), B)
15851585
kron(A::_SparseVectorUnion, B::_AdjOrTransSparseVectorUnion) = A .* B
15861586
# disambiguation
15871587
kron(A::AbstractCompressedVector, B::AdjOrTrans{<:Any,<:AbstractCompressedVector}) = A .* B

test/linalg.jl

+9-13
Original file line numberDiff line numberDiff line change
@@ -751,31 +751,27 @@ end
751751
@test Array(kron(t(a), c_di)::SparseMatrixCSC) == kron(t(a_d), c_d)
752752
@test Array(kron(a, t(c_di))::SparseMatrixCSC) == kron(a_d, t(c_d))
753753
@test Array(kron(t(a), t(c_di))::SparseMatrixCSC) == kron(t(a_d), t(c_d))
754-
@test issparse(kron(c_di, y))
755-
@test Array(kron(c_di, y)) == kron(c_di, y_d)
756-
@test issparse(kron(x, d_di))
757-
@test Array(kron(x, d_di)) == kron(x_d, d_di)
754+
@test Array(kron(c_di, y)::SparseMatrixCSC) == kron(c_di, y_d)
755+
@test Array(kron(x, d_di)::SparseMatrixCSC) == kron(x_d, d_di)
758756
end
759757
end
760758
# vec ⊗ vec
761-
@test Vector(kron(x, y)) == kron(x_d, y_d)
762-
@test Vector(kron(x_d, y)) == kron(x_d, y_d)
763-
@test Vector(kron(x, y_d)) == kron(x_d, y_d)
759+
@test Vector(kron(x, y)::SparseVector) == kron(x_d, y_d)
760+
@test Vector(kron(x_d, y)::SparseVector) == kron(x_d, y_d)
761+
@test Vector(kron(x, y_d)::SparseVector) == kron(x_d, y_d)
764762
for t in (identity, adjoint, transpose)
765763
# mat ⊗ vec
766764
@test Array(kron(t(a), y)::SparseMatrixCSC) == kron(t(a_d), y_d)
767-
@test Array(kron(t(a_d), y)) == kron(t(a_d), y_d)
765+
@test Array(kron(t(a_d), y)::SparseMatrixCSC) == kron(t(a_d), y_d)
768766
@test Array(kron(t(a), y_d)::SparseMatrixCSC) == kron(t(a_d), y_d)
769767
# vec ⊗ mat
770768
@test Array(kron(x, t(b))::SparseMatrixCSC) == kron(x_d, t(b_d))
771769
@test Array(kron(x_d, t(b))::SparseMatrixCSC) == kron(x_d, t(b_d))
772-
@test Array(kron(x, t(b_d))) == kron(x_d, t(b_d))
770+
@test Array(kron(x, t(b_d))::SparseMatrixCSC) == kron(x_d, t(b_d))
773771
end
774772
# vec ⊗ vec'
775-
@test issparse(kron(v, y'))
776-
@test issparse(kron(x, y'))
777-
@test Array(kron(v, y')) == kron(v_d, y_d')
778-
@test Array(kron(x, y')) == kron(x_d, y_d')
773+
@test Array(kron(v, y')::SparseMatrixCSC) == kron(v_d, y_d')
774+
@test Array(kron(x, y')::SparseMatrixCSC) == kron(x_d, y_d')
779775
# test different types
780776
z = convert(SparseVector{Float16, Int8}, y); z_d = Vector(z)
781777
@test Vector(kron(x, z)) == kron(x_d, z_d)

0 commit comments

Comments
 (0)