Skip to content

Commit

Permalink
rename make_complex, use absolute_type only...etc
Browse files Browse the repository at this point in the history
1. remove/add_complex
2. to_real/complex alias
3. using absolute_type not outplace_absolute_type
4. remove unneeded .get() in test

Co-authored-by: Tobias Ribizel <[email protected]>
  • Loading branch information
yhmtsai and upsj committed Sep 15, 2020
1 parent 95b1edc commit 7d9a37a
Show file tree
Hide file tree
Showing 37 changed files with 134 additions and 101 deletions.
6 changes: 3 additions & 3 deletions core/matrix/coo.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -249,13 +249,13 @@ void Coo<ValueType, IndexType>::compute_absolute_inplace()


template <typename ValueType, typename IndexType>
std::unique_ptr<typename Coo<ValueType, IndexType>::outplace_absolute_type>
std::unique_ptr<typename Coo<ValueType, IndexType>::absolute_type>
Coo<ValueType, IndexType>::compute_absolute() const
{
auto exec = this->get_executor();

auto abs_coo = outplace_absolute_type::create(
exec, this->get_size(), this->get_num_stored_elements());
auto abs_coo = absolute_type::create(exec, this->get_size(),
this->get_num_stored_elements());

abs_coo->col_idxs_ = col_idxs_;
abs_coo->row_idxs_ = row_idxs_;
Expand Down
6 changes: 3 additions & 3 deletions core/matrix/csr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -498,13 +498,13 @@ void Csr<ValueType, IndexType>::compute_absolute_inplace()


template <typename ValueType, typename IndexType>
std::unique_ptr<typename Csr<ValueType, IndexType>::outplace_absolute_type>
std::unique_ptr<typename Csr<ValueType, IndexType>::absolute_type>
Csr<ValueType, IndexType>::compute_absolute() const
{
auto exec = this->get_executor();

auto abs_csr = outplace_absolute_type::create(
exec, this->get_size(), this->get_num_stored_elements());
auto abs_csr = absolute_type::create(exec, this->get_size(),
this->get_num_stored_elements());

abs_csr->col_idxs_ = col_idxs_;
abs_csr->row_ptrs_ = row_ptrs_;
Expand Down
4 changes: 2 additions & 2 deletions core/matrix/dense.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -764,13 +764,13 @@ void Dense<ValueType>::compute_absolute_inplace()


template <typename ValueType>
std::unique_ptr<typename Dense<ValueType>::outplace_absolute_type>
std::unique_ptr<typename Dense<ValueType>::absolute_type>
Dense<ValueType>::compute_absolute() const
{
auto exec = this->get_executor();

// do not inherit the stride
auto abs_dense = outplace_absolute_type::create(exec, this->get_size());
auto abs_dense = absolute_type::create(exec, this->get_size());

exec->run(dense::make_outplace_absolute_dense(this, abs_dense.get()));

Expand Down
5 changes: 2 additions & 3 deletions core/matrix/diagonal.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -277,13 +277,12 @@ void Diagonal<ValueType>::compute_absolute_inplace()


template <typename ValueType>
std::unique_ptr<typename Diagonal<ValueType>::outplace_absolute_type>
std::unique_ptr<typename Diagonal<ValueType>::absolute_type>
Diagonal<ValueType>::compute_absolute() const
{
auto exec = this->get_executor();

auto abs_diagonal =
outplace_absolute_type::create(exec, this->get_size()[0]);
auto abs_diagonal = absolute_type::create(exec, this->get_size()[0]);

exec->run(diagonal::make_outplace_absolute_array(
this->get_const_values(), this->get_size()[0],
Expand Down
4 changes: 2 additions & 2 deletions core/matrix/ell.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -267,12 +267,12 @@ void Ell<ValueType, IndexType>::compute_absolute_inplace()


template <typename ValueType, typename IndexType>
std::unique_ptr<typename Ell<ValueType, IndexType>::outplace_absolute_type>
std::unique_ptr<typename Ell<ValueType, IndexType>::absolute_type>
Ell<ValueType, IndexType>::compute_absolute() const
{
auto exec = this->get_executor();

auto abs_ell = outplace_absolute_type::create(
auto abs_ell = absolute_type::create(
exec, this->get_size(), this->get_num_stored_elements_per_row(),
this->get_stride());

Expand Down
4 changes: 2 additions & 2 deletions core/matrix/hybrid.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -303,13 +303,13 @@ void Hybrid<ValueType, IndexType>::compute_absolute_inplace()


template <typename ValueType, typename IndexType>
std::unique_ptr<typename Hybrid<ValueType, IndexType>::outplace_absolute_type>
std::unique_ptr<typename Hybrid<ValueType, IndexType>::absolute_type>
Hybrid<ValueType, IndexType>::compute_absolute() const
{
auto exec = this->get_executor();

// use default strategy
auto abs_hybrid = outplace_absolute_type::create(exec, this->get_size());
auto abs_hybrid = absolute_type::create(exec, this->get_size());

abs_hybrid->ell_->copy_from(ell_->compute_absolute());
abs_hybrid->coo_->copy_from(coo_->compute_absolute());
Expand Down
4 changes: 2 additions & 2 deletions core/matrix/sellp.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -320,12 +320,12 @@ void Sellp<ValueType, IndexType>::compute_absolute_inplace()


template <typename ValueType, typename IndexType>
std::unique_ptr<typename Sellp<ValueType, IndexType>::outplace_absolute_type>
std::unique_ptr<typename Sellp<ValueType, IndexType>::absolute_type>
Sellp<ValueType, IndexType>::compute_absolute() const
{
auto exec = this->get_executor();

auto abs_sellp = outplace_absolute_type::create(
auto abs_sellp = absolute_type::create(
exec, this->get_size(), this->get_slice_size(),
this->get_stride_factor(), this->get_total_cols());

Expand Down
6 changes: 3 additions & 3 deletions core/test/base/lin_op.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -326,7 +326,7 @@ class DummyLinOpWithType
public gko::EnableCreateMethod<DummyLinOpWithType<Type>>,
public gko::EnableAbsoluteComputation<DummyLinOpWithType<Type>> {
public:
using outplace_absolute_type = gko::remove_complex<DummyLinOpWithType>;
using absolute_type = gko::remove_complex<DummyLinOpWithType>;
DummyLinOpWithType(std::shared_ptr<const gko::Executor> exec)
: gko::EnableLinOp<DummyLinOpWithType>(exec)
{}
Expand All @@ -338,9 +338,9 @@ class DummyLinOpWithType

void compute_absolute_inplace() override { value_ = gko::abs(value_); }

std::unique_ptr<outplace_absolute_type> compute_absolute() const override
std::unique_ptr<absolute_type> compute_absolute() const override
{
return std::make_unique<outplace_absolute_type>(
return std::make_unique<absolute_type>(
this->get_executor(), this->get_size(), gko::abs(value_));
}

Expand Down
4 changes: 2 additions & 2 deletions cuda/test/matrix/coo_kernels.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -278,7 +278,7 @@ TEST_F(Coo, InplaceAbsoluteMatrixIsEquivalentToRef)
mtx->compute_absolute_inplace();
dmtx->compute_absolute_inplace();

GKO_ASSERT_MTX_NEAR(mtx.get(), dmtx.get(), 1e-14);
GKO_ASSERT_MTX_NEAR(mtx, dmtx, 1e-14);
}


Expand All @@ -289,7 +289,7 @@ TEST_F(Coo, OutplaceAbsoluteMatrixIsEquivalentToRef)
auto abs_mtx = mtx->compute_absolute();
auto dabs_mtx = dmtx->compute_absolute();

GKO_ASSERT_MTX_NEAR(abs_mtx.get(), dabs_mtx.get(), 1e-14);
GKO_ASSERT_MTX_NEAR(abs_mtx, dabs_mtx, 1e-14);
}


Expand Down
8 changes: 4 additions & 4 deletions cuda/test/matrix/csr_kernels.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -747,7 +747,7 @@ TEST_F(Csr, InplaceAbsoluteMatrixIsEquivalentToRef)
mtx->compute_absolute_inplace();
dmtx->compute_absolute_inplace();

GKO_ASSERT_MTX_NEAR(mtx.get(), dmtx.get(), 1e-14);
GKO_ASSERT_MTX_NEAR(mtx, dmtx, 1e-14);
}


Expand All @@ -758,7 +758,7 @@ TEST_F(Csr, OutplaceAbsoluteMatrixIsEquivalentToRef)
auto abs_mtx = mtx->compute_absolute();
auto dabs_mtx = dmtx->compute_absolute();

GKO_ASSERT_MTX_NEAR(abs_mtx.get(), dabs_mtx.get(), 1e-14);
GKO_ASSERT_MTX_NEAR(abs_mtx, dabs_mtx, 1e-14);
}


Expand All @@ -769,7 +769,7 @@ TEST_F(Csr, InplaceAbsoluteComplexMatrixIsEquivalentToRef)
complex_mtx->compute_absolute_inplace();
complex_dmtx->compute_absolute_inplace();

GKO_ASSERT_MTX_NEAR(complex_mtx.get(), complex_dmtx.get(), 1e-14);
GKO_ASSERT_MTX_NEAR(complex_mtx, complex_dmtx, 1e-14);
}


Expand All @@ -780,7 +780,7 @@ TEST_F(Csr, OutplaceAbsoluteComplexMatrixIsEquivalentToRef)
auto abs_mtx = complex_mtx->compute_absolute();
auto dabs_mtx = complex_dmtx->compute_absolute();

GKO_ASSERT_MTX_NEAR(abs_mtx.get(), dabs_mtx.get(), 1e-14);
GKO_ASSERT_MTX_NEAR(abs_mtx, dabs_mtx, 1e-14);
}


Expand Down
4 changes: 2 additions & 2 deletions cuda/test/matrix/dense_kernels.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -594,7 +594,7 @@ TEST_F(Dense, InplaceAbsoluteMatrixIsEquivalentToRef)
x->compute_absolute_inplace();
dx->compute_absolute_inplace();

GKO_ASSERT_MTX_NEAR(x.get(), dx.get(), 1e-14);
GKO_ASSERT_MTX_NEAR(x, dx, 1e-14);
}


Expand All @@ -605,7 +605,7 @@ TEST_F(Dense, OutplaceAbsoluteMatrixIsEquivalentToRef)
auto abs_x = x->compute_absolute();
auto dabs_x = dx->compute_absolute();

GKO_ASSERT_MTX_NEAR(abs_x.get(), dabs_x.get(), 1e-14);
GKO_ASSERT_MTX_NEAR(abs_x, dabs_x, 1e-14);
}


Expand Down
4 changes: 2 additions & 2 deletions cuda/test/matrix/diagonal_kernels.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -256,7 +256,7 @@ TEST_F(Diagonal, InplaceAbsoluteMatrixIsEquivalentToRef)
diag->compute_absolute_inplace();
ddiag->compute_absolute_inplace();

GKO_ASSERT_MTX_NEAR(diag.get(), ddiag.get(), 1e-14);
GKO_ASSERT_MTX_NEAR(diag, ddiag, 1e-14);
}


Expand All @@ -267,7 +267,7 @@ TEST_F(Diagonal, OutplaceAbsoluteMatrixIsEquivalentToRef)
auto abs_diag = diag->compute_absolute();
auto dabs_diag = ddiag->compute_absolute();

GKO_ASSERT_MTX_NEAR(abs_diag.get(), dabs_diag.get(), 1e-14);
GKO_ASSERT_MTX_NEAR(abs_diag, dabs_diag, 1e-14);
}


Expand Down
4 changes: 2 additions & 2 deletions cuda/test/matrix/ell_kernels.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -363,7 +363,7 @@ TEST_F(Ell, InplaceAbsoluteMatrixIsEquivalentToRef)
mtx->compute_absolute_inplace();
dmtx->compute_absolute_inplace();

GKO_ASSERT_MTX_NEAR(mtx.get(), dmtx.get(), 1e-14);
GKO_ASSERT_MTX_NEAR(mtx, dmtx, 1e-14);
}


Expand All @@ -374,7 +374,7 @@ TEST_F(Ell, OutplaceAbsoluteMatrixIsEquivalentToRef)
auto abs_mtx = mtx->compute_absolute();
auto dabs_mtx = dmtx->compute_absolute();

GKO_ASSERT_MTX_NEAR(abs_mtx.get(), dabs_mtx.get(), 1e-14);
GKO_ASSERT_MTX_NEAR(abs_mtx, dabs_mtx, 1e-14);
}


Expand Down
4 changes: 2 additions & 2 deletions cuda/test/matrix/hybrid_kernels.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -238,7 +238,7 @@ TEST_F(Hybrid, InplaceAbsoluteMatrixIsEquivalentToRef)
mtx->compute_absolute_inplace();
dmtx->compute_absolute_inplace();

GKO_ASSERT_MTX_NEAR(mtx.get(), dmtx.get(), 1e-14);
GKO_ASSERT_MTX_NEAR(mtx, dmtx, 1e-14);
}


