diff --git a/src/eigen/lobpcg_hyper_impl.jl b/src/eigen/lobpcg_hyper_impl.jl index 88a792ac12..7bd1c3106c 100644 --- a/src/eigen/lobpcg_hyper_impl.jl +++ b/src/eigen/lobpcg_hyper_impl.jl @@ -76,29 +76,37 @@ function Base.size(A::LazyHcat) (n, m) end -Base.Array(A::LazyHcat) = stack(A.blocks) - +Base.Array(A::LazyHcat) = stack(A.blocks) Base.adjoint(A::LazyHcat) = Adjoint(A) -@views function Base.:*(Aadj::Adjoint{T,<:LazyHcat}, B::LazyHcat) where {T} - A = Aadj.parent - rows = size(A)[2] - cols = size(B)[2] - ret = similar(A.blocks[1], rows, cols) - - orow = 0 # row offset - for blA in A.blocks - ocol = 0 # column offset - for blB in B.blocks +# Computes A*B matrix product for LazyHcat type. Special case if product is assumed to be Hermitian +@views function _mul(A::Adjoint{T,<:LazyHcat}, B::LazyHcat; hermitian=Val(false)) where {T} + Ap = A.parent + rows = size(Ap, 2) + cols = size(B, 2) + ret = similar(B.blocks[1], rows, cols) + + # Only popuplate the upper block diagonal in Hermitian case + ocol = 0 # column offset + for (ib, blB) in enumerate(B.blocks) + orow = 0 # row offset + for (ia, blA) in enumerate(Ap.blocks) + (hermitian isa Val{true} && ib < ia) && continue ret[orow .+ (1:size(blA, 2)), ocol .+ (1:size(blB, 2))] .= blA' * blB - ocol += size(blB, 2) + orow += size(blA, 2) end - orow += size(blA, 2) + ocol += size(blB, 2) + end + + if hermitian isa Val{true} + Hermitian(ret) + else + ret end - ret end -Base.:*(Aadj::Adjoint{T,<:LazyHcat}, B::AbstractMatrix) where {T} = Aadj * LazyHcat(B) +Base.:*(A::Adjoint{T,<:LazyHcat}, B::LazyHcat) where {T} = _mul(A, B) +Base.:*(A::Adjoint{T,<:LazyHcat}, B::AbstractMatrix) where {T} = A * LazyHcat(B) @views function *(Ablock::LazyHcat, B::AbstractMatrix) res = Ablock.blocks[1] * B[1:size(Ablock.blocks[1], 2), :] # First multiplication @@ -115,11 +123,16 @@ function LinearAlgebra.mul!(res::AbstractMatrix, Ablock::LazyHcat, mul!(res, Ablock*B, I, α, β) end +mul_hermi(A, B) = Hermitian(A * B) +function mul_hermi(A::Adjoint{T,<:LazyHcat}, B::LazyHcat) where {T} + _mul(A, B; hermitian=Val(true)) +end + # Perform a Rayleigh-Ritz for the N first eigenvectors. @timing function rayleigh_ritz(X, AX, N) - XAX = X' * AX - @assert !any(isnan, XAX) - rayleigh_ritz(Hermitian(XAX), N) + XAX = mul_hermi(X', AX) + @assert !any(isnan, UpperTriangular(parent(XAX))) + rayleigh_ritz(XAX, N) end @views function rayleigh_ritz(XAX::Hermitian, N) # Fallback: Use whatever is the default dense eigensolver. @@ -149,7 +162,7 @@ end # (which implies that X'BX is relatively well-conditioned, and # therefore that it is safe to cholesky it and reuse the B apply) function B_ortho!(X, BX) - O = Hermitian(X'*BX) + O = mul_hermi(X', BX) U = cholesky(O).U @assert !any(isnan, U) rdiv!(X, U) @@ -173,7 +186,7 @@ normest(M) = maximum(abs.(diag(M))) + norm(M - Diagonal(diag(M))) success = false nchol = 0 while true - O = Hermitian(X'X) + O = mul_hermi(X', X) try R = cholesky(O).U nchol += 1 diff --git a/test/lobpcg.jl b/test/lobpcg.jl index 19e463a172..46c03ad9ae 100644 --- a/test/lobpcg.jl +++ b/test/lobpcg.jl @@ -124,6 +124,7 @@ end @testitem "LOBPCG Internal data structures" setup=[TestCases] begin using DFTK using LinearAlgebra + import DFTK: mul_hermi a1 = rand(10, 5) a2 = rand(10, 2) @@ -142,4 +143,6 @@ end D = rand(10, 4) @test mul!(D,Ablock, C, 1, 0) ≈ A*C + + @test mul_hermi(Ablock', Ablock) ≈ A' * A end