[XLA:GPU] Add RaggedAllToAllDecomposer pass.
The pass rewrites `ragged-all-to-all` as a regular `all-to-all`.

This rewrite is not intended to be the production implementation of `ragged-all-to-all`: it materializes one full-input-sized chunk per ragged row before the dense all-to-all, so it uses much more memory than necessary (see the host-side sketch below).

Adding this pass had the following goals:
  * Unblock end-to-end integration of `ragged-all-to-all` in XLA:GPU.
  * Serve as a reference implementation.
  * Help write end-to-end tests.

Once we have a proper implementation with NCCL, this pass should be removed.

Integration into the GPU compilation pipeline will follow.
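For intuition, here is a simplified host-side sketch in C++ of the ragged-to-dense conversion that the pass expresses with HLO pad and dynamic-slice ops. It is not part of this commit; the function name and types are illustrative, and each chunk's tail is zero-filled here for clarity, whereas the actual pass slices a full-sized window out of a zero-padded copy of the input.

#include <cstdint>
#include <vector>

// Host-side model: `input` is a 1-D ragged buffer where row i starts at
// offsets[i] and has sizes[i] valid elements. Each row is copied into its own
// chunk as large as the whole input, which is why the decomposition needs
// roughly num_rows times the memory of the ragged input.
std::vector<std::vector<float>> RaggedToDenseModel(
    const std::vector<float>& input, const std::vector<int64_t>& offsets,
    const std::vector<int64_t>& sizes) {
  std::vector<std::vector<float>> chunks;
  for (size_t i = 0; i < offsets.size(); ++i) {
    std::vector<float> chunk(input.size(), 0.0f);
    for (int64_t j = 0; j < sizes[i]; ++j) {
      chunk[j] = input[offsets[i] + j];
    }
    chunks.push_back(std::move(chunk));
  }
  return chunks;
}

For example, with input = {1, 2, 3, 4}, offsets = {0, 3}, and sizes = {3, 1}, the model produces the chunks {1, 2, 3, 0} and {4, 0, 0, 0}.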

PiperOrigin-RevId: 702398142
olegshyshkov authored and tensorflower-gardener committed Dec 3, 2024
1 parent 97a8000 commit 015e42b
Showing 4 changed files with 473 additions and 0 deletions.
41 changes: 41 additions & 0 deletions third_party/xla/xla/service/gpu/transforms/BUILD
@@ -3344,3 +3344,44 @@ xla_cc_test(
"@local_tsl//tsl/platform:statusor",
],
)

cc_library(
name = "ragged_all_to_all_decomposer",
srcs = ["ragged_all_to_all_decomposer.cc"],
hdrs = ["ragged_all_to_all_decomposer.h"],
deps = [
"//xla:comparison_util",
"//xla:literal_util",
"//xla:shape_util",
"//xla:util",
"//xla/hlo/ir:hlo",
"//xla/hlo/pass:hlo_pass",
"@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/container:inlined_vector",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings:string_view",
"@local_tsl//tsl/platform:errors",
],
)

xla_cc_test(
name = "ragged_all_to_all_decomposer_test",
srcs = ["ragged_all_to_all_decomposer_test.cc"],
deps = [
":ragged_all_to_all_decomposer",
"//xla/hlo/ir:hlo",
"//xla/hlo/testlib:filecheck",
"//xla/service:hlo_cse",
"//xla/service:hlo_runner",
"//xla/service:platform_util",
"//xla/tests:new_hlo_test_base",
"//xla/tests:test_utils",
"//xla/tsl/lib/core:status_test_util",
"@com_google_absl//absl/log",
"@com_google_googletest//:gtest",
"@com_google_googletest//:gtest_main",
"@local_tsl//tsl/platform:statusor",
"@local_tsl//tsl/platform:test",
],
)
262 changes: 262 additions & 0 deletions third_party/xla/xla/service/gpu/transforms/ragged_all_to_all_decomposer.cc
@@ -0,0 +1,262 @@
/* Copyright 2024 The OpenXLA Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "xla/service/gpu/transforms/ragged_all_to_all_decomposer.h"

#include <cstdint>
#include <vector>

#include "absl/container/flat_hash_set.h"
#include "absl/container/inlined_vector.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/string_view.h"
#include "xla/comparison_util.h"
#include "xla/hlo/ir/dfs_hlo_visitor.h"
#include "xla/hlo/ir/hlo_casting_utils.h"
#include "xla/hlo/ir/hlo_computation.h"
#include "xla/hlo/ir/hlo_instruction.h"
#include "xla/hlo/ir/hlo_instructions.h"
#include "xla/hlo/ir/hlo_module.h"
#include "xla/hlo/ir/hlo_opcode.h"
#include "xla/literal_util.h"
#include "xla/shape.h"
#include "xla/shape_util.h"
#include "xla/util.h"
#include "tsl/platform/errors.h"

namespace xla {
namespace gpu {

// Returns a multi-index offset for the i-th row. Tensors are always ragged in
// the outermost dimension, so `offsets` contains indices into the outermost
// dimension and all other dimensions of the offset are 0.
absl::InlinedVector<HloInstruction*, 4> GetOffsetMultiIndex(
HloComputation* computation, HloInstruction* offsets, int64_t index,
int64_t rank) {
absl::InlinedVector<HloInstruction*, 4> result(
rank, computation->AddInstruction(
HloInstruction::CreateConstant(LiteralUtil::Zero(S32))));

HloInstruction* index_value =
computation->AddInstruction(HloInstruction::CreateSlice(
/*shape=*/ShapeUtil::MakeShape(S32, {1}),
/*operand=*/offsets,
/*start_indices=*/{index},
/*limit_indices=*/{index + 1},
/*strides=*/{1}));
result[0] = computation->AddInstruction(
HloInstruction::CreateReshape(/*shape=*/
ShapeUtil::MakeScalarShape(S32),
index_value));
return result;
}

// Pads the outermost dimension of the input tensor to double the size.
HloInstruction* PadOutermostDimension(HloComputation* computation,
HloInstruction* input) {
Shape padded_shape = input->shape();
PaddingConfig padding_config = MakeNoPaddingConfig(padded_shape.rank());
padding_config.mutable_dimensions(0)->set_edge_padding_high(
padded_shape.dimensions(0));

padded_shape.set_dimensions(0, 2 * padded_shape.dimensions(0));

HloInstruction* padding_value =
computation->AddInstruction(HloInstruction::CreateConstant(
LiteralUtil::Zero(input->shape().element_type())));

return computation->AddInstruction(HloInstruction::CreatePad(
padded_shape, input, padding_value, padding_config));
}

// Takes a ragged tensor and a vector of chunk sizes. Returns a ragged tensor
// where padding is filled with zeros.
HloInstruction* FillPaddingWithZeros(HloComputation* computation,
HloInstruction* input,
HloInstruction* sizes) {
HloInstruction* zero = computation->AddInstruction(
HloInstruction::CreateConstant(LiteralUtil::Zero(S32)));

// Create reduction computation.
auto embedded_builder = HloComputation::Builder("add");
auto lhs = embedded_builder.AddInstruction(
HloInstruction::CreateParameter(0, ShapeUtil::MakeShape(S32, {}), "lhs"));
auto rhs = embedded_builder.AddInstruction(
HloInstruction::CreateParameter(1, ShapeUtil::MakeShape(S32, {}), "rhs"));
embedded_builder.AddInstruction(
HloInstruction::CreateBinary(lhs->shape(), HloOpcode::kAdd, lhs, rhs));

HloComputation* add_computation =
computation->parent()->AddEmbeddedComputation(embedded_builder.Build());

// Find the total size of the significant data in the ragged tensor.
HloInstruction* total_size =
computation->AddInstruction(HloInstruction::CreateReduce(
ShapeUtil::MakeScalarShape(S32), sizes, zero, {0}, add_computation));

Shape iota_shape = ShapeUtil::MakeShape(S32, input->shape().dimensions());

HloInstruction* iota =
computation->AddInstruction(HloInstruction::CreateIota(iota_shape, 0));

HloInstruction* total_size_broadcast = computation->AddInstruction(
HloInstruction::CreateBroadcast(iota_shape, total_size, {}));

Shape mask_shape = ShapeUtil::MakeShape(PRED, iota_shape.dimensions());

// Get bitmask for the significant data in the ragged tensor.
HloInstruction* iota_mask =
computation->AddInstruction(HloInstruction::CreateCompare(
mask_shape, iota, total_size_broadcast, Comparison::Direction::kLt));

HloInstruction* padding_value =
computation->AddInstruction(HloInstruction::CreateConstant(
LiteralUtil::Zero(input->shape().element_type())));

HloInstruction* zero_broadcast = computation->AddInstruction(
HloInstruction::CreateBroadcast(input->shape(), padding_value, {}));

// Fill padding with zeros.
return computation->AddInstruction(HloInstruction::CreateTernary(
input->shape(), HloOpcode::kSelect, iota_mask, input, zero_broadcast));
}

// Returns a dense representation of the ragged input tensor.
//
// The dense representation is a list of slices of the input tensor, one per
// ragged row, each padded with zeros to the same size as the ragged input.
std::vector<HloInstruction*> RaggedToDense(HloComputation* computation,
HloInstruction* ragged_input,
HloInstruction* offsets,
HloInstruction* sizes) {
int64_t num_rows = offsets->shape().dimensions(0);

std::vector<HloInstruction*> sliced_operands;

for (int64_t i = 0; i < num_rows; ++i) {
auto offset_multi_index = GetOffsetMultiIndex(computation, offsets, i,
ragged_input->shape().rank());

HloInstruction* padded_input =
PadOutermostDimension(computation, ragged_input);

HloInstruction* row_slice =
computation->AddInstruction(HloInstruction::CreateDynamicSlice(
ragged_input->shape(), padded_input, offset_multi_index,
ragged_input->shape().dimensions()));

sliced_operands.push_back(row_slice);
}

return sliced_operands;
}

// Returns a ragged representation of the dense output tensor.
HloInstruction* DenseToRagged(HloComputation* computation,
HloInstruction* dense_inputs,
HloInstruction* ragged_output,
HloInstruction* offsets, HloInstruction* sizes) {
int64_t num_rows = offsets->shape().dimensions(0);
int64_t rank = ragged_output->shape().rank();

Shape original_shape = ragged_output->shape();

HloInstruction* padded_ragged_output =
PadOutermostDimension(computation, ragged_output);

for (int64_t i = 0; i < num_rows; ++i) {
auto offset_multi_index = GetOffsetMultiIndex(
computation, offsets, i, padded_ragged_output->shape().rank());

HloInstruction* update = computation->AddInstruction(
HloInstruction::CreateGetTupleElement(dense_inputs, i));

padded_ragged_output =
computation->AddInstruction(HloInstruction::CreateDynamicUpdateSlice(
padded_ragged_output->shape(), padded_ragged_output, update,
offset_multi_index));
}

ragged_output = computation->AddInstruction(HloInstruction::CreateSlice(
original_shape, padded_ragged_output, std::vector<int64_t>(rank, 0),
original_shape.dimensions(), std::vector<int64_t>(rank, 1)));

ragged_output = FillPaddingWithZeros(computation, ragged_output, sizes);

return ragged_output;
}

// Rewrites a ragged all-to-all into a sequence of dynamic-slices, a regular
// all-to-all, and a sequence of dynamic-update-slices.
absl::Status DecomposeRaggedAllToAll(HloInstruction* hlo,
HloComputation* computation,
HloModule* module) {
HloRaggedAllToAllInstruction* all_to_all =
Cast<HloRaggedAllToAllInstruction>(hlo);
HloInstruction* input_operand = all_to_all->mutable_operand(0);
HloInstruction* output_operand = all_to_all->mutable_operand(1);

HloInstruction* input_offsets = all_to_all->mutable_operand(2);
HloInstruction* send_sizes = all_to_all->mutable_operand(3);
HloInstruction* output_offsets = all_to_all->mutable_operand(4);
HloInstruction* recv_sizes = all_to_all->mutable_operand(5);

auto dense_input =
RaggedToDense(computation, input_operand, input_offsets, send_sizes);

std::vector<Shape> dense_input_shapes;
dense_input_shapes.reserve(dense_input.size());
for (HloInstruction* slice : dense_input) {
  dense_input_shapes.push_back(slice->shape());
}

auto dense_output =
computation->AddInstruction(HloInstruction::CreateAllToAll(
ShapeUtil::MakeTupleShape(dense_input_shapes), dense_input,
all_to_all->device_list(),
/*constrain_layout=*/false,
/*channel_id=*/all_to_all->channel_id()));

auto* ragged_output = DenseToRagged(computation, dense_output, output_operand,
output_offsets, recv_sizes);

TF_RETURN_IF_ERROR(all_to_all->ReplaceAllUsesWith(ragged_output));
TF_RETURN_IF_ERROR(
computation->RemoveInstructionAndUnusedOperands(all_to_all));

return absl::OkStatus();
}

absl::StatusOr<bool> RaggedAllToAllDecomposer::Run(
HloModule* module,
const absl::flat_hash_set<absl::string_view>& execution_threads) {
bool changed = false;

for (auto computation : module->computations(execution_threads)) {
for (auto hlo : computation->MakeInstructionPostOrder()) {
if (HloPredicateIsNotOp<HloOpcode::kRaggedAllToAll>(hlo)) {
continue;
}
changed = true;
TF_RETURN_IF_ERROR(DecomposeRaggedAllToAll(hlo, computation, module));
}
}
return changed;
}

} // namespace gpu
} // namespace xla
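To complement the sketch in the commit description, here is a host-side model, again not part of this commit and with hypothetical names, of the dense-to-ragged step that `DenseToRagged` and `FillPaddingWithZeros` above implement with HLO ops. Offsets are assumed to be in bounds, whereas the real `dynamic-update-slice` would clamp them.

#include <cstdint>
#include <vector>

// Host-side model: each received chunk is written into a zero-padded output
// buffer at its row offset (later writes overwrite parts of earlier ones where
// chunks overlap), the buffer is cut back to the original size, and everything
// past the total number of valid elements is zeroed, mirroring the final slice
// and select in the pass.
std::vector<float> DenseToRaggedModel(
    const std::vector<std::vector<float>>& chunks,
    const std::vector<int64_t>& offsets, const std::vector<int64_t>& sizes,
    size_t output_size) {
  std::vector<float> padded(2 * output_size, 0.0f);
  for (size_t i = 0; i < chunks.size(); ++i) {
    for (size_t j = 0; j < chunks[i].size(); ++j) {
      padded[offsets[i] + j] = chunks[i][j];
    }
  }
  std::vector<float> output(padded.begin(), padded.begin() + output_size);
  int64_t total_size = 0;
  for (int64_t size : sizes) total_size += size;
  for (size_t k = total_size; k < output.size(); ++k) {
    output[k] = 0.0f;
  }
  return output;
}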
56 changes: 56 additions & 0 deletions third_party/xla/xla/service/gpu/transforms/ragged_all_to_all_decomposer.h
@@ -0,0 +1,56 @@
/* Copyright 2024 The OpenXLA Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef XLA_SERVICE_GPU_TRANSFORMS_RAGGED_ALL_TO_ALL_DECOMPOSER_H_
#define XLA_SERVICE_GPU_TRANSFORMS_RAGGED_ALL_TO_ALL_DECOMPOSER_H_

#include "absl/container/flat_hash_set.h"
#include "absl/status/statusor.h"
#include "absl/strings/string_view.h"
#include "xla/hlo/ir/hlo_module.h"
#include "xla/hlo/pass/hlo_pass_interface.h"

namespace xla {
namespace gpu {

// Rewrites a `ragged-all-to-all` as a regular `all-to-all`.
//
// A ragged tensor is converted into a dense representation by slicing each
// ragged row from the input and padding with zeros. Then, `all-to-all` is
// performed on the dense representation to exchange rows between replicas.
// Finally, the dense representation is converted back to a ragged tensor with
// `dynamic-update-slice`, and padded values are filled with zeros.
//
// This pass is intended as a temporary solution to unblock end-to-end
// integration of `ragged-all-to-all` on GPU, to serve as a reference
// implementation, and to help with writing integration tests.
//
// TODO(b/379629619): Remove this pass once `ragged-all-to-all` is implemented
// natively on GPU with NCCL.
class RaggedAllToAllDecomposer : public HloModulePass {
public:
absl::string_view name() const override {
return "ragged-all-to-all-decomposer";
}

absl::StatusOr<bool> Run(
HloModule* module,
const absl::flat_hash_set<absl::string_view>& execution_threads) override;
};

} // namespace gpu
} // namespace xla

#endif // XLA_SERVICE_GPU_TRANSFORMS_RAGGED_ALL_TO_ALL_DECOMPOSER_H_
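Finally, a minimal usage sketch, not part of this commit: how the pass could be invoked directly on an HloModule, for example from a debugging tool or a test driver. The wrapper function name is hypothetical; the pass API is the one declared in the header above.

#include "absl/status/status.h"
#include "xla/hlo/ir/hlo_module.h"
#include "xla/service/gpu/transforms/ragged_all_to_all_decomposer.h"
#include "tsl/platform/statusor.h"

// Hypothetical helper: rewrites every ragged-all-to-all in `module` into the
// dense all-to-all form described above.
absl::Status DecomposeRaggedAllToAlls(xla::HloModule* module) {
  xla::gpu::RaggedAllToAllDecomposer decomposer;
  // An empty thread set means the pass runs on all execution threads.
  TF_ASSIGN_OR_RETURN(bool changed,
                      decomposer.Run(module, /*execution_threads=*/{}));
  // `changed` is true iff at least one ragged-all-to-all was decomposed.
  (void)changed;
  return absl::OkStatus();
}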