Skip to content

Commit

Permalink
[XLA:GPU] Add RaggedAllToAllDecomposer to the GPU compilation pipeline.
Browse files Browse the repository at this point in the history
Also add a small e2e test. More e2e tests are coming later.

PiperOrigin-RevId: 702684333
  • Loading branch information
olegshyshkov authored and tensorflower-gardener committed Dec 4, 2024
1 parent 9db82cc commit e67b0c4
Show file tree
Hide file tree
Showing 6 changed files with 117 additions and 4 deletions.
9 changes: 9 additions & 0 deletions third_party/xla/xla/debug_options_flags.cc
Original file line number Diff line number Diff line change
Expand Up @@ -296,6 +296,7 @@ DebugOptions DefaultDebugOptionsIgnoringFlags() {
opts.set_xla_gpu_experimental_parallel_collective_overlap_limit(1);
opts.set_xla_pjrt_allow_auto_layout_in_hlo(false);
opts.set_xla_gpu_enable_scatter_determinism_expander(true);
opts.set_xla_gpu_unsupported_enable_ragged_all_to_all_decomposer(true);
return opts;
}

Expand Down Expand Up @@ -2104,6 +2105,14 @@ void MakeDebugOptionsFlags(std::vector<tsl::Flag>* flag_list,
"performance."
"Note that even when this flag is disabled, scatter operations may still "
"be deterministic, although with additional overhead."));
flag_list->push_back(tsl::Flag(
"xla_gpu_unsupported_enable_ragged_all_to_all_decomposer",
bool_setter_for(
&DebugOptions::
set_xla_gpu_unsupported_enable_ragged_all_to_all_decomposer),
debug_options->xla_gpu_unsupported_enable_ragged_all_to_all_decomposer(),
"Internal: Enable the RaggedAllToAllDecomposer, an experimental pass "
"that rewrites ragged-all-to-all as a dense all-to-all operation."));
} // NOLINT(readability/fn_size)