Expand All @@ -249,7 +249,7 @@ TEST_F(Hybrid, OutplaceAbsoluteMatrixIsEquivalentToRef)
auto abs_mtx = mtx->compute_absolute();
auto dabs_mtx = dmtx->compute_absolute();

GKO_ASSERT_MTX_NEAR(abs_mtx.get(), dabs_mtx.get(), 1e-14);
GKO_ASSERT_MTX_NEAR(abs_mtx, dabs_mtx, 1e-14);
}


Expand Down
4 changes: 2 additions & 2 deletions cuda/test/matrix/sellp_kernels.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -357,7 +357,7 @@ TEST_F(Sellp, InplaceAbsoluteMatrixIsEquivalentToRef)
mtx->compute_absolute_inplace();
dmtx->compute_absolute_inplace();

GKO_ASSERT_MTX_NEAR(mtx.get(), dmtx.get(), 1e-14);
GKO_ASSERT_MTX_NEAR(mtx, dmtx, 1e-14);
}


Expand All @@ -368,7 +368,7 @@ TEST_F(Sellp, OutplaceAbsoluteMatrixIsEquivalentToRef)
auto abs_mtx = mtx->compute_absolute();
auto dabs_mtx = dmtx->compute_absolute();

GKO_ASSERT_MTX_NEAR(abs_mtx.get(), dabs_mtx.get(), 1e-14);
GKO_ASSERT_MTX_NEAR(abs_mtx, dabs_mtx, 1e-14);
}


