Skip to content

Commit

Permalink
Improve type-stability in SymTridiagonal triu!/tril!
Browse files Browse the repository at this point in the history
  • Loading branch information
jishnub committed Aug 31, 2024
1 parent e22e4de commit 1ae34db
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 18 deletions.
4 changes: 2 additions & 2 deletions stdlib/LinearAlgebra/src/tridiag.jl
Original file line number Diff line number Diff line change
Expand Up @@ -372,7 +372,7 @@ function tril!(M::SymTridiagonal{T}, k::Integer=0) where T
return Tridiagonal(M.ev,M.dv,zero(M.ev))
elseif k == 0
return Tridiagonal(M.ev,M.dv,zero(M.ev))
elseif k >= 1
else # if k >= 1
return Tridiagonal(M.ev,M.dv,copy(M.ev))
end
end
Expand All @@ -391,7 +391,7 @@ function triu!(M::SymTridiagonal{T}, k::Integer=0) where T
return Tridiagonal(zero(M.ev),M.dv,M.ev)
elseif k == 0
return Tridiagonal(zero(M.ev),M.dv,M.ev)
elseif k <= -1
else # if k <= -1
return Tridiagonal(M.ev,M.dv,copy(M.ev))
end
end
Expand Down
48 changes: 32 additions & 16 deletions stdlib/LinearAlgebra/test/tridiag.jl
Original file line number Diff line number Diff line change
Expand Up @@ -135,27 +135,43 @@ end
@test_throws ArgumentError tril!(SymTridiagonal(d, dl), n)
@test_throws ArgumentError tril!(Tridiagonal(dl, d, du), -n - 2)
@test_throws ArgumentError tril!(Tridiagonal(dl, d, du), n)
@test tril(SymTridiagonal(d,dl)) == Tridiagonal(dl,d,zerosdl)
@test tril(SymTridiagonal(d,dl),1) == Tridiagonal(dl,d,dl)
@test tril(SymTridiagonal(d,dl),-1) == Tridiagonal(dl,zerosd,zerosdl)
@test tril(SymTridiagonal(d,dl),-2) == Tridiagonal(zerosdl,zerosd,zerosdl)
@test tril(Tridiagonal(dl,d,du)) == Tridiagonal(dl,d,zerosdu)
@test tril(Tridiagonal(dl,d,du),1) == Tridiagonal(dl,d,du)
@test tril(Tridiagonal(dl,d,du),-1) == Tridiagonal(dl,zerosd,zerosdu)
@test tril(Tridiagonal(dl,d,du),-2) == Tridiagonal(zerosdl,zerosd,zerosdu)
@test @inferred(tril(SymTridiagonal(d,dl))) == Tridiagonal(dl,d,zerosdl)
@test @inferred(tril(SymTridiagonal(d,dl),1)) == Tridiagonal(dl,d,dl)
@test @inferred(tril(SymTridiagonal(d,dl),-1)) == Tridiagonal(dl,zerosd,zerosdl)
@test @inferred(tril(SymTridiagonal(d,dl),-2)) == Tridiagonal(zerosdl,zerosd,zerosdl)
@test @inferred(tril(Tridiagonal(dl,d,du))) == Tridiagonal(dl,d,zerosdu)
@test @inferred(tril(Tridiagonal(dl,d,du),1)) == Tridiagonal(dl,d,du)
@test @inferred(tril(Tridiagonal(dl,d,du),-1)) == Tridiagonal(dl,zerosd,zerosdu)
@test @inferred(tril(Tridiagonal(dl,d,du),-2)) == Tridiagonal(zerosdl,zerosd,zerosdu)
@test @inferred(tril!(copy(SymTridiagonal(d,dl)))) == Tridiagonal(dl,d,zerosdl)
@test @inferred(tril!(copy(SymTridiagonal(d,dl)),1)) == Tridiagonal(dl,d,dl)
@test @inferred(tril!(copy(SymTridiagonal(d,dl)),-1)) == Tridiagonal(dl,zerosd,zerosdl)
@test @inferred(tril!(copy(SymTridiagonal(d,dl)),-2)) == Tridiagonal(zerosdl,zerosd,zerosdl)
@test @inferred(tril!(copy(Tridiagonal(dl,d,du)))) == Tridiagonal(dl,d,zerosdu)
@test @inferred(tril!(copy(Tridiagonal(dl,d,du)),1)) == Tridiagonal(dl,d,du)
@test @inferred(tril!(copy(Tridiagonal(dl,d,du)),-1)) == Tridiagonal(dl,zerosd,zerosdu)
@test @inferred(tril!(copy(Tridiagonal(dl,d,du)),-2)) == Tridiagonal(zerosdl,zerosd,zerosdu)

