Skip to content

Commit

Permalink
'einsum' function supports Hadamard+contraction on the outer indices …
Browse files Browse the repository at this point in the history
…with Hadamard or contraction on the inner indices.
  • Loading branch information
bimalgaudel committed Jan 14, 2024
1 parent 350041d commit 06d7736
Showing 1 changed file with 35 additions and 9 deletions.
44 changes: 35 additions & 9 deletions src/TiledArray/einsum/tiledarray.h
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ auto einsum(expressions::TsrExpr<ArrayA_> A, expressions::TsrExpr<ArrayB_> B,
struct {
std::string a, b, c;
// Hadamard, external, internal indices for inner tensor
Einsum::Index<std::string> h, e, i;
Einsum::Index<std::string> A, B, C, h, e, i;
} inner;
if constexpr (std::tuple_size<decltype(cs)>::value == 2) {
if constexpr (IsArrayToT<ArrayA>)
Expand All @@ -116,14 +116,14 @@ auto einsum(expressions::TsrExpr<ArrayA_> A, expressions::TsrExpr<ArrayB_> B,
inner.c = ";" + (std::string)std::get<1>(cs);

Einsum::Index<std::string> a_idx, b_idx, c_idx;
if constexpr (IsArrayToT<ArrayA>) a_idx = std::get<1>(Einsum::idx(A));
if constexpr (IsArrayToT<ArrayB>) b_idx = std::get<1>(Einsum::idx(B));
if constexpr (IsArrayToT<ArrayA>) inner.A = std::get<1>(Einsum::idx(A));
if constexpr (IsArrayToT<ArrayB>) inner.B = std::get<1>(Einsum::idx(B));
if constexpr (IsArrayToT<ArrayA> || IsArrayToT<ArrayB>)
c_idx = std::get<1>(cs);
inner.C = std::get<1>(cs);

inner.h = a_idx & b_idx & c_idx;
inner.e = (a_idx ^ b_idx);
inner.i = (a_idx & b_idx) - inner.h;
inner.h = inner.A & inner.B & inner.C;
inner.e = (inner.A ^ inner.B);
inner.i = (inner.A & inner.B) - inner.h;
}

// these are "Hadamard" (fused) indices
Expand Down Expand Up @@ -227,8 +227,34 @@ auto einsum(expressions::TsrExpr<ArrayA_> A, expressions::TsrExpr<ArrayB_> B,
ai = ai.reshape(shape, batch);
bi = bi.reshape(shape, batch);
for (size_t k = 0; k < batch; ++k) {
auto hk = ai.batch(k).dot(bi.batch(k));
tile({k}) += hk;
using Ix = ::Einsum::Index<std::string>;
if constexpr (AreArrayToT<ArrayA, ArrayB>) {
TA_ASSERT(inner.h ^ inner.i &&
"Hadamard with contraction not supported between the "
"inner tensors");

auto aik = ai.batch(k);
auto bik = bi.batch(k);
auto vol = aik.total_size();
TA_ASSERT(vol == bik.total_size());

auto &el = tile({k});
using TensorT = std::remove_reference_t<decltype(el)>;

auto mult_op = [&inner](auto const &l, auto const &r) -> TensorT {
return inner.h ? TA::detail::tensor_hadamard(l, inner.A, r,
inner.B, inner.C)
: TA::detail::tensor_contract(l, inner.A, r,
inner.B, inner.C);
};

for (auto i = 0; i < vol; ++i)
el.add_to(mult_op(aik.data()[i], bik.data()[i]));

} else {
auto hk = ai.batch(k).dot(bi.batch(k));
tile({k}) += hk;
}
}
}
auto pc = C.permutation;
Expand Down

0 comments on commit 06d7736

Please sign in to comment.