From 947a478538f53d96fc7732002fea910f3342d90f Mon Sep 17 00:00:00 2001 From: zjjott Date: Mon, 22 Apr 2024 14:49:51 +0800 Subject: [PATCH] add cuda error debug info.add all2all test --- .../auto_reorder/auto_reorder_test.cc | 41 +++++++++++++++++-- xla/stream_executor/cuda/cuda_asm_compiler.cc | 2 +- xla/stream_executor/gpu/asm_compiler.cc | 2 + 3 files changed, 40 insertions(+), 5 deletions(-) diff --git a/xla/hlo/experimental/auto_reorder/auto_reorder_test.cc b/xla/hlo/experimental/auto_reorder/auto_reorder_test.cc index 5276265448a98..ff063426accf2 100644 --- a/xla/hlo/experimental/auto_reorder/auto_reorder_test.cc +++ b/xla/hlo/experimental/auto_reorder/auto_reorder_test.cc @@ -218,13 +218,30 @@ ENTRY %elementwise { HloInstruction* param = async_builder.AddInstruction( HloInstruction::CreateParameter(0, input_shape, "pasync")); async_builder.AddInstruction(HloInstruction::CreateReduceScatter( - output_shape, {param}, MakeReduction(type, module), {}, false, + output_shape, {param}, MakeReduction(type, module), CreateReplicaGroups({{0, 1}}), false, std::nullopt, false, 0)); HloComputation* reduction = module->AddEmbeddedComputation(async_builder.Build()); return reduction; } + HloComputation* MakeAll2All(Shape input_shape,HloModule* module){ + HloComputation::Builder async_builder("AsyncOp"); + HloInstruction* param = async_builder.AddInstruction( + HloInstruction::CreateParameter(0, input_shape, "pasync")); + HloInstruction* param1 = async_builder.AddInstruction( + HloInstruction::CreateParameter(0, input_shape, "pasync")); + async_builder.AddInstruction(HloInstruction::CreateAllToAll( + input_shape, {param,param1}, + /*replica_groups=*/CreateReplicaGroups({{0, 1}}), + /*constrain_layout=*/false, /*channel_id=*/std::nullopt, + /*split_dimension*/ 0 + )); + HloComputation* reduction = + module->AddEmbeddedComputation(async_builder.Build()); + + return reduction; + } std::string GetInstructionsOrderString(HloModule* hlo_module) { auto insts = hlo_module->schedule() .sequence(hlo_module->entry_computation()) @@ -511,10 +528,26 @@ ENTRY %elementwise { auto add_reduce_scatter = builder.AddInstruction(HloInstruction::CreateBinary( - shape, HloOpcode::kAdd, reducescater_ret1, reducescater_ret0)); - // root: d01_dot_p0,d23_dot_p1 + reduce_shape, HloOpcode::kAdd, reducescater_ret1, reducescater_ret0)); + auto all2all_ret = builder.AddInstruction(HloInstruction::CreateAllToAll( + reduce_shape, {add_reduce_scatter}, + /*replica_groups=*/CreateReplicaGroups({{0, 1}}), + /*constrain_layout=*/false, /*channel_id=*/std::nullopt, + /*split_dimension*/ 0 + )); + // auto all2all_op = MakeAll2All(shape,module); + // HloInstruction* all2all_start = + // builder.AddInstruction(HloInstruction::CreateAsyncStart( + // reduce_shape, {add_reduce_scatter,add_reduce_scatter}, all2all_op, + // /*async_group_id=*/std::nullopt, + // /*async_execution_thread=*/"parallel_thread")); + // auto all2all_ret = + // builder.AddInstruction(HloInstruction::CreateAsyncDone( + // reduce_shape, all2all_start, all2all_op, + // /*async_group_id=*/std::nullopt, + // /*async_execution_thread=*/"parallel_thread")); auto ret = builder.AddInstruction(HloInstruction::CreateTuple( - {d01_dot_p0_add_p2, d23_dot_p1_add_p3, add_reduce_scatter})); + {d01_dot_p0_add_p2, d23_dot_p1_add_p3,all2all_ret})); auto computation = builder.Build(); computation->set_root_instruction(ret); auto entry_computation = diff --git a/xla/stream_executor/cuda/cuda_asm_compiler.cc b/xla/stream_executor/cuda/cuda_asm_compiler.cc index 50df8ba769964..6f5b8318fe227 100644 --- a/xla/stream_executor/cuda/cuda_asm_compiler.cc +++ b/xla/stream_executor/cuda/cuda_asm_compiler.cc @@ -101,7 +101,7 @@ tsl::StatusOr> LinkUsingNvlink( tsl::SubProcess process; process.SetProgram(bin_path, args); process.SetChannelAction(tsl::CHAN_STDERR, tsl::ACTION_PIPE); - + VLOG(5)<<"subprocess running:"<> CompileGpuAsm(int cc_major, int cc_minor, ptxas_info_dumper.SetProgram(ptxas_path, ptxas_args); ptxas_info_dumper.SetChannelAction(tsl::CHAN_STDERR, tsl::ACTION_PIPE); + VLOG(5)<<"subprocess running:"<> BundleGpuAsm( } fatbinary.SetProgram(fatbinary_path, fatbinary_args); fatbinary.SetChannelAction(tsl::CHAN_STDERR, tsl::ACTION_PIPE); + VLOG(5)<<"subprocess running:"<