diff --git a/src/TiledArray/expressions/binary_engine.h b/src/TiledArray/expressions/binary_engine.h index 93192e2b5e..411a1c7c13 100644 --- a/src/TiledArray/expressions/binary_engine.h +++ b/src/TiledArray/expressions/binary_engine.h @@ -204,8 +204,10 @@ class BinaryEngine : public ExprEngine { /// \param target_indices The target index list for this expression void perm_indices(const BipartiteIndexList& target_indices) { if (permute_tiles_) { - TA_ASSERT(left_.indices().size() == target_indices.size()); - TA_ASSERT(right_.indices().size() == target_indices.size()); + TA_ASSERT(left_.indices().size() == target_indices.size() || + (left_.indices().second().size() ^ target_indices.second().size())); + TA_ASSERT(right_.indices().size() == target_indices.size() || + (right_.indices().second().size() ^ target_indices.second().size())); init_indices_(target_indices); diff --git a/src/TiledArray/expressions/mult_engine.h b/src/TiledArray/expressions/mult_engine.h index 91924efeb2..9713e0b0df 100644 --- a/src/TiledArray/expressions/mult_engine.h +++ b/src/TiledArray/expressions/mult_engine.h @@ -407,6 +407,9 @@ class MultEngine : public ContEngine> { return op_type(op_base_type()); } else if (inner_prod == TensorProduct::Contraction) { return op_type(op_base_type(this->element_return_op_)); + } else if (inner_prod == TensorProduct::Scale) { + TA_ASSERT(this->product_type() == TensorProduct::Hadamard); + return op_type(op_base_type()); } else abort(); } else { // plain tensors @@ -432,6 +435,9 @@ class MultEngine : public ContEngine> { return op_type(op_base_type(), perm); } else if (inner_prod == TensorProduct::Contraction) { return op_type(op_base_type(this->element_return_op_), perm); + } else if (inner_prod == TensorProduct::Scale) { + TA_ASSERT(this->product_type() == TensorProduct::Hadamard); + return op_type(op_base_type(this->element_return_op_), perm); } else abort(); } else { // plain tensor diff --git a/tests/einsum.cpp b/tests/einsum.cpp index 
37889a73f9..9ea4dd39d3 100644 --- a/tests/einsum.cpp +++ b/tests/einsum.cpp @@ -900,6 +900,98 @@ BOOST_AUTO_TEST_CASE(ikj_mn_eq_ij_mn_times_jk) { // tot_type out = einsum(lhs("i,j;m,n"), rhs("j,k"), "k,i,j;n,m"); } +BOOST_AUTO_TEST_CASE(ij_mn_eq_ji_mn_times_ij) { + using t_type = DistArray, SparsePolicy>; + using tot_type = DistArray>, SparsePolicy>; + using matrix_il = TiledArray::detail::matrix_il>; + auto& world = TiledArray::get_default_world(); + Tensor lhs_elem_0_0( + Range{7, 2}, {49, 73, 28, 46, 12, 83, 29, 61, 61, 98, 57, 28, 96, 57}); + Tensor lhs_elem_0_1( + Range{7, 2}, {78, 15, 69, 55, 87, 94, 28, 94, 79, 30, 26, 88, 48, 74}); + Tensor lhs_elem_1_0( + Range{7, 2}, {70, 32, 25, 71, 6, 56, 4, 13, 72, 50, 15, 95, 52, 89}); + Tensor lhs_elem_1_1( + Range{7, 2}, {12, 29, 17, 68, 37, 79, 5, 52, 13, 35, 53, 54, 78, 71}); + Tensor lhs_elem_2_0( + Range{7, 2}, {77, 39, 34, 94, 16, 82, 63, 27, 75, 12, 14, 59, 3, 14}); + Tensor lhs_elem_2_1( + Range{7, 2}, {65, 90, 37, 41, 65, 75, 59, 16, 44, 85, 86, 11, 40, 24}); + Tensor lhs_elem_3_0( + Range{7, 2}, {77, 53, 11, 6, 99, 63, 46, 68, 83, 56, 76, 86, 91, 79}); + Tensor lhs_elem_3_1( + Range{7, 2}, {56, 11, 33, 90, 36, 38, 33, 54, 60, 21, 16, 28, 6, 97}); + Tensor lhs_elem_4_0( + Range{7, 2}, {77, 53, 11, 6, 99, 63, 46, 68, 83, 56, 76, 86, 91, 79}); + Tensor lhs_elem_4_1( + Range{7, 2}, {56, 11, 33, 90, 36, 38, 33, 54, 60, 21, 16, 28, 6, 97}); + Tensor lhs_elem_5_0( + Range{7, 2}, {77, 53, 11, 6, 99, 63, 46, 68, 83, 56, 76, 86, 91, 79}); + Tensor lhs_elem_5_1( + Range{7, 2}, {56, 11, 33, 90, 36, 38, 33, 54, 60, 21, 16, 28, 6, 97}); + matrix_il lhs_il{{lhs_elem_0_0, lhs_elem_0_1}, + {lhs_elem_1_0, lhs_elem_1_1}, + {lhs_elem_2_0, lhs_elem_2_1}, + {lhs_elem_3_0, lhs_elem_3_1}, + {lhs_elem_4_0, lhs_elem_4_1}, + {lhs_elem_5_0, lhs_elem_5_1}}; + TiledRange lhs_trange{{0, 2, 6}, {0, 2}}; + tot_type lhs(world, lhs_trange, lhs_il); + + TiledRange rhs_trange{{0, 2}, {0, 2, 6}}; + t_type rhs(world, rhs_trange); + 
rhs.fill_random(); + + // + // i,j;m,n = j,i;m,n * i,j + // + TiledRange ref_result_trange{rhs_trange.dim(0), rhs_trange.dim(1)}; + tot_type ref_result(world, ref_result_trange); + + // NOTE: lhs and rhs are captured by value; why capture by ref does not work is an open question + auto make_tile = [lhs, rhs](TA::Range const& rng) { + tot_type::value_type result_tile{rng}; + for (auto&& res_ix : result_tile.range()) { + auto i = res_ix[0]; + auto j = res_ix[1]; + + using Ix2 = std::array; + + auto lhs_tile_ix = lhs.trange().element_to_tile(Ix2{j, i}); + auto lhs_tile = lhs.find(lhs_tile_ix).get(/* dowork */ false); + + auto rhs_tile_ix = rhs.trange().element_to_tile(Ix2({i, j})); + auto rhs_tile = rhs.find(rhs_tile_ix).get(/* dowork */ false ); + + auto& res_el = + result_tile.at_ordinal(result_tile.range().ordinal(Ix2{i, j})); + auto const& lhs_el = + lhs_tile.at_ordinal(lhs_tile.range().ordinal(Ix2{j, i})); + auto rhs_el = + rhs_tile.at_ordinal(rhs_tile.range().ordinal(Ix2{i, j})); + res_el = tot_type::element_type( + lhs_el.scale(rhs_el), // scale + TiledArray::Permutation{0, 1} // permute + ); + } + return result_tile; + }; + + using std::begin; + using std::end; + + for (auto it = begin(ref_result); it != end(ref_result); ++it) { + auto tile = TA::get_default_world().taskq.add(make_tile, it.make_range()); + *it = tile; + } + + tot_type result; + BOOST_REQUIRE_NO_THROW(result("i,j;m,n") = lhs("j,i;m,n") * rhs("i,j")); + + const bool are_equal = ToTArrayFixture::are_equal(result, ref_result); + BOOST_CHECK(are_equal); +} + BOOST_AUTO_TEST_SUITE_END() // einsum_tot_t // Eigen einsum indices