diff --git a/src/TiledArray/expressions/cont_engine.h b/src/TiledArray/expressions/cont_engine.h index 9a1cb9f5f9..5ec69c7d0d 100644 --- a/src/TiledArray/expressions/cont_engine.h +++ b/src/TiledArray/expressions/cont_engine.h @@ -158,9 +158,10 @@ class ContEngine : public BinaryEngine { TensorProduct inner_product_type() const { TA_ASSERT(inner_product_type_ != TensorProduct::Invalid); // init_indices() must initialize this - /// only Hadamard and contraction are supported now + /// only Hadamard, contraction, and scale are supported now TA_ASSERT(inner_product_type_ == TensorProduct::Hadamard || - inner_product_type_ == TensorProduct::Contraction); + inner_product_type_ == TensorProduct::Contraction || + inner_product_type_ == TensorProduct::Scale); return inner_product_type_; } @@ -473,7 +474,8 @@ class ContEngine : public BinaryEngine { result_tile_type, left_tile_type, right_tile_type>; const auto inner_prod = this->inner_product_type(); TA_ASSERT(inner_prod == TensorProduct::Contraction || - inner_prod == TensorProduct::Hadamard); + inner_prod == TensorProduct::Hadamard || + inner_prod == TensorProduct::Scale); if (inner_prod == TensorProduct::Contraction) { TA_ASSERT(tot_x_tot); if constexpr (tot_x_tot) { @@ -577,8 +579,8 @@ class ContEngine : public BinaryEngine { } }; } - } // ToT x ToT - } else if (inner_prod == TensorProduct::General) { + } // ToT x T or T x ToT + } else if (inner_prod == TensorProduct::Scale) { TA_ASSERT(!tot_x_tot); constexpr bool tot_x_t = TiledArray::detail::is_tensor_of_tensor_v { std::conditional_t; - auto scal_op = [do_perm = this->permute_tiles_, - perm = this->permute_tiles_ ? inner(this->perm_) + auto scal_op = [perm = this->permute_tiles_ ? inner(this->perm_) : Permutation{}]( const left_tile_element_type& left, const right_tile_element_type& right) -> result_tile_element_type { using TiledArray::scale; if constexpr (tot_x_t) { - if (do_perm) + if (perm) return scale(left, right, perm); else return scale(left, right); } else if constexpr (tot_x_t) { - if (do_perm) + if (perm) return scale(right, left, perm); else return scale(right, left); diff --git a/src/TiledArray/expressions/product.h b/src/TiledArray/expressions/product.h index 381b1f485c..7111b7831b 100644 --- a/src/TiledArray/expressions/product.h +++ b/src/TiledArray/expressions/product.h @@ -39,6 +39,9 @@ enum class TensorProduct { Contraction, /// free, fused, and contracted indices General, + /// no indices on one, free indices on the other; only used for inner index + /// products in mixed nested products (ToT x T) + Scale, /// invalid Invalid = -1 }; @@ -59,7 +62,7 @@ inline TensorProduct compute_product_type(const IndexList& left_indices, result = TensorProduct::Contraction; } else if ((left_indices && !right_indices) || (!left_indices && right_indices)) { // used for ToT*T or T*ToT - result = TensorProduct::General; + result = TensorProduct::Scale; } return result; } diff --git a/tests/einsum.cpp b/tests/einsum.cpp index 3033936381..ea5529e5b8 100644 --- a/tests/einsum.cpp +++ b/tests/einsum.cpp @@ -718,6 +718,49 @@ BOOST_AUTO_TEST_SUITE_END() // einsum_tot BOOST_AUTO_TEST_SUITE(einsum_tot_t) +BOOST_AUTO_TEST_CASE(ilkj_nm_eq_ij_mn_times_kl) { + using t_type = DistArray, SparsePolicy>; + using tot_type = DistArray>, SparsePolicy>; + using matrix_il = TiledArray::detail::matrix_il>; + auto& world = TiledArray::get_default_world(); + Tensor lhs_elem_0_0( + Range{7, 2}, {49, 73, 28, 46, 12, 83, 29, 61, 61, 98, 57, 28, 96, 57}); + Tensor lhs_elem_0_1( + Range{7, 2}, {78, 15, 69, 55, 87, 94, 28, 94, 79, 30, 26, 88, 48, 74}); + Tensor lhs_elem_1_0( + Range{7, 2}, {70, 32, 25, 71, 6, 56, 4, 13, 72, 50, 15, 95, 52, 89}); + Tensor lhs_elem_1_1( + Range{7, 2}, {12, 29, 17, 68, 37, 79, 5, 52, 13, 35, 53, 54, 78, 71}); + Tensor lhs_elem_2_0( + Range{7, 2}, {77, 39, 34, 94, 16, 82, 63, 27, 75, 12, 14, 59, 3, 14}); + Tensor lhs_elem_2_1( + Range{7, 2}, {65, 90, 37, 41, 65, 75, 59, 16, 44, 85, 86, 11, 40, 24}); + Tensor lhs_elem_3_0( + Range{7, 2}, {77, 53, 11, 6, 99, 63, 46, 68, 83, 56, 76, 86, 91, 79}); + Tensor lhs_elem_3_1( + Range{7, 2}, {56, 11, 33, 90, 36, 38, 33, 54, 60, 21, 16, 28, 6, 97}); + matrix_il lhs_il{{lhs_elem_0_0, lhs_elem_0_1}, + {lhs_elem_1_0, lhs_elem_1_1}, + {lhs_elem_2_0, lhs_elem_2_1}, + {lhs_elem_3_0, lhs_elem_3_1}}; + TiledRange lhs_trange{{0, 2, 4}, {0, 2}}; + tot_type lhs(world, lhs_trange, lhs_il); + + TiledRange rhs_trange{{0, 2}, {0, 2, 4, 6}}; + t_type rhs(world, rhs_trange); + rhs.fill_random(); + + TiledRange ref_result_trange{lhs_trange.dim(0), rhs_trange.dim(1), + rhs_trange.dim(0)}; + tot_type ref_result(world, ref_result_trange); + // TODO compute ref_result + + tot_type result; + BOOST_REQUIRE_NO_THROW(result("i,l,k,j;n,m") = lhs("i,j;m,n") * rhs("k,l")); + + // TODO check result against ref_result +} + BOOST_AUTO_TEST_CASE(ikj_mn_eq_ij_mn_times_jk) { using t_type = DistArray, SparsePolicy>; using tot_type = DistArray>, SparsePolicy>; @@ -764,11 +807,7 @@ BOOST_AUTO_TEST_CASE(ikj_mn_eq_ij_mn_times_jk) { // tot_type result; // BOOST_REQUIRE_NO_THROW(result("i,k,j;m,n") = lhs("i,j;m,n") * rhs("j,k")); - // will try to make this work FIRST since this is used by the einsum code - // below - tot_type out; - out("i,l,k,j;n,m") = lhs("i,j;m,n") * rhs("k,l"); - // will try to make this work NEXT + // will try to make this work // tot_type out = einsum(lhs("i,j;m,n"), rhs("j,k"), "i,j,k;m,n"); }