Skip to content

Commit

Permalink
introduced BlkTsrExpr::{{set_,}trange_lobound,preserve_lobound}() tha…
Browse files Browse the repository at this point in the history
…t allow to use block tensor expressions even with DistArrays that have non-zero lobound
  • Loading branch information
evaleev committed Sep 11, 2024
1 parent 57907cc commit 6e8624a
Show file tree
Hide file tree
Showing 6 changed files with 293 additions and 21 deletions.
54 changes: 42 additions & 12 deletions src/TiledArray/expressions/blk_tsr_engine.h
Original file line number Diff line number Diff line change
Expand Up @@ -158,22 +158,29 @@ class BlkTsrEngineBase : public LeafEngine<Derived> {
using LeafEngine_::array_;

container::svector<std::size_t>
lower_bound_; ///< Lower bound of the tile block
lower_bound_; ///< Tile coordinates of the lower bound of the tile block
///< in the host array
container::svector<std::size_t>
upper_bound_; ///< Upper bound of the tile block
upper_bound_; ///< Tile coordinates of the upper bound of the tile block
///< in the host array
std::optional<Range::index_type>
trange_lobound_; ///< Lobound of the result trange, modulo permutation
///< (i.e. referring to the modes of the host array)

public:
template <typename Array, bool Alias>
BlkTsrEngineBase(const BlkTsrExpr<Array, Alias>& expr)
: LeafEngine_(expr),
lower_bound_(expr.lower_bound()),
upper_bound_(expr.upper_bound()) {}
upper_bound_(expr.upper_bound()),
trange_lobound_(expr.trange_lobound()) {}

template <typename Array, typename Scalar>
BlkTsrEngineBase(const ScalBlkTsrExpr<Array, Scalar>& expr)
: LeafEngine_(expr),
lower_bound_(expr.lower_bound()),
upper_bound_(expr.upper_bound()) {}
upper_bound_(expr.upper_bound()),
trange_lobound_(expr.trange_lobound()) {}

/// Non-permuting tiled range factory function

Expand All @@ -199,9 +206,12 @@ class BlkTsrEngineBase : public LeafEngine<Derived> {
if (lower_d != upper_d) {
auto i = lower_d;
const auto base_d = trange[d].tile(i).first;
trange1_data.emplace_back(0ul);
const auto trange1_lobound =
trange_lobound_ ? (*trange_lobound_)[d] : 0ul;
trange1_data.emplace_back(trange1_lobound);
for (; i < upper_d; ++i)
trange1_data.emplace_back(trange[d].tile(i).second - base_d);
trange1_data.emplace_back(trange[d].tile(i).extent() +
trange1_data.back());
// Add the trange1 to the tiled range data
trange_data.emplace_back(trange1_data.begin(), trange1_data.end());
trange1_data.resize(0ul);
Expand Down Expand Up @@ -241,9 +251,12 @@ class BlkTsrEngineBase : public LeafEngine<Derived> {
// Copy, shift, and permute the tiling of the block
auto i = lower_i;
const auto base_d = trange[inv_perm_d].tile(i).first;
trange1_data.emplace_back(0ul);
const auto trange1_lobound =
trange_lobound_ ? (*trange_lobound_)[inv_perm_d] : 0ul;
trange1_data.emplace_back(trange1_lobound);
for (; i < upper_i; ++i)
trange1_data.emplace_back(trange[inv_perm_d].tile(i).second - base_d);
trange1_data.emplace_back(trange[inv_perm_d].tile(i).extent() +
trange1_data.back());

// Add the trange1 to the tiled range data
trange_data.emplace_back(trange1_data.begin(), trange1_data.end());
Expand Down Expand Up @@ -341,6 +354,7 @@ class BlkTsrEngine
protected:
// Import base class variables to this scope
using BlkTsrEngineBase_::lower_bound_;
using BlkTsrEngineBase_::trange_lobound_;
using BlkTsrEngineBase_::upper_bound_;
using ExprEngine_::implicit_permute_inner_;
using ExprEngine_::implicit_permute_outer_;
Expand Down Expand Up @@ -391,8 +405,12 @@ class BlkTsrEngine
const auto lower_d = lower[d];
const auto upper_d = upper[d];
if (lower_d != upper_d) {
// element lobound of the block in the host
const auto base_d = trange[d].tile(lower_d).first;
range_shift.emplace_back(-base_d);
// element lobound of the target of this expression
const auto target_base_d =
trange_lobound_ ? (*trange_lobound_)[d] : 0ul;
range_shift.emplace_back(target_base_d - base_d);
} else {
range_shift.emplace_back(0l);
}
Expand Down Expand Up @@ -427,8 +445,11 @@ class BlkTsrEngine
const auto lower_d = lower[d];
const auto upper_d = upper[d];
if (lower_d != upper_d) {
// element lobound of the block in the host
const auto base_d = trange[d].tile(lower_d).first;
range_shift[perm_d] = -base_d;
// element lobound of the target of this expression
const auto target_base_d = trange_lobound_ ? (*trange_lobound_)[d] : 0;
range_shift[perm_d] = target_base_d - base_d;
}
}

Expand Down Expand Up @@ -496,6 +517,7 @@ class ScalBlkTsrEngine
protected:
// Import base class variables to this scope
using BlkTsrEngineBase_::lower_bound_;
using BlkTsrEngineBase_::trange_lobound_;
using BlkTsrEngineBase_::upper_bound_;
using ExprEngine_::implicit_permute_inner_;
using ExprEngine_::implicit_permute_outer_;
Expand Down Expand Up @@ -549,8 +571,12 @@ class ScalBlkTsrEngine
const auto lower_d = lower[d];
const auto upper_d = upper[d];
if (lower_d != upper_d) {
// element lobound of the block in the host
const auto base_d = trange[d].tile(lower_d).first;
range_shift.emplace_back(-base_d);
// element lobound of the target of this expression
const auto target_base_d =
trange_lobound_ ? (*trange_lobound_)[d] : 0ul;
range_shift.emplace_back(target_base_d - base_d);
} else
range_shift.emplace_back(0);
}
Expand Down Expand Up @@ -584,8 +610,12 @@ class ScalBlkTsrEngine
const auto lower_d = lower[d];
const auto upper_d = upper[d];
if (lower_d != upper_d) {
// element lobound of the block in the host
const auto base_d = trange[d].tile(lower_d).first;
range_shift[perm_d] = -base_d;
// element lobound of the target of this expression
const auto target_base_d =
trange_lobound_ ? (*trange_lobound_)[d] : 0ul;
range_shift[perm_d] = target_base_d - base_d;
}
}

Expand Down
36 changes: 36 additions & 0 deletions src/TiledArray/expressions/blk_tsr_expr.h
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,8 @@
#include <TiledArray/expressions/unary_expr.h>
#include "blk_tsr_engine.h"

#include <optional>

namespace TiledArray {
namespace expressions {

Expand Down Expand Up @@ -118,6 +120,10 @@ class BlkTsrExprBase : public Expr<Derived> {
lower_bound_; ///< Lower bound of the tile block
container::svector<std::size_t>
upper_bound_; ///< Upper bound of the tile block
/// If non-null, element lobound of the expression trange (else zeros will be
/// used) Fusing permutation does not affect this (i.e. this refers to the
/// modes of the host array).
std::optional<Range::index_type> trange_lobound_;

void check_valid() const {
TA_ASSERT(array_);
Expand Down Expand Up @@ -285,6 +291,36 @@ class BlkTsrExprBase : public Expr<Derived> {
/// \return The block upper bound
const auto& upper_bound() const { return upper_bound_; }

/// Sets result trange lobound
/// @param[in] trange_lobound The result trange lobound
template <typename Index1,
typename = std::enable_if_t<
TiledArray::detail::is_integral_range_v<Index1>>>
Derived& set_trange_lobound(const Index1& trange_lobound) {
trange_lobound_.emplace(std::begin(trange_lobound),
std::end(trange_lobound));
return static_cast<Derived&>(*this);
}

/// Sets result trange lobound
/// @param[in] trange_lobound The result trange lobound
template <typename Integer,
typename = std::enable_if_t<std::is_integral_v<Integer>>>
Derived& set_trange_lobound(std::initializer_list<Integer> trange_lobound) {
return this->set_trange_lobound<std::initializer_list<Integer>>(
trange_lobound);
}

/// Sets result trange lobound such that the tile lobounds are not changed
Derived& preserve_lobound() {
return set_trange_lobound(
array_.trange().make_tile_range(lower_bound()).lobound());
}

/// @return optional to result trange lobound; if null, the result trange
/// lobound is zero
const auto& trange_lobound() const { return trange_lobound_; }

}; // class BlkTsrExprBase

/// Block expression
Expand Down
11 changes: 11 additions & 0 deletions src/TiledArray/expressions/expr.h
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@

#include <TiledArray/tensor/type_traits.h>

#include <range/v3/algorithm/equal.hpp>
#include <range/v3/range/conversion.hpp>
#include <range/v3/view/zip_with.hpp>

Expand Down Expand Up @@ -464,6 +465,16 @@ class Expr {
// set even though this is a requirement.
#endif // NDEBUG

// Assignment to block expression uses trange of the array it is bounded to
// Assert that the user did not try to override the trange by accident using
// set_trange_lobound or at least that it matches tsr.array's trange
TA_ASSERT(!tsr.trange_lobound().has_value() ||
(ranges::equal(tsr.trange_lobound().value(),
tsr.array()
.trange()
.make_tile_range(tsr.lower_bound())
.lobound())));

// Get the target world.
World& world = tsr.array().world();

Expand Down
13 changes: 12 additions & 1 deletion src/TiledArray/expressions/fwd.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@

#include <type_traits>


namespace TiledArray::expressions {

template <typename>
Expand All @@ -43,6 +42,10 @@ class BlkTsrExpr;
template <typename, typename>
class ScalBlkTsrExpr;

/// used to indicate that block tensor expression should preserve the underlying
/// tensor's trange lobound
struct preserve_lobound_t {};

template <typename>
struct is_aliased : std::true_type {};

Expand All @@ -68,6 +71,14 @@ class ScalTsrExpr;
template <typename, typename, typename>
class ScalTsrEngine;

} // namespace TiledArray::expressions

namespace TiledArray {

/// used to tag block tensor expression methods that preserve the underlying
/// tensor's trange lobound
inline constexpr expressions::preserve_lobound_t preserve_lobound;

} // namespace TiledArray

#endif // TILEDARRAY_EXPRESSIONS_FWD_H__INCLUDED
Loading

0 comments on commit 6e8624a

Please sign in to comment.