@test_throws ArgumentError triu!(SymTridiagonal(d, dl), -n)
@test_throws ArgumentError triu!(SymTridiagonal(d, dl), n + 2)
@test_throws ArgumentError triu!(Tridiagonal(dl, d, du), -n)
@test_throws ArgumentError triu!(Tridiagonal(dl, d, du), n + 2)
@test triu(SymTridiagonal(d,dl)) == Tridiagonal(zerosdl,d,dl)
@test triu(SymTridiagonal(d,dl),-1) == Tridiagonal(dl,d,dl)
@test triu(SymTridiagonal(d,dl),1) == Tridiagonal(zerosdl,zerosd,dl)
@test triu(SymTridiagonal(d,dl),2) == Tridiagonal(zerosdl,zerosd,zerosdl)
@test triu(Tridiagonal(dl,d,du)) == Tridiagonal(zerosdl,d,du)
@test triu(Tridiagonal(dl,d,du),-1) == Tridiagonal(dl,d,du)
@test triu(Tridiagonal(dl,d,du),1) == Tridiagonal(zerosdl,zerosd,du)
@test triu(Tridiagonal(dl,d,du),2) == Tridiagonal(zerosdl,zerosd,zerosdu)
@test @inferred(triu(SymTridiagonal(d,dl))) == Tridiagonal(zerosdl,d,dl)
@test @inferred(triu(SymTridiagonal(d,dl),-1)) == Tridiagonal(dl,d,dl)
@test @inferred(triu(SymTridiagonal(d,dl),1)) == Tridiagonal(zerosdl,zerosd,dl)
@test @inferred(triu(SymTridiagonal(d,dl),2)) == Tridiagonal(zerosdl,zerosd,zerosdl)
@test @inferred(triu(Tridiagonal(dl,d,du))) == Tridiagonal(zerosdl,d,du)
@test @inferred(triu(Tridiagonal(dl,d,du),-1)) == Tridiagonal(dl,d,du)
@test @inferred(triu(Tridiagonal(dl,d,du),1)) == Tridiagonal(zerosdl,zerosd,du)
@test @inferred(triu(Tridiagonal(dl,d,du),2)) == Tridiagonal(zerosdl,zerosd,zerosdu)
@test @inferred(triu!(copy(SymTridiagonal(d,dl)))) == Tridiagonal(zerosdl,d,dl)
@test @inferred(triu!(copy(SymTridiagonal(d,dl)),-1)) == Tridiagonal(dl,d,dl)
@test @inferred(triu!(copy(SymTridiagonal(d,dl)),1)) == Tridiagonal(zerosdl,zerosd,dl)
@test @inferred(triu!(copy(SymTridiagonal(d,dl)),2)) == Tridiagonal(zerosdl,zerosd,zerosdl)
@test @inferred(triu!(copy(Tridiagonal(dl,d,du)))) == Tridiagonal(zerosdl,d,du)
@test @inferred(triu!(copy(Tridiagonal(dl,d,du)),-1)) == Tridiagonal(dl,d,du)
@test @inferred(triu!(copy(Tridiagonal(dl,d,du)),1)) == Tridiagonal(zerosdl,zerosd,du)
@test @inferred(triu!(copy(Tridiagonal(dl,d,du)),2)) == Tridiagonal(zerosdl,zerosd,zerosdu)

@test !istril(SymTridiagonal(d,dl))
@test istril(SymTridiagonal(d,zerosdl))
Expand Down

0 comments on commit 1ae34db

Please sign in to comment.