diff --git a/src/atlas/interpolation/method/sphericalvector/SphericalVector.cc b/src/atlas/interpolation/method/sphericalvector/SphericalVector.cc index e5ac80fb2..1099877e9 100644 --- a/src/atlas/interpolation/method/sphericalvector/SphericalVector.cc +++ b/src/atlas/interpolation/method/sphericalvector/SphericalVector.cc @@ -58,34 +58,49 @@ void sparseMatrixForEach(const MatrixT& matrix, const Functor& functor) { atlas_omp_parallel_for(auto k = 0; k < matrix.outerSize(); ++k) { for (auto it = typename MatrixT::InnerIterator(matrix, k); it; ++it) { - functor(it.value(), it.row(), it.col()); + functor(it.row(), it.col(), it.value()); } } } -template -void matrixMultiply(const MatrixT& matrix, SourceView&& sourceView, - TargetView&& targetView, const Functor& multiplyFunctor) { +template +void sparseMatrixForEach(const MatrixT1& matrix1, const MatrixT2& matrix2, + const Functor& functor) { - sparseMatrixForEach(matrix, [&](const auto& weight, auto i, auto j) { + atlas_omp_parallel_for(auto k = 0; k < matrix1.outerSize(); ++k) { + for (auto[it1, it2] = + std::make_pair(typename MatrixT1::InnerIterator(matrix1, k), + typename MatrixT2::InnerIterator(matrix2, k)); + it1; ++it1, ++it2) { + functor(it1.row(), it1.col(), it1.value(), it2.value()); + } + } +} + +template +void matrixMultiply(const SourceView& sourceView, TargetView& targetView, + const Functor& multiplyFunctor, + const Matrices&... matrices) { + + sparseMatrixForEach(matrices..., [&](auto i, auto j, const auto&... weights) { constexpr auto Rank = std::decay_t::rank(); if constexpr (Rank == 2) { const auto sourceSlice = sourceView.slice(j, array::Range::all()); auto targetSlice = targetView.slice(i, array::Range::all()); - multiplyFunctor(weight, sourceSlice, targetSlice); - } - else if constexpr (Rank == 3) { + multiplyFunctor(sourceSlice, targetSlice, weights...); + } else if constexpr (Rank == 3) { const auto sourceSlice = sourceView.slice(j, array::Range::all(), array::Range::all()); auto targetSlice = targetView.slice(i, array::Range::all(), array::Range::all()); array::helpers::ArrayForEach<0>::apply( std::tie(sourceSlice, targetSlice), - [&](auto&&... slices) { multiplyFunctor(weight, slices...); }); - } - else { + [&](auto&& sourceVars, auto&& targetVars) { + multiplyFunctor(sourceVars, targetVars, weights...); + }); + } else { ATLAS_NOTIMPLEMENTED; } }); @@ -122,7 +137,7 @@ void SphericalVector::do_setup(const FunctionSpace& source, const auto sourceLonLats = array::make_view(source_.lonlat()); const auto targetLonLats = array::make_view(target_.lonlat()); - sparseMatrixForEach(realWeights, [&](const auto& weight, auto i, auto j) { + sparseMatrixForEach(realWeights, [&](auto i, auto j, const auto& weight) { const auto sourceLonLat = PointLonLat(sourceLonLats(j, 0), sourceLonLats(j, 1)); @@ -214,35 +229,31 @@ void SphericalVector::interpolate_vector_field(const Field& sourceField, auto targetView = array::make_view(targetField); targetView.assign(0.); - const auto horizontalComponent = [](const auto& weight, auto&& sourceVars, - auto&& targetVars) { + const auto horizontalComponent = [](const auto& sourceVars, auto& targetVars, + const auto& complexWeight) { const auto sourceVector = Complex(sourceVars(0), sourceVars(1)); - const auto targetVector = weight * sourceVector; + const auto targetVector = complexWeight * sourceVector; targetVars(0) += targetVector.real(); targetVars(1) += targetVector.imag(); }; if (sourceField.variables() == 2) { - matrixMultiply(*complexWeights_, sourceView, targetView, - horizontalComponent); + matrixMultiply(sourceView, targetView, horizontalComponent, + *complexWeights_); return; } else if (sourceField.variables() == 3) { - const auto magnitudesArray = matrix().data(); - const auto* weightsBegin = complexWeights_->valuePtr(); - const auto weightMagnitude = [&](const auto& weight) { - const auto idx = std::distance(weightsBegin, &weight); - return magnitudesArray[idx]; - }; + const auto realWeights = makeMatrixMap(matrix()); const auto horizontalAndVerticalComponent = [&]( - const auto& weight, auto&& sourceVars, auto&& targetVars) { - horizontalComponent(weight, sourceVars, targetVars); - targetVars(2) += weightMagnitude(weight) * sourceVars(2); + const auto& sourceVars, auto& targetVars, const auto& complexWeight, + const auto& realWeight) { + horizontalComponent(sourceVars, targetVars, complexWeight); + targetVars(2) += realWeight * sourceVars(2); }; - matrixMultiply(*complexWeights_, sourceView, targetView, - horizontalAndVerticalComponent); + matrixMultiply(sourceView, targetView, horizontalAndVerticalComponent, + *complexWeights_, realWeights); return; }