Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Minor optimization of the LOBPCG solver #1037

Merged
merged 9 commits into from
Feb 5, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
55 changes: 34 additions & 21 deletions src/eigen/lobpcg_hyper_impl.jl
Original file line number Diff line number Diff line change
Expand Up @@ -76,29 +76,37 @@
(n, m)
end

# Materialise the lazily concatenated blocks as one dense matrix.
# The blocks sit side by side (they are addressed by column offsets elsewhere in
# this file), so they are joined with `hcat`. `stack` is not suitable here: it
# requires all blocks to have the same size and adds an extra dimension instead
# of concatenating along the columns.
# The splat is fine: `blocks` only holds the handful of blocks of one LazyHcat.
Base.Array(A::LazyHcat) = hcat(A.blocks...)

Check warning on line 79 in src/eigen/lobpcg_hyper_impl.jl

View check run for this annotation

Codecov / codecov/patch

src/eigen/lobpcg_hyper_impl.jl#L79

Added line #L79 was not covered by tests
# The adjoint of a LazyHcat stays lazy: wrap it in `Adjoint` rather than
# touching the individual blocks.
function Base.adjoint(A::LazyHcat)
    return Adjoint(A)
end

@views function Base.:*(Aadj::Adjoint{T,<:LazyHcat}, B::LazyHcat) where {T}
A = Aadj.parent
rows = size(A)[2]
cols = size(B)[2]
ret = similar(A.blocks[1], rows, cols)

orow = 0 # row offset
for blA in A.blocks
ocol = 0 # column offset
for blB in B.blocks
# Computes the matrix product A * B, where `A` is the adjoint of a LazyHcat and `B` is a
# LazyHcat, one pair of blocks at a time.
#
# Keyword `hermitian`: pass `Val(true)` if the product is known to be Hermitian. Then only
# the upper block triangle of the result is computed and the result is returned wrapped in
# `Hermitian`; the lower block triangle of the underlying storage is left uninitialised.
# With the default `Val(false)` the full dense product is returned.
@views function _mul(A::Adjoint{T,<:LazyHcat}, B::LazyHcat; hermitian=Val(false)) where {T}
    Ap = A.parent
    rows = size(Ap, 2)
    cols = size(B, 2)
    ret = similar(B.blocks[1], rows, cols)

    ocol = 0  # column offset: advances once per block of B
    for (ib, blB) in enumerate(B.blocks)
        orow = 0  # row offset: restarts for every block column
        for (ia, blA) in enumerate(Ap.blocks)
            # Only populate the upper block triangle in the Hermitian case. Every
            # remaining block of this column lies below the diagonal, so stop here.
            (hermitian isa Val{true} && ib < ia) && break
            ret[orow .+ (1:size(blA, 2)), ocol .+ (1:size(blB, 2))] .= blA' * blB
            orow += size(blA, 2)
        end
        ocol += size(blB, 2)
    end

    # `hermitian` is encoded in the type domain, so this branch is resolved per method
    # specialisation and the return type stays inferrable.
    if hermitian isa Val{true}
        return Hermitian(ret)
    else
        return ret
    end
end

# Products with an adjoint LazyHcat on the left are delegated to the block-wise `_mul`.
# A plain matrix on the right is first wrapped with `LazyHcat(B)` so that it takes the
# same code path. Each signature is defined exactly once to avoid method overwrites.
Base.:*(A::Adjoint{T,<:LazyHcat}, B::LazyHcat) where {T} = _mul(A, B)
Base.:*(A::Adjoint{T,<:LazyHcat}, B::AbstractMatrix) where {T} = A * LazyHcat(B)

@views function *(Ablock::LazyHcat, B::AbstractMatrix)
res = Ablock.blocks[1] * B[1:size(Ablock.blocks[1], 2), :] # First multiplication
Expand All @@ -115,11 +123,16 @@
mul!(res, Ablock*B, I, α, β)
end

# Generic fallback: form the full product A * B and wrap it in `Hermitian`.
# The caller is responsible for the product actually being Hermitian.
function mul_hermi(A, B)
    product = A * B
    return Hermitian(product)
end
# LazyHcat operands: forward to the block-wise `_mul`, requesting its Hermitian variant.
mul_hermi(A::Adjoint{T,<:LazyHcat}, B::LazyHcat) where {T} =
    _mul(A, B; hermitian=Val(true))

# Perform a Rayleigh-Ritz for the N first eigenvectors.
# `X` and `AX` are combined into the projected matrix X' * AX, which is then handed to
# the `rayleigh_ritz(::Hermitian, N)` method.
@timing function rayleigh_ritz(X, AX, N)
    # Build the projected matrix once, already wrapped as Hermitian.
    XAX = mul_hermi(X', AX)
    # Check only the upper triangle of the underlying storage: that is the part a
    # default `Hermitian` wrapper reads, and the rest may never have been written.
    @assert !any(isnan, UpperTriangular(parent(XAX)))
    return rayleigh_ritz(XAX, N)
end
@views function rayleigh_ritz(XAX::Hermitian, N)
# Fallback: Use whatever is the default dense eigensolver.
Expand Down Expand Up @@ -149,7 +162,7 @@
# (which implies that X'BX is relatively well-conditioned, and
# therefore that it is safe to cholesky it and reuse the B apply)
function B_ortho!(X, BX)
O = Hermitian(X'*BX)
O = mul_hermi(X', BX)

Check warning on line 165 in src/eigen/lobpcg_hyper_impl.jl

View check run for this annotation

Codecov / codecov/patch

src/eigen/lobpcg_hyper_impl.jl#L165

Added line #L165 was not covered by tests
U = cholesky(O).U
@assert !any(isnan, U)
rdiv!(X, U)
Expand All @@ -173,7 +186,7 @@
success = false
nchol = 0
while true
O = Hermitian(X'X)
O = mul_hermi(X', X)
try
R = cholesky(O).U
nchol += 1
Expand Down
3 changes: 3 additions & 0 deletions test/lobpcg.jl
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,7 @@ end
@testitem "LOBPCG Internal data structures" setup=[TestCases] begin
using DFTK
using LinearAlgebra
import DFTK: mul_hermi

a1 = rand(10, 5)
a2 = rand(10, 2)
Expand All @@ -142,4 +143,6 @@ end

D = rand(10, 4)
@test mul!(D,Ablock, C, 1, 0) ≈ A*C

@test mul_hermi(Ablock', Ablock) ≈ A' * A
end
Loading