Use CUTENSOR for Tessellate backpropagation in supported cases (#2460)
tbennun authored Jul 1, 2024
1 parent fa569da commit 93e04a8
Showing 8 changed files with 348 additions and 79 deletions.
91 changes: 91 additions & 0 deletions ci_test/unit_tests/test_unit_layer_tessellate_backprop.py
@@ -0,0 +1,91 @@
import lbann
import numpy as np
import test_util


@test_util.lbann_test(check_gradients=True)
def test_tessellate_pad():
# Prepare reference output
np.random.seed(20240627)
x = np.random.rand(37, 20, 1)
reference_numpy = np.tile(x, (1, 1, 5))

tester = test_util.ModelTester()

x = tester.inputs_like(x)
reference = tester.make_reference(reference_numpy)

# Test layer
y = lbann.Tessellate(x, dims=[20, 5])

# Set test loss
tester.set_loss(lbann.MeanSquaredError(y, reference))
tester.set_check_gradients_tensor(lbann.Square(y))

return tester


@test_util.lbann_test(check_gradients=True)
def test_tessellate_scalar():
# Prepare reference output
np.random.seed(20240627)
x = np.random.rand(37, 1, 1, 1)
reference_numpy = np.tile(x, (1, 5, 6, 7))

tester = test_util.ModelTester()

x = tester.inputs_like(x)
reference = tester.make_reference(reference_numpy)

# Test layer
y = lbann.Tessellate(x, dims=[5, 6, 7])

# Set test loss
tester.set_loss(lbann.MeanSquaredError(y, reference))
tester.set_check_gradients_tensor(lbann.Square(y))

return tester


@test_util.lbann_test(check_gradients=True)
def test_tessellate_repro1():
# Prepare reference output
np.random.seed(20240627)
x = np.random.rand(16, 3, 1, 1)
reference_numpy = np.tile(x, (1, 16, 3, 3))

tester = test_util.ModelTester()

x = tester.inputs_like(x)
reference = tester.make_reference(reference_numpy)

# Test layer
y = lbann.Tessellate(x, dims=[3 * 16, 3, 3])

# Set test loss
tester.set_loss(lbann.MeanSquaredError(y, reference))
tester.set_check_gradients_tensor(lbann.Square(y))

return tester


@test_util.lbann_test(check_gradients=True)
def test_tessellate_repro2():
# Prepare reference output
np.random.seed(20240627)
x = np.random.rand(37, 16, 1, 1)
reference_numpy = np.tile(x, (1, 1, 64, 64))

tester = test_util.ModelTester()

x = tester.inputs_like(x)
reference = tester.make_reference(reference_numpy)

# Test layer
y = lbann.Tessellate(x, dims=[16, 64, 64])

# Set test loss
tester.set_loss(lbann.MeanSquaredError(y, reference))
tester.set_check_gradients_tensor(lbann.Square(y))

return tester
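
For context, the tests above compare lbann.Tessellate against np.tile and rely on the harness's gradient check to exercise the backward pass. A minimal NumPy sketch of the forward semantics being tested follows; the tessellate_forward helper is illustrative only and is not part of LBANN.

import numpy as np

def tessellate_forward(x, out_dims):
    # Illustrative tessellate: tile a per-sample tensor x up to out_dims.
    # Assumes each output dimension is a whole multiple of the matching
    # input dimension, as in the test cases above.
    reps = tuple(o // i for o, i in zip(out_dims, x.shape))
    return np.tile(x, reps)

# One sample from test_tessellate_pad: (20, 1) tiled to (20, 5)
x = np.random.rand(20, 1)
assert tessellate_forward(x, (20, 5)).shape == (20, 5)
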
40 changes: 34 additions & 6 deletions include/lbann/layers/transform/tessellate.hpp
@@ -219,12 +219,33 @@ class tessellate_layer : public data_type_layer<TensorDataType>
}
auto& local_gradient_wrt_input = m_input_v->Matrix();

// Apply back prop with local data
/// @todo Support >3 dimensions
bp_compute_3d(input_dims,
output_dims,
gradient_wrt_output,
local_gradient_wrt_input);
// Check if multi-dimensional reduction tessellate backprop is supported
#ifdef LBANN_HAS_CUTENSOR
bool multidim_reduce = true;
for (size_t i = 0; i < output_dims.size(); ++i) {
if (input_dims[i] != 1 && output_dims[i] != input_dims[i]) {
multidim_reduce = false;
break;
}
}
if (multidim_reduce) {
// Sizes either match or are broadcast from dimension-1:
// compute tessellate backprop via multi-dimensional reduction
bp_compute_cutensor(input_dims,
output_dims,
gradient_wrt_output,
local_gradient_wrt_input);
}
else
#endif
{
// Apply back prop with local data
/// @todo Support >3 dimensions
bp_compute_3d(input_dims,
output_dims,
gradient_wrt_output,
local_gradient_wrt_input);
}

// Accumulate local error signals, if needed
if (m_input_v->DistData() != gradient_wrt_input.DistData()) {
@@ -254,6 +275,13 @@ class tessellate_layer : public data_type_layer<TensorDataType>
const std::vector<int>& output_dims,
const AbsDistMatrixType& gradient_wrt_output,
AbsMatrixType& gradient_wrt_input);

#ifdef LBANN_HAS_CUTENSOR
void bp_compute_cutensor(const std::vector<int>& input_dims,
const std::vector<int>& output_dims,
const AbsDistMatrixType& gradient_wrt_output,
AbsMatrixType& gradient_wrt_input);
#endif
};

template <typename T, data_layout L, El::Device D>
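
The dispatch above takes the cuTENSOR path only when every output dimension either equals the corresponding input dimension or broadcasts from an input dimension of size 1; in that case the input gradient is a plain sum over the broadcast axes. A hedged NumPy sketch of that logic follows; the function names are illustrative and are not LBANN or cuTENSOR API.

import numpy as np

def multidim_reduce_supported(input_dims, output_dims):
    # Mirrors the eligibility loop in bp_compute: every output dim must
    # match the input dim or broadcast from an input dim of size 1.
    return all(i == 1 or o == i for i, o in zip(input_dims, output_dims))

def tessellate_backprop_reduce(grad_output, input_dims):
    # Backprop of a broadcast-style tessellate: sum the output gradient
    # over every axis whose input size is 1.
    axes = tuple(ax for ax, d in enumerate(input_dims) if d == 1)
    return grad_output.sum(axis=axes, keepdims=True)

# One sample from test_tessellate_repro2: (16, 1, 1) broadcast to (16, 64, 64)
g = np.random.rand(16, 64, 64)
assert tessellate_backprop_reduce(g, (16, 1, 1)).shape == (16, 1, 1)

Cases like test_tessellate_repro1, where a non-unit dimension is itself tiled (3 to 48), fail this check and fall back to the existing bp_compute_3d path.
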
38 changes: 23 additions & 15 deletions include/lbann/utils/cutensor_support.hpp
@@ -52,10 +52,10 @@
namespace lbann {

namespace cutensor {
using DimsType = ColMajorDims<int64_t>;
using StridesType = ColMajorStrides<int64_t>;
using ModesType = std::vector<int32_t>;
} // namespace cutensor
using DimsType = ColMajorDims<int64_t>;
using StridesType = ColMajorStrides<int64_t>;
using ModesType = std::vector<int32_t>;
} // namespace cutensor

template <typename CppType>
struct CUDATypeT;
@@ -65,31 +65,36 @@ struct CUDATypeT<__half>
{
typedef float scalar_type;
static constexpr auto value = CUDA_R_16F;
static constexpr auto compute_type = CUTENSOR_COMPUTE_16F;
};

template <>
struct CUDATypeT<float>
{
typedef float scalar_type;
static constexpr auto value = CUDA_R_32F;
static constexpr auto compute_type = CUTENSOR_COMPUTE_32F;
};
template <>
struct CUDATypeT<double>
{
typedef double scalar_type;
static constexpr auto value = CUDA_R_64F;
static constexpr auto compute_type = CUTENSOR_COMPUTE_64F;
};
template <>
struct CUDATypeT<El::Complex<float>>
{
typedef El::Complex<float> scalar_type;
static constexpr auto value = CUDA_C_32F;
static constexpr auto compute_type = CUTENSOR_COMPUTE_32F;
};
template <>
struct CUDATypeT<El::Complex<double>>
{
typedef El::Complex<double> scalar_type;
static constexpr auto value = CUDA_C_64F;
static constexpr auto compute_type = CUTENSOR_COMPUTE_64F;
};

template <typename CppType>
@@ -113,7 +118,6 @@ static cutensorHandle_t* get_handle_ptr()
return &handle;
}


static inline cutensor::ModesType make_modes(size_t const ndims)
{
std::vector<int32_t> modes(ndims + 1); // Add the sample dim.
@@ -122,24 +126,28 @@ static inline cutensor::ModesType make_modes(size_t const ndims)
}

template <typename DataT>
static std::string get_desc_key(
El::Matrix<DataT, El::Device::GPU> const& mat,
cutensor::DimsType const& dims_in)
static std::string get_desc_key(El::Matrix<DataT, El::Device::GPU> const& mat,
cutensor::DimsType const& dims_in)
{
auto const& dims = dims_in.get();
std::ostringstream oss;
oss << mat.Height() << "," << mat.Width() << "," << mat.LDim() << ";"
<< dims.front();
for (size_t ii = 1; ii < dims.size(); ++ii)
oss << "," << dims[ii];
oss << mat.Height() << "," << mat.Width() << "," << mat.LDim() << ";";
if (dims.size() == 0) {
oss << "scalar";
}
else {
oss << dims.front();
for (size_t ii = 1; ii < dims.size(); ++ii)
oss << "," << dims[ii];
}
oss << ";" << lbann::TypeName<DataT>();
return oss.str();
}

template <typename DataT>
static cutensorTensorDescriptor_t get_descriptor(
El::Matrix<DataT, El::Device::GPU> const& mat,
cutensor::DimsType const& dims)
static cutensorTensorDescriptor_t
get_descriptor(El::Matrix<DataT, El::Device::GPU> const& mat,
cutensor::DimsType const& dims)
{
/** @brief Keep track of descriptors so we don't have to repeatedly
* rebuild them.
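
The get_desc_key change makes the descriptor-cache key well defined when dims is empty (a scalar tensor), writing the literal "scalar" instead of reading dims.front() from an empty vector. A rough Python rendering of the key layout, for illustration only:

def desc_key(height, width, ldim, dims, type_name):
    # Illustrative cache key: "H,W,LDim;dims;TypeName", with "scalar"
    # standing in for an empty dims list, as in get_desc_key.
    dims_part = "scalar" if not dims else ",".join(str(d) for d in dims)
    return f"{height},{width},{ldim};{dims_part};{type_name}"

assert desc_key(65536, 32, 65536, [16, 64, 64], "float") == "65536,32,65536;16,64,64;float"
assert desc_key(1, 32, 1, [], "float") == "1,32,1;scalar;float"
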
14 changes: 8 additions & 6 deletions include/lbann/utils/tensor_dims_utils.hpp
@@ -49,8 +49,8 @@ class NamedVector
explicit NamedVector(std::vector<T> const& v) : m_data{v} {}
explicit NamedVector(std::vector<T>&& v) : m_data{std::move(v)} {}
template <typename U>
explicit NamedVector(std::vector<U> const& v) : m_data(v.begin(), v.end()) {}

explicit NamedVector(std::vector<U> const& v) : m_data(v.begin(), v.end())
{}

NamedVector(NamedVector const& other) = default;
NamedVector(NamedVector&& other) = default;
@@ -234,10 +234,12 @@ auto get_strides_as(ColMajorDims<DimT> const& dims)
size_t const ndims = dim_vec.size();

std::vector<StrideT> strides;
strides.reserve(ndims);
strides.push_back(StrideT{1});
for (size_t ii = 0UL; ii < ndims - 1; ++ii)
strides.push_back(strides[ii] * static_cast<StrideT>(dim_vec[ii]));
if (ndims > 0) {
strides.reserve(ndims);
strides.push_back(StrideT{1});
for (size_t ii = 0UL; ii < ndims - 1; ++ii)
strides.push_back(strides[ii] * static_cast<StrideT>(dim_vec[ii]));
}
return ColMajorStrides<StrideT>(strides);
}

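
The guard added to get_strides_as keeps the stride computation from indexing an empty dims vector, the scalar case the new reduction path can now encounter. Column-major strides are running products of the preceding dimensions; a small sketch under the same convention, with an illustrative helper name:

def col_major_strides(dims):
    # Column-major strides: 1 for the first dimension, then the running
    # product of the dimensions before it; an empty dims list yields an
    # empty stride list, matching the guarded C++ code.
    if not dims:
        return []
    strides = [1]
    for d in dims[:-1]:
        strides.append(strides[-1] * d)
    return strides

assert col_major_strides([]) == []
assert col_major_strides([16, 64, 64]) == [1, 16, 1024]
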
[Diffs for the remaining four changed files are not shown in this view.]
