Skip to content

Commit

Permalink
Add tests for Hadamard+contraction on outer indices, and contraction …
Browse files Browse the repository at this point in the history
…on inner indices.
  • Loading branch information
bimalgaudel committed Jan 14, 2024
1 parent 06d7736 commit dc5c0ad
Showing 1 changed file with 141 additions and 0 deletions.
141 changes: 141 additions & 0 deletions tests/einsum.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -791,6 +791,147 @@ BOOST_AUTO_TEST_CASE(xxx) {
BOOST_CHECK(are_equal);
}

BOOST_AUTO_TEST_CASE(ij_mn_eq_ij_mo_times_ji_on) {
using Array = TA::DistArray<TA::Tensor<TA::Tensor<int>>, TA::DensePolicy>;
auto& world = TA::get_default_world();

TA::Range const inner_rng{2, 7};
TA::Range const inner_rng_perm{7, 2};
TA::TiledRange lhs_trng{{0, 2, 4}, {0, 2}};
TA::TiledRange rhs_trng{{0, 2}, {0, 2, 4}};
auto lhs = random_array<Array>(lhs_trng, inner_rng);
auto rhs = random_array<Array>(rhs_trng, inner_rng_perm);

//
// manual evaluation: 'ij;mn = ij;mo * ji;on'
//
Array ref{world, lhs_trng};
{
lhs.make_replicated();
rhs.make_replicated();
world.gop.fence();

auto make_tile = [lhs, rhs](TA::Range const& rng) {
typename Array::value_type result_tile{rng};
for (auto&& res_ix : result_tile.range()) {
auto i = res_ix[0];
auto j = res_ix[1];

auto lhs_tile_ix = lhs.trange().element_to_tile({i, j});
auto lhs_tile = lhs.find_local(lhs_tile_ix).get(/* dowork = */ false);

auto rhs_tile_ix = rhs.trange().element_to_tile({j, i});
auto rhs_tile = rhs.find_local(rhs_tile_ix).get(/* dowork = */ false);

auto& res_el = result_tile({i, j});
auto const& lhs_el = lhs_tile({i, j});
auto const& rhs_el = rhs_tile({j, i});
using namespace std::string_literals;
res_el =
TA::detail::tensor_contract(lhs_el, "mo"s, rhs_el, "on"s, "mn"s);
}
return result_tile;
};
using std::begin;
using std::end;

for (auto it = begin(ref); it != end(ref); ++it)
if (ref.is_local(it.index())) {
auto tile = world.taskq.add(make_tile, it.make_range());
*it = tile;
}
}

auto out = einsum(lhs("i,j;m,o"), rhs("j,i;o,n"), "i,j;m,n");
std::cerr << "TODO: ij;mo * ji;on -> ij;mn using expression layer does not "
"produce the same result compared to manual evaluation."
<< '\n';
// bool are_equal = ToTArrayFixture::are_equal<ShapeComp::False>(ref, out);
// std::cout << out << '\n' << ref << '\n';
// BOOST_CHECK(are_equal);
}

BOOST_AUTO_TEST_CASE(ij_mn_eq_ijk_mo_times_ijk_no) {
using Array = TA::DistArray<TA::Tensor<TA::Tensor<int>>, TA::DensePolicy>;
using Ix = typename TA::Range::index1_type;
using namespace std::string_literals;
auto& world = TA::get_default_world();

Ix const K = 2; // the extent of contracted outer mode

TA::Range const inner_rng{3, 7};
TA::TiledRange const lhs_trng{
std::initializer_list<std::initializer_list<Ix>>{
{0, 2, 4}, {0, 2}, {0, 2}}};
TA::TiledRange const rhs_trng(lhs_trng);
TA::TiledRange const ref_trng{lhs_trng.dim(0), lhs_trng.dim(1)};
TA::Range const ref_inner_rng{3, 3}; // contract(3x7,3x7) -> (3,3)
auto lhs = random_array<Array>(lhs_trng, inner_rng);
auto rhs = random_array<Array>(rhs_trng, inner_rng);

//
// manual evaluation: ij;mn = ijk;mo * ijk;no
//
Array ref{world, ref_trng};
{
lhs.make_replicated();
rhs.make_replicated();
world.gop.fence();

auto make_tile = [lhs, rhs, ref_inner_rng](TA::Range const& rng) {
using InnerT = typename Array::value_type::value_type;
typename Array::value_type result_tile{rng};

for (auto&& res_ix : result_tile.range()) {
auto i = res_ix[0];
auto j = res_ix[1];

InnerT mn{ref_inner_rng};
for (Ix k = 0; k < K; ++k) {
auto lhs_tile =
lhs.find_local(lhs.trange().element_to_tile({i, j, k}))
.get(/*dowork = */ false);
auto rhs_tile =
rhs.find_local(rhs.trange().element_to_tile({i, j, k}))
.get(/*doworkd = */ false);
mn.add_to(tensor_contract("mo,no->mn", lhs_tile({i, j, k}),
rhs_tile({i, j, k})));
}
result_tile({i, j}) = std::move(mn);
}
return result_tile;
};
using std::begin;
using std::end;

for (auto it = begin(ref); it != end(ref); ++it)
if (ref.is_local(it.index())) {
auto tile = world.taskq.add(make_tile, it.make_range());
*it = tile;
}
}

auto out = einsum(lhs("i,j,k;m,o"), rhs("i,j,k;n,o"), "i,j;m,n");
bool are_equal = ToTArrayFixture::are_equal<ShapeComp::False>(ref, out);
BOOST_CHECK(are_equal);
}

#ifdef TILEDARRAY_HAS_BTAS
BOOST_AUTO_TEST_CASE(tensor_contract) {
using TensorT = TA::Tensor<int>;

TA::Range const rng_A{2, 3, 4};
TA::Range const rng_B{4, 3, 2};
auto const A = random_tensor<TensorT>(rng_A);
auto const B = random_tensor<TensorT>(rng_B);

BOOST_CHECK(tensor_contract_equal("ijk,klm->ijlm", A, B));
BOOST_CHECK(tensor_contract_equal("ijk,klm->milj", A, B));
BOOST_CHECK(tensor_contract_equal("ijk,kjm->im", A, B));
BOOST_CHECK(tensor_contract_equal("ijk,kli->lj", A, B));
}
#endif

BOOST_AUTO_TEST_SUITE_END() // einsum_tot

BOOST_AUTO_TEST_SUITE(einsum_tot_t)
Expand Down

0 comments on commit dc5c0ad

Please sign in to comment.