Skip to content

Commit

Permalink
Added suggested changes
Browse files Browse the repository at this point in the history
  • Loading branch information
abussy committed Feb 3, 2025
1 parent e7dd283 commit 33b4b60
Showing 1 changed file with 35 additions and 30 deletions.
65 changes: 35 additions & 30 deletions src/eigen/lobpcg_hyper_impl.jl
Original file line number Diff line number Diff line change
Expand Up @@ -80,47 +80,51 @@ Base.Array(A::LazyHcat) = stack(A.blocks)

Base.adjoint(A::LazyHcat) = Adjoint(A)

@views function Base.:*(Aadj::Adjoint{T,<:LazyHcat}, B::LazyHcat) where {T}
A = Aadj.parent
rows = size(A, 2)
# Computes A*B matrix product for LazyHcat type. Special case if product is assumed to be Hermitian
@views function _mul(A::Adjoint{T,<:LazyHcat}, B::LazyHcat; hermitian=Val(false)) where {T}
Ap = A.parent
rows = size(Ap, 2)
cols = size(B, 2)
ret = similar(A.blocks[1], rows, cols)
ret = similar(B.blocks[1], rows, cols)

orow = 0 # row offset
for blA in A.blocks
if hermitian isa Val{true}
# Only popuplate the upper block diagonal in Hermitian case
ocol = 0 # column offset
for blB in B.blocks
ret[orow .+ (1:size(blA, 2)), ocol .+ (1:size(blB, 2))] .= blA' * blB
for (ib, blB) in enumerate(B.blocks)
orow = 0 # row offset
for (ia, blA) in enumerate(Ap.blocks)
ib < ia && continue
ret[orow .+ (1:size(blA, 2)), ocol .+ (1:size(blB, 2))] .= blA' * blB
orow += size(blA, 2)
end
ocol += size(blB, 2)
end
Hermitian(ret)
else
ocol = 0 # column offset
for (ib, blB) in enumerate(B.blocks)
orow = 0 # row offset
for (ia, blA) in enumerate(Ap.blocks)
ret[orow .+ (1:size(blA, 2)), ocol .+ (1:size(blB, 2))] .= blA' * blB
orow += size(blA, 2)
end
ocol += size(blB, 2)
end
orow += size(blA, 2)
ret
end
ret
end

Base.:*(Aadj::Adjoint{T,<:LazyHcat}, B::AbstractMatrix) where {T} = Aadj * LazyHcat(B)
@views function Base.:*(A::Adjoint{T,<:LazyHcat}, B::LazyHcat) where {T}
_mul(A, B)
end

# Special case of Hermitian result: can only actively compute the block upper diagonal
@views function mul_hermi(Aadj::Adjoint{T,<:LazyHcat}, B::LazyHcat) where {T}
A = Aadj.parent
rows = size(A, 2)
cols = size(B, 2)
ret = similar(B.blocks[1], rows, cols)
Base.:*(A::Adjoint{T,<:LazyHcat}, B::AbstractMatrix) where {T} = A * LazyHcat(B)

ocol = 0 # column offset
for (ib, blB) in enumerate(B.blocks)
orow = 0 # row offset
for (ia, blA) in enumerate(A.blocks)
ib < ia && continue
ret[orow .+ (1:size(blA, 2)), ocol .+ (1:size(blB, 2))] .= blA' * blB
orow += size(blA, 2)
end
ocol += size(blB, 2)
end
Hermitian(ret)
end
mul_hermi(A, B) = Hermitian(A * B)

mul_hermi(Aadj::AbstractArray{T}, B::AbstractArray{T}) where {T} = Hermitian(Aadj * B)
@views function mul_hermi(A::Adjoint{T,<:LazyHcat}, B::LazyHcat) where {T}
_mul(A, B; hermitian=Val(true))
end

@views function *(Ablock::LazyHcat, B::AbstractMatrix)
res = Ablock.blocks[1] * B[1:size(Ablock.blocks[1], 2), :] # First multiplication
Expand All @@ -140,6 +144,7 @@ end
# Perform a Rayleigh-Ritz for the N first eigenvectors.
@timing function rayleigh_ritz(X, AX, N)
XAX = mul_hermi(X', AX)
@assert !any(isnan, UpperTriangular(parent(XAX)))
rayleigh_ritz(XAX, N)
end
@views function rayleigh_ritz(XAX::Hermitian, N)
Expand Down

0 comments on commit 33b4b60

Please sign in to comment.