Expand Down
4 changes: 2 additions & 2 deletions hip/test/matrix/coo_kernels.hip.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -278,7 +278,7 @@ TEST_F(Coo, InplaceAbsoluteMatrixIsEquivalentToRef)
mtx->compute_absolute_inplace();
dmtx->compute_absolute_inplace();

GKO_ASSERT_MTX_NEAR(mtx.get(), dmtx.get(), 1e-14);
GKO_ASSERT_MTX_NEAR(mtx, dmtx, 1e-14);
}


Expand All @@ -289,7 +289,7 @@ TEST_F(Coo, OutplaceAbsoluteMatrixIsEquivalentToRef)
auto abs_mtx = mtx->compute_absolute();
auto dabs_mtx = dmtx->compute_absolute();

GKO_ASSERT_MTX_NEAR(abs_mtx.get(), dabs_mtx.get(), 1e-14);
GKO_ASSERT_MTX_NEAR(abs_mtx, dabs_mtx, 1e-14);
}


Expand Down
8 changes: 4 additions & 4 deletions hip/test/matrix/csr_kernels.hip.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -732,7 +732,7 @@ TEST_F(Csr, InplaceAbsoluteMatrixIsEquivalentToRef)
mtx->compute_absolute_inplace();
dmtx->compute_absolute_inplace();

