diff --git a/core/base/device_matrix_data.cpp b/core/base/device_matrix_data.cpp index a2e5d6e7044..adbd5af8e60 100644 --- a/core/base/device_matrix_data.cpp +++ b/core/base/device_matrix_data.cpp @@ -93,6 +93,15 @@ device_matrix_data::create_from_host( } +template +void device_matrix_data::fill_zero() +{ + row_idxs_.fill(0); + col_idxs_.fill(0); + values_.fill(ValueType{0}); +} + + template void device_matrix_data::sort_row_major() { diff --git a/include/ginkgo/core/base/device_matrix_data.hpp b/include/ginkgo/core/base/device_matrix_data.hpp index 35e3f300954..16a68517b2a 100644 --- a/include/ginkgo/core/base/device_matrix_data.hpp +++ b/include/ginkgo/core/base/device_matrix_data.hpp @@ -114,6 +114,11 @@ class device_matrix_data { static device_matrix_data create_from_host( std::shared_ptr exec, const host_type& data); + /** + * Fills the matrix entries with zeros + */ + void fill_zero(); + /** * Sorts the matrix entries in row-major order * This means that they will be sorted by row index first, and then by diff --git a/test/base/device_matrix_data_kernels.cpp b/test/base/device_matrix_data_kernels.cpp index ffadbcfb245..6ddc926b76c 100644 --- a/test/base/device_matrix_data_kernels.cpp +++ b/test/base/device_matrix_data_kernels.cpp @@ -241,6 +241,28 @@ TYPED_TEST(DeviceMatrixData, CopiesToHost) } +TYPED_TEST(DeviceMatrixData, CanFillEntriesWithZeros) +{ + using value_type = typename TestFixture::value_type; + using index_type = typename TestFixture::index_type; + using device_matrix_data = gko::device_matrix_data; + auto device_data = device_matrix_data{this->exec, gko::dim<2>{4, 3}, 10}; + + device_data.fill_zero(); + + auto arrays = device_data.empty_out(); + auto expected_row_idxs = gko::array(this->exec, 10); + auto expected_col_idxs = gko::array(this->exec, 10); + auto expected_values = gko::array(this->exec, 10); + expected_row_idxs.fill(0); + expected_col_idxs.fill(0); + expected_values.fill(0.0); + GKO_ASSERT_ARRAY_EQ(arrays.row_idxs, expected_row_idxs); + GKO_ASSERT_ARRAY_EQ(arrays.col_idxs, expected_col_idxs); + GKO_ASSERT_ARRAY_EQ(arrays.values, expected_values); +} + + TYPED_TEST(DeviceMatrixData, SortsRowMajor) { using value_type = typename TestFixture::value_type;