Skip to content

Commit

Permalink
Support for pure hadamard product between a tot and a t: 'i,j;m,n * i…
Browse files Browse the repository at this point in the history
…,j -> i,j;m,n'
  • Loading branch information
bimalgaudel committed Dec 7, 2023
1 parent b75b1fc commit f8d4100
Show file tree
Hide file tree
Showing 3 changed files with 102 additions and 2 deletions.
6 changes: 4 additions & 2 deletions src/TiledArray/expressions/binary_engine.h
Original file line number Diff line number Diff line change
Expand Up @@ -204,8 +204,10 @@ class BinaryEngine : public ExprEngine<Derived> {
/// \param target_indices The target index list for this expression
void perm_indices(const BipartiteIndexList& target_indices) {
if (permute_tiles_) {
TA_ASSERT(left_.indices().size() == target_indices.size());
TA_ASSERT(right_.indices().size() == target_indices.size());
TA_ASSERT(left_.indices().size() == target_indices.size() ||
(left_.indices().second().size() ^ target_indices.second().size()));
TA_ASSERT(right_.indices().size() == target_indices.size() ||
(right_.indices().second().size() ^ target_indices.second().size()));

init_indices_<TensorProduct::Hadamard>(target_indices);

Expand Down
6 changes: 6 additions & 0 deletions src/TiledArray/expressions/mult_engine.h
Original file line number Diff line number Diff line change
Expand Up @@ -407,6 +407,9 @@ class MultEngine : public ContEngine<MultEngine<Left, Right, Result>> {
return op_type(op_base_type());
} else if (inner_prod == TensorProduct::Contraction) {
return op_type(op_base_type(this->element_return_op_));
} else if (inner_prod == TensorProduct::Scale) {
TA_ASSERT(this->product_type() == TensorProduct::Hadamard);
return op_type(op_base_type());
} else
abort();
} else { // plain tensors
Expand All @@ -432,6 +435,9 @@ class MultEngine : public ContEngine<MultEngine<Left, Right, Result>> {
return op_type(op_base_type(), perm);
} else if (inner_prod == TensorProduct::Contraction) {
return op_type(op_base_type(this->element_return_op_), perm);
} else if (inner_prod == TensorProduct::Scale) {
TA_ASSERT(this->product_type() == TensorProduct::Hadamard);
return op_type(op_base_type(this->element_return_op_), perm);
} else
abort();
} else { // plain tensor
Expand Down
92 changes: 92 additions & 0 deletions tests/einsum.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -900,6 +900,98 @@ BOOST_AUTO_TEST_CASE(ikj_mn_eq_ij_mn_times_jk) {
// tot_type out = einsum(lhs("i,j;m,n"), rhs("j,k"), "k,i,j;n,m");
}

BOOST_AUTO_TEST_CASE(ij_mn_eq_ji_mn_times_ij) {
using t_type = DistArray<Tensor<double>, SparsePolicy>;
using tot_type = DistArray<Tensor<Tensor<double>>, SparsePolicy>;
using matrix_il = TiledArray::detail::matrix_il<Tensor<double>>;
auto& world = TiledArray::get_default_world();
Tensor<double> lhs_elem_0_0(
Range{7, 2}, {49, 73, 28, 46, 12, 83, 29, 61, 61, 98, 57, 28, 96, 57});
Tensor<double> lhs_elem_0_1(
Range{7, 2}, {78, 15, 69, 55, 87, 94, 28, 94, 79, 30, 26, 88, 48, 74});
Tensor<double> lhs_elem_1_0(
Range{7, 2}, {70, 32, 25, 71, 6, 56, 4, 13, 72, 50, 15, 95, 52, 89});
Tensor<double> lhs_elem_1_1(
Range{7, 2}, {12, 29, 17, 68, 37, 79, 5, 52, 13, 35, 53, 54, 78, 71});
Tensor<double> lhs_elem_2_0(
Range{7, 2}, {77, 39, 34, 94, 16, 82, 63, 27, 75, 12, 14, 59, 3, 14});
Tensor<double> lhs_elem_2_1(
Range{7, 2}, {65, 90, 37, 41, 65, 75, 59, 16, 44, 85, 86, 11, 40, 24});
Tensor<double> lhs_elem_3_0(
Range{7, 2}, {77, 53, 11, 6, 99, 63, 46, 68, 83, 56, 76, 86, 91, 79});
Tensor<double> lhs_elem_3_1(
Range{7, 2}, {56, 11, 33, 90, 36, 38, 33, 54, 60, 21, 16, 28, 6, 97});
Tensor<double> lhs_elem_4_0(
Range{7, 2}, {77, 53, 11, 6, 99, 63, 46, 68, 83, 56, 76, 86, 91, 79});
Tensor<double> lhs_elem_4_1(
Range{7, 2}, {56, 11, 33, 90, 36, 38, 33, 54, 60, 21, 16, 28, 6, 97});
Tensor<double> lhs_elem_5_0(
Range{7, 2}, {77, 53, 11, 6, 99, 63, 46, 68, 83, 56, 76, 86, 91, 79});
Tensor<double> lhs_elem_5_1(
Range{7, 2}, {56, 11, 33, 90, 36, 38, 33, 54, 60, 21, 16, 28, 6, 97});
matrix_il lhs_il{{lhs_elem_0_0, lhs_elem_0_1},
{lhs_elem_1_0, lhs_elem_1_1},
{lhs_elem_2_0, lhs_elem_2_1},
{lhs_elem_3_0, lhs_elem_3_1},
{lhs_elem_4_0, lhs_elem_4_1},
{lhs_elem_5_0, lhs_elem_5_1}};
TiledRange lhs_trange{{0, 2, 6}, {0, 2}};
tot_type lhs(world, lhs_trange, lhs_il);

TiledRange rhs_trange{{0, 2}, {0, 2, 6}};
t_type rhs(world, rhs_trange);
rhs.fill_random();

//
// i,j;m,n = j,i;n,m * i,j
//
TiledRange ref_result_trange{rhs_trange.dim(0), rhs_trange.dim(1)};
tot_type ref_result(world, ref_result_trange);

// why cannot lhs and rhs be captured by ref?
auto make_tile = [lhs, rhs](TA::Range const& rng) {
tot_type::value_type result_tile{rng};
for (auto&& res_ix : result_tile.range()) {
auto i = res_ix[0];
auto j = res_ix[1];

using Ix2 = std::array<decltype(i), 2>;

auto lhs_tile_ix = lhs.trange().element_to_tile(Ix2{j, i});
auto lhs_tile = lhs.find(lhs_tile_ix).get(/* dowork */ false);

auto rhs_tile_ix = rhs.trange().element_to_tile(Ix2({i, j}));
auto rhs_tile = rhs.find(rhs_tile_ix).get(/* dowork */ false );

auto& res_el =
result_tile.at_ordinal(result_tile.range().ordinal(Ix2{i, j}));
auto const& lhs_el =
lhs_tile.at_ordinal(lhs_tile.range().ordinal(Ix2{j, i}));
auto rhs_el =
rhs_tile.at_ordinal(rhs_tile.range().ordinal(Ix2{i, j}));
res_el = tot_type::element_type(
lhs_el.scale(rhs_el), // scale
TiledArray::Permutation{0, 1} // permute
);
}
return result_tile;
};

using std::begin;
using std::end;

for (auto it = begin(ref_result); it != end(ref_result); ++it) {
auto tile = TA::get_default_world().taskq.add(make_tile, it.make_range());
*it = tile;
}

tot_type result;
BOOST_REQUIRE_NO_THROW(result("i,j;m,n") = lhs("j,i;m,n") * rhs("i,j"));

const bool are_equal = ToTArrayFixture::are_equal(result, ref_result);
BOOST_CHECK(are_equal);
}

BOOST_AUTO_TEST_SUITE_END() // einsum_tot_t

// Eigen einsum indices
Expand Down

0 comments on commit f8d4100

Please sign in to comment.