Skip to content

Commit

Permalink
Minor optimization of the LOBPCG solver (#1037)
Browse files Browse the repository at this point in the history
  • Loading branch information
abussy authored Feb 5, 2025
1 parent a6bd941 commit 24eb717
Show file tree
Hide file tree
Showing 2 changed files with 37 additions and 21 deletions.
55 changes: 34 additions & 21 deletions src/eigen/lobpcg_hyper_impl.jl
Original file line number Diff line number Diff line change
Expand Up @@ -76,29 +76,37 @@ function Base.size(A::LazyHcat)
(n, m)
end

Base.Array(A::LazyHcat) = stack(A.blocks)

Base.Array(A::LazyHcat) = stack(A.blocks)
Base.adjoint(A::LazyHcat) = Adjoint(A)

@views function Base.:*(Aadj::Adjoint{T,<:LazyHcat}, B::LazyHcat) where {T}
A = Aadj.parent
rows = size(A)[2]
cols = size(B)[2]
ret = similar(A.blocks[1], rows, cols)

orow = 0 # row offset
for blA in A.blocks
ocol = 0 # column offset
for blB in B.blocks
# Computes A*B matrix product for LazyHcat type. Special case if product is assumed to be Hermitian
@views function _mul(A::Adjoint{T,<:LazyHcat}, B::LazyHcat; hermitian=Val(false)) where {T}
Ap = A.parent
rows = size(Ap, 2)
cols = size(B, 2)
ret = similar(B.blocks[1], rows, cols)

# Only populate the upper block diagonal in Hermitian case
ocol = 0 # column offset
for (ib, blB) in enumerate(B.blocks)
orow = 0 # row offset
for (ia, blA) in enumerate(Ap.blocks)
(hermitian isa Val{true} && ib < ia) && continue
ret[orow .+ (1:size(blA, 2)), ocol .+ (1:size(blB, 2))] .= blA' * blB
ocol += size(blB, 2)
orow += size(blA, 2)
end
orow += size(blA, 2)
ocol += size(blB, 2)
end

if hermitian isa Val{true}
Hermitian(ret)
else
ret
end
ret
end

Base.:*(Aadj::Adjoint{T,<:LazyHcat}, B::AbstractMatrix) where {T} = Aadj * LazyHcat(B)
# Adjoint-of-LazyHcat times LazyHcat: forwarded to `_mul` with its default keywords.
function Base.:*(A::Adjoint{T,<:LazyHcat}, B::LazyHcat) where {T}
    return _mul(A, B)
end

# Adjoint-of-LazyHcat times a plain matrix: wrap the right operand in a `LazyHcat`
# and re-dispatch on `*`.
function Base.:*(A::Adjoint{T,<:LazyHcat}, B::AbstractMatrix) where {T}
    wrapped = LazyHcat(B)
    return A * wrapped
end

@views function *(Ablock::LazyHcat, B::AbstractMatrix)
res = Ablock.blocks[1] * B[1:size(Ablock.blocks[1], 2), :] # First multiplication
Expand All @@ -115,11 +123,16 @@ function LinearAlgebra.mul!(res::AbstractMatrix, Ablock::LazyHcat,
mul!(res, Ablock*B, I, α, β)
end

# Untyped method: compute the ordinary product `A * B` and wrap the result in a
# `Hermitian` view.
function mul_hermi(A, B)
    product = A * B
    return Hermitian(product)
end
# Method for `LazyHcat` operands: delegates to `_mul`, requesting the Hermitian variant
# through the `hermitian=Val(true)` keyword.
mul_hermi(A::Adjoint{T,<:LazyHcat}, B::LazyHcat) where {T} =
    _mul(A, B; hermitian=Val(true))

# Perform a Rayleigh-Ritz for the N first eigenvectors.
@timing function rayleigh_ritz(X, AX, N)
XAX = X' * AX
@assert !any(isnan, XAX)
rayleigh_ritz(Hermitian(XAX), N)
XAX = mul_hermi(X', AX)
@assert !any(isnan, UpperTriangular(parent(XAX)))
rayleigh_ritz(XAX, N)
end
@views function rayleigh_ritz(XAX::Hermitian, N)
# Fallback: Use whatever is the default dense eigensolver.
Expand Down Expand Up @@ -149,7 +162,7 @@ end
# (which implies that X'BX is relatively well-conditioned, and
# therefore that it is safe to cholesky it and reuse the B apply)
function B_ortho!(X, BX)
O = Hermitian(X'*BX)
O = mul_hermi(X', BX)
U = cholesky(O).U
@assert !any(isnan, U)
rdiv!(X, U)
Expand All @@ -173,7 +186,7 @@ normest(M) = maximum(abs.(diag(M))) + norm(M - Diagonal(diag(M)))
success = false
nchol = 0
while true
O = Hermitian(X'X)
O = mul_hermi(X', X)
try
R = cholesky(O).U
nchol += 1
Expand Down
3 changes: 3 additions & 0 deletions test/lobpcg.jl
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,7 @@ end
@testitem "LOBPCG Internal data structures" setup=[TestCases] begin
using DFTK
using LinearAlgebra
import DFTK: mul_hermi

a1 = rand(10, 5)
a2 = rand(10, 2)
Expand All @@ -142,4 +143,6 @@ end

D = rand(10, 4)
@test mul!(D,Ablock, C, 1, 0) ≈ A*C

@test mul_hermi(Ablock', Ablock) ≈ A' * A
end

0 comments on commit 24eb717

Please sign in to comment.