diff --git a/src/adjtrans.jl b/src/adjtrans.jl index d5dbe76d..a34e53c2 100644 --- a/src/adjtrans.jl +++ b/src/adjtrans.jl @@ -141,6 +141,25 @@ function mul!( conj!(res) end +function mul!( + res::AbstractMatrix, + op::AdjointLinearOperator{T, S}, + m::AbstractMatrix, + α, + β, +) where {T, S} + p = op.parent + (size(m, 1) == size(p, 1) && size(res, 1) == size(p, 2) && size(m, 2) == size(res, 2)) || + throw(LinearOperatorException("shape mismatch")) + if ishermitian(p) + return mul!(res, p, m, α, β) + elseif p.ctprod! !== nothing + return p.ctprod!(res, m, α, β) + else + error("Not implemented") + end +end + function mul!( res::AbstractVector, op::TransposeLinearOperator{T, S}, @@ -188,6 +207,25 @@ function mul!( conj!(res) end +function mul!( + res::AbstractMatrix, + op::TransposeLinearOperator{T, S}, + m::AbstractMatrix, + α, + β, +) where {T, S} + p = op.parent + (size(m, 1) == size(p, 1) && size(res, 1) == size(p, 2) && size(m, 2) == size(res, 2)) || + throw(LinearOperatorException("shape mismatch")) + if issymmetric(p) + return mul!(res, p, m, α, β) + elseif p.tprod! !== nothing + return p.tprod!(res, m, α, β) + else + error("Not implemented") + end +end + function mul!( res::AbstractVector, op::ConjugateLinearOperator{T, S}, @@ -200,6 +238,18 @@ function mul!( conj!(res) end +function mul!( + res::AbstractMatrix, + op::ConjugateLinearOperator{T, S}, + v::AbstractMatrix, + α, + β, +) where {T, S} + p = op.parent + mul!(res, p, conj.(v), α, β) + conj!(res) +end + -(op::AdjointLinearOperator) = adjoint(-op.parent) -(op::TransposeLinearOperator) = transpose(-op.parent) -(op::ConjugateLinearOperator) = conj(-op.parent) diff --git a/src/constructors.jl b/src/constructors.jl index 41030e3e..c41852d4 100644 --- a/src/constructors.jl +++ b/src/constructors.jl @@ -102,6 +102,8 @@ op = LinearOperator(Float64, 2, 2, false, false, (res, v) -> mul!(res, A, v), (res, w) -> mul!(res, A', w)) ``` + +The 3-args `mul!` also works when applying the operator on a matrix. """ function LinearOperator( ::Type{T}, diff --git a/src/operations.jl b/src/operations.jl index ca3d2584..18259d40 100644 --- a/src/operations.jl +++ b/src/operations.jl @@ -32,7 +32,11 @@ function mul!(res::AbstractVector, op::AbstractLinearOperator{T}, v::AbstractVec end end -function mul!(res::AbstractVector, op::AbstractLinearOperator, v::AbstractVector{T}) where {T} +function mul!(res::AbstractMatrix, op::AbstractLinearOperator, m::AbstractMatrix{T}, α, β) where {T} + op.prod!(res, m, α, β) +end + +function mul!(res::AbstractVecOrMat, op::AbstractLinearOperator, v::AbstractVecOrMat{T}) where {T} mul!(res, op, v, one(T), zero(T)) end @@ -73,6 +77,26 @@ function *( return v_wrapper(res) end +# Apply an operator to a matrix (only in-place, since operator * matrix is a matrix). + +function mul!( + res::Adjoint{S1, M1}, + m::Adjoint{S2, M2}, + op::AbstractLinearOperator{T}, +) where {T, S1, S2, M1 <: AbstractMatrix{S1}, M2 <: AbstractMatrix{S2}} + mul!(adjoint(res), adjoint(op), adjoint(m)) + return res +end + +function mul!( + res::Transpose{S1, M1}, + m::Transpose{S2, M2}, + op::AbstractLinearOperator{T}, +) where {T, S1, S2, M1 <: AbstractMatrix{S1}, M2 <: AbstractMatrix{S2}} + mul!(transpose(res), transpose(op), transpose(m)) + return res +end + # Unary operations. +(op::AbstractLinearOperator) = op @@ -95,11 +119,11 @@ function -(op::AbstractLinearOperator{T}) where {T} end function prod_op!( - res::AbstractVector, + res::AbstractVecOrMat, op1::AbstractLinearOperator, op2::AbstractLinearOperator, - vtmp::AbstractVector, - v::AbstractVector, + vtmp::AbstractVecOrMat, + v::AbstractVecOrMat, α, β, ) @@ -162,10 +186,10 @@ end # Operator + operator. function sum_prod!( - res::AbstractVector, + res::AbstractVecOrMat, op1::AbstractLinearOperator, op2::AbstractLinearOperator{T}, - v::AbstractVector, + v::AbstractVecOrMat, α, β, ) where {T} diff --git a/test/test_linop.jl b/test/test_linop.jl index 64b666c6..71039448 100644 --- a/test/test_linop.jl +++ b/test/test_linop.jl @@ -61,6 +61,18 @@ function test_linop() @test(norm(transpose(u) * A - transpose(u) * op) <= rtol * norm(u)) @test(typeof(u' * op * v) <: Number) @test(norm(u' * A * v - u' * op * v) <= rtol * norm(u)) + + mv = hcat(v, -2v) + mu = hcat(u, -2u) + res_mat = similar(mu) + res_trans = similar(mv) + res_adj = similar(mv) + mul!(res_mat, op, mv) + mul!(res_trans, transpose(op), mu) + mul!(res_adj, op', mu) + @test(norm(A * mv - res_mat) <= rtol * norm(mv)) + @test(norm(transpose(A) * mu - res_trans) <= rtol * norm(mu)) + @test(norm(A' * mu - res_adj) <= rtol * norm(mu)) end A3 = Hermitian(A2' * A2)