Skip to content

Commit

Permalink
Update get_values and add test
Browse files Browse the repository at this point in the history
  • Loading branch information
pratikvn committed Jul 27, 2023
1 parent a2b5394 commit a618e7f
Show file tree
Hide file tree
Showing 3 changed files with 35 additions and 5 deletions.
8 changes: 8 additions & 0 deletions core/test/base/batch_multi_vector.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,14 @@ TYPED_TEST(BatchMultiVector, KnowsItsSizeAndValues)
}


TYPED_TEST(BatchMultiVector, CanGetValuesForEntry)
{
using value_type = typename TestFixture::value_type;

ASSERT_EQ(this->mtx->get_values_for_entry(1)[0], value_type{1.0});
}


TYPED_TEST(BatchMultiVector, CanBeCopied)
{
auto mtx_copy = gko::BatchMultiVector<TypeParam>::create(this->exec);
Expand Down
29 changes: 26 additions & 3 deletions include/ginkgo/core/base/batch_multi_vector.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -170,6 +170,28 @@ class BatchMultiVector
*/
dim<2> get_common_size() const { return batch_size_.get_common_size(); }

/**
* Returns a pointer to the array of values of the multi-vector
*
* @return the pointer to the array of values
*/
value_type* get_values(size_type batch_id = 0) noexcept
{
return values_.get_data();
}

/**
* @copydoc get_values(size_type)
*
* @note This is the constant version of the function, which can be
* significantly more memory efficient than the non-constant version,
* so always prefer this version.
*/
const value_type* get_const_values() const noexcept
{
return values_.get_const_data();
}

/**
* Returns a pointer to the array of values of the multi-vector for a
* specific batch entry.
Expand All @@ -178,21 +200,22 @@ class BatchMultiVector
*
* @return the pointer to the array of values
*/
value_type* get_values(size_type batch_id = 0) noexcept
value_type* get_values_for_entry(size_type batch_id) noexcept
{
GKO_ASSERT(batch_id < this->get_num_batch_entries());
return values_.get_data() +
this->get_size().get_cumulative_offset(batch_id);
}

/**
* @copydoc get_values(size_type)
* @copydoc get_values_at_entry(size_type)
*
* @note This is the constant version of the function, which can be
* significantly more memory efficient than the non-constant version,
* so always prefer this version.
*/
const value_type* get_const_values(size_type batch_id = 0) const noexcept
const value_type* get_const_values_for_entry(
size_type batch_id) const noexcept
{
GKO_ASSERT(batch_id < this->get_num_batch_entries());
return values_.get_const_data() +
Expand Down
3 changes: 1 addition & 2 deletions test/test_install/test_install.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -213,8 +213,7 @@ int main()
// core/base/batch_dim.hpp
{
using type1 = int;
auto common_size = gko::dim<2>{4, 2};
auto test = gko::batch_dim<2, type1>{2, common_size};
auto test = gko::batch_dim<2, type1>{};
}

// core/base/batch_multi_vector.hpp
Expand Down

0 comments on commit a618e7f

Please sign in to comment.