From 80b31b09f6fff5e0c55db77387c3ac1be0af4e1f Mon Sep 17 00:00:00 2001 From: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> Date: Thu, 16 May 2024 09:01:27 +0200 Subject: [PATCH 1/3] Add mul! for matrices --- src/adjtrans.jl | 2 ++ src/cat.jl | 4 ++++ src/kron.jl | 2 ++ src/operations.jl | 39 ++++++++++++++++++++++++++++++++++----- src/special-operators.jl | 15 +++++++++++++++ 5 files changed, 57 insertions(+), 5 deletions(-) diff --git a/src/adjtrans.jl b/src/adjtrans.jl index d5dbe76d..f2ccdfb5 100644 --- a/src/adjtrans.jl +++ b/src/adjtrans.jl @@ -200,6 +200,8 @@ function mul!( conj!(res) end +# TODO: overload the above for matrices? + -(op::AdjointLinearOperator) = adjoint(-op.parent) -(op::TransposeLinearOperator) = transpose(-op.parent) -(op::ConjugateLinearOperator) = conj(-op.parent) diff --git a/src/cat.jl b/src/cat.jl index 764c7893..39b0b410 100644 --- a/src/cat.jl +++ b/src/cat.jl @@ -32,6 +32,8 @@ function hcat_ctprod!( mul!(view(res, (Ancol + 1):nV), B, u, α, β) end +# TODO: overload the above for matrices? + function hcat(A::AbstractLinearOperator, B::AbstractLinearOperator) size(A, 1) == size(B, 1) || throw(LinearOperatorException("hcat: inconsistent row sizes")) @@ -91,6 +93,8 @@ function vcat_ctprod!( mul!(res, B, view(v, (Anrow + 1):nV), α, one(T)) end +# TODO: overload the above for matrices? + function vcat(A::AbstractLinearOperator, B::AbstractLinearOperator) size(A, 2) == size(B, 2) || throw(LinearOperatorException("vcat: inconsistent column sizes")) diff --git a/src/kron.jl b/src/kron.jl index 9be09d16..d847e6cc 100644 --- a/src/kron.jl +++ b/src/kron.jl @@ -44,6 +44,8 @@ function kron(A::AbstractLinearOperator, B::AbstractLinearOperator) return LinearOperator{T}(nrow, ncol, symm, herm, prod!, tprod!, ctprod!) end +# TODO: overload the above for matrices? + kron(A::AbstractMatrix, B::AbstractLinearOperator) = kron(LinearOperator(A), B) kron(A::AbstractLinearOperator, B::AbstractMatrix) = kron(A, LinearOperator(B)) diff --git a/src/operations.jl b/src/operations.jl index ca3d2584..7f3c3fb0 100644 --- a/src/operations.jl +++ b/src/operations.jl @@ -32,10 +32,19 @@ function mul!(res::AbstractVector, op::AbstractLinearOperator{T}, v::AbstractVec end end +function mul!(res::AbstractMatrix, op::AbstractLinearOperator{T}, m::AbstractMatrix, α, β) where {T} + # TODO: how to handle storage? + error("5-argument `mul!` is not defined between a linear operator and a matrix.") +end + function mul!(res::AbstractVector, op::AbstractLinearOperator, v::AbstractVector{T}) where {T} mul!(res, op, v, one(T), zero(T)) end +function mul!(res::AbstractMatrix, op::AbstractLinearOperator, m::AbstractMatrix{T}) where {T} + mul!(res, op, m, one(T), zero(T)) +end + # Apply an operator to a vector. function *(op::AbstractLinearOperator{T}, v::AbstractVector{S}) where {T, S} nrow, ncol = size(op) @@ -73,6 +82,26 @@ function *( return v_wrapper(res) end +# Apply an operator to a matrix (only in-place, since operator * matrix is a matrix). + +function mul!( + res::Adjoint{S1, M1}, + m::Adjoint{S2, M2}, + op::AbstractLinearOperator{T}, +) where {T, S1, S2, M1 <: AbstractMatrix{S1}, M2 <: AbstractMatrix{S2}} + mul!(adjoint(res), adjoint(op), adjoint(m)) + return res +end + +function mul!( + res::Transpose{S1, M1}, + m::Transpose{S2, M2}, + op::AbstractLinearOperator{T}, +) where {T, S1, S2, M1 <: AbstractMatrix{S1}, M2 <: AbstractMatrix{S2}} + mul!(transpose(res), transpose(op), transpose(m)) + return res +end + # Unary operations. +(op::AbstractLinearOperator) = op @@ -95,11 +124,11 @@ function -(op::AbstractLinearOperator{T}) where {T} end function prod_op!( - res::AbstractVector, + res::AbstractVecOrMat, op1::AbstractLinearOperator, op2::AbstractLinearOperator, - vtmp::AbstractVector, - v::AbstractVector, + vtmp::AbstractVecOrMat, + v::AbstractVecOrMat, α, β, ) @@ -162,10 +191,10 @@ end # Operator + operator. function sum_prod!( - res::AbstractVector, + res::AbstractVecOrMat, op1::AbstractLinearOperator, op2::AbstractLinearOperator{T}, - v::AbstractVector, + v::AbstractVecOrMat, α, β, ) where {T} diff --git a/src/special-operators.jl b/src/special-operators.jl index c58e857a..7e5470ba 100644 --- a/src/special-operators.jl +++ b/src/special-operators.jl @@ -43,6 +43,8 @@ function mulOpEye!(res, v, α, β::T, n_min) where {T} end end +# TODO: overload the above for matrices? + """ opEye(T, n; S = Vector{T}) opEye(n) @@ -84,6 +86,8 @@ function mulOpOnes!(res, v, α, β::T) where {T} end end +# TODO: overload the above for matrices? + """ opOnes(T, nrow, ncol; S = Vector{T}) opOnes(nrow, ncol) @@ -107,6 +111,8 @@ function mulOpZeros!(res, v, α, β::T) where {T} end end +# TODO: overload the above for matrices? + """ opZeros(T, nrow, ncol; S = Vector{T}) opZeros(nrow, ncol) @@ -130,6 +136,8 @@ function mulSquareOpDiagonal!(res, d, v, α, β::T) where {T} end end +# TODO: overload the above for matrices? + """ opDiagonal(d) @@ -149,6 +157,9 @@ function mulOpDiagonal!(res, d, v, α, β::T, n_min) where {T} end res[(n_min + 1):end] .= 0 end + +# TODO: overload the above for matrices? + """ opDiagonal(nrow, ncol, d) @@ -173,6 +184,8 @@ function multRestrict!(res, I, u, α, β) res[I] = u end +# TODO: overload the above for matrices? + """ Z = opRestriction(I, ncol) Z = opRestriction(:, ncol) @@ -289,3 +302,5 @@ function BlockDiagonalOperator(ops...; S = promote_type(storage_type.(ops)...)) args5 = all((has_args5(op) for op ∈ ops)) CompositeLinearOperator(T, nrow, ncol, symm, herm, prod!, tprod!, ctprod!, args5, S = S) end + +# TODO: overload the above for matrices? From 9dc35f8390bc4f8b0c94f1ce8c8bfe0e9e883e95 Mon Sep 17 00:00:00 2001 From: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> Date: Thu, 16 May 2024 10:40:03 +0200 Subject: [PATCH 2/3] First version with passing minimal tests --- src/adjtrans.jl | 50 +++++++++++++++++++++++++++++++++++++++- src/cat.jl | 4 ---- src/constructors.jl | 2 ++ src/kron.jl | 2 -- src/operations.jl | 11 +++------ src/special-operators.jl | 13 ----------- test/test_linop.jl | 12 ++++++++++ 7 files changed, 66 insertions(+), 28 deletions(-) diff --git a/src/adjtrans.jl b/src/adjtrans.jl index f2ccdfb5..a34e53c2 100644 --- a/src/adjtrans.jl +++ b/src/adjtrans.jl @@ -141,6 +141,25 @@ function mul!( conj!(res) end +function mul!( + res::AbstractMatrix, + op::AdjointLinearOperator{T, S}, + m::AbstractMatrix, + α, + β, +) where {T, S} + p = op.parent + (size(m, 1) == size(p, 1) && size(res, 1) == size(p, 2) && size(m, 2) == size(res, 2)) || + throw(LinearOperatorException("shape mismatch")) + if ishermitian(p) + return mul!(res, p, m, α, β) + elseif p.ctprod! !== nothing + return p.ctprod!(res, m, α, β) + else + error("Not implemented") + end +end + function mul!( res::AbstractVector, op::TransposeLinearOperator{T, S}, @@ -188,6 +207,25 @@ function mul!( conj!(res) end +function mul!( + res::AbstractMatrix, + op::TransposeLinearOperator{T, S}, + m::AbstractMatrix, + α, + β, +) where {T, S} + p = op.parent + (size(m, 1) == size(p, 1) && size(res, 1) == size(p, 2) && size(m, 2) == size(res, 2)) || + throw(LinearOperatorException("shape mismatch")) + if issymmetric(p) + return mul!(res, p, m, α, β) + elseif p.tprod! !== nothing + return p.tprod!(res, m, α, β) + else + error("Not implemented") + end +end + function mul!( res::AbstractVector, op::ConjugateLinearOperator{T, S}, @@ -200,7 +238,17 @@ function mul!( conj!(res) end -# TODO: overload the above for matrices? +function mul!( + res::AbstractMatrix, + op::ConjugateLinearOperator{T, S}, + v::AbstractMatrix, + α, + β, +) where {T, S} + p = op.parent + mul!(res, p, conj.(v), α, β) + conj!(res) +end -(op::AdjointLinearOperator) = adjoint(-op.parent) -(op::TransposeLinearOperator) = transpose(-op.parent) diff --git a/src/cat.jl b/src/cat.jl index 39b0b410..764c7893 100644 --- a/src/cat.jl +++ b/src/cat.jl @@ -32,8 +32,6 @@ function hcat_ctprod!( mul!(view(res, (Ancol + 1):nV), B, u, α, β) end -# TODO: overload the above for matrices? - function hcat(A::AbstractLinearOperator, B::AbstractLinearOperator) size(A, 1) == size(B, 1) || throw(LinearOperatorException("hcat: inconsistent row sizes")) @@ -93,8 +91,6 @@ function vcat_ctprod!( mul!(res, B, view(v, (Anrow + 1):nV), α, one(T)) end -# TODO: overload the above for matrices? - function vcat(A::AbstractLinearOperator, B::AbstractLinearOperator) size(A, 2) == size(B, 2) || throw(LinearOperatorException("vcat: inconsistent column sizes")) diff --git a/src/constructors.jl b/src/constructors.jl index 41030e3e..c41852d4 100644 --- a/src/constructors.jl +++ b/src/constructors.jl @@ -102,6 +102,8 @@ op = LinearOperator(Float64, 2, 2, false, false, (res, v) -> mul!(res, A, v), (res, w) -> mul!(res, A', w)) ``` + +The 3-args `mul!` also works when applying the operator on a matrix. """ function LinearOperator( ::Type{T}, diff --git a/src/kron.jl b/src/kron.jl index d847e6cc..9be09d16 100644 --- a/src/kron.jl +++ b/src/kron.jl @@ -44,8 +44,6 @@ function kron(A::AbstractLinearOperator, B::AbstractLinearOperator) return LinearOperator{T}(nrow, ncol, symm, herm, prod!, tprod!, ctprod!) end -# TODO: overload the above for matrices? - kron(A::AbstractMatrix, B::AbstractLinearOperator) = kron(LinearOperator(A), B) kron(A::AbstractLinearOperator, B::AbstractMatrix) = kron(A, LinearOperator(B)) diff --git a/src/operations.jl b/src/operations.jl index 7f3c3fb0..18259d40 100644 --- a/src/operations.jl +++ b/src/operations.jl @@ -32,19 +32,14 @@ function mul!(res::AbstractVector, op::AbstractLinearOperator{T}, v::AbstractVec end end -function mul!(res::AbstractMatrix, op::AbstractLinearOperator{T}, m::AbstractMatrix, α, β) where {T} - # TODO: how to handle storage? - error("5-argument `mul!` is not defined between a linear operator and a matrix.") +function mul!(res::AbstractMatrix, op::AbstractLinearOperator, m::AbstractMatrix{T}, α, β) where {T} + op.prod!(res, m, α, β) end -function mul!(res::AbstractVector, op::AbstractLinearOperator, v::AbstractVector{T}) where {T} +function mul!(res::AbstractVecOrMat, op::AbstractLinearOperator, v::AbstractVecOrMat{T}) where {T} mul!(res, op, v, one(T), zero(T)) end -function mul!(res::AbstractMatrix, op::AbstractLinearOperator, m::AbstractMatrix{T}) where {T} - mul!(res, op, m, one(T), zero(T)) -end - # Apply an operator to a vector. function *(op::AbstractLinearOperator{T}, v::AbstractVector{S}) where {T, S} nrow, ncol = size(op) diff --git a/src/special-operators.jl b/src/special-operators.jl index 7e5470ba..b47717e0 100644 --- a/src/special-operators.jl +++ b/src/special-operators.jl @@ -43,8 +43,6 @@ function mulOpEye!(res, v, α, β::T, n_min) where {T} end end -# TODO: overload the above for matrices? - """ opEye(T, n; S = Vector{T}) opEye(n) @@ -86,8 +84,6 @@ function mulOpOnes!(res, v, α, β::T) where {T} end end -# TODO: overload the above for matrices? - """ opOnes(T, nrow, ncol; S = Vector{T}) opOnes(nrow, ncol) @@ -111,8 +107,6 @@ function mulOpZeros!(res, v, α, β::T) where {T} end end -# TODO: overload the above for matrices? - """ opZeros(T, nrow, ncol; S = Vector{T}) opZeros(nrow, ncol) @@ -136,8 +130,6 @@ function mulSquareOpDiagonal!(res, d, v, α, β::T) where {T} end end -# TODO: overload the above for matrices? - """ opDiagonal(d) @@ -157,9 +149,6 @@ function mulOpDiagonal!(res, d, v, α, β::T, n_min) where {T} end res[(n_min + 1):end] .= 0 end - -# TODO: overload the above for matrices? - """ opDiagonal(nrow, ncol, d) @@ -184,8 +173,6 @@ function multRestrict!(res, I, u, α, β) res[I] = u end -# TODO: overload the above for matrices? - """ Z = opRestriction(I, ncol) Z = opRestriction(:, ncol) diff --git a/test/test_linop.jl b/test/test_linop.jl index 64b666c6..71039448 100644 --- a/test/test_linop.jl +++ b/test/test_linop.jl @@ -61,6 +61,18 @@ function test_linop() @test(norm(transpose(u) * A - transpose(u) * op) <= rtol * norm(u)) @test(typeof(u' * op * v) <: Number) @test(norm(u' * A * v - u' * op * v) <= rtol * norm(u)) + + mv = hcat(v, -2v) + mu = hcat(u, -2u) + res_mat = similar(mu) + res_trans = similar(mv) + res_adj = similar(mv) + mul!(res_mat, op, mv) + mul!(res_trans, transpose(op), mu) + mul!(res_adj, op', mu) + @test(norm(A * mv - res_mat) <= rtol * norm(mv)) + @test(norm(transpose(A) * mu - res_trans) <= rtol * norm(mu)) + @test(norm(A' * mu - res_adj) <= rtol * norm(mu)) end A3 = Hermitian(A2' * A2) From 2e60f39b26faf769c974e46b97a1c60be6d52d38 Mon Sep 17 00:00:00 2001 From: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> Date: Thu, 16 May 2024 19:26:08 +0200 Subject: [PATCH 3/3] Remove comment --- src/special-operators.jl | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/special-operators.jl b/src/special-operators.jl index b47717e0..c58e857a 100644 --- a/src/special-operators.jl +++ b/src/special-operators.jl @@ -289,5 +289,3 @@ function BlockDiagonalOperator(ops...; S = promote_type(storage_type.(ops)...)) args5 = all((has_args5(op) for op ∈ ops)) CompositeLinearOperator(T, nrow, ncol, symm, herm, prod!, tprod!, ctprod!, args5, S = S) end - -# TODO: overload the above for matrices?