Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Backport "Fix (l/r)mul! with Diagonal/Bidiagonal #55052" to v1.11 #55359

Merged
merged 2 commits into from
Aug 19, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
72 changes: 70 additions & 2 deletions stdlib/LinearAlgebra/src/bidiag.jl
Original file line number Diff line number Diff line change
Expand Up @@ -435,8 +435,76 @@ const BiTri = Union{Bidiagonal,Tridiagonal}
@inline _mul!(C::AbstractMatrix, A::AbstractMatrix, B::BandedMatrix, alpha::Number, beta::Number) = _mul!(C, A, B, MulAddMul(alpha, beta))
@inline _mul!(C::AbstractMatrix, A::BandedMatrix, B::BandedMatrix, alpha::Number, beta::Number) = _mul!(C, A, B, MulAddMul(alpha, beta))

lmul!(A::Bidiagonal, B::AbstractVecOrMat) = @inline _mul!(B, A, B, MulAddMul())
rmul!(B::AbstractMatrix, A::Bidiagonal) = @inline _mul!(B, B, A, MulAddMul())
# B .= A * B
function lmul!(A::Bidiagonal, B::AbstractVecOrMat)
_muldiag_size_check(A, B)
(; dv, ev) = A
if A.uplo == 'U'
for k in axes(B,2)
for i in axes(ev,1)
B[i,k] = dv[i] * B[i,k] + ev[i] * B[i+1,k]
end
B[end,k] = dv[end] * B[end,k]
end
else
for k in axes(B,2)
for i in reverse(axes(dv,1)[2:end])
B[i,k] = dv[i] * B[i,k] + ev[i-1] * B[i-1,k]
end
B[1,k] = dv[1] * B[1,k]
end
end
return B
end
# B .= D * B
function lmul!(D::Diagonal, B::Bidiagonal)
_muldiag_size_check(D, B)
(; dv, ev) = B
isL = B.uplo == 'L'
dv[1] = D.diag[1] * dv[1]
for i in axes(ev,1)
ev[i] = D.diag[i + isL] * ev[i]
dv[i+1] = D.diag[i+1] * dv[i+1]
end
return B
end
# B .= B * A
function rmul!(B::AbstractMatrix, A::Bidiagonal)
_muldiag_size_check(A, B)
(; dv, ev) = A
if A.uplo == 'U'
for k in reverse(axes(dv,1)[2:end])
for i in axes(B,1)
B[i,k] = B[i,k] * dv[k] + B[i,k-1] * ev[k-1]
end
end
for i in axes(B,1)
B[i,1] *= dv[1]
end
else
for k in axes(ev,1)
for i in axes(B,1)
B[i,k] = B[i,k] * dv[k] + B[i,k+1] * ev[k]
end
end
for i in axes(B,1)
B[i,end] *= dv[end]
end
end
return B
end
# B .= B * D
function rmul!(B::Bidiagonal, D::Diagonal)
_muldiag_size_check(B, D)
(; dv, ev) = B
isU = B.uplo == 'U'
dv[1] *= D.diag[1]
for i in axes(ev,1)
ev[i] *= D.diag[i + isU]
dv[i+1] *= D.diag[i+1]
end
return B
end

function check_A_mul_B!_sizes(C, A, B)
mA, nA = size(A)
Expand Down
45 changes: 43 additions & 2 deletions stdlib/LinearAlgebra/src/diagonal.jl
Original file line number Diff line number Diff line change
Expand Up @@ -310,8 +310,49 @@ function (*)(D::Diagonal, V::AbstractVector)
return D.diag .* V
end

