Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Global interpolation matrix --> MPI-distributed interpolation matrices #258

Open
wants to merge 25 commits into
base: develop
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 24 commits
Commits
Show all changes
25 commits
Select commit Hold shift + click to select a range
3a421d8
test setup to code: from global matrix with the two function spaces t…
sbrdar Feb 5, 2025
8d46f76
prepare assemble distribute for more cases
sbrdar Feb 5, 2025
e38190e
setup the distribution of the global matrix - to be finalised still
sbrdar Feb 7, 2025
02406c4
allow make_view on zero size matrix
wdeconinck Feb 7, 2025
963b749
implement distribute_global_matrix up to halo
sbrdar Feb 7, 2025
4063277
add Intepolation(Config, FunctionSpace, FunctionSpace, Cache) and fid…
sbrdar Feb 11, 2025
87fb2de
fix the unit test to compare the global matrix and the distributed in…
sbrdar Feb 12, 2025
a2579e1
1) fix more bugs in the distribution of the global matrix; 2) make su…
sbrdar Feb 12, 2025
bf47bf3
improve unit test for distribution of global matrices; comment out te…
sbrdar Feb 12, 2025
261e6da
comment out the intermediate check
sbrdar Feb 12, 2025
9dc524b
fix usage of Cache in different interpolation methods
sbrdar Feb 13, 2025
a40953e
keep the unit test under 5 sec exection time
sbrdar Feb 13, 2025
83f0dde
clean up
sbrdar Feb 13, 2025
ef6ddbd
make sandbox tool -atlas-global-matrix- work parallel when reading in…
sbrdar Feb 13, 2025
44b5f4b
cleanup and but fix in the output from the interpolation
sbrdar Feb 13, 2025
9577c62
make default naming work, i.e. Both work: ./bin/atlas-global-matrix a…
sbrdar Feb 13, 2025
14a4658
cleanup
sbrdar Feb 13, 2025
07838c7
add timers in sandbox tool atlas-global-matrix when in the reading-in…
sbrdar Feb 13, 2025
dda030d
cleanup in atlas-global-matrix
sbrdar Feb 13, 2025
ccdb644
another cleanup in atlas-global-matrix
sbrdar Feb 13, 2025
5bd8055
fix interpolation setup from cache for the nearest-neighbour class
sbrdar Feb 14, 2025
5a058ec
make the knn methods and grid-box-method work again in the atlas-glob…
sbrdar Feb 14, 2025
efc5b6c
1) use unordered_map in the distribution of the global matrix; 2) add…
wdeconinck Feb 14, 2025
4cb4448
remove warnings
sbrdar Feb 18, 2025
b673b30
fix for the review #256
sbrdar Feb 20, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
148 changes: 142 additions & 6 deletions src/atlas/interpolation/AssembleGlobalMatrix.cc
Original file line number Diff line number Diff line change
Expand Up @@ -11,15 +11,18 @@

#include "AssembleGlobalMatrix.h"

#include <unordered_map>

#include "eckit/linalg/types.h"

#include "atlas/array.h"
#include "atlas/linalg/sparse/SparseMatrixToTriplets.h"
#include "atlas/functionspace/StructuredColumns.h"
#include "atlas/interpolation/Cache.h"

