Skip to content

Commit

Permalink
Fix kron indexing for types without a unique zero
Browse files Browse the repository at this point in the history
  • Loading branch information
jishnub committed Oct 18, 2024
1 parent f1990e2 commit f9d0c17
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 27 deletions.
50 changes: 25 additions & 25 deletions stdlib/LinearAlgebra/src/diagonal.jl
Original file line number Diff line number Diff line change
Expand Up @@ -700,16 +700,16 @@ end
zerofilled = true
end
end
@inbounds for i = 1:nA, j = 1:nB
for i in eachindex(valA), j in eachindex(valB)
idx = (i-1)*nB+j
C[idx, idx] = valA[i] * valB[j]
@inbounds C[idx, idx] = valA[i] * valB[j]
end
if !zerofilled
for j in 1:nA, i in 1:mA
for j in axes(A,2), i in axes(A,1)
Δrow, Δcol = (i-1)*mB, (j-1)*nB
for k in 1:nB, l in 1:mB
for k in axes(B,2), l in axes(B,1)
i == j && k == l && continue
C[Δrow + l, Δcol + k] = A[i,j] * B[l,k]
@inbounds C[Δrow + l, Δcol + k] = A[i,j] * B[l,k]
end
end
end
Expand Down Expand Up @@ -749,24 +749,24 @@ end
end
end
m = 1
@inbounds for j = 1:nA
A_jj = A[j,j]
for k = 1:nB
for l = 1:mB
C[m] = A_jj * B[l,k]
for j in axes(A,2)
A_jj = @inbounds A[j,j]
for k in axes(B,2)
for l in axes(B,1)
@inbounds C[m] = A_jj * B[l,k]
m += 1
end
m += (nA - 1) * mB
end
if !zerofilled
# populate the zero elements
for i in 1:mA
for i in axes(A,1)
i == j && continue
A_ij = A[i, j]
A_ij = @inbounds A[i, j]
Δrow, Δcol = (i-1)*mB, (j-1)*nB
for k in 1:nB, l in 1:nA
B_lk = B[l, k]
C[Δrow + l, Δcol + k] = A_ij * B_lk
for k in axes(B,2), l in axes(B,1)
B_lk = @inbounds B[l, k]
@inbounds C[Δrow + l, Δcol + k] = A_ij * B_lk
end
end
end
Expand All @@ -792,23 +792,23 @@ end
end
end
m = 1
@inbounds for j = 1:nA
for l = 1:mB
Bll = B[l,l]
for i = 1:mA
C[m] = A[i,j] * Bll
for j in axes(A,2)
for l in axes(B,1)
Bll = @inbounds B[l,l]
for i in axes(A,1)
@inbounds C[m] = A[i,j] * Bll
m += nB
end
m += 1
end
if !zerofilled
for i in 1:mA
A_ij = A[i, j]
for i in axes(A,1)
A_ij = @inbounds A[i, j]
Δrow, Δcol = (i-1)*mB, (j-1)*nB
for k in 1:nB, l in 1:mB
for k in axes(B,2), l in axes(B,1)
l == k && continue
B_lk = B[l, k]
C[Δrow + l, Δcol + k] = A_ij * B_lk
B_lk = @inbounds B[l, k]
@inbounds C[Δrow + l, Δcol + k] = A_ij * B_lk
end
end
end
Expand Down
4 changes: 2 additions & 2 deletions stdlib/LinearAlgebra/test/diagonal.jl
Original file line number Diff line number Diff line change
Expand Up @@ -353,7 +353,7 @@ Random.seed!(1)
D3 = Diagonal(convert(Vector{elty}, rand(n÷2)))
DM3= Matrix(D3)
@test Matrix(kron(D, D3)) kron(DM, DM3)
M4 = rand(elty, n÷2, n÷2)
M4 = rand(elty, size(D3,1) + 1, size(D3,2) + 2) # choose a different size from D3
@test kron(D3, M4) kron(DM3, M4)
@test kron(M4, D3) kron(M4, DM3)
X = [ones(1,1) for i in 1:2, j in 1:2]
Expand Down Expand Up @@ -1392,7 +1392,7 @@ end
end

@testset "zeros in kron with block matrices" begin
D = Diagonal(1:2)
D = Diagonal(1:4)
B = reshape([ones(2,2), ones(3,2), ones(2,3), ones(3,3)], 2, 2)
@test kron(D, B) == kron(Array(D), B)
@test kron(B, D) == kron(B, Array(D))
Expand Down

0 comments on commit f9d0c17

Please sign in to comment.