GKO_ASSERT_MTX_NEAR(mtx.get(), dmtx.get(), 1e-14);
GKO_ASSERT_MTX_NEAR(mtx, dmtx, 1e-14);
}


Expand All @@ -743,7 +743,7 @@ TEST_F(Csr, OutplaceAbsoluteMatrixIsEquivalentToRef)
auto abs_mtx = mtx->compute_absolute();
auto dabs_mtx = dmtx->compute_absolute();

GKO_ASSERT_MTX_NEAR(abs_mtx.get(), dabs_mtx.get(), 1e-14);
GKO_ASSERT_MTX_NEAR(abs_mtx, dabs_mtx, 1e-14);
}


Expand All @@ -754,7 +754,7 @@ TEST_F(Csr, InplaceAbsoluteComplexMatrixIsEquivalentToRef)
complex_mtx->compute_absolute_inplace();
complex_dmtx->compute_absolute_inplace();

GKO_ASSERT_MTX_NEAR(complex_mtx.get(), complex_dmtx.get(), 1e-14);
GKO_ASSERT_MTX_NEAR(complex_mtx, complex_dmtx, 1e-14);
}


Expand All @@ -765,7 +765,7 @@ TEST_F(Csr, OutplaceAbsoluteComplexMatrixIsEquivalentToRef)
auto abs_mtx = complex_mtx->compute_absolute();
auto dabs_mtx = complex_dmtx->compute_absolute();

