diff --git a/stdlib/LinearAlgebra/src/triangular.jl b/stdlib/LinearAlgebra/src/triangular.jl index c45db3e90fab2..365ce8ee4bae2 100644 --- a/stdlib/LinearAlgebra/src/triangular.jl +++ b/stdlib/LinearAlgebra/src/triangular.jl @@ -702,6 +702,43 @@ function _triscale!(A::LowerOrUnitLowerTriangular, c::Number, B::UnitLowerTriang return A end +function _trirdiv!(A::UpperTriangular, B::UpperOrUnitUpperTriangular, c::Number) + n = checksize1(A, B) + for j in 1:n + for i in 1:j + @inbounds A[i, j] = B[i, j] / c + end + end + return A +end +function _trirdiv!(A::LowerTriangular, B::LowerOrUnitLowerTriangular, c::Number) + n = checksize1(A, B) + for j in 1:n + for i in j:n + @inbounds A[i, j] = B[i, j] / c + end + end + return A +end +function _trildiv!(A::UpperTriangular, c::Number, B::UpperOrUnitUpperTriangular) + n = checksize1(A, B) + for j in 1:n + for i in 1:j + @inbounds A[i, j] = c \ B[i, j] + end + end + return A +end +function _trildiv!(A::LowerTriangular, c::Number, B::LowerOrUnitLowerTriangular) + n = checksize1(A, B) + for j in 1:n + for i in j:n + @inbounds A[i, j] = c \ B[i, j] + end + end + return A +end + rmul!(A::UpperOrLowerTriangular, c::Number) = @inline _triscale!(A, A, c, MulAddMul()) lmul!(c::Number, A::UpperOrLowerTriangular) = @inline _triscale!(A, c, A, MulAddMul()) @@ -1095,7 +1132,11 @@ for (t, unitt) in ((UpperTriangular, UnitUpperTriangular), tstrided = t{<:Any, <:StridedMaybeAdjOrTransMat} @eval begin (*)(A::$t, x::Number) = $t(A.data*x) - (*)(A::$tstrided, x::Number) = A .* x + function (*)(A::$tstrided, x::Number) + eltype_dest = promote_op(*, eltype(A), typeof(x)) + dest = $t(similar(parent(A), eltype_dest)) + _triscale!(dest, x, A, MulAddMul()) + end function (*)(A::$unitt, x::Number) B = $t(A.data)*x @@ -1106,7 +1147,11 @@ for (t, unitt) in ((UpperTriangular, UnitUpperTriangular), end (*)(x::Number, A::$t) = $t(x*A.data) - (*)(x::Number, A::$tstrided) = x .* A + function (*)(x::Number, A::$tstrided) + eltype_dest = promote_op(*, typeof(x), eltype(A)) + dest = $t(similar(parent(A), eltype_dest)) + _triscale!(dest, x, A, MulAddMul()) + end function (*)(x::Number, A::$unitt) B = x*$t(A.data) @@ -1117,7 +1162,11 @@ for (t, unitt) in ((UpperTriangular, UnitUpperTriangular), end (/)(A::$t, x::Number) = $t(A.data/x) - (/)(A::$tstrided, x::Number) = A ./ x + function (/)(A::$tstrided, x::Number) + eltype_dest = promote_op(/, eltype(A), typeof(x)) + dest = $t(similar(parent(A), eltype_dest)) + _trirdiv!(dest, A, x) + end function (/)(A::$unitt, x::Number) B = $t(A.data)/x @@ -1129,7 +1178,11 @@ for (t, unitt) in ((UpperTriangular, UnitUpperTriangular), end (\)(x::Number, A::$t) = $t(x\A.data) - (\)(x::Number, A::$tstrided) = x .\ A + function (\)(x::Number, A::$tstrided) + eltype_dest = promote_op(\, typeof(x), eltype(A)) + dest = $t(similar(parent(A), eltype_dest)) + _trildiv!(dest, x, A) + end function (\)(x::Number, A::$unitt) B = x\$t(A.data) diff --git a/stdlib/LinearAlgebra/test/triangular.jl b/stdlib/LinearAlgebra/test/triangular.jl index 5ee8143e3f4bb..3f7cea91ec6d4 100644 --- a/stdlib/LinearAlgebra/test/triangular.jl +++ b/stdlib/LinearAlgebra/test/triangular.jl @@ -1180,4 +1180,18 @@ end @test V == Diagonal([1, 1]) end +@testset "preserve structure in scaling by NaN" begin + M = rand(Int8,2,2) + for (Ts, TD) in (((UpperTriangular, UnitUpperTriangular), UpperTriangular), + ((LowerTriangular, UnitLowerTriangular), LowerTriangular)) + for T in Ts + U = T(M) + for V in (U * NaN, NaN * U, U / NaN, NaN \ U) + @test V isa TD{Float64, Matrix{Float64}} + @test all(isnan, diag(V)) + end + end + end +end + end # module TestTriangular