rmul!(A::AbstractMatrix, D::Diagonal) = @inline mul!(A, A, D)
lmul!(D::Diagonal, B::AbstractVecOrMat) = @inline mul!(B, D, B)
function rmul!(A::AbstractMatrix, D::Diagonal)
_muldiag_size_check(A, D)
for I in CartesianIndices(A)
row, col = Tuple(I)
@inbounds A[row, col] *= D.diag[col]
end
return A
end
# T .= T * D
function rmul!(T::Tridiagonal, D::Diagonal)
_muldiag_size_check(T, D)
(; dl, d, du) = T
d[1] *= D.diag[1]
for i in axes(dl,1)
dl[i] *= D.diag[i]
du[i] *= D.diag[i+1]
d[i+1] *= D.diag[i+1]
end
return T
end

function lmul!(D::Diagonal, B::AbstractVecOrMat)
_muldiag_size_check(D, B)
for I in CartesianIndices(B)
row = I[1]
@inbounds B[I] = D.diag[row] * B[I]
end
return B
end

# in-place multiplication with a diagonal
# T .= D * T
function lmul!(D::Diagonal, T::Tridiagonal)
_muldiag_size_check(D, T)
(; dl, d, du) = T
d[1] = D.diag[1] * d[1]
for i in axes(dl,1)
dl[i] = D.diag[i+1] * dl[i]
du[i] = D.diag[i] * du[i]
d[i+1] = D.diag[i+1] * d[i+1]
end
return T
end

function __muldiag!(out, D::Diagonal, B, _add::MulAddMul{ais1,bis0}) where {ais1,bis0}
require_one_based_indexing(out, B)
Expand Down
35 changes: 35 additions & 0 deletions stdlib/LinearAlgebra/test/bidiag.jl
Original file line number Diff line number Diff line change
Expand Up @@ -884,6 +884,41 @@ end
@test mul!(C1, B, sv, 1, 2) == mul!(C2, B, v, 1 ,2)
end

@testset "rmul!/lmul! with banded matrices" begin
dv, ev = rand(4), rand(3)
for A in (Bidiagonal(dv, ev, :U), Bidiagonal(dv, ev, :L))
@testset "$(nameof(typeof(B)))" for B in (
Bidiagonal(dv, ev, :U),
Bidiagonal(dv, ev, :L),
Diagonal(dv)
)
@test_throws ArgumentError rmul!(B, A)
@test_throws ArgumentError lmul!(A, B)
end
D = Diagonal(dv)
@test rmul!(copy(A), D) ≈ A * D
@test lmul!(D, copy(A)) ≈ D * A
end
@testset "non-commutative" begin
S32 = SizedArrays.SizedArray{(3,2)}(rand(3,2))
S33 = SizedArrays.SizedArray{(3,3)}(rand(3,3))
S22 = SizedArrays.SizedArray{(2,2)}(rand(2,2))
for uplo in (:L, :U)
B = Bidiagonal(fill(S32, 4), fill(S32, 3), uplo)
D = Diagonal(fill(S22, size(B,2)))
@test rmul!(copy(B), D) ≈ B * D
D = Diagonal(fill(S33, size(B,1)))
@test lmul!(D, copy(B)) ≈ D * B
end

B = Bidiagonal(fill(S33, 4), fill(S33, 3), :U)
D = Diagonal(fill(S32, 4))
@test lmul!(B, Array(D)) ≈ B * D
B = Bidiagonal(fill(S22, 4), fill(S22, 3), :U)
@test rmul!(Array(D), B) ≈ D * B
end
end

@testset "off-band indexing error" begin
B = Bidiagonal(Vector{BigInt}(undef, 4), Vector{BigInt}(undef,3), :L)
@test_throws "cannot set entry" B[1,2] = 4
Expand Down
12 changes: 12 additions & 0 deletions stdlib/LinearAlgebra/test/diagonal.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1288,4 +1288,16 @@ end
@test yadj == x'
end

