Skip to content

Commit

Permalink
#11: initial commit
Browse files Browse the repository at this point in the history
  • Loading branch information
cwschilly committed Nov 11, 2024
1 parent bd81b6e commit 7afc92a
Showing 1 changed file with 53 additions and 141 deletions.
194 changes: 53 additions & 141 deletions tests/ops/ops_kokkos_level2.cc
Original file line number Diff line number Diff line change
Expand Up @@ -3,79 +3,12 @@
#include "pressio/ops.hpp"
#include "ops_shared_level2.hpp"

#include <Kokkos_DualView.hpp>
#include <Kokkos_Core.hpp>

//-------------------------------------------
// Test implementation and utilities
// Test implementation
//-------------------------------------------

// Note: Thanks to this wrapper we can create "double expressions"
// and pass them to testing routines same as Kokkos::DualView.
template <typename DualViewType, typename ExprGen>
class Expression2DualViewAdapter {
public:
using dualview_type = DualViewType;
using device_view_type = typename dualview_type::t_dev;
using host_view_type = typename dualview_type::t_host;
using device_expr_type = decltype((std::declval<ExprGen>())(std::declval<device_view_type&>()));
using host_expr_type = decltype((std::declval<ExprGen>())(std::declval<host_view_type&>()));

public:
Expression2DualViewAdapter(DualViewType dual_view, ExprGen F)
: base_view(dual_view),
dev_view(base_view.view_device()),
host_view(base_view.view_host()),
// Important: expressions store references so we pass base views
// that are stored in the adapter (as opposed to temp instances)
expr_dev(F(dev_view)),
expr_host(F(host_view))
{}

auto extent(size_t i) const { return expr_dev.extent(i); }

auto& view_host() { return expr_host; } // for modification on host
auto& view_device() { return expr_dev; } // for product() on device

void sync_host() { base_view.sync_host(); };
void sync_device() { base_view.sync_device(); };
void modify_host() { base_view.modify_host(); };
void modify_device() { base_view.modify_device(); };

protected:
dualview_type base_view;
device_view_type dev_view;
host_view_type host_view;
device_expr_type expr_dev;
host_expr_type expr_host;
};

template <typename DView, typename EGen>
static auto make_adapter(DView dual_view, EGen expr_gen) {
return Expression2DualViewAdapter<DView, EGen>(dual_view, expr_gen);
}

template <typename DView>
auto span_adapter(DView dual_view, std::size_t index, std::size_t size) {
return make_adapter(dual_view, [index, size](auto view) {
return pressio::span(view, index, size);
});
}

template <typename DView2D>
auto diag_adapter(DView2D matrix) {
return make_adapter(matrix, [](auto mtx) {
return pressio::diagonal(mtx);
});
}

template <typename DView2D>
auto subspan_adapter(DView2D matrix, std::size_t i0, std::size_t ext0, std::size_t i1, std::size_t ext1) {
using range_t = std::pair<size_t, size_t>;
return make_adapter(matrix, [=](auto mtx) {
return pressio::subspan(mtx, range_t{ i0, i0 + ext0 }, range_t{ i1, i1 + ext1 });
});
}