namespace atlas::interpolation {

atlas::linalg::SparseMatrixStorage assemble_global_matrix(const Interpolation& interpolation, int mpi_root) {
linalg::SparseMatrixStorage assemble_global_matrix(const Interpolation& interpolation, int mpi_root) {

auto src_fs = interpolation.source();
auto tgt_fs = interpolation.target();
Expand Down Expand Up @@ -50,10 +53,10 @@ atlas::linalg::SparseMatrixStorage assemble_global_matrix(const Interpolation& i
return src_remote_index(idx) + one_if_structuredcolumns;
};

const auto src_global_index = array::make_view<gidx_t,1>(src_fs.global_index());
const auto tgt_global_index = array::make_view<gidx_t,1>(tgt_fs.global_index());
const auto src_part = array::make_view<int,1>(src_fs.partition());
const auto tgt_ghost = array::make_view<int,1>(tgt_fs.ghost());
const auto src_global_index = array::make_view<gidx_t, 1>(src_fs.global_index());
const auto tgt_global_index = array::make_view<gidx_t, 1>(tgt_fs.global_index());
const auto src_part = array::make_view<int, 1>(src_fs.partition());
const auto tgt_ghost = array::make_view<int, 1>(tgt_fs.ghost());

eckit::mpi::Buffer<gidx_t> recv_global_idx_buf(mpi_size);
{
Expand Down Expand Up @@ -149,7 +152,7 @@ atlas::linalg::SparseMatrixStorage assemble_global_matrix(const Interpolation& i
gidx_t tgt_max_gidx = compute_max_global_index(tgt_fs);
gidx_t src_max_gidx = compute_max_global_index(src_fs);

atlas::linalg::SparseMatrixStorage global_matrix;
linalg::SparseMatrixStorage global_matrix;
if (mpi_rank == mpi_root) {
size_t nrows = tgt_max_gidx;
size_t ncols = src_max_gidx;
Expand All @@ -160,6 +163,139 @@ atlas::linalg::SparseMatrixStorage assemble_global_matrix(const Interpolation& i
return global_matrix;
}

template <typename ViewValue, typename ViewIndex, typename Value, typename Index>
void distribute_global_matrix(const linalg::SparseMatrixView<ViewValue,ViewIndex>& global_matrix,
const array::Array& partition, std::vector<Index>& rows, std::vector<Index>& cols, std::vector<Value>& vals, int mpi_root) {
ATLAS_TRACE("distribute_global_matrix_lowlevel");

const auto tgt_part_glb = array::make_view<int,1>(partition);

auto& mpi_comm = mpi::comm();
auto mpi_size = mpi_comm.size();
auto mpi_rank = mpi_comm.rank();

// compute how many nnz-entries each task gets
size_t nnz_loc = 0;
int mpi_tag = 0;
std::vector<std::size_t> nnz_per_task(mpi_size);
const auto outer = global_matrix.outer();
const auto inner = global_matrix.inner();
const auto value = global_matrix.value();
if (mpi_rank == mpi_root) {

for(std::size_t r = 0; r < global_matrix.rows(); ++r) {
nnz_per_task[tgt_part_glb(r)] += outer[r+1] - outer[r];
}
for (int jproc = 0; jproc < mpi::comm().size(); ++jproc) {
if (jproc != mpi_root) {
mpi::comm().send(nnz_per_task.data() + jproc, 1, jproc, mpi_tag);
}
}
nnz_loc = nnz_per_task[mpi_root];
}
else {
mpi_comm.receive(&nnz_loc, 1, mpi_root, mpi_tag);
}

rows.resize(nnz_loc);
cols.resize(nnz_loc);
vals.resize(nnz_loc);

if (mpi_rank == mpi_root) {
std::vector<std::vector<Index>> send_rows(mpi_size);
std::vector<std::vector<Index>> send_cols(mpi_size);
std::vector<std::vector<Value>> send_vals(mpi_size);
for(std::size_t jproc=0; jproc < mpi_size; ++jproc) {
send_rows[jproc].reserve(nnz_per_task[jproc]);
send_cols[jproc].reserve(nnz_per_task[jproc]);
send_vals[jproc].reserve(nnz_per_task[jproc]);
}
for(std::size_t r = 0; r < global_matrix.rows(); ++r) {
int jproc = tgt_part_glb(r);
for (auto c = outer[r]; c < outer[r + 1]; ++c) {
auto col = inner[c];
send_rows[jproc].emplace_back(r);
send_cols[jproc].emplace_back(col);
send_vals[jproc].emplace_back(value[c]);
}
}
for(std::size_t jproc = 0; jproc < mpi_size; ++jproc) {
if (jproc != mpi_root) {
mpi_comm.send(send_rows[jproc].data(), send_rows[jproc].size(), jproc, mpi_tag);
mpi_comm.send(send_cols[jproc].data(), send_cols[jproc].size(), jproc, mpi_tag);
mpi_comm.send(send_vals[jproc].data(), send_vals[jproc].size(), jproc, mpi_tag);
}
else {
rows = send_rows[jproc];
cols = send_cols[jproc];
vals = send_vals[jproc];
}
}
}
else {
mpi_comm.receive(rows.data(), nnz_loc, mpi_root, mpi_tag);
mpi_comm.receive(cols.data(), nnz_loc, mpi_root, mpi_tag);
mpi_comm.receive(vals.data(), nnz_loc, mpi_root, mpi_tag);
}
}

linalg::SparseMatrixStorage distribute_global_matrix(const FunctionSpace& src_fs, const FunctionSpace& tgt_fs, const linalg::SparseMatrixStorage& gmatrix, int mpi_root) {
ATLAS_TRACE("distribute_global_matrix");
Field field_tgt_part_glb = tgt_fs.createField(tgt_fs.partition(), option::global(mpi_root));
ATLAS_TRACE_SCOPE("gather partition") {
tgt_fs.gather(tgt_fs.partition(), field_tgt_part_glb);
}

using Index = eckit::linalg::Index;
using Value = eckit::linalg::Scalar;
std::vector<Index> rows, cols;
std::vector<Value> vals;
distribute_global_matrix(atlas::linalg::make_host_view<Value, Index>(gmatrix), field_tgt_part_glb, rows, cols, vals, mpi_root);

// map global index to local index
std::unordered_map<gidx_t, idx_t> to_local_rows;
std::unordered_map<gidx_t, idx_t> to_local_cols;

ATLAS_TRACE_SCOPE("convert to local indexing") {
auto tgt_gidx_exchanged = tgt_fs.createField(tgt_fs.global_index());
tgt_gidx_exchanged.array().copy(tgt_fs.global_index());
tgt_fs.haloExchange(tgt_gidx_exchanged);
const auto tgt_global_index = array::make_view<gidx_t, 1>(tgt_gidx_exchanged);
const auto tgt_ghost = array::make_view<int,1>(tgt_fs.ghost());

auto src_gidx_exchanged = src_fs.createField(src_fs.global_index());
src_gidx_exchanged.array().copy(src_fs.global_index());
src_fs.haloExchange(src_gidx_exchanged);
const auto src_global_index = array::make_view<gidx_t, 1>(src_gidx_exchanged);
const auto src_ghost = array::make_view<int,1>(src_fs.ghost());

for (idx_t r = 0; r < tgt_global_index.size(); ++r) {
auto gr = tgt_global_index(r);
if (tgt_ghost(r) && to_local_rows.find(gr) != to_local_rows.end()) {
continue;
}
to_local_rows[gr] = r;
}
for (idx_t c = 0; c < src_global_index.size(); ++c) {
auto gc = src_global_index(c);
if (src_ghost(c) && to_local_cols.find(gc) != to_local_cols.end()) {
continue;
}
to_local_cols[gc] = c;
}
for (int r = 0; r < rows.size(); ++r) {
rows[r] = to_local_rows[rows[r] + 1];
cols[r] = to_local_cols[cols[r] + 1];
}
}

linalg::SparseMatrixStorage matrix;
constexpr int index_base = 0;
constexpr bool is_sorted = false;
matrix = linalg::make_sparse_matrix_storage_from_rows_columns_values(tgt_fs.size(), src_fs.size(), rows, cols, vals, index_base, is_sorted);

return matrix;
}


} //end namespace
Expand Down
2 changes: 2 additions & 0 deletions src/atlas/interpolation/AssembleGlobalMatrix.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,4 +17,6 @@ namespace atlas::interpolation {

atlas::linalg::SparseMatrixStorage assemble_global_matrix(const Interpolation& interpolation, int mpi_root = 0);

atlas::linalg::SparseMatrixStorage distribute_global_matrix(const FunctionSpace& src_fs, const FunctionSpace& tgt_fs, const linalg::SparseMatrixStorage&, int mpi_root = 0);

} // namespace atlas::interpolation
10 changes: 10 additions & 0 deletions src/atlas/interpolation/Interpolation.cc
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,16 @@ Interpolation::Interpolation(const Interpolation::Config& config, const Grid& so
return impl;
}()) {}

Interpolation::Interpolation(const Interpolation::Config& config, const FunctionSpace& fs_in, const FunctionSpace& fs_out,
const Interpolation::Cache& cache):
Handle([&]() -> Implementation* {
std::string type;
ATLAS_ASSERT(config.get("type", type));
Implementation* impl = interpolation::MethodFactory::build(type, config);
impl->setup(fs_in, fs_out, cache);
return impl;
}()) {}

extern "C" {
Interpolation::Implementation* atlas__Interpolation__new(const eckit::Parametrisation* config,
const functionspace::FunctionSpaceImpl* source,
Expand Down
2 changes: 2 additions & 0 deletions src/atlas/interpolation/Interpolation.h
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,8 @@ class Interpolation : DOXYGEN_HIDE(public util::ObjectHandle<interpolation::Meth

Interpolation(const Config&, const Grid& source, const Grid& target, const Cache&) noexcept(false);

Interpolation(const Config&, const FunctionSpace& source, const FunctionSpace& target, const Cache&) noexcept(false);

friend std::ostream& operator<<(std::ostream& out, const Interpolation& i) {
i.print(out);
return out;
Expand Down
5 changes: 5 additions & 0 deletions src/atlas/interpolation/method/Method.cc
Original file line number Diff line number Diff line change
Expand Up @@ -295,6 +295,11 @@ void Method::setup(const Grid& source, const Grid& target, const Cache& cache) {
this->do_setup(source, target, cache);
}

void Method::setup(const FunctionSpace& source, const FunctionSpace& target, const Cache& cache) {
ATLAS_TRACE("atlas::interpolation::method::Method::setup(FunctionSpace, FunctionSpace, Cache)");
this->do_setup(source, target, cache);
}

Method::Metadata Method::execute(const FieldSet& source, FieldSet& target) const {
ATLAS_TRACE("atlas::interpolation::method::Method::execute(FieldSet, FieldSet)");
Metadata metadata;
Expand Down
2 changes: 2 additions & 0 deletions src/atlas/interpolation/method/Method.h
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ class Method : public util::Object {
void setup(const FunctionSpace& source, const Field& target);
void setup(const FunctionSpace& source, const FieldSet& target);
void setup(const Grid& source, const Grid& target, const Cache&);
void setup(const FunctionSpace& source, const FunctionSpace& target, const Cache&);

Metadata execute(const FieldSet& source, FieldSet& target) const;
Metadata execute(const Field& source, Field& target) const;
Expand Down Expand Up @@ -136,6 +137,7 @@ class Method : public util::Object {

virtual void do_setup(const FunctionSpace& source, const FunctionSpace& target) = 0;
virtual void do_setup(const Grid& source, const Grid& target, const Cache&) = 0;
virtual void do_setup(const FunctionSpace& source, const FunctionSpace& target, const Cache&) = 0;
virtual void do_setup(const FunctionSpace& source, const Field& target);
virtual void do_setup(const FunctionSpace& source, const FieldSet& target);

Expand Down
6 changes: 6 additions & 0 deletions src/atlas/interpolation/method/binning/Binning.cc
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,12 @@ void Binning::do_setup(const Grid& source,
ATLAS_NOTIMPLEMENTED;
}

void Binning::do_setup(const FunctionSpace& source,
const FunctionSpace& target,
const Cache&) {
ATLAS_NOTIMPLEMENTED;
}


void Binning::do_setup(const FunctionSpace& source,
const FunctionSpace& target) {
Expand Down
1 change: 1 addition & 0 deletions src/atlas/interpolation/method/binning/Binning.h
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ class Binning : public Method {
void do_setup(const FunctionSpace& source,
const FunctionSpace& target) override;
void do_setup(const Grid& source, const Grid& target, const Cache&) override;
void do_setup(const FunctionSpace& source, const FunctionSpace& target, const Cache&) override;

std::vector<double> getAreaWeights(const FunctionSpace& source) const;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,10 @@ void CubedSphereBilinear::do_setup(const Grid& source, const Grid& target, const
ATLAS_NOTIMPLEMENTED;
}

void CubedSphereBilinear::do_setup(const FunctionSpace& source, const FunctionSpace& target, const Cache&) {
ATLAS_NOTIMPLEMENTED;
}

void CubedSphereBilinear::do_setup(const FunctionSpace& source, const FunctionSpace& target) {
source_ = source;
target_ = target;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ class CubedSphereBilinear : public Method {
using Method::do_setup;
void do_setup(const FunctionSpace& source, const FunctionSpace& target) override;
void do_setup(const Grid& source, const Grid& target, const Cache&) override;
void do_setup(const FunctionSpace& source, const FunctionSpace& target, const Cache&) override;

FunctionSpace source_;
FunctionSpace target_;
Expand Down
15 changes: 15 additions & 0 deletions src/atlas/interpolation/method/knn/GridBoxMethod.cc
Original file line number Diff line number Diff line change
Expand Up @@ -181,6 +181,21 @@ void GridBoxMethod::do_setup(const Grid& source, const Grid& target, const Cache
}
}

void GridBoxMethod::do_setup(const FunctionSpace& source, const FunctionSpace& target, const Cache& cache) {
ATLAS_TRACE("GridBoxMethod::setup()");

if (not matrixFree_ && interpolation::MatrixCache(cache)) {
setMatrix(cache);
source_ = source;
target_ = target;
ATLAS_ASSERT(matrix().rows() == target.size());
ATLAS_ASSERT(matrix().cols() == source.size());
return;
}

Log::warning() << "Can not create GridBoxMethod from (FunctionSpace, FunctionSpace, Cache). Use (Grid, Grid, Cache)";
ATLAS_NOTIMPLEMENTED;
}

void GridBoxMethod::giveUp(const std::forward_list<size_t>& failures) {
Log::warning() << "Failed to intersect grid boxes: ";
Expand Down
1 change: 1 addition & 0 deletions src/atlas/interpolation/method/knn/GridBoxMethod.h
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ class GridBoxMethod : public KNearestNeighboursBase {
*/
virtual void do_setup(const FunctionSpace& source, const FunctionSpace& target) override;
virtual void do_setup(const Grid& source, const Grid& target, const Cache&) override;
virtual void do_setup(const FunctionSpace& source, const FunctionSpace& target, const Cache&) override;

virtual const FunctionSpace& source() const override { return source_; }
virtual const FunctionSpace& target() const override { return target_; }
Expand Down
10 changes: 9 additions & 1 deletion src/atlas/interpolation/method/knn/KNearestNeighbours.cc
Original file line number Diff line number Diff line change
Expand Up @@ -62,11 +62,19 @@ void KNearestNeighbours::do_setup(const Grid& source, const Grid& target, const
do_setup(functionspace(source), functionspace(target));
}

void KNearestNeighbours::do_setup(const FunctionSpace& source, const FunctionSpace& target, const Cache& cache) {
if (interpolation::MatrixCache(cache)) {
setMatrix(cache);
source_ = source;
target_ = target;
buildPointSearchTree(source);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Building the PointSearchTree should not be needed when the cache contains a matrix.

On the other hand when there's no matrix in the cache this sets up nothing. That needs fixing.

if(interpolation::MatrixCache(cache)) { ... } 
else { 
  THIS SITUATION
}

}
}

void KNearestNeighbours::do_setup(const FunctionSpace& source, const FunctionSpace& target) {
source_ = source;
target_ = target;

// build point-search tree
buildPointSearchTree(source);

array::ArrayView<double, 2> lonlat = array::make_view<double, 2>(target.lonlat());
Expand Down
1 change: 1 addition & 0 deletions src/atlas/interpolation/method/knn/KNearestNeighbours.h
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ class KNearestNeighbours : public KNearestNeighboursBase {
using KNearestNeighboursBase::do_setup;
virtual void do_setup(const FunctionSpace& source, const FunctionSpace& target) override;
virtual void do_setup(const Grid& source, const Grid& target, const Cache&) override;
virtual void do_setup(const FunctionSpace& source, const FunctionSpace& target, const Cache&) override;

FunctionSpace source_;
FunctionSpace target_;
Expand Down
2 changes: 0 additions & 2 deletions src/atlas/interpolation/method/knn/KNearestNeighboursBase.cc
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,6 @@ void KNearestNeighboursBase::buildPointSearchTree(Mesh& meshSource, const mesh::
ATLAS_TRACE();
eckit::TraceTimer<Atlas> tim("KNearestNeighboursBase::buildPointSearchTree()");


auto lonlat = array::make_view<double, 2>(meshSource.nodes().lonlat());
auto halo = array::make_view<int, 1>(meshSource.nodes().halo());
int h = _halo.size();
Expand All @@ -48,7 +47,6 @@ void KNearestNeighboursBase::buildPointSearchTree(Mesh& meshSource, const mesh::
}
pTree_.build();


// // generate 3D point coordinates
// mesh::actions::BuildXYZField("xyz")(meshSource);
}
Expand Down
9 changes: 9 additions & 0 deletions src/atlas/interpolation/method/knn/NearestNeighbour.cc
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,15 @@ MethodBuilder<NearestNeighbour> __builder("nearest-neighbour");

} // namespace

void NearestNeighbour::do_setup(const FunctionSpace& source, const FunctionSpace& target, const Cache& cache) {
if (interpolation::MatrixCache(cache)) {
setMatrix(cache);
source_ = source;
target_ = target;
buildPointSearchTree(source);
}
}

void NearestNeighbour::do_setup(const Grid& source, const Grid& target, const Cache&) {
if (mpi::size() > 1) {
ATLAS_NOTIMPLEMENTED;
Expand Down
2 changes: 1 addition & 1 deletion src/atlas/interpolation/method/knn/NearestNeighbour.h
Original file line number Diff line number Diff line change
Expand Up @@ -39,8 +39,8 @@ class NearestNeighbour : public KNearestNeighboursBase {
* @param target functionspace containing target points
*/
virtual void do_setup(const FunctionSpace& source, const FunctionSpace& target) override;

virtual void do_setup(const Grid& source, const Grid& target, const Cache&) override;
virtual void do_setup(const FunctionSpace& source, const FunctionSpace& target, const Cache&) override;

FunctionSpace source_;
FunctionSpace target_;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,11 @@ void SphericalVector::do_setup(const Grid& source, const Grid& target,
ATLAS_NOTIMPLEMENTED;
}

void SphericalVector::do_setup(const FunctionSpace& source, const FunctionSpace& target,
const Cache&) {
ATLAS_NOTIMPLEMENTED;
}

void SphericalVector::do_setup(const FunctionSpace& source,
const FunctionSpace& target) {
ATLAS_TRACE("interpolation::method::SphericalVector::do_setup");
Expand Down
Loading
Loading