GKO_ASSERT_MTX_NEAR(abs_mtx.get(), dabs_mtx.get(), 1e-14);
GKO_ASSERT_MTX_NEAR(abs_mtx, dabs_mtx, 1e-14);
}


Expand Down
4 changes: 2 additions & 2 deletions hip/test/matrix/dense_kernels.hip.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -576,7 +576,7 @@ TEST_F(Dense, InplaceAbsoluteMatrixIsEquivalentToRef)
x->compute_absolute_inplace();
dx->compute_absolute_inplace();

GKO_ASSERT_MTX_NEAR(x.get(), dx.get(), 1e-14);
GKO_ASSERT_MTX_NEAR(x, dx, 1e-14);
}


Expand All @@ -587,7 +587,7 @@ TEST_F(Dense, OutplaceAbsoluteMatrixIsEquivalentToRef)
auto abs_x = x->compute_absolute();
auto dabs_x = dx->compute_absolute();

GKO_ASSERT_MTX_NEAR(abs_x.get(), dabs_x.get(), 1e-14);
GKO_ASSERT_MTX_NEAR(abs_x, dabs_x, 1e-14);
}


Expand Down
4 changes: 2 additions & 2 deletions hip/test/matrix/diagonal_kernels.hip.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -256,7 +256,7 @@ TEST_F(Diagonal, InplaceAbsoluteMatrixIsEquivalentToRef)
diag->compute_absolute_inplace();
ddiag->compute_absolute_inplace();

GKO_ASSERT_MTX_NEAR(diag.get(), ddiag.get(), 1e-14);
GKO_ASSERT_MTX_NEAR(diag, ddiag, 1e-14);
}


Expand All @@ -267,7 +267,7 @@ TEST_F(Diagonal, OutplaceAbsoluteMatrixIsEquivalentToRef)
auto abs_diag = diag->compute_absolute();
auto dabs_diag = ddiag->compute_absolute();

GKO_ASSERT_MTX_NEAR(abs_diag.get(), dabs_diag.get(), 1e-14);
GKO_ASSERT_MTX_NEAR(abs_diag, dabs_diag, 1e-14);
}


Expand Down
4 changes: 2 additions & 2 deletions hip/test/matrix/ell_kernels.hip.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -363,7 +363,7 @@ TEST_F(Ell, InplaceAbsoluteMatrixIsEquivalentToRef)
mtx->compute_absolute_inplace();
dmtx->compute_absolute_inplace();

GKO_ASSERT_MTX_NEAR(mtx.get(), dmtx.get(), 1e-14);
GKO_ASSERT_MTX_NEAR(mtx, dmtx, 1e-14);
}


Expand All @@ -374,7 +374,7 @@ TEST_F(Ell, OutplaceAbsoluteMatrixIsEquivalentToRef)
auto abs_mtx = mtx->compute_absolute();
auto dabs_mtx = dmtx->compute_absolute();

GKO_ASSERT_MTX_NEAR(abs_mtx.get(), dabs_mtx.get(), 1e-14);
GKO_ASSERT_MTX_NEAR(abs_mtx, dabs_mtx, 1e-14);
}


Expand Down
Loading

0 comments on commit 7d9a37a

Please sign in to comment.