Skip to content

Commit

Permalink
Tidied fused loop.
Browse files: browse the repository at this point in the history
  • Loading branch information
odlomax committed Dec 1, 2023
1 parent 6107e91 commit 77856f3
Showing 1 changed file with 40 additions and 29 deletions.
69 changes: 40 additions & 29 deletions src/atlas/interpolation/method/sphericalvector/SphericalVector.cc
Original file line number | Diff line number | Diff line change
Expand Up @@ -58,34 +58,49 @@ void sparseMatrixForEach(const MatrixT& matrix, const Functor& functor) {

atlas_omp_parallel_for(auto k = 0; k < matrix.outerSize(); ++k) {
for (auto it = typename MatrixT::InnerIterator(matrix, k); it; ++it) {
functor(it.value(), it.row(), it.col());
functor(it.row(), it.col(), it.value());
}
}
}

template <typename MatrixT, typename SourceView, typename TargetView,
typename Functor>
void matrixMultiply(const MatrixT& matrix, SourceView&& sourceView,
TargetView&& targetView, const Functor& multiplyFunctor) {
template <typename MatrixT1, typename MatrixT2, typename Functor>
void sparseMatrixForEach(const MatrixT1& matrix1, const MatrixT2& matrix2,
const Functor& functor) {

sparseMatrixForEach(matrix, [&](const auto& weight, auto i, auto j) {
atlas_omp_parallel_for(auto k = 0; k < matrix1.outerSize(); ++k) {
for (auto[it1, it2] =
std::make_pair(typename MatrixT1::InnerIterator(matrix1, k),
typename MatrixT2::InnerIterator(matrix2, k));
it1; ++it1, ++it2) {
functor(it1.row(), it1.col(), it1.value(), it2.value());
}
}
}

template <typename SourceView, typename TargetView, typename Functor,
typename... Matrices>
void matrixMultiply(const SourceView& sourceView, TargetView& targetView,
const Functor& multiplyFunctor,
const Matrices&... matrices) {

sparseMatrixForEach(matrices..., [&](auto i, auto j, const auto&... weights) {

constexpr auto Rank = std::decay_t<SourceView>::rank();
if constexpr (Rank == 2) {
const auto sourceSlice = sourceView.slice(j, array::Range::all());
auto targetSlice = targetView.slice(i, array::Range::all());
multiplyFunctor(weight, sourceSlice, targetSlice);
}
else if constexpr (Rank == 3) {
multiplyFunctor(sourceSlice, targetSlice, weights...);
} else if constexpr (Rank == 3) {
const auto sourceSlice =
sourceView.slice(j, array::Range::all(), array::Range::all());
auto targetSlice =
targetView.slice(i, array::Range::all(), array::Range::all());
array::helpers::ArrayForEach<0>::apply(
std::tie(sourceSlice, targetSlice),
[&](auto&&... slices) { multiplyFunctor(weight, slices...); });
}
else {
[&](auto&& sourceVars, auto&& targetVars) {
multiplyFunctor(sourceVars, targetVars, weights...);
});
} else {
ATLAS_NOTIMPLEMENTED;
}
});
Expand Down Expand Up @@ -122,7 +137,7 @@ void SphericalVector::do_setup(const FunctionSpace& source,
const auto sourceLonLats = array::make_view<double, 2>(source_.lonlat());
const auto targetLonLats = array::make_view<double, 2>(target_.lonlat());

sparseMatrixForEach(realWeights, [&](const auto& weight, auto i, auto j) {
sparseMatrixForEach(realWeights, [&](auto i, auto j, const auto& weight) {

const auto sourceLonLat =
PointLonLat(sourceLonLats(j, 0), sourceLonLats(j, 1));
Expand Down Expand Up @@ -214,35 +229,31 @@ void SphericalVector::interpolate_vector_field(const Field& sourceField,
auto targetView = array::make_view<Value, Rank>(targetField);
targetView.assign(0.);

const auto horizontalComponent = [](const auto& weight, auto&& sourceVars,
auto&& targetVars) {
const auto horizontalComponent = [](const auto& sourceVars, auto& targetVars,
const auto& complexWeight) {
const auto sourceVector = Complex(sourceVars(0), sourceVars(1));
const auto targetVector = weight * sourceVector;
const auto targetVector = complexWeight * sourceVector;
targetVars(0) += targetVector.real();
targetVars(1) += targetVector.imag();
};

if (sourceField.variables() == 2) {
matrixMultiply(*complexWeights_, sourceView, targetView,
horizontalComponent);
matrixMultiply(sourceView, targetView, horizontalComponent,
*complexWeights_);
return;
} else if (sourceField.variables() == 3) {

const auto magnitudesArray = matrix().data();
const auto* weightsBegin = complexWeights_->valuePtr();
const auto weightMagnitude = [&](const auto& weight) {
const auto idx = std::distance(weightsBegin, &weight);
return magnitudesArray[idx];
};
const auto realWeights = makeMatrixMap(matrix());

const auto horizontalAndVerticalComponent = [&](
const auto& weight, auto&& sourceVars, auto&& targetVars) {
horizontalComponent(weight, sourceVars, targetVars);
targetVars(2) += weightMagnitude(weight) * sourceVars(2);
const auto& sourceVars, auto& targetVars, const auto& complexWeight,
const auto& realWeight) {
horizontalComponent(sourceVars, targetVars, complexWeight);
targetVars(2) += realWeight * sourceVars(2);
};

matrixMultiply(*complexWeights_, sourceView, targetView,
horizontalAndVerticalComponent);
matrixMultiply(sourceView, targetView, horizontalAndVerticalComponent,
*complexWeights_, realWeights);

return;
}
Expand Down

0 comments on commit 77856f3

Please sign in to comment.