From cb43fd59edda93e79101fede8f635f2448875bb5 Mon Sep 17 00:00:00 2001 From: Tommy Hofmann Date: Wed, 24 Jan 2024 10:49:20 +0100 Subject: [PATCH] feat: fix generalized indexing - make A[i::Int, js::Array] return a Vector and support views --- src/Matrix.jl | 28 ++++++++++++++++++++++++---- src/generic/GenericTypes.jl | 5 +++++ src/generic/Matrix.jl | 19 ++++++++++++++++++- test/generic/Matrix-test.jl | 22 ++++++++++++++++++---- 4 files changed, 65 insertions(+), 9 deletions(-) diff --git a/src/Matrix.jl b/src/Matrix.jl index fca9e99a18..44b2a15ecf 100644 --- a/src/Matrix.jl +++ b/src/Matrix.jl @@ -57,7 +57,9 @@ end _checkbounds(i::Int, j::Int) = 1 <= j <= i -function _checkbounds(A, i::Int, j::Int) +_checkbounds(i::Int, j::AbstractVector{Int}) = all(jj -> 1 <= jj <= i, j) + +function _checkbounds(A, i::Union{Int, AbstractVector{Int}}, j::Union{Int, AbstractVector{Int}}) (_checkbounds(nrows(A), i) && _checkbounds(ncols(A), j)) || Base.throw_boundserror(A, (i, j)) end @@ -386,18 +388,36 @@ function getindex(M::MatElem, rows::AbstractVector{Int}, cols::AbstractVector{In return A end +function getindex(M::MatElem, i::Int, cols::AbstractVector{Int}) + _checkbounds(M, i, cols) + A = Vector{elem_type(base_ring(M))}(undef, length(cols)) + for j in eachindex(cols) + A[j] = deepcopy(M[i, cols[j]]) + end + return A +end + +function getindex(M::MatElem, rows::AbstractVector{Int}, j::Int) + _checkbounds(M, rows, j) + A = Vector{elem_type(base_ring(M))}(undef, length(rows)) + for i in eachindex(rows) + A[i] = deepcopy(M[rows[i], j]) + end + return A + end + getindex(M::MatElem, rows::Union{Int,Colon,AbstractVector{Int}}, cols::Union{Int,Colon,AbstractVector{Int}}) = M[_to_indices(M, rows, cols)...] function _to_indices(x, rows, cols) if rows isa Integer - rows = rows:rows + rows = rows elseif rows isa Colon rows = 1:nrows(x) end if cols isa Integer - cols = cols:cols + cols = cols elseif cols isa Colon cols = 1:ncols(x) end @@ -2519,7 +2539,7 @@ function trace_of_prod(M::MatElem, N::MatElem) is_square(M) && is_square(N) || error("Not a square matrix in trace") d = zero(base_ring(M)) for i = 1:nrows(M) - d += (M[i, :] * N[:, i])[1, 1] + d += (M[i:i, :] * N[:, i:i])[1, 1] end return d end diff --git a/src/generic/GenericTypes.jl b/src/generic/GenericTypes.jl index 9a30c57657..88f31a166b 100644 --- a/src/generic/GenericTypes.jl +++ b/src/generic/GenericTypes.jl @@ -1096,6 +1096,11 @@ struct MatSpaceView{T <: NCRingElement, V, W} <: Mat{T} base_ring::NCRing end +struct MatSpaceVecView{T <: NCRingElement, V, W} <: AbstractVector{T} + entries::SubArray{T, 1, Matrix{T}, V, W} + base_ring::NCRing +end + ############################################################################### # # MatRing / MatRingElem diff --git a/src/generic/Matrix.jl b/src/generic/Matrix.jl index 701f3f0366..e83ac7196d 100644 --- a/src/generic/Matrix.jl +++ b/src/generic/Matrix.jl @@ -100,10 +100,18 @@ function deepcopy_internal(d::MatSpaceView{T}, dict::IdDict) where T <: NCRingEl return MatSpaceView(deepcopy_internal(d.entries, dict), d.base_ring) end -function Base.view(M::Mat{T}, rows::AbstractUnitRange{Int}, cols::AbstractUnitRange{Int}) where T <: NCRingElement +function Base.view(M::Mat{T}, rows::Union{Colon, AbstractVector{Int}}, cols::Union{Colon, AbstractVector{Int}}) where T <: NCRingElement return MatSpaceView(view(M.entries, rows, cols), M.base_ring) end +function Base.view(M::Mat{T}, rows::Int, cols::Union{Colon, AbstractVector{Int}}) where T <: NCRingElement + return MatSpaceVecView(view(M.entries, rows, cols), M.base_ring) +end + +function Base.view(M::Mat{T}, rows::Union{Colon, AbstractVector{Int}}, cols::Int) where T <: NCRingElement + return MatSpaceVecView(view(M.entries, rows, cols), M.base_ring) +end + ################################################################################ # # Size, axes and is_square @@ -228,3 +236,12 @@ function AbstractAlgebra.mul!(A::Mat{T}, B::Mat{T}, C::Mat{T}, f::Bool = false) return A end +Base.length(V::MatSpaceVecView) = length(V.entries) + +Base.getindex(V::MatSpaceVecView, i::Int) = V.entries[i] + +Base.setindex!(V::MatSpaceVecView{T}, z::T, i::Int) where {T} = (V.entries[i] = z) + +Base.setindex!(V::MatSpaceVecView, z::RingElement, i::Int) = setindex!(V.entries, V.base_ring(z), i) + +Base.size(V::MatSpaceVecView) = (length(V.entries), ) diff --git a/test/generic/Matrix-test.jl b/test/generic/Matrix-test.jl index 65ea3f0570..d33ab6282b 100644 --- a/test/generic/Matrix-test.jl +++ b/test/generic/Matrix-test.jl @@ -1420,8 +1420,8 @@ end Q = inv(P) PA = P*A - @test PA == reduce(vcat, [A[Q[i], :] for i in 1:nrows(A)]) - @test PA == reduce(vcat, A[Q[i], :] for i in 1:nrows(A)) + @test PA == reduce(vcat, [A[Q[i]:Q[i], :] for i in 1:nrows(A)]) + @test PA == reduce(vcat, A[Q[i]:Q[i], :] for i in 1:nrows(A)) @test PA == S(reduce(vcat, A.entries[Q[i], :] for i in 1:nrows(A))) @test A == Q*(P*A) end @@ -4022,12 +4022,26 @@ end @test fflu(N3) == fflu(M) # tests that deepcopy is correct @test M2 == M - for i in [ 1, 1:2, : ], j in [ 1, 1:2, : ] + for i in [ 1:1, 1:2, : ], j in [ 1:1, 1:2, : ] v = @view M[i,j] @test v isa Generic.MatSpaceView @test M[i,j] == v end + M2 = deepcopy(M) + M3 = @view M2[2, 1:2] + @test length(M3) == 2 + @test M3 == [2, 3] + M3[2] = 5 + @test M2 == ZZ[1 2 3; 2 5 4; 3 4 5] + + M2 = deepcopy(M) + M3 = @view M2[1:3, 3] + @test length(M3) == 3 + @test M3 == [3, 4, 5] + M3[1] = 10 + @test M2 == ZZ[1 2 10; 2 3 4; 3 4 5] + # Test views over noncommutative ring R = matrix_ring(ZZ, 2) @@ -4035,7 +4049,7 @@ end M = rand(S, -10:10) - for i in [ 1, 1:2, : ], j in [ 1, 1:2, : ] + for i in [ 1:1, 1:2, : ], j in [ 1:1, 1:2, : ] v = @view M[i,j] @test v isa Generic.MatSpaceView @test M[i,j] == v