diff --git a/tests/einsum.cpp b/tests/einsum.cpp index 79d4e70e06..4ab944e676 100644 --- a/tests/einsum.cpp +++ b/tests/einsum.cpp @@ -791,6 +791,147 @@ BOOST_AUTO_TEST_CASE(xxx) { BOOST_CHECK(are_equal); } +BOOST_AUTO_TEST_CASE(ij_mn_eq_ij_mo_times_ji_on) { /* Builds a reference array for ij;mn = ij;mo * ji;on one outer element at a time with TA::detail::tensor_contract, then evaluates the same product through einsum. The comparison of the two results is commented out at the end of this case, so the case currently asserts nothing and only prints a TODO to std::cerr. */ + using Array = TA::DistArray>, TA::DensePolicy>; /* NOTE(review): the angle brackets are unbalanced here, so the template argument list looks truncated in this copy of the patch - restore it from the repository before applying */ + auto& world = TA::get_default_world(); + + TA::Range const inner_rng{2, 7}; + TA::Range const inner_rng_perm{7, 2}; /* same two values as inner_rng, in swapped order */ + TA::TiledRange lhs_trng{{0, 2, 4}, {0, 2}}; + TA::TiledRange rhs_trng{{0, 2}, {0, 2, 4}}; /* the two dimensions of lhs_trng, in swapped order */ + auto lhs = random_array(lhs_trng, inner_rng); + auto rhs = random_array(rhs_trng, inner_rng_perm); + + // + // manual evaluation: 'ij;mn = ij;mo * ji;on' + // + Array ref{world, lhs_trng}; + { + lhs.make_replicated(); + rhs.make_replicated(); + world.gop.fence(); /* presumably replication followed by the fence is what lets find_local below reach every tile on each rank - TODO confirm */ + + auto make_tile = [lhs, rhs](TA::Range const& rng) { + typename Array::value_type result_tile{rng}; + for (auto&& res_ix : result_tile.range()) { + auto i = res_ix[0]; + auto j = res_ix[1]; + + auto lhs_tile_ix = lhs.trange().element_to_tile({i, j}); + auto lhs_tile = lhs.find_local(lhs_tile_ix).get(/* dowork = */ false); + + auto rhs_tile_ix = rhs.trange().element_to_tile({j, i}); /* rhs is addressed as (j, i), matching the j,i outer annotation given to einsum below */ + auto rhs_tile = rhs.find_local(rhs_tile_ix).get(/* dowork = */ false); + + auto& res_el = result_tile({i, j}); + auto const& lhs_el = lhs_tile({i, j}); + auto const& rhs_el = rhs_tile({j, i}); + using namespace std::string_literals; + res_el = + TA::detail::tensor_contract(lhs_el, "mo"s, rhs_el, "on"s, "mn"s); + } + return result_tile; + }; + using std::begin; + using std::end; + + for (auto it = begin(ref); it != end(ref); ++it) + if (ref.is_local(it.index())) { + auto tile = world.taskq.add(make_tile, it.make_range()); + *it = tile; /* only tiles for which ref.is_local(...) holds are filled, each with the value returned by queuing make_tile on world.taskq */ + } + } + + auto out = einsum(lhs("i,j;m,o"), rhs("j,i;o,n"), "i,j;m,n"); /* out is computed but never compared: the check against ref is commented out below */ + std::cerr << "TODO: ij;mo * ji;on -> ij;mn using expression layer does not " + "produce the same result compared to manual evaluation." 
+ << '\n'; + // bool are_equal = ToTArrayFixture::are_equal(ref, out); + // std::cout << out << '\n' << ref << '\n'; + // BOOST_CHECK(are_equal); +} + +BOOST_AUTO_TEST_CASE(ij_mn_eq_ijk_mo_times_ijk_no) { /* Builds a reference array for ij;mn = ijk;mo * ijk;no: for every outer (i, j) it accumulates tensor_contract over k in [0, K) into one inner tensor, then checks the einsum result against that reference with ToTArrayFixture::are_equal. */ + using Array = TA::DistArray>, TA::DensePolicy>; + using Ix = typename TA::Range::index1_type; + using namespace std::string_literals; + auto& world = TA::get_default_world(); + + Ix const K = 2; // the extent of contracted outer mode /* NOTE(review): K repeats the upper bound of the third dimension {0, 2} of lhs_trng below; keep the two in sync */ + + TA::Range const inner_rng{3, 7}; + TA::TiledRange const lhs_trng{ + std::initializer_list>{ + {0, 2, 4}, {0, 2}, {0, 2}}}; + TA::TiledRange const rhs_trng(lhs_trng); + TA::TiledRange const ref_trng{lhs_trng.dim(0), lhs_trng.dim(1)}; /* the reference keeps only the first two dimensions of lhs_trng */ + TA::Range const ref_inner_rng{3, 3}; // contract(3x7,3x7) -> (3,3) + auto lhs = random_array(lhs_trng, inner_rng); + auto rhs = random_array(rhs_trng, inner_rng); + + // + // manual evaluation: ij;mn = ijk;mo * ijk;no + // + Array ref{world, ref_trng}; + { + lhs.make_replicated(); + rhs.make_replicated(); + world.gop.fence(); + + auto make_tile = [lhs, rhs, ref_inner_rng](TA::Range const& rng) { + using InnerT = typename Array::value_type::value_type; + typename Array::value_type result_tile{rng}; + + for (auto&& res_ix : result_tile.range()) { + auto i = res_ix[0]; + auto j = res_ix[1]; + + InnerT mn{ref_inner_rng}; /* NOTE(review): mn is only ever updated through add_to, so the sum relies on InnerT(range) starting out zero-filled - confirm against the Tensor constructor, or construct mn with an explicit zero fill */ + for (Ix k = 0; k < K; ++k) { + auto lhs_tile = + lhs.find_local(lhs.trange().element_to_tile({i, j, k})) + .get(/*dowork = */ false); + auto rhs_tile = + rhs.find_local(rhs.trange().element_to_tile({i, j, k})) + .get(/*dowork = */ false); + mn.add_to(tensor_contract("mo,no->mn", lhs_tile({i, j, k}), + rhs_tile({i, j, k}))); + } + result_tile({i, j}) = std::move(mn); + } + return result_tile; + }; + using std::begin; + using std::end; + + for (auto it = begin(ref); it != end(ref); ++it) + if (ref.is_local(it.index())) { + auto tile = world.taskq.add(make_tile, it.make_range()); + *it = tile; + } + } + + auto out = einsum(lhs("i,j,k;m,o"), rhs("i,j,k;n,o"), "i,j;m,n"); + bool 
are_equal = ToTArrayFixture::are_equal(ref, out); + BOOST_CHECK(are_equal); +} + +#ifdef TILEDARRAY_HAS_BTAS +BOOST_AUTO_TEST_CASE(tensor_contract) { /* Compiled only when TILEDARRAY_HAS_BTAS is defined. Runs tensor_contract_equal on four contraction annotations over random tensors built from ranges {2, 3, 4} and {4, 3, 2}. tensor_contract_equal is defined outside this hunk - presumably it compares tensor_contract against an independent reference, TODO confirm. */ + using TensorT = TA::Tensor; + + TA::Range const rng_A{2, 3, 4}; + TA::Range const rng_B{4, 3, 2}; + auto const A = random_tensor(rng_A); + auto const B = random_tensor(rng_B); + + BOOST_CHECK(tensor_contract_equal("ijk,klm->ijlm", A, B)); + BOOST_CHECK(tensor_contract_equal("ijk,klm->milj", A, B)); + BOOST_CHECK(tensor_contract_equal("ijk,kjm->im", A, B)); + BOOST_CHECK(tensor_contract_equal("ijk,kli->lj", A, B)); +} +#endif + BOOST_AUTO_TEST_SUITE_END() // einsum_tot BOOST_AUTO_TEST_SUITE(einsum_tot_t)