struct kokkosFixture
: public ::testing::Test {

Expand All @@ -84,42 +17,44 @@ struct kokkosFixture
static constexpr auto alpha1 = ::pressio::Constants<double>::one();
static constexpr auto beta0 = alpha0;
static constexpr auto beta1 = alpha1;
static constexpr auto one = ::pressio::Constants<size_t>::one();

const size_t x_size = 3;
const size_t y_size = 4;
// plain views
Kokkos::DualView<double**> A{ "A", y_size, x_size };
Kokkos::DualView<double*> x{ "x", x_size };
Kokkos::DualView<double*> y{ "y", y_size };
Kokkos::DualView<double*> xt{ "xt", y_size };
Kokkos::DualView<double*> yt{ "yt", x_size };
Kokkos::View<double**> A{ "A", y_size, x_size };
Kokkos::View<double*> x{ "x", x_size };
Kokkos::View<double*> y{ "y", y_size };
Kokkos::View<double*> xt{ "xt", y_size };
Kokkos::View<double*> yt{ "yt", x_size };
// expression base (data views)
const size_t input_size_ext = (x_size > y_size ? x_size : y_size) + 2;
Kokkos::DualView<double*> x_span_base{ "x_span", input_size_ext };
Kokkos::DualView<double*> y_span_base{ "y_span", input_size_ext };
Kokkos::DualView<double**> x_diag_base{ "x_diag", x_size, x_size };
Kokkos::DualView<double**> xt_diag_base{ "xt_diag", y_size, y_size };
Kokkos::DualView<double**> y_diag_base{ "y_diag", y_size, y_size };
Kokkos::DualView<double**> yt_diag_base{ "yt_diag", x_size, x_size };
Kokkos::DualView<double**> A_subspan_base{ "A_subspan", y_size + 2, x_size + 2 };
Kokkos::View<double*> x_span_base{ "x_span", input_size_ext };
Kokkos::View<double*> y_span_base{ "y_span", input_size_ext };
Kokkos::View<double**> x_diag_base{ "x_diag", x_size, x_size };
Kokkos::View<double**> xt_diag_base{ "xt_diag", y_size, y_size };
Kokkos::View<double**> y_diag_base{ "y_diag", y_size, y_size };
Kokkos::View<double**> yt_diag_base{ "yt_diag", x_size, x_size };
Kokkos::View<double**> A_subspan_base{ "A_subspan", y_size + 2, x_size + 2 };
// expressions
auto x_span() { return span_adapter(x_span_base, 1, x_size); }
auto xt_span() { return span_adapter(x_span_base, 1, y_size); }
auto y_span() { return span_adapter(y_span_base, 1, y_size); }
auto yt_span() { return span_adapter(y_span_base, 1, x_size); }
auto x_diagonal() { return diag_adapter(x_diag_base); }
auto xt_diagonal() { return diag_adapter(xt_diag_base); }
auto y_diagonal() { return diag_adapter(y_diag_base); }
auto yt_diagonal() { return diag_adapter(yt_diag_base); }
auto A_subspan() { return subspan_adapter(A_subspan_base, 1, y_size, 1, x_size); }
auto x_span() { return pressio::span(x_span_base, one, x_size); }
auto xt_span() { return pressio::span(x_span_base, one, y_size); }
auto y_span() { return pressio::span(y_span_base, one, y_size); }
auto yt_span() { return pressio::span(y_span_base, one, x_size); }
auto x_diagonal() { return pressio::diagonal(x_diag_base); }
auto xt_diagonal() { return pressio::diagonal(xt_diag_base); }
auto y_diagonal() { return pressio::diagonal(y_diag_base); }
auto yt_diagonal() { return pressio::diagonal(yt_diag_base); }
auto A_subspan() {
using range_t = std::pair<size_t, size_t>;
return pressio::subspan(A_subspan_base, range_t{ one, one + y_size }, range_t{ one, one + x_size });
}

virtual void SetUp(){
auto A_h = A.view_host();
A_h(0, 0) = 1.; A_h(0, 1) = 0.; A_h(0 ,2) = 2.;
A_h(1, 0) = 2.; A_h(1, 1) = 1.; A_h(1, 2) = 3.;
A_h(2, 0) = 0.; A_h(2, 1) = 0.; A_h(2, 2) = 1.;
A_h(3, 0) = 2.; A_h(3, 1) = 3.; A_h(3, 2) = 4.;
A.modify_host();
A(0, 0) = 1.; A(0, 1) = 0.; A(0 ,2) = 2.;
A(1, 0) = 2.; A(1, 1) = 1.; A(1, 2) = 3.;
A(2, 0) = 0.; A(2, 1) = 0.; A(2, 2) = 1.;
A(3, 0) = 2.; A(3, 1) = 3.; A(3, 2) = 4.;
set_input(x, { 2., 6., 4. });
set_input(xt, { 4., 2., 6., 3. });
// expressions
Expand All @@ -145,24 +80,16 @@ struct kokkosFixture
Kokkos::deep_copy(x, x_h);
}

template <typename ...ViewProps>
static void set_input(Kokkos::DualView<ViewProps...> x, const std::vector<double> &values) {
set_input(x.view_device(), values);
x.modify_device();
x.sync_host();
}

// populates input matrix with unique integer values
template <typename ...ViewProps>
static void set_matrix(Kokkos::DualView<ViewProps...> mtx) {
auto mtx_h = mtx.view_host();
static void set_matrix(Kokkos::View<ViewProps...> mtx) {
auto mtx_h = mtx;
size_t ex0 = mtx.extent(0), ex1 = mtx.extent(1);
for (size_t i = 0; i < ex0; ++i) {
for (size_t j = 0; j < ex1; ++j) {
mtx_h(i, j) = (double)(i * ex1 + j + 1.0);
}
}
mtx.modify_host();
}

};
Expand All @@ -172,27 +99,20 @@ using ops_kokkos = kokkosFixture; // alias for nicer test naming
template <typename TransMode, typename AType, typename XType, typename YType, typename ScalarType>
void test_impl(TransMode trans, ScalarType alpha, AType A, XType x, ScalarType beta, YType y) {
// copy original values
Kokkos::View<double*, Kokkos::HostSpace> y_ref("y_ref", y.extent(0));
y.sync_host();
auto y_h = y.view_host(); // can't use deep_copy() because y can be [wrapped] Pressio expression
Kokkos::View<double*> y_ref("y_ref", y.extent(0));
for (size_t i = 0; i < y.extent(0); ++i) {
y_ref(i) = y_h(i);
y_ref(i) = y(i);
}

// call tested routine on device
A.sync_device();
x.sync_device();
y.sync_device();
// note: explicit instance needed here because we take ref in ::pressio::ops::product()
auto y_d = y.view_device();
pressio::ops::product(trans, alpha, A.view_device(), x.view_device(), beta, y_d);
y.modify_device();

// call reference gemv() on host
vanilla_gemv(trans, alpha, A.view_host(), x.view_host(), beta, y_ref);
y.sync_host();
for (size_t i = 0; i < y_h.extent(0); ++i) {
EXPECT_DOUBLE_EQ(y_h(i), y_ref(i));
pressio::ops::product(trans, alpha, A, x, beta, y);

// call reference gemv()
vanilla_gemv(trans, alpha, A, x, beta, y_ref);

// compare y and y_ref
for (size_t i = 0; i < y.extent(0); ++i) {
EXPECT_DOUBLE_EQ(y(i), y_ref(i));
}
}

Expand All @@ -201,32 +121,24 @@ void test_impl(TransMode trans, ScalarType alpha, AType A, XType x, ScalarType b
template <typename FixtureType, typename TransMode, typename AType, typename XType, typename YType>
void test_impl(const FixtureType &test, TransMode trans, AType A, XType x, YType y) {
// alpha = 1, beta = 0, simulate NaN injection in uninitialized y
y.sync_device();
::pressio::ops::fill(y.view_device(), test.NaN);
y.modify_device();
test_impl(trans, test.alpha1, test.A, x, test.beta0, y);
auto y_nan = y;
::pressio::ops::fill(y_nan, test.NaN);
test_impl(trans, test.alpha1, A, x, test.beta0, y_nan);

// alpha = 1, beta = 1, reuse values in y
test_impl(trans, test.alpha1, test.A, x, test.beta1, y);

// simulate NaN in input
A.sync_host();
auto A_h = A.view_host();
const auto original = A_h(0, 0);
A_h(0, 0) = test.NaN;
A.modify_host();
test_impl(trans, test.alpha1, A, x, test.beta1, y);

// alpha = 0, beta = 1, simulate NaN in input
test_impl(trans, test.alpha0, test.A, x, test.beta1, y);
const auto original = A(0, 0);
A(0, 0) = test.NaN;
test_impl(trans, test.alpha0, A, x, test.beta1, y);

// alpha = 0, beta = 0, NaN in both input and result
::pressio::ops::fill(y.view_device(), test.NaN);
y.modify_device();
test_impl(trans, test.alpha0, test.A, x, test.beta0, y);
::pressio::ops::fill(y_nan, test.NaN);
test_impl(trans, test.alpha0, A, x, test.beta0, y_nan);

// restore original A
A_h(0, 0) = original;
A.modify_host();
A(0, 0) = original;
}

//-------------------------------------------
Expand Down

0 comments on commit 7afc92a

Please sign in to comment.