diff --git a/src/sparsematrix.jl b/src/sparsematrix.jl index ae2411f2..34f1e9d6 100644 --- a/src/sparsematrix.jl +++ b/src/sparsematrix.jl @@ -3126,12 +3126,15 @@ function getindex(A::AbstractSparseMatrixCSC{Tv,Ti}, I::AbstractArray) where {Tv colptrB[colB] = 1 idxB = 1 + CartIndsA = CartesianIndices(szA) + CartIndsB = CartesianIndices(szB) + for i in 1:n @boundscheck checkbounds(A, I[i]) - row,col = Base._ind2sub(szA, I[i]) + row,col = Tuple(CartIndsA[I[i]]) for r in colptrA[col]:(colptrA[col+1]-1) @inbounds if rowvalA[r] == row - rowB,colB = Base._ind2sub(szB, i) + rowB,colB = Tuple(CartIndsB[i]) colptrB[colB+1] += 1 rowvalB[idxB] = rowB nzvalB[idxB] = nzvalA[r] @@ -3635,13 +3638,15 @@ function setindex!(A::AbstractSparseMatrixCSC, x::AbstractArray, Ix::AbstractVec isa(x, AbstractArray) && setindex_shape_check(x, length(I)) + CartIndsA = CartesianIndices(szA) + lastcol = 0 (nrowA, ncolA) = szA @inbounds for xidx in 1:n sxidx = S[xidx] (sxidx < n) && (I[sxidx] == I[sxidx+1]) && continue - row,col = Base._ind2sub(szA, I[sxidx]) + row,col = Tuple(CartIndsA[I[sxidx]]) v = x[sxidx] if col > lastcol diff --git a/src/sparsevector.jl b/src/sparsevector.jl index f82d6a36..e79b52e0 100644 --- a/src/sparsevector.jl +++ b/src/sparsevector.jl @@ -779,9 +779,12 @@ function getindex(A::AbstractSparseMatrixCSC{Tv}, I::AbstractUnitRange) where Tv rowvalB = Vector{Int}(undef, nnzB) nzvalB = Vector{Tv}(undef, nnzB) + CartIndsA = CartesianIndices(szA) + LinIndsA = LinearIndices(szA) + if nnzB > 0 - rowstart,colstart = Base._ind2sub(szA, first(I)) - rowend,colend = Base._ind2sub(szA, last(I)) + rowstart,colstart = Tuple(CartIndsA[first(I)]) + rowend,colend = Tuple(CartIndsA[last(I)]) idxB = 1 @inbounds for col in colstart:colend @@ -790,7 +793,7 @@ function getindex(A::AbstractSparseMatrixCSC{Tv}, I::AbstractUnitRange) where Tv for r in colptrA[col]:(colptrA[col+1]-1) rowA = rowvalA[r] if minrow <= rowA <= maxrow - rowvalB[idxB] = Base._sub2ind(szA, rowA, col) - first(I) + 1 + rowvalB[idxB] = LinIndsA[rowA, col] - first(I) + 1 nzvalB[idxB] = nzvalA[r] idxB += 1 end @@ -818,9 +821,11 @@ function getindex(A::AbstractSparseMatrixCSC{Tv,Ti}, I::AbstractVector) where {T rowvalB = Vector{Ti}(undef, nnzB) nzvalB = Vector{Tv}(undef, nnzB) + CartIndsA = CartesianIndices(szA) + idxB = 1 for i in 1:n - row,col = Base._ind2sub(szA, I[i]) + row,col = Tuple(CartIndsA[I[i]]) for r in colptrA[col]:(colptrA[col+1]-1) @inbounds if rowvalA[r] == row if idxB <= nnzB