@testset "rmul!/lmul! with banded matrices" begin
@testset "$(nameof(typeof(B)))" for B in (
Bidiagonal(rand(4), rand(3), :L),
Tridiagonal(rand(3), rand(4), rand(3))
)
BA = Array(B)
D = Diagonal(rand(size(B,1)))
DA = Array(D)
@test rmul!(copy(B), D) ≈ B * D ≈ BA * DA
@test lmul!(D, copy(B)) ≈ D * B ≈ DA * BA
end
end
end # module TestDiagonal
18 changes: 18 additions & 0 deletions stdlib/LinearAlgebra/test/tridiag.jl
Original file line number Diff line number Diff line change
Expand Up @@ -833,4 +833,22 @@ end
@test axes(B) === (ax, ax)
end

@testset "rmul!/lmul! with banded matrices" begin
dl, d, du = rand(3), rand(4), rand(3)
A = Tridiagonal(dl, d, du)
D = Diagonal(d)
@test rmul!(copy(A), D) ≈ A * D
@test lmul!(D, copy(A)) ≈ D * A

@testset "non-commutative" begin
S32 = SizedArrays.SizedArray{(3,2)}(rand(3,2))
S33 = SizedArrays.SizedArray{(3,3)}(rand(3,3))
S22 = SizedArrays.SizedArray{(2,2)}(rand(2,2))
T = Tridiagonal(fill(S32,3), fill(S32, 4), fill(S32, 3))
D = Diagonal(fill(S22, size(T,2)))
@test rmul!(copy(T), D) ≈ T * D
D = Diagonal(fill(S33, size(T,1)))
@test lmul!(D, copy(T)) ≈ D * T
end
end
end # module TestTridiagonal
18 changes: 18 additions & 0 deletions test/testhelpers/SizedArrays.jl
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,9 @@ Base.first(::SOneTo) = 1
Base.last(r::SOneTo) = length(r)
Base.show(io::IO, r::SOneTo) = print(io, "SOneTo(", length(r), ")")

Broadcast.axistype(a::Base.OneTo, s::SOneTo) = s
Broadcast.axistype(s::SOneTo, a::Base.OneTo) = s

struct SizedArray{SZ,T,N,A<:AbstractArray} <: AbstractArray{T,N}
data::A
function SizedArray{SZ}(data::AbstractArray{T,N}) where {SZ,T,N}
Expand All @@ -43,10 +46,25 @@ Base.size(a::SizedArray) = size(typeof(a))
Base.size(::Type{<:SizedArray{SZ}}) where {SZ} = SZ
Base.axes(a::SizedArray) = map(SOneTo, size(a))
Base.getindex(A::SizedArray, i...) = getindex(A.data, i...)
Base.setindex!(A::SizedArray, v, i...) = setindex!(A.data, v, i...)
Base.zero(::Type{T}) where T <: SizedArray = SizedArray{size(T)}(zeros(eltype(T), size(T)))
Base.parent(S::SizedArray) = S.data
+(S1::SizedArray{SZ}, S2::SizedArray{SZ}) where {SZ} = SizedArray{SZ}(S1.data + S2.data)
==(S1::SizedArray{SZ}, S2::SizedArray{SZ}) where {SZ} = S1.data == S2.data

homogenize_shape(t::Tuple) = (_homogenize_shape(first(t)), homogenize_shape(Base.tail(t))...)
homogenize_shape(::Tuple{}) = ()
_homogenize_shape(x::Integer) = x
_homogenize_shape(x::AbstractUnitRange) = length(x)
const Dims = Union{Integer, Base.OneTo, SOneTo}
function Base.similar(::Type{A}, shape::Tuple{Dims, Vararg{Dims}}) where {A<:AbstractArray}
similar(A, homogenize_shape(shape))
end
function Base.similar(::Type{A}, shape::Tuple{SOneTo, Vararg{SOneTo}}) where {A<:AbstractArray}
R = similar(A, length.(shape))
SizedArray{length.(shape)}(R)
end

const SizedMatrixLike = Union{SizedMatrix, Transpose{<:Any, <:SizedMatrix}, Adjoint{<:Any, <:SizedMatrix}}

_data(S::SizedArray) = S.data
Expand Down