// Allocates flag_values and flag_objects; this function must not be called more
Expand Down
1 change: 1 addition & 0 deletions third_party/xla/xla/service/gpu/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -1550,6 +1550,7 @@ cc_library(
"//xla/service/gpu/transforms:layout_assignment",
"//xla/service/gpu/transforms:move_copy_to_users",
"//xla/service/gpu/transforms:pipelined_p2p_rewriter",
"//xla/service/gpu/transforms:ragged_all_to_all_decomposer",
"//xla/service/gpu/transforms:reduce_scatter_creator",
"//xla/service/gpu/transforms:reduction_degenerate_dim_remover",
"//xla/service/gpu/transforms:reduction_dimension_grouper",
Expand Down
4 changes: 4 additions & 0 deletions third_party/xla/xla/service/gpu/gpu_compiler.cc
Original file line number Diff line number Diff line change
Expand Up @@ -209,6 +209,7 @@ limitations under the License.
#include "xla/service/gpu/transforms/layout_assignment.h"
#include "xla/service/gpu/transforms/move_copy_to_users.h"
#include "xla/service/gpu/transforms/pipelined_p2p_rewriter.h"
#include "xla/service/gpu/transforms/ragged_all_to_all_decomposer.h"
#include "xla/service/gpu/transforms/reduce_scatter_creator.h"
#include "xla/service/gpu/transforms/reduction_degenerate_dim_remover.h"
#include "xla/service/gpu/transforms/reduction_dimension_grouper.h"
Expand Down Expand Up @@ -850,6 +851,9 @@ absl::Status RunCollectiveOptimizationPasses(
const DebugOptions& debug_options = config.debug_options();

HloPassPipeline collectives_pipeline("collective-optimizations");
if (debug_options.xla_gpu_unsupported_enable_ragged_all_to_all_decomposer()) {
collectives_pipeline.AddPass<RaggedAllToAllDecomposer>();
}
collectives_pipeline.AddPass<AllReduceSimplifier>();
collectives_pipeline.AddPass<AllReduceFolder>();
collectives_pipeline.AddPass<AllReduceSplitter>();
Expand Down
11 changes: 9 additions & 2 deletions third_party/xla/xla/tests/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -261,7 +261,6 @@ cc_library(
"@llvm-project//llvm:Support",
"@local_tsl//tsl/platform:errors",
"@local_tsl//tsl/platform:logging",
"@local_tsl//tsl/platform:protobuf",
"@local_tsl//tsl/platform:statusor",
"@local_tsl//tsl/platform:test",
],
Expand Down Expand Up @@ -2479,15 +2478,23 @@ xla_test(
":test_macros_header",
":test_utils",
":xla_internal_test_main",
"//xla:error_spec",
"//xla:literal",
"//xla:literal_util",
"//xla:types",
"//xla/hlo/ir:hlo",
"//xla/hlo/utils:hlo_matchers",
"//xla/service:computation_placer_hdr",
"//xla/service:executable",
"//xla/service:hlo_module_config",
"//xla/service/gpu:backend_configs_cc",
"//xla/stream_executor:device_description",
"//xla/stream_executor:stream_executor_h",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/log",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/strings:string_view",
"@com_google_absl//absl/types:span",
"@local_tsl//tsl/platform:statusor",
"@local_tsl//tsl/platform:test",
],
Expand Down
91 changes: 90 additions & 1 deletion third_party/xla/xla/tests/collective_ops_e2e_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -13,27 +13,36 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <cstdint>
#include <memory>
#include <utility>
#include <variant>
#include <vector>

#include "absl/container/flat_hash_map.h"
#include "absl/log/log.h"
#include "absl/status/statusor.h"
#include "absl/strings/str_replace.h"
#include "absl/strings/string_view.h"
#include "absl/types/span.h"
#include "xla/error_spec.h"
#include "xla/hlo/ir/hlo_casting_utils.h"
#include "xla/hlo/ir/hlo_instruction.h"
#include "xla/hlo/ir/hlo_instructions.h"
#include "xla/hlo/ir/hlo_opcode.h"
#include "xla/hlo/utils/hlo_matchers.h"
#include "xla/literal.h"
#include "xla/literal_util.h"
#include "xla/service/computation_placer.h"
#include "xla/service/executable.h"
#include "xla/service/gpu/backend_configs.pb.h"
#include "xla/service/hlo_module_config.h"
#include "xla/stream_executor/device_description.h"
#include "xla/stream_executor/stream_executor.h"
#include "xla/tests/hlo_test_base.h"
#include "xla/tests/literal_test_util.h"
#include "xla/tests/test_macros.h"
#include "xla/tests/test_utils.h"
#include "xla/types.h"
#include "tsl/platform/statusor.h"
#include "tsl/platform/test.h"

Expand Down Expand Up @@ -1526,5 +1535,85 @@ ENTRY entry {
EXPECT_TRUE(executable->has_module());
}

class RaggedAllToAllTestE2E : public CollectiveOpsTestE2E {
public:
struct RaggedTensor {
RaggedTensor(absl::Span<float const> data,
absl::Span<const int32_t> offsets,
absl::Span<const int32_t> sizes)
: data(LiteralUtil::CreateR1<float>(data)),
offsets(LiteralUtil::CreateR1<int32_t>(offsets)),
sizes(LiteralUtil::CreateR1<int32_t>(sizes)) {}

Literal data;
Literal offsets;
Literal sizes;
};

std::vector<Literal*> ToReplicaLiteralPtrs(RaggedTensor& input,
RaggedTensor& output) {
return {&input.data, &output.data, &input.offsets,
&input.sizes, &output.offsets, &output.sizes};
}
};

TEST_F(RaggedAllToAllTestE2E, RaggedAllToAll) {
  // Two replicas in one replica group exchange ragged chunks; one partition.
  const int64_t kNumReplicas = 2;
  const int64_t kNumPartitions = 1;
  SKIP_TEST_IF_NUM_DEVICES_LESS_THAN(kNumReplicas * kNumPartitions);

  // Entry computation takes (input, output, input_offsets, send_sizes,
  // output_offsets, recv_sizes) and performs a ragged-all-to-all over
  // replica group {0, 1}.
  absl::string_view kModuleReplicatedStr = R"(
HloModule module, entry_computation_layout={(f32[4], f32[4], s32[2], s32[2],
s32[2], s32[2])->f32[4]}, num_partitions=1
ENTRY entry {
input = f32[4] parameter(0)
output = f32[4] parameter(1)
input_offsets = s32[2] parameter(2)
send_sizes = s32[2] parameter(3)
output_offsets = s32[2] parameter(4)
recv_sizes = s32[2] parameter(5)
ROOT ra2a = f32[4] ragged-all-to-all(input, output, input_offsets,
send_sizes, output_offsets, recv_sizes), replica_groups={{0,1}}
}
)";

  HloModuleConfig config =
      GetModuleConfigForTest(/*replica_count=*/kNumReplicas * kNumPartitions);
  TF_ASSERT_OK_AND_ASSIGN(
      auto module, ParseAndReturnVerifiedModule(kModuleReplicatedStr, config));

  // Before the ragged-all-to-all, replica 0 holds chunk {1} for replica 0 and
  // chunk {2} for replica 1; replica 1 holds {3, 4} for replica 0 and {5} for
  // replica 1.
  RaggedTensor replica0_input(/*data=*/{1., 2., 0., 0.},
                              /*offsets=*/{0, 1}, /*sizes=*/{1, 1});
  RaggedTensor replica1_input(/*data=*/{3., 4., 5., 0.},
                              /*offsets=*/{0, 2}, /*sizes=*/{2, 1});

  // After the exchange, replica 0 should hold {1} (from itself) followed by
  // {3, 4} (from replica 1); replica 1 should hold {2} followed by {5}.
  RaggedTensor replica0_output(/*data=*/{0., 0., 0., 0.},
                               /*offsets=*/{0, 1}, /*sizes=*/{1, 2});
  RaggedTensor replica1_output(/*data=*/{0., 0., 0., 0.},
                               /*offsets=*/{0, 1}, /*sizes=*/{1, 1});

  // One literal-pointer list per replica, in module parameter order.
  std::vector<std::vector<Literal*>> replica_arguments = {
      ToReplicaLiteralPtrs(replica0_input, replica0_output),
      ToReplicaLiteralPtrs(replica1_input, replica1_output)};

  TF_ASSERT_OK_AND_ASSIGN(
      std::vector<Literal> results,
      HloTestBase::ExecuteReplicated(std::move(module), replica_arguments,
                                     /*num_replicas=*/kNumReplicas,
                                     /*run_hlo_passes=*/true,
                                     /*device_assignment=*/nullptr));
  ASSERT_EQ(results.size(), kNumReplicas);
  LiteralTestUtil::ExpectR1Equal<float>({1., 3., 4., 0}, results[0]);
  LiteralTestUtil::ExpectR1Equal<float>({2., 5., 0., 0.}, results[1]);
}

} // namespace
} // namespace xla
5 changes: 4 additions & 1 deletion third_party/xla/xla/xla.proto
Original file line number Diff line number Diff line change
Expand Up @@ -164,6 +164,9 @@ message DebugOptions {
// with priority fusion.
bool xla_gpu_experimental_enable_triton_softmax_priority_fusion = 325;

// Internal testing flag to switch RaggedAllToAllDecomposer on or off.
bool xla_gpu_unsupported_enable_ragged_all_to_all_decomposer = 350;

// Internal debug/testing flag to switch Triton GEMM fusions on or off.
bool xla_gpu_unsupported_enable_triton_gemm = 322;

Expand Down Expand Up @@ -1069,7 +1072,7 @@ message DebugOptions {
// be deterministic, although with additional overhead.
bool xla_gpu_enable_scatter_determinism_expander = 345;

// Next id: 350
// Next id: 351

// Extra options to pass to the compilation backend (e.g. LLVM); specific
// interpretation of these values is left to the backend.
Expand Down

0 comments on commit e67b0c4

Please sign in to comment.