Skip to content

Commit

Permalink
Add CUDA error debug info; add all-to-all test
Browse files Browse the repository at this point in the history
  • Loading branch information
zjjott committed Apr 22, 2024
1 parent 7866099 commit 947a478
Show file tree
Hide file tree
Showing 3 changed files with 40 additions and 5 deletions.
41 changes: 37 additions & 4 deletions xla/hlo/experimental/auto_reorder/auto_reorder_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -218,13 +218,30 @@ ENTRY %elementwise {
HloInstruction* param = async_builder.AddInstruction(
HloInstruction::CreateParameter(0, input_shape, "pasync"));
async_builder.AddInstruction(HloInstruction::CreateReduceScatter(
output_shape, {param}, MakeReduction(type, module), {}, false,
output_shape, {param}, MakeReduction(type, module), CreateReplicaGroups({{0, 1}}), false,
std::nullopt, false, 0));
HloComputation* reduction =
module->AddEmbeddedComputation(async_builder.Build());

return reduction;
}
// Builds and returns an embedded computation containing a two-operand
// AllToAll over replica group {0, 1}, for wrapping in async-start/done
// in the reordering tests.
HloComputation* MakeAll2All(Shape input_shape, HloModule* module) {
  HloComputation::Builder async_builder("AsyncOp");
  HloInstruction* param = async_builder.AddInstruction(
      HloInstruction::CreateParameter(0, input_shape, "pasync"));
  // BUG FIX: the second parameter must have a distinct parameter number
  // and name — the original created two parameters both numbered 0
  // ("pasync"), which makes the computation invalid.
  HloInstruction* param1 = async_builder.AddInstruction(
      HloInstruction::CreateParameter(1, input_shape, "pasync.1"));
  // NOTE(review): two operands with a non-null split_dimension and a
  // non-tuple result shape mixes the array-form AllToAll (one operand,
  // split_dimension set) with the tuple form (N operands, tuple shape,
  // split_dimension nullopt) — confirm this passes HLO verification.
  async_builder.AddInstruction(HloInstruction::CreateAllToAll(
      input_shape, {param, param1},
      /*replica_groups=*/CreateReplicaGroups({{0, 1}}),
      /*constrain_layout=*/false, /*channel_id=*/std::nullopt,
      /*split_dimension=*/0));
  // Renamed from `reduction`: this computation performs an all-to-all,
  // not a reduction.
  HloComputation* all2all_computation =
      module->AddEmbeddedComputation(async_builder.Build());

  return all2all_computation;
}
std::string GetInstructionsOrderString(HloModule* hlo_module) {
auto insts = hlo_module->schedule()
.sequence(hlo_module->entry_computation())
Expand Down Expand Up @@ -511,10 +528,26 @@ ENTRY %elementwise {

auto add_reduce_scatter =
builder.AddInstruction(HloInstruction::CreateBinary(
shape, HloOpcode::kAdd, reducescater_ret1, reducescater_ret0));
// root: d01_dot_p0,d23_dot_p1
reduce_shape, HloOpcode::kAdd, reducescater_ret1, reducescater_ret0));
auto all2all_ret = builder.AddInstruction(HloInstruction::CreateAllToAll(
reduce_shape, {add_reduce_scatter},
/*replica_groups=*/CreateReplicaGroups({{0, 1}}),
/*constrain_layout=*/false, /*channel_id=*/std::nullopt,
/*split_dimension*/ 0
));
// auto all2all_op = MakeAll2All(shape,module);
// HloInstruction* all2all_start =
// builder.AddInstruction(HloInstruction::CreateAsyncStart(
// reduce_shape, {add_reduce_scatter,add_reduce_scatter}, all2all_op,
// /*async_group_id=*/std::nullopt,
// /*async_execution_thread=*/"parallel_thread"));
// auto all2all_ret =
// builder.AddInstruction(HloInstruction::CreateAsyncDone(
// reduce_shape, all2all_start, all2all_op,
// /*async_group_id=*/std::nullopt,
// /*async_execution_thread=*/"parallel_thread"));
auto ret = builder.AddInstruction(HloInstruction::CreateTuple(
{d01_dot_p0_add_p2, d23_dot_p1_add_p3, add_reduce_scatter}));
{d01_dot_p0_add_p2, d23_dot_p1_add_p3,all2all_ret}));
auto computation = builder.Build();
computation->set_root_instruction(ret);
auto entry_computation =
Expand Down
2 changes: 1 addition & 1 deletion xla/stream_executor/cuda/cuda_asm_compiler.cc
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ tsl::StatusOr<std::vector<uint8_t>> LinkUsingNvlink(
tsl::SubProcess process;
process.SetProgram(bin_path, args);
process.SetChannelAction(tsl::CHAN_STDERR, tsl::ACTION_PIPE);

VLOG(5)<<"subprocess running:"<<bin_path<<"args: "<< absl::StrJoin(args, " ");
TF_RET_CHECK(process.Start());
std::string stderr_output;
int exit_status = process.Communicate(
Expand Down
2 changes: 2 additions & 0 deletions xla/stream_executor/gpu/asm_compiler.cc
Original file line number Diff line number Diff line change
Expand Up @@ -307,6 +307,7 @@ tsl::StatusOr<std::vector<uint8_t>> CompileGpuAsm(int cc_major, int cc_minor,

ptxas_info_dumper.SetProgram(ptxas_path, ptxas_args);
ptxas_info_dumper.SetChannelAction(tsl::CHAN_STDERR, tsl::ACTION_PIPE);
VLOG(5)<<"subprocess running:"<<ptxas_path<<"args: "<< absl::StrJoin(ptxas_args, " ");
if (!ptxas_info_dumper.Start()) {
return tsl::errors::Internal("Failed to launch ptxas");
}
Expand Down Expand Up @@ -418,6 +419,7 @@ tsl::StatusOr<std::vector<uint8_t>> BundleGpuAsm(
}
fatbinary.SetProgram(fatbinary_path, fatbinary_args);
fatbinary.SetChannelAction(tsl::CHAN_STDERR, tsl::ACTION_PIPE);
VLOG(5)<<"subprocess running:"<<fatbinary_path<<"args: "<< absl::StrJoin(fatbinary_args, " ");
if (!fatbinary.Start()) {
return tsl::errors::Internal("Failed to launch fatbinary.");
}
Expand Down

0 comments on commit 947a478

Please sign in to comment.