Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Move all PW basis k-point data to new KpointSet struct #1021

Draft
wants to merge 1 commit into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/DFTK.jl
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ include("Model.jl")
include("structure.jl")
include("bzmesh.jl")
include("fft.jl")
include("Kpoint.jl")
include("kpoints.jl")
include("PlaneWaveBasis.jl")
include("orbitals.jl")
include("input_output.jl")
Expand Down
84 changes: 0 additions & 84 deletions src/Kpoint.jl

This file was deleted.

176 changes: 33 additions & 143 deletions src/PlaneWaveBasis.jl
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ struct PlaneWaveBasis{T,
VT <: Real,
Arch <: AbstractArchitecture,
FFTtype <: FFTGrid{T, VT},
T_kpt_G_vecs <: AbstractVector{Vec3{Int}},
KpointSettype <: KpointSet{T},
} <: AbstractBasis{T}

# T is the default type to express data, VT the corresponding bare value type (i.e. not dual)
Expand All @@ -47,35 +47,8 @@ struct PlaneWaveBasis{T,
# A FFTGrid containing all necessary data for FFT opertations related to this basis
fft_grid::FFTtype

## MPI-local information of the kpoints this processor treats
# In principle, irreducible kpoints (although some kpoints might be duplicated in parallel runs).
# In the case of collinear spin, this lists all the spin up, then all the spin down
kpoints::Vector{Kpoint{T, T_kpt_G_vecs}}
# BZ integration weights, summing up to model.n_spin_components
kweights::Vector{T}

## (MPI-global) information on the k-point grid
## These fields are not actually used in computation, but can be used to reconstruct a basis
# Monkhorst-Pack grid used to generate the k-points, or nothing for custom k-points
kgrid::AbstractKgrid
# Full list of (non spin doubled) k-point coordinates in the irreducible BZ (duplicates possible)
# Best to use the irreducible_kcoords_global() and irreducible_kweights_global() functions
# to insure none of the k-points are duplicated
kcoords_global::Vector{Vec3{T}}
kweights_global::Vector{T}

# Number of irreducible k-points in the basis. If there are more MPI ranks than irreducible
# k-points, some are duplicated over the MPI ranks (with adjusted weight). In such a case
# n_irreducible_kpoints < length(kcoords_global)
n_irreducible_kpoints::Int

## Setup for MPI-distributed processing over k-points
comm_kpts::MPI.Comm # communicator for the kpoints distribution
krange_thisproc::Vector{UnitRange{Int}} # Indices of kpoints treated explicitly by this
# # processor in the global kcoords array. To allow for contiguous array
# # indexing, this is given as a unit range for spin-up and spin-down
krange_allprocs::Vector{Vector{UnitRange{Int}}} # Same as above, but one entry per rank
krange_thisproc_allspin::Vector{Int} # Indexing version == reduce(vcat, krange_thisproc)
# A KpointSet containing all data related to the Kpoints (weight, coordinates, MPI distribution)
kpoint_set::KpointSettype

## Information on the hardware and device used for computations.
architecture::Arch
Expand All @@ -102,12 +75,35 @@ Base.Broadcast.broadcastable(basis::PlaneWaveBasis) = Ref(basis)

Base.eltype(::PlaneWaveBasis{T}) where {T} = T


function Kpoint(basis::PlaneWaveBasis, coordinate::AbstractVector, spin::Int)
Kpoint(spin, coordinate, basis.model.recip_lattice, basis.fft_size, basis.Ecut;
basis.variational, basis.architecture)
end

# Allow direct access to basis.kpoint_set fields (TODO: this is temporary, need to
# discuss merits of this, versus explicitly writing basis.kpoint_set everywhere,
# versus writing functions such as kpoints(basis) or kweights(basis))
# For now, allows minimum modification of the code
function Base.getproperty(basis::PlaneWaveBasis, symbol::Symbol)
if symbol in fieldnames(KpointSet)
return getfield(basis.kpoint_set, symbol)
else
return getfield(basis, symbol)
end
end

# Forward all references to kpoints from the PlaneWaveBasis to its KpointSet
function irreducible_kcoords_global(basis::PlaneWaveBasis)
irreducible_kcoords_global(basis.kpoint_set)
end

function irreducible_kweights_global(basis::PlaneWaveBasis)
irreducible_kweights_global(basis.kpoint_set)
end

function weighted_ksum(basis::PlaneWaveBasis, array)
weighted_ksum(basis.kpoint_set, array)
end

# Returns the kpoint at given coordinate. If outside the Brillouin zone, it is created
# from an equivalent kpoint in the basis (also returned)
Expand Down Expand Up @@ -164,90 +160,25 @@ function PlaneWaveBasis(model::Model{T}, Ecut::Real, fft_size::Tuple{Int, Int, I
end
symmetries = symmetries_preserving_kgrid(symmetries, kgrid)

# Build the irreducible k-point coordinates
if use_symmetries_for_kpoint_reduction
kdata = irreducible_kcoords(kgrid, symmetries)
else
kdata = irreducible_kcoords(kgrid, [one(SymOp)])
end

# Init MPI, and store MPI-global values for reference
MPI.Init()
kcoords_global = convert(Vector{Vec3{T}}, kdata.kcoords)
kweights_global = convert(Vector{T}, kdata.kweights)
# Create a set of kpoints (incl. MPI parallelization info)
kpoint_set = KpointSet(model, Ecut, fft_size, variational, kgrid,
symmetries, use_symmetries_for_kpoint_reduction,
comm_kpts, architecture)

# Setup FFT plans
fft_grid = FFTGrid(fft_size, model.unit_cell_volume, architecture)

# Compute k-point information and spread them across processors
# Right now we split only the kcoords: both spin channels have to be handled
# by the same process
n_procs = mpi_nprocs(comm_kpts)
n_kpt = length(kcoords_global)
n_irreducible_kpoints = n_kpt

# The code cannot handle MPI ranks without k-points. If there are more prcocesses
# than k-points, we duplicate k-points with the highest weight on the empty MPI
# ranks (and scale the weight accordingly)
if n_procs > n_kpt
for i in n_kpt+1:n_procs
idx = argmax(kweights_global)
kweights_global[idx] *= 0.5
push!(kweights_global, kweights_global[idx])
push!(kcoords_global, kcoords_global[idx])
end
@warn("Attempting to parallelize $n_kpt k-points over $n_procs MPI ranks. " *
"DFTK does not support processes empty of k-point. Some k-points were " *
"duplicated over the extra ranks with scaled weights.")
end
n_kpt = length(kcoords_global)

# get the slice of 1:n_kpt to be handled by this process
# Note: MPI ranks are 0-based
krange_allprocs1 = split_evenly(1:n_kpt, n_procs)
krange_thisproc1 = krange_allprocs1[1 + MPI.Comm_rank(comm_kpts)]
@assert mpi_sum(length(krange_thisproc1), comm_kpts) == n_kpt
@assert !isempty(krange_thisproc1)

# Setup k-point basis sets
!variational && @warn(
"Non-variational calculations are experimental. " *
"Not all features of DFTK may be supported or work as intended."
)
kpoints = build_kpoints(model, fft_size, kcoords_global[krange_thisproc1], Ecut;
variational, architecture)
# kpoints is now possibly twice the size of krange. Make things consistent
if model.n_spin_components == 1
kweights = kweights_global[krange_thisproc1]
krange_allprocs = [[range] for range in krange_allprocs1]
else
kweights = vcat(kweights_global[krange_thisproc1],
kweights_global[krange_thisproc1])
krange_allprocs = [[range, n_kpt .+ range] for range in krange_allprocs1]
end
krange_thisproc = krange_allprocs[1 + MPI.Comm_rank(comm_kpts)]
krange_thisproc_allspin = reduce(vcat, krange_thisproc)

@assert mpi_sum(sum(kweights), comm_kpts) ≈ model.n_spin_components
@assert length(kpoints) == length(kweights)
@assert length(kpoints) == sum(length, krange_thisproc)
@assert length(kpoints) == length( krange_thisproc_allspin)

if architecture isa GPU && Threads.nthreads() > 1
error("Can't mix multi-threading and GPU computations yet.")
end

dvol = model.unit_cell_volume ./ prod(fft_size)
terms = Vector{Any}(undef, length(model.term_types)) # Dummy terms array, filled below

basis = PlaneWaveBasis{T, value_type(T), Arch, typeof(fft_grid),
typeof(kpoints[1].G_vectors)}(
basis = PlaneWaveBasis{T, value_type(T), Arch, typeof(fft_grid), typeof(kpoint_set)}(
model, fft_size, dvol,
Ecut, variational,
fft_grid,
kpoints, kweights, kgrid,
kcoords_global, kweights_global, n_irreducible_kpoints,
comm_kpts, krange_thisproc, krange_allprocs, krange_thisproc_allspin,
fft_grid, kpoint_set,
architecture, symmetries, symmetries_respect_rgrid,
use_symmetries_for_kpoint_reduction, terms)

Expand Down Expand Up @@ -423,47 +354,6 @@ function krange_spin(basis::PlaneWaveBasis, spin::Integer)
(1 + (spin - 1) * n_kpts_per_spin):(spin * n_kpts_per_spin)
end

"""
Sum an array over kpoints, taking weights into account
"""
function weighted_ksum(basis::PlaneWaveBasis, array)
res = sum(basis.kweights .* array)
mpi_sum(res, basis.comm_kpts)
end

"""
Utilities to get information about the irreducible k-point mesh (in case of duplication)
Useful for I/O, where k-point information should not be duplicated
"""
function irreducible_kcoords_global(basis::PlaneWaveBasis)
# Assume that duplicated k-points are appended at the end of the kcoords array
basis.kcoords_global[1:basis.n_irreducible_kpoints]
end

function irreducible_kweights_global(basis::PlaneWaveBasis{T}) where {T}
function same_kpoint(i_irr, i_dupl)
maximum(abs, basis.kcoords_global[i_dupl]-basis.kcoords_global[i_irr]) < eps(T)
end

# Check that weights add up to 1 on entry (non spin doubled k-points)
@assert sum(basis.kweights_global) ≈ 1

# Assume that duplicated k-points are appended at the end of the kcoords array
irr_kweights = basis.kweights_global[1:basis.n_irreducible_kpoints]
for i_dupl = basis.n_irreducible_kpoints+1:length(basis.kweights_global)
for i_irr = 1:basis.n_irreducible_kpoints
if same_kpoint(i_irr, i_dupl)
irr_kweights[i_irr] += basis.kweights_global[i_dupl]
break
end
end
end

# Test that irreducible weight add up to 1 (non spin doubled k-points)
@assert sum(irr_kweights) ≈ 1
irr_kweights
end

"""
Gather the distributed ``k``-point data on the master process and return
it as a `PlaneWaveBasis`. On the other (non-master) processes `nothing` is returned.
Expand Down
Loading
Loading