From 48f3784d985f5774132be472809e46c5183e5a45 Mon Sep 17 00:00:00 2001 From: Daniel Karrasch Date: Wed, 8 May 2024 11:06:19 +0200 Subject: [PATCH] further adjustments --- src/linalg.jl | 36 +++++++++++------------------------- 1 file changed, 11 insertions(+), 25 deletions(-) diff --git a/src/linalg.jl b/src/linalg.jl index 1ce09b63..131a21bc 100644 --- a/src/linalg.jl +++ b/src/linalg.jl @@ -53,25 +53,19 @@ generic_matmatmul!(C::StridedMatrix, tA, tB, A::SparseMatrixCSCUnion2, B::Abstra spdensemul!(C, tA, tB, A, B, alpha, beta) generic_matvecmul!(C::StridedVecOrMat, tA, A::SparseMatrixCSCUnion2, B::DenseInputVector, alpha::Number, beta::Number) = spdensemul!(C, tA, 'N', A, B, alpha, beta) -# legacy methods: TODO: remove -generic_matmatmul!(C::StridedMatrix, tA, tB, A::SparseMatrixCSCUnion2, B::DenseMatrixUnion, _add::MulAddMul) = - spdensemul!(C, tA, tB, A, B, _add.alpha, _add.beta) -generic_matmatmul!(C::StridedMatrix, tA, tB, A::SparseMatrixCSCUnion2, B::AbstractTriangular, _add::MulAddMul) = - spdensemul!(C, tA, tB, A, B, _add.alpha, _add.beta) -generic_matvecmul!(C::StridedVecOrMat, tA, A::SparseMatrixCSCUnion2, B::DenseInputVector, _add::MulAddMul) = - spdensemul!(C, tA, 'N', A, B, _add.alpha, _add.beta) Base.@constprop :aggressive function spdensemul!(C, tA, tB, A, B, alpha, beta) - if tA == 'N' + tA_uc, tB_uc = uppercase(tA), uppercase(tB) + if tA_uc == 'N' _spmatmul!(C, A, wrap(B, tB), alpha, beta) - elseif tA == 'T' + elseif tA_uc == 'T' _At_or_Ac_mul_B!(transpose, C, A, wrap(B, tB), alpha, beta) - elseif tA == 'C' + elseif tA_uc == 'C' _At_or_Ac_mul_B!(adjoint, C, A, wrap(B, tB), alpha, beta) - elseif tA in ('S', 's', 'H', 'h') && tB == 'N' + elseif tA_uc in ('S', 'H') && tB_uc == 'N' rangefun = isuppercase(tA) ? nzrangeup : nzrangelo - diagop = tA in ('S', 's') ? identity : real - odiagop = tA in ('S', 's') ? transpose : adjoint + diagop = tA_uc == 'S' ? identity : real + odiagop = tA_uc == 'S' ? transpose : adjoint T = eltype(C) _mul!(rangefun, diagop, odiagop, C, A, B, T(alpha), T(beta)) else @@ -123,9 +117,6 @@ function _At_or_Ac_mul_B!(tfun::Function, C, A, B, α, β) C end -# TODO:remove -generic_matmatmul!(C::StridedMatrix, tA, tB, A::DenseMatrixUnion, B::SparseMatrixCSCUnion2, _add::MulAddMul) = - generic_matmatmul!(C, tA, tB, A, B, _add.alpha, _add.beta) Base.@constprop :aggressive function generic_matmatmul!(C::StridedMatrix, tA, tB, A::DenseMatrixUnion, B::SparseMatrixCSCUnion2, alpha::Number, beta::Number) transA = tA == 'N' ? identity : tA == 'T' ? transpose : adjoint if tB == 'N' @@ -328,17 +319,12 @@ function estimate_mulsize(m::Integer, nnzA::Integer, n::Integer, nnzB::Integer, p >= 1 ? m*k : p > 0 ? Int(ceil(-expm1(log1p(-p) * n)*m*k)) : 0 # (1-(1-p)^n)*m*k end -# TODO: remove this one method -Base.@constprop :aggressive function generic_matmatmul!(C::SparseMatrixCSCUnion2, tA, tB, A::SparseMatrixCSCUnion2, B::SparseMatrixCSCUnion2, _add::MulAddMul) - A, tA = tA in ('H', 'h', 'S', 's') ? (wrap(A, tA), 'N') : (A, tA) - B, tB = tB in ('H', 'h', 'S', 's') ? (wrap(B, tB), 'N') : (B, tB) - _generic_matmatmul!(C, tA, tB, A, B, _add) -end Base.@constprop :aggressive function generic_matmatmul!(C::SparseMatrixCSCUnion2, tA, tB, A::SparseMatrixCSCUnion2, B::SparseMatrixCSCUnion2, alpha::Number, beta::Number) - A, tA = tA in ('H', 'h', 'S', 's') ? (wrap(A, tA), 'N') : (A, tA) - B, tB = tB in ('H', 'h', 'S', 's') ? (wrap(B, tB), 'N') : (B, tB) - @stable_muladdmul _generic_matmatmul!(C, tA, tB, A, B, MulAddMul(alpha, beta)) + tA_uc, tB_uc = uppercase(tA), uppercase(tB) + Anew, ta = tA_uc in ('S', 'H') ? (wrap(A, tA), oftype(tA, 'N')) : (A, tA) + Bnew, tb = tB_uc in ('S', 'H') ? (wrap(B, tB), oftype(tB, 'N')) : (B, tB) + @stable_muladdmul _generic_matmatmul!(C, ta, tb, Anew, Bnew, MulAddMul(alpha, beta)) end function _generic_matmatmul!(C::SparseMatrixCSCUnion2, tA, tB, A::AbstractVecOrMat, B::AbstractVecOrMat, _add::MulAddMul)