Skip to content

Commit 1b7a095

Browse files
fredrikekreandreasnoack
authored andcommitted
unify methods for cholesky (JuliaLang#21595)
1 parent d762038 commit 1b7a095

File tree

1 file changed

+9
-36
lines changed

1 file changed

+9
-36
lines changed

base/linalg/cholesky.jl

Lines changed: 9 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -16,10 +16,7 @@
1616
# supported for the four LAPACK element types. For other types, e.g. BigFloats Val{true} will
1717
# give an error. It is required that the input is Hermitian (including real symmetric) either
1818
# through the Hermitian and Symmetric views or exact symmetric or Hermitian elements which
19-
# is checked for and an error is thrown if the check fails. The dispatch
20-
# is further complicated by a limitation in the formulation of Unions. The relevant union
21-
# would be Union{Symmetric{T<:Real,S}, Hermitian} but, right now, it doesn't work in Julia
22-
# so we'll have to define methods for the two elements of the union separately.
19+
# is checked for and an error is thrown if the check fails.
2320

2421
# FixMe? The dispatch below seems overly complicated. One simplification could be to
2522
# merge the two Cholesky types into one. It would remove the need for Val completely but
@@ -121,9 +118,7 @@ non_hermitian_error(f) = throw(ArgumentError("matrix is not symmetric/" *
121118

122119
# chol!. Destructive methods for computing Cholesky factor of real symmetric or Hermitian
123120
# matrix
124-
chol!(A::Hermitian) =
125-
_chol!(A.uplo == 'U' ? A.data : LinAlg.copytri!(A.data, 'L', true), UpperTriangular)
126-
chol!(A::Symmetric{<:Real,<:StridedMatrix}) =
121+
chol!(A::RealHermSymComplexHerm{<:Real,<:StridedMatrix}) =
127122
_chol!(A.uplo == 'U' ? A.data : LinAlg.copytri!(A.data, 'L', true), UpperTriangular)
128123
function chol!(A::StridedMatrix)
129124
ishermitian(A) || non_hermitian_error("chol!")
@@ -134,7 +129,7 @@ end
134129

135130
# chol. Non-destructive methods for computing Cholesky factor of a real symmetric or
136131
# Hermitian matrix. Promotes elements to a type that is stable under square roots.
137-
function chol(A::Hermitian)
132+
function chol(A::RealHermSymComplexHerm)
138133
T = promote_type(typeof(chol(one(eltype(A)))), Float32)
139134
AA = similar(A, T, size(A))
140135
if A.uplo == 'U'
@@ -144,16 +139,6 @@ function chol(A::Hermitian)
144139
end
145140
chol!(Hermitian(AA, :U))
146141
end
147-
function chol(A::Symmetric{T,<:AbstractMatrix}) where T<:Real
148-
TT = promote_type(typeof(chol(one(T))), Float32)
149-
AA = similar(A, TT, size(A))
150-
if A.uplo == 'U'
151-
copy!(AA, A.data)
152-
else
153-
Base.ctranspose!(AA, A.data)
154-
end
155-
chol!(Hermitian(AA, :U))
156-
end
157142

158143
## for StridedMatrices, check that matrix is symmetric/Hermitian
159144
"""
@@ -206,14 +191,7 @@ chol(x::Number, args...) = _chol!(x, nothing)
206191
# cholfact!. Destructive methods for computing Cholesky factorization of real symmetric
207192
# or Hermitian matrix
208193
## No pivoting
209-
function cholfact!(A::Hermitian, ::Type{Val{false}})
210-
if A.uplo == 'U'
211-
Cholesky(_chol!(A.data, UpperTriangular).data, 'U')
212-
else
213-
Cholesky(_chol!(A.data, LowerTriangular).data, 'L')
214-
end
215-
end
216-
function cholfact!(A::Symmetric{<:Real}, ::Type{Val{false}})
194+
function cholfact!(A::RealHermSymComplexHerm, ::Type{Val{false}})
217195
if A.uplo == 'U'
218196
Cholesky(_chol!(A.data, UpperTriangular).data, 'U')
219197
else
@@ -248,8 +226,8 @@ function cholfact!(A::StridedMatrix, uplo::Symbol, ::Type{Val{false}})
248226
end
249227

250228
### Default to no pivoting (and storing of upper factor) when not explicit
251-
cholfact!(A::Hermitian) = cholfact!(A, Val{false})
252-
cholfact!(A::Symmetric{<:Real}) = cholfact!(A, Val{false})
229+
cholfact!(A::RealHermSymComplexHerm) = cholfact!(A, Val{false})
230+
253231
#### for StridedMatrices, check that matrix is symmetric/Hermitian
254232
function cholfact!(A::StridedMatrix, uplo::Symbol = :U)
255233
ishermitian(A) || non_hermitian_error("cholfact!")
@@ -288,9 +266,7 @@ end
288266
# cholfact. Non-destructive methods for computing Cholesky factorization of real symmetric
289267
# or Hermitian matrix
290268
## No pivoting
291-
cholfact(A::Hermitian, ::Type{Val{false}}) =
292-
cholfact!(copy_oftype(A, promote_type(typeof(chol(one(eltype(A)))),Float32)), Val{false})
293-
cholfact(A::Symmetric{<:Real,<:StridedMatrix}, ::Type{Val{false}}) =
269+
cholfact(A::RealHermSymComplexHerm{<:Real,<:StridedMatrix}, ::Type{Val{false}}) =
294270
cholfact!(copy_oftype(A, promote_type(typeof(chol(one(eltype(A)))),Float32)), Val{false})
295271

296272
### for StridedMatrices, check that matrix is symmetric/Hermitian
@@ -342,8 +318,8 @@ function cholfact(A::StridedMatrix, uplo::Symbol, ::Type{Val{false}})
342318
end
343319

344320
### Default to no pivoting (and storing of upper factor) when not explicit
345-
cholfact(A::Hermitian) = cholfact(A, Val{false})
346-
cholfact(A::Symmetric{<:Real,<:StridedMatrix}) = cholfact(A, Val{false})
321+
cholfact(A::RealHermSymComplexHerm{<:Real,<:StridedMatrix}) = cholfact(A, Val{false})
322+
347323
#### for StridedMatrices, check that matrix is symmetric/Hermitian
348324
function cholfact(A::StridedMatrix, uplo::Symbol = :U)
349325
ishermitian(A) || non_hermitian_error("cholfact")
@@ -352,9 +328,6 @@ end
352328

353329

354330
## With pivoting
355-
cholfact(A::Hermitian, ::Type{Val{true}}; tol = 0.0) =
356-
cholfact!(copy_oftype(A, promote_type(typeof(chol(one(eltype(A)))),Float32)),
357-
Val{true}; tol = tol)
358331
cholfact(A::RealHermSymComplexHerm{<:Real,<:StridedMatrix}, ::Type{Val{true}}; tol = 0.0) =
359332
cholfact!(copy_oftype(A, promote_type(typeof(chol(one(eltype(A)))),Float32)),
360333
Val{true}; tol = tol)

0 commit comments

Comments
 (0)