diff --git a/xla/hlo/experimental/auto_reorder/BUILD b/xla/hlo/experimental/auto_reorder/BUILD index 3a4edaee97078..dd3094b724153 100644 --- a/xla/hlo/experimental/auto_reorder/BUILD +++ b/xla/hlo/experimental/auto_reorder/BUILD @@ -88,6 +88,11 @@ cc_library( ":auto_reorder_solver" ], ) +tf_proto_library( + name = "instr_profile_info_proto", + srcs = ["instr_profile_info.proto"], +) + cc_library( name="convert_xplane", srcs=["convert_xplane.cc"], @@ -97,6 +102,7 @@ cc_library( "//xla:xla_proto_cc", "//xla/hlo/ir:hlo", "//xla/service:hlo_proto_cc", + "//xla:shape_util", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/status", "@com_google_absl//absl/strings", @@ -111,6 +117,8 @@ cc_library( "@tsl//tsl/profiler/utils:xplane_schema", "@tsl//tsl/profiler/utils:xplane_utils", "@tsl//tsl/profiler/utils:xplane_visitor", + "@com_google_protobuf//:protobuf", + ":instr_profile_info_proto_cc" ] ) xla_cc_test( @@ -135,16 +143,22 @@ xla_cc_test( ], ) -cc_binary( +xla_cc_binary( name="convert_xplane_tools", + linkopts = [ + "-Wl,--allow-multiple-definition", + "-lstdc++fs", # For std::filesystem + ], srcs=["convert_xplane_bin.cc"], deps=[ ":convert_xplane", + ":auto_reorder", "//xla/hlo/ir:hlo", "//xla/service:hlo_parser", "//xla/service:latency_hiding_scheduler", "//xla/service/gpu:gpu_hlo_schedule", "//xla/service:gpu_plugin", + "//xla:device_util", "@com_google_absl//absl/log", "@tsl//tsl/platform:statusor", diff --git a/xla/hlo/experimental/auto_reorder/auto_reorder.cc b/xla/hlo/experimental/auto_reorder/auto_reorder.cc index 32a3e86ba5b35..3a66801c2ebc1 100644 --- a/xla/hlo/experimental/auto_reorder/auto_reorder.cc +++ b/xla/hlo/experimental/auto_reorder/auto_reorder.cc @@ -171,13 +171,15 @@ AutoReorderPass::ScheduleComputation(HloComputation* computation) { } if (reorder::solve_debug) { // save to pid related file - solver_->RenderGraphviz(absl::StrCat("gantt_before_", computation->name())); + solver_->SaveGraphviz(absl::StrCat("gantt_before_", 
computation->name())); + solver_->SaveJSON(absl::StrCat("gantt_before_", computation->name())); } - auto status = solver_->Solve(); + auto status = + solver_->Solve(absl::StrCat("mps_file_of_", computation->name())); if (reorder::solve_debug) { // save to pid related file - solver_->RenderGantt(absl::StrCat("gantt_", computation->name())); - solver_->RenderGraphviz(absl::StrCat("gantt_", computation->name())); + solver_->SaveGantt(absl::StrCat("gantt_", computation->name())); + solver_->SaveGraphviz(absl::StrCat("gantt_", computation->name())); } if (status.ok()) { diff --git a/xla/hlo/experimental/auto_reorder/auto_reorder_solver.cc b/xla/hlo/experimental/auto_reorder/auto_reorder_solver.cc index 611f228e0312d..28f39778731c6 100644 --- a/xla/hlo/experimental/auto_reorder/auto_reorder_solver.cc +++ b/xla/hlo/experimental/auto_reorder/auto_reorder_solver.cc @@ -23,7 +23,31 @@ using Task = std::tuple; // (channel, processing_time), we have two // channel now:communication and computation using Job = std::vector; +namespace reorder { +uint32_t get_autoreorder_timeout() { + const char* env = std::getenv("XLA_AUTOREORDER_TIMEOUT"); + if (env == nullptr) { + return ksolveTimeout; + } + return std::atoi(env); +}; +int get_horizon(int max_time) { + // scale should be fit with module? 
+ return max_time * 2; +} +const bool is_keep_communicate_order() { + const char* env = std::getenv("XLA_KEEP_COMMUNICATE_ORDER"); + if (env == nullptr) { + return false; + } + return std::strcmp(env, "true") == 0; +}; +int get_cpu_number() { + // return 8; + return std::thread::hardware_concurrency(); +} +} // namespace reorder template LinearProgramScheduler::~LinearProgramScheduler() { uuid2container.clear(); @@ -102,7 +126,8 @@ LPSchedulerFunc(StatusOr)::AddNodeToTask(ContainerType* node) { node_to_task_.emplace(node->UUID(), std::make_tuple(node, task)); return task; }; -LPSchedulerFunc(tsl::Status)::Solve() { + +LPSchedulerFunc(tsl::Status)::Solve(std::string mps_filename) { uint32_t max_execution_time = 0; for (auto node : nodes_) { node->Freeze(); @@ -154,9 +179,8 @@ LPSchedulerFunc(tsl::Status)::Solve() { if (IsSingleChannel(dep_type)) { auto dep_task = std::get<1>(node_to_task_.at(dep_node->UUID())); // interval - IntervalVar interval = cp_model_.NewIntervalVar( - dep_task.end, cost, - node_task.start); + IntervalVar interval = + cp_model_.NewIntervalVar(dep_task.end, cost, node_task.start); no_overlap_edges.push_back(interval); } } @@ -171,7 +195,6 @@ LPSchedulerFunc(tsl::Status)::Solve() { } cp_model_.AddMaxEquality(obj_var, ends); cp_model_.Minimize(obj_var); - // cp_model_. 
// VLOG(2)<<"Number of variables:"< 0) { + operations_research::MPModelProto output; + operations_research::sat::ConvertCpModelProtoToMPModelProto(model, &output); + auto status_of_string = operations_research::ExportModelAsMpsFormat(output); + if (status_of_string.ok()) { + VLOG(2) << "ExportModelAsMpsFormat success"; + std::ofstream out(absl::StrCat("/tmp/", mps_filename, ".mps")); + out << status_of_string.value(); + out.close(); + } + } + const operations_research::sat::CpSolverResponse response = - operations_research::sat::SolveWithParameters(cp_model_.Build(), - parameters); + operations_research::sat::SolveWithParameters(model, parameters); uint64_t solve_time = response.wall_time(); VLOG(1) << "Solve finish:" << response.status() << " solve time:" << solve_time; @@ -224,7 +262,7 @@ std::string ReplaceUnusedChar(const std::string str, } return result; } -LPSchedulerFunc(std::vector)::GetSortedNodes() const{ +LPSchedulerFunc(std::vector)::GetSortedNodes() const { std::vector sorted_nodes; sorted_nodes.reserve(nodes_.size()); for (auto node : nodes_) { @@ -239,7 +277,64 @@ LPSchedulerFunc(std::vector)::GetSortedNodes() const{ }); return sorted_nodes; } -LPSchedulerFunc(void)::RenderGraphviz(std::string filename) const{ +LPSchedulerFunc(void)::SaveJSON(std::string filename) const { + std::string json_file = absl::StrCat("/tmp/", filename, ".json"); + std::ofstream json_out(json_file); + json_out << "{" << std::endl; + json_out << "\"nodes\": [" << std::endl; + int32_t node_count = 0; + int32_t edge_count = 0; + + for (auto node : this->GetSortedNodes()) { + std::string name; + if (node->IsCommunication()) { + name = "communication"; + } else { + name = "compute"; + } + if (node_count > 0) { + json_out << ",\n{ \"uuid\": \"" << node->UUID() << "\",\"typename\": \"" + << name << "\", \"name\": \"" + << ReplaceUnusedChar(node->GetName(), "'") + << "\", \"cost\": " << node->GetCost() << " }"; + } else { + json_out << "{ \"uuid\": \"" << node->UUID() << 
"\",\"typename\": \"" + << name << "\", \"name\": \"" + << ReplaceUnusedChar(node->GetName(), "'") + << "\", \"cost\": " << node->GetCost() << " }"; + } + node_count++; + } + json_out << "]," << std::endl; + json_out << "\"edges\": [" << std::endl; + for (auto node : this->GetSortedNodes()) { + for (auto dep_pair : node->GetDeps()) { + auto dep_node = std::get<0>(dep_pair); + auto dep_cost = std::get<1>(dep_pair); + NodeType dep_type = std::get<2>(dep_pair); + std::string name; + if (IsSingleChannel(dep_type)) { + name = "communication"; + } else { + name = "compute"; + } + // draw edge + if (edge_count > 0) { + json_out << ",\n{ \"from\": \"" << dep_node->UUID() << "\", \"to\": \"" + << node->UUID() << "\", \"typename\": \"" << name + << "\", \"cost\": " << dep_cost << " }"; + } else { + json_out << "{ \"from\": \"" << dep_node->UUID() << "\", \"to\": \"" + << node->UUID() << "\", \"typename\": \"" << name + << "\", \"cost\": " << dep_cost << " }"; + } + edge_count++; + } + } + json_out << "]" << std::endl; + json_out << "}" << std::endl; +} +LPSchedulerFunc(void)::SaveGraphviz(std::string filename) const { // write a dot file std::string dot_file = absl::StrCat("/tmp/", filename, ".dot"); std::ofstream out(dot_file); @@ -287,7 +382,7 @@ LPSchedulerFunc(void)::RenderGraphviz(std::string filename) const{ auto status = system(cmd.c_str()); VLOG(4) << cmd << " execute status:" << status << std::endl; } -LPSchedulerFunc(void)::RenderGantt(std::string filename) const{ +LPSchedulerFunc(void)::SaveGantt(std::string filename) const { // https://g2.antv.antgroup.com/en/examples/storytelling/storytelling/#gantt // { name: 'compute',label:'kernel name1', startTime: 1, endTime: 4 }, VLOG(4) << "write node number:" << nodes_.size() << " to /tmp/" << filename @@ -346,7 +441,6 @@ LPSchedulerFunc(void)::RenderGantt(std::string filename) const{ chart.render();)"; } - LPContainerDAGFunc(bool)::IsIn(LPContainer* a) { return operands_.find(a) != operands_.end(); }; diff --git 
a/xla/hlo/experimental/auto_reorder/auto_reorder_solver.h b/xla/hlo/experimental/auto_reorder/auto_reorder_solver.h index 5ee753a3fac3b..d6a9625eac3f5 100644 --- a/xla/hlo/experimental/auto_reorder/auto_reorder_solver.h +++ b/xla/hlo/experimental/auto_reorder/auto_reorder_solver.h @@ -23,92 +23,15 @@ using CpModelBuilder = operations_research::sat::CpModelBuilder; using IntervalVar = operations_research::sat::IntervalVar; namespace reorder { const uint32_t ksolveTimeout = 180; // 3min -uint32_t get_autoreorder_timeout() { - const char* env = std::getenv("XLA_AUTOREORDER_TIMEOUT"); - if (env == nullptr) { - return ksolveTimeout; - } - return std::atoi(env); -}; -static const int kChannelNumber = 2; -int get_horizon(int max_time) { - // scale should be fit with module? - return max_time * 2; -} -bool solve_debug = true; +uint32_t get_autoreorder_timeout(); +constexpr const int kChannelNumber = 2; +int get_horizon(int max_time); +constexpr bool solve_debug = true; // TODO: no keep order will cause hung on multi processing, we should consider // how to resolve it // get cpu number of current machine -const bool is_keep_communicate_order() { - const char* env = std::getenv("XLA_KEEP_COMMUNICATE_ORDER"); - if (env == nullptr) { - return false; - } - return std::strcmp(env, "true") == 0; -}; -void save_to_cache(const std::string& content) { - const char* cache_filename = std::getenv("XLA_REORDER_CACHE_FILE"); - if (cache_filename == nullptr) { - cache_filename = "reorder.cache"; - } - std::ofstream file(cache_filename); - file << content; - file.close(); -}; -bool is_cache_enable() { - const char* cache_filename = std::getenv("XLA_REORDER_CACHE_FILE"); - if (cache_filename == nullptr) { - cache_filename = "reorder.cache"; - } - // check file exists - return std::filesystem::exists(cache_filename); -}; -std::string load_from_cache() { - const char* cache_filename = std::getenv("XLA_REORDER_CACHE_FILE"); - if (cache_filename == nullptr) { - cache_filename = "reorder.cache"; 
- } - - std::ifstream file(cache_filename); - std::string content; - std::string line; - while (std::getline(file, line)) { - content += line; - } - file.close(); - return content; -}; -bool accuired_reorder_lock() { - const char* lock_filename = std::getenv("XLA_REORDER_LOCK_FILE"); - if (lock_filename == nullptr) { - lock_filename = "/tmp/reorder.lock"; - } - mode_t m = umask(0); - int fd = open(lock_filename, O_RDWR | O_CREAT, 0666); - umask(m); - if (fd >= 0 && flock(fd, LOCK_EX | LOCK_NB) < 0) { - close(fd); - fd = -1; - } - return fd >= 0; -}; -void release_reorder_lock() { - const char* lock_filename = std::getenv("XLA_REORDER_LOCK_FILE"); - if (lock_filename == nullptr) { - lock_filename = "/tmp/reorder.lock"; - } - mode_t m = umask(0); - int fd = open(lock_filename, O_RDWR | O_CREAT, 0666); - umask(m); - if (fd >= 0 && flock(fd, LOCK_UN) < 0) { - close(fd); - fd = -1; - } -}; -int get_cpu_number() { - // return 8; - return std::thread::hardware_concurrency(); -} +const bool is_keep_communicate_order(); +int get_cpu_number(); } // namespace reorder enum class NodeType { kCompute = 0, @@ -307,7 +230,8 @@ class LinearProgramScheduler { // add Node to scheduler, its deps will execute before it Status AddConstraint(ContainerType* node); // solve the LP problem - Status Solve(); + // Status Solve(); + Status Solve(std::string mps_filename); // find instruction,if not exist, return error StatusOr FindInstructionLPNode(ElementType instruction); // find LPNode by instruction,if not exist,create it @@ -315,10 +239,11 @@ class LinearProgramScheduler { NodeType type); // ContainerType* std::vector GetSortedNodes() const; - // for debug: render graph viz - void RenderGraphviz(std::string filename) const; + // for debug: save graph viz file + void SaveGraphviz(std::string filename) const; // for debug: render gantt chart - void RenderGantt(std::string filename) const; + void SaveGantt(std::string filename) const; + void SaveJSON(std::string filename) const; // set max 
start time as horizon void SetHorizon(uint32_t horizon) { horizon_ = horizon; } StatusOr FindTask(ContainerType* node); diff --git a/xla/hlo/experimental/auto_reorder/auto_reorder_test.cc b/xla/hlo/experimental/auto_reorder/auto_reorder_test.cc index c871f3c976537..801b010e3e87e 100644 --- a/xla/hlo/experimental/auto_reorder/auto_reorder_test.cc +++ b/xla/hlo/experimental/auto_reorder/auto_reorder_test.cc @@ -18,6 +18,7 @@ #include #include #include +#include #include "absl/algorithm/container.h" #include "xla/hlo/experimental/auto_reorder/auto_reorder.h" @@ -186,6 +187,8 @@ SchedulerConfig GetDefaultSchedConfig() { class AutoReorderingTest : public HloTestBase { protected: + void SetUp() override { setenv("XLA_AUTOREORDER_TIMEOUT", "60", 1); } + void TearDown() override { unsetenv("XLA_AUTOREORDER_TIMEOUT"); } const char* const add_hlo_string_ = R"( HloModule module ENTRY %elementwise { @@ -220,7 +223,7 @@ ENTRY %elementwise { HloInstruction::CreateParameter(0, input_shape, "pasync")); async_builder.AddInstruction(HloInstruction::CreateReduceScatter( output_shape, {param}, MakeReduction(type, module), - CreateReplicaGroups({{0, 1}}), false, std::nullopt, false, 0)); + CreateReplicaGroups({{0, 1}}), false, /*channel_id*/ 3, false, 0)); HloComputation* reduction = module->AddEmbeddedComputation(async_builder.Build()); @@ -533,7 +536,7 @@ ENTRY %elementwise { auto all2all_ret = builder.AddInstruction(HloInstruction::CreateAllToAll( reduce_shape, {add_reduce_scatter}, /*replica_groups=*/CreateReplicaGroups({{0, 1}}), - /*constrain_layout=*/true, /*channel_id=*/1, + /*constrain_layout=*/true, /*channel_id=*/2, /*split_dimension*/ 0)); std::vector compute_vec = {d01_dot_p0_add_p2, d23_dot_p1_add_p3, all2all_ret}; @@ -569,7 +572,8 @@ ENTRY %elementwise { VLOG(2) << "create computation begin,test name: " << TestName() << ",inst_nums=" << inst_nums << ",max_deps=" << max_deps << ",communication_rate=" << communication_rate; - HloComputation::Builder builder(TestName()); + 
HloComputation::Builder builder( + absl::StrCat(TestName(), "N", inst_nums, "R", communication_rate)); Shape shape = ShapeUtil::MakeShape(F32, {4, 256, 256}); // insts_list: store instruction list,which have one result @@ -959,11 +963,15 @@ ENTRY %elementwise { } }; TEST_F(AutoReorderingTest, ConvertPDO) { - GTEST_SKIP() << "using convert here;"; - + // GTEST_SKIP() << "using convert here;"; + // get filepath from env + const char* env = std::getenv("XLA_AUTOREORDER_XPLANE_DIR"); + if (env == nullptr) { + GTEST_SKIP() << "have no set XLA_AUTOREORDER_XPLANE_DIR env skip"; + } auto status = ConvertXplaneToFile( - "/root/tb/llama_xla_trace/plugins/profile/2024_05_10_17_05_39/", - "/root/tb/llama_xla_trace/llama_fdo.pbtxt"); + env, "/root/tb/llama_xla_trace_2n16g/llama_fdo.jsonl"); + std::cout << status.message() << std::endl; EXPECT_TRUE(status.ok()); } @@ -1448,7 +1456,7 @@ TEST_F(AutoReorderingTest, ReorderPassWithRandom) { auto gpu_latency_estimator = std::make_unique(); SchedulerConfig sched_config = GetDefaultSchedConfig(); auto st = MakeRandomComputation(hlo_module.get(), gpu_latency_estimator.get(), - /*inst num*/ 100, + /*inst num*/ 200, /*max deps*/ 5, /*communication rate*/ 0.2); // std::cout<ToString()<& hlo_module_info, absl::flat_hash_map* hlo_latency_info) { // Iterate events. 
- xplane.ForEachLine([hlo_latency_info, - hlo_module_info](const XLineVisitor& xline) { + xplane.ForEachLine([hlo_latency_info](const XLineVisitor& xline) { if (xline.DisplayName() == tsl::profiler::kXlaAsyncOpLineName) { return; } - xline.ForEachEvent([hlo_latency_info, - hlo_module_info](const XEventVisitor& xevent) { + VLOG(5) << "Processing line: " << xline.DisplayName(); + xline.ForEachEvent([hlo_latency_info](const XEventVisitor& xevent) { int64_t event_type = xevent.Type().value_or(HostEventType::kUnknownHostEventType); if (IsInternalEvent(event_type)) return; std::optional hlo_name = std::nullopt; - std::optional hlo_module_name = std::nullopt; - std::optional fingerprint = std::nullopt; - std::optional program_id = std::nullopt; auto for_each_stat = [&](const XStatVisitor& stat) { if (stat.ValueCase() == tsl::profiler::XStat::VALUE_NOT_SET) return; @@ -43,34 +51,14 @@ void GetXPlaneLatencyInfo( if (stat.Name() == GetStatTypeStr(StatType::kHloOp)) { hlo_name = stat.ToString(); } - if (stat.Name() == GetStatTypeStr(StatType::kProgramId)) { - program_id = stat.IntValue(); - } - if (stat.Name() == GetStatTypeStr(StatType::kHloModule)) { - hlo_module_name = stat.ToString(); - } }; xevent.Metadata().ForEachStat(for_each_stat); xevent.ForEachStat(for_each_stat); - if (!hlo_name.has_value() || !hlo_module_name.has_value()) { - return; - } - - if (hlo_module_name.has_value()) { - std::string fingerprint_key = hlo_module_name.value(); - if (program_id.has_value()) { - fingerprint_key = tsl::profiler::HloModuleNameWithProgramId( - hlo_module_name.value(), program_id.value()); - } - if (hlo_module_info.contains(fingerprint_key)) { - fingerprint = hlo_module_info.at(fingerprint_key); - } - } double latency = static_cast(xevent.DurationNs()) / 1e3; + VLOG(5) << "hlo_name: " << hlo_name.value_or("N/A") + << "latency:" << latency; + std::string key = hlo_name.value(); - if (fingerprint.has_value()) { - key = absl::StrCat(fingerprint.value(), kCostNameSep, 
hlo_name.value()); - } (*hlo_latency_info)[key].durations.emplace_back(latency); }); }); @@ -88,27 +76,103 @@ std::unique_ptr CreateModuleFromProto( return nullptr; } -std::optional GetHloModuleFingerprint( - const xla::HloModuleProto& hlo_module_proto) { +Status GetHloInstrProfileInfo( + const xla::HloModuleProto& hlo_module_proto, + absl::flat_hash_map* + hlo_module_info) { std::unique_ptr hlo_module = CreateModuleFromProto(hlo_module_proto); if (hlo_module == nullptr) { - return std::nullopt; - } - const auto& map = hlo_module->entry_computation() - ->root_instruction() - ->frontend_attributes() - .map(); - auto it = map.find("fingerprint_before_lhs"); - if (it != map.end()) { - return it->second; + return absl::InternalError("Failed to create HloModule from proto"); } - return std::nullopt; + VLOG(5) << "success get hlo module from proto"; + for (HloComputation* computation : + hlo_module->MakeNonfusionComputations({})) { + for (auto* instr : computation->instructions()) { + // instr to json + //{name:"name",opcode:"opcode",operand_count:1,operand_names:["a"],operand_types:["f32"],shape:"[1,2,3]",result_type:"f32",result_shape:"[1,2,3]",result_element_type:"f32",result_element_shape:"[1,2,3]",result_element_count:6} + // TODO: should we need shard info? + // TODO: custom call + // there are 3 category instrs: + // 1. custom call, include GEMM now; record its input shape/dtype + // 2. communicate call, include async reducescatter ; record its input + // shape/dtype + // 3. 
other, + HloInstructionProto instr_origin_proto = instr->ToProto(); + auto_reorder::InstrProfileInfo instr_info; + auto_reorder::Size ret_size; + instr_info.set_name(instr_origin_proto.name()); + HloOpcode code = instr->opcode(); + + instr_info.set_opcode(static_cast(code)); + + // set operand count/type/size + instr_info.set_operand_count(instr->operand_count()); + for (auto operand : instr->operands()) { + Shape op_shape = operand->shape(); + // operand dtype + + instr_info.add_operand_types( + PrimitiveType_Name(op_shape.element_type())); + auto_reorder::Size* op_size = instr_info.add_operand_sizes(); + op_size->set_rank(op_shape.dimensions_size()); + for (size_t i = 0; i < op_shape.dimensions_size(); i++) { + op_size->add_sizes(op_shape.dimensions(i)); + } + } + + Shape shape = instr->shape(); + instr_info.mutable_result_size()->set_rank(shape.dimensions_size()); + for (size_t i = 0; i < shape.dimensions_size(); i++) { + /* code */ + instr_info.mutable_result_size()->add_sizes(shape.dimensions(i)); + } + // custom call + switch (code) { + case HloOpcode::kCustomCall: { + instr_info.set_custom_call_target(instr->custom_call_target()); + break; + } + case HloOpcode::kReduceScatter: + case HloOpcode::kAllGather: + case HloOpcode::kAllGatherStart: + case HloOpcode::kAllReduce: + case HloOpcode::kAllReduceStart: + case HloOpcode::kCollectivePermuteStart: { // comm op need record + // process group + // example :{{1,2,3,4}}, {{1,2},{3,4}} + std::vector replica_groups = instr->replica_groups(); + uint16_t group_id = 0; + for (auto replica_group : replica_groups) { + xla::auto_reorder::ReplicaGroup* group = + instr_info.add_process_groups(); + group->set_replica_group_id(group_id); + group_id++; + for (auto replica : replica_group.replica_ids()) { + group->add_replica_ids(replica); + } + } + + // instr_info.set_process_group(); + break; + } + case HloOpcode::kAsyncStart: { + // get async inner instr + } + default: + break; + + } // end switch + 
hlo_module_info->emplace(instr_origin_proto.name(), instr_info); + } // end for instrs + } // end for computations + return absl::OkStatus(); } -void GetXPlaneHloModuleInfo( +void GetXPlaneHloModuleProfileInfo( const XPlaneVisitor& xplane, - absl::flat_hash_map* hlo_module_info) { + absl::flat_hash_map* + hlo_module_info) { // Iterate events. xplane.ForEachEventMetadata([&](const XEventMetadataVisitor& event_metadata) { event_metadata.ForEachStat([&](const XStatVisitor& stat) { @@ -117,65 +181,102 @@ void GetXPlaneHloModuleInfo( stat.BytesValue().size())) { const xla::HloModuleProto& hlo_module_proto = hlo_proto.hlo_module(); - std::optional fingerprint = - GetHloModuleFingerprint(hlo_module_proto); - if (fingerprint.has_value()) { - std::string key_with_id = tsl::profiler::HloModuleNameWithProgramId( - hlo_module_proto.name(), hlo_module_proto.id()); - (*hlo_module_info)[key_with_id] = fingerprint.value(); + Status st = GetHloInstrProfileInfo(hlo_module_proto, hlo_module_info); + if (!st.ok()) { + VLOG(5) << "Failed to get HloInstrProfileInfo from HloModuleProto"; } } }); }); } -Status ConvertXplaneToProfiledInstructionsProto( +Status ConvertXplaneToProfiledJSONLine( std::vector xspaces, - tensorflow::profiler::ProfiledInstructionsProto* - profiled_instructions_proto) { + std::vector* jsonline_vector) { + // name to HloLatencyInfo absl::flat_hash_map hlo_latency_info; - absl::flat_hash_map hlo_module_info; + // name to HloInstructionProto + absl::flat_hash_map + hlo_instr_profile_info; + google::protobuf::util::JsonPrintOptions options; + options.add_whitespace = true; + options.always_print_primitive_fields = true; + google::protobuf::util::Status st; + // st = google::protobuf::util::MessageToJsonString(profile_proto, + // &json_string, options); if(!st.ok()) { + // return absl::InternalError("Failed to convert ProfiledInstructionsProto + // to json"); + // } // Iterate through each host. 
for (const XSpace& xspace : xspaces) { const XPlane* metadata_plane = FindPlaneWithName(xspace, tsl::profiler::kMetadataPlaneName); if (metadata_plane != nullptr) { XPlaneVisitor xplane = CreateTfXPlaneVisitor(metadata_plane); - GetXPlaneHloModuleInfo(xplane, &hlo_module_info); + GetXPlaneHloModuleProfileInfo(xplane, &hlo_instr_profile_info); } std::vector device_planes = FindPlanesWithPrefix(xspace, tsl::profiler::kGpuPlanePrefix); // We don't expect GPU and TPU planes and custom devices to be present in // the same XSpace. if (device_planes.empty()) { + VLOG(5) << "No GPU plane found, try to find TPU plane."; device_planes = FindPlanesWithPrefix(xspace, tsl::profiler::kTpuPlanePrefix); } if (device_planes.empty()) { + VLOG(5) << "No TPU plane found, try to find custom device plane."; device_planes = FindPlanesWithPrefix(xspace, tsl::profiler::kCustomPlanePrefix); } // Go over each device plane. for (const XPlane* device_plane : device_planes) { XPlaneVisitor xplane = CreateTfXPlaneVisitor(device_plane); - GetXPlaneLatencyInfo(xplane, hlo_module_info, &hlo_latency_info); + GetXPlaneLatencyInfo(xplane, &hlo_latency_info); } } + if (hlo_instr_profile_info.empty()) { + VLOG(5) << "No HLO instruction info found in xplane protobuf."; + return absl::InternalError("No HLO instruction info found in xplane"); + } + if (hlo_latency_info.empty()) { + VLOG(5) << "No HLO latency info found in xplane."; + return absl::InternalError("No HLO latency info found in xplane"); + } + HloLatencyStats stats{}; // Get the mean duration for each hlo and store into the proto. 
for (const auto& iter : hlo_latency_info) { - auto* cost = profiled_instructions_proto->add_costs(); - std::vector durations = iter.second.durations; - double sum = std::accumulate(durations.begin(), durations.end(), 0.0); - cost->set_cost_us(sum / durations.size()); - cost->set_name(iter.first); - } + // auto* cost = profiled_instructions_proto->add_costs(); + auto profile_it = hlo_instr_profile_info.find(iter.first); + if (profile_it == hlo_instr_profile_info.end()) { + VLOG(5) << "No instr info found for instr: " << iter.first; + stats.misses++; + continue; + } else { + stats.hits++; + } + auto_reorder::InstrProfileInfo cost = profile_it->second; + for (auto duration : iter.second.durations) { + // cost->add_durations(d); + cost.set_cost(duration); + std::string json_string; + auto st = google::protobuf::util::MessageToJsonString(cost, &json_string, + options); + if (!st.ok()) { + return absl::InternalError( + "Failed to convert ProfiledInstructionsProto to json"); + } + jsonline_vector->push_back(json_string); + } + } + VLOG(5) << "Lookup inst profiler, Hits: " << stats.hits + << " Misses: " << stats.misses; return OkStatus(); } Status ConvertXplaneUnderLogdirToProfiledInstructionsProto( - const std::string& logdir, tensorflow::profiler::ProfiledInstructionsProto* - profiled_instructions_proto) { + const std::string& logdir, std::vector* jsonline_vector) { // Find the xplane files for each host under logdir. 
std::vector children_path; TF_RETURN_IF_ERROR(tsl::Env::Default()->GetChildren(logdir, &children_path)); @@ -198,21 +299,26 @@ Status ConvertXplaneUnderLogdirToProfiledInstructionsProto( absl::StrCat("Could not find xplane file under: ", logdir)); } VLOG(3) << "Have load " << xspaces.size() << " xspaces"; - return ConvertXplaneToProfiledInstructionsProto(xspaces, - profiled_instructions_proto); + return ConvertXplaneToProfiledJSONLine(xspaces, jsonline_vector); } Status ConvertXplaneToFile(const std::string& xplane_dir, const std::string& output_filename) { tensorflow::profiler::ProfiledInstructionsProto profile_proto; + std::vector jsonline_vector; auto status = ConvertXplaneUnderLogdirToProfiledInstructionsProto( - xplane_dir, &profile_proto); + xplane_dir, &jsonline_vector); if (!status.ok()) { return status; } - std::string profile_proto_str = profile_proto.SerializeAsString(); - TF_RETURN_IF_ERROR(tsl::WriteStringToFile( - tsl::Env::Default(), output_filename, profile_proto_str)); + // open file,write jsonline + std::ofstream fout = std::ofstream(output_filename); + if (!fout.is_open()) { + return absl::InternalError("Failed to open file for writing"); + } + for (const std::string& jsonline : jsonline_vector) { + fout << jsonline << std::endl; + } return OkStatus(); } diff --git a/xla/hlo/experimental/auto_reorder/convert_xplane.h b/xla/hlo/experimental/auto_reorder/convert_xplane.h index 0c1f350d25012..46e0f382b1dbe 100644 --- a/xla/hlo/experimental/auto_reorder/convert_xplane.h +++ b/xla/hlo/experimental/auto_reorder/convert_xplane.h @@ -9,6 +9,9 @@ #include #include +#include +#include + #include "absl/container/flat_hash_map.h" #include "absl/status/status.h" #include "absl/strings/match.h" @@ -18,6 +21,7 @@ // #include "xla/status.h" #include "xla/hlo/ir/hlo_module.h" #include "xla/service/hlo.pb.h" +#include "xla/primitive_util.h" #include "xla/xla.pb.h" #include "tsl/platform/env.h" @@ -31,6 +35,8 @@ #include "tsl/profiler/utils/xplane_visitor.h" 
#include "tsl/profiler/protobuf/profiled_instructions.pb.h" #include "tsl/profiler/protobuf/xplane.pb.h" +#include "google/protobuf/util/json_util.h" +#include "xla/hlo/experimental/auto_reorder/instr_profile_info.pb.h" namespace xla { @@ -53,11 +59,15 @@ using tsl::profiler::XLineVisitor; using tsl::profiler::XPlaneVisitor; using tsl::profiler::XStatVisitor; -// Latency info for a single HLO instruction. +// Latency info for a single HLO instruction. it's a vector of durations. Each +// duration is the latency of the instruction struct HloLatencyInfo { std::vector durations; }; - +struct HloLatencyStats { + uint32_t hits = 0; + uint32_t misses = 0; +}; Status ConvertXplaneToProfiledInstructionsProto( std::vector xspaces, tensorflow::profiler::ProfiledInstructionsProto* @@ -69,5 +79,4 @@ Status ConvertXplaneUnderLogdirToProfiledInstructionsProto( const std::string& logdir, std::vector* jsonline_vector); Status ConvertXplaneToFile(const std::string& xplane_dir, const std::string& output_filename); - -} //namespace xla \ No newline at end of file +} // namespace xla \ No newline at end of file diff --git a/xla/hlo/experimental/auto_reorder/instr_profile_info.proto b/xla/hlo/experimental/auto_reorder/instr_profile_info.proto new file mode 100644 index 0000000000000..afe8162aa6837 --- /dev/null +++ b/xla/hlo/experimental/auto_reorder/instr_profile_info.proto @@ -0,0 +1,42 @@ +// Copyright 2023 The Lynx Authors. All rights reserved. +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+syntax = "proto3"; + +package xla.auto_reorder; + +import "google/protobuf/any.proto"; +// Size(rank=3,sizes=[2,3,4]) +message Size { + int64 rank = 1; + repeated int64 sizes = 2; +} +//ReplicaGroup(replica_group_id=1, replica_ids={1,2}) +message ReplicaGroup{ + int64 replica_group_id=1; + repeated int64 replica_ids=2; +} +// as xla/service/hlo.proto HloInstructionProto subset,we focus on compute/communicate complexity +message InstrProfileInfo { + string name = 1; + uint32 operand_count=2; + uint32 result_count=3; + uint32 opcode=4; + uint32 version=5; + repeated string operand_types = 6; + repeated string result_types = 7; + repeated Size operand_sizes = 8; + Size result_size = 9; + repeated ReplicaGroup process_groups=10; + optional string custom_call_target = 11; + double cost=12; +} \ No newline at end of file diff --git a/xla/hlo/utils/common_ortools_deps.h b/xla/hlo/utils/common_ortools_deps.h index 9d26d32cbad38..46a0de861d000 100644 --- a/xla/hlo/utils/common_ortools_deps.h +++ b/xla/hlo/utils/common_ortools_deps.h @@ -2,8 +2,10 @@ #define ORTOOLS_LINEAR_SOLVER_H #include "ortools/linear_solver/linear_solver.h" #include "ortools/linear_solver/linear_solver.pb.h" +#include "ortools/linear_solver/model_exporter.h" #include "ortools/sat/cp_model.h" #include "ortools/sat/cp_model.pb.h" #include "ortools/sat/cp_model_solver.h" +#include "ortools/sat/lp_utils.h" #include "absl/strings/string_view.h" #endif \ No newline at end of file diff --git a/xla/service/gpu/gpu_hlo_schedule.cc b/xla/service/gpu/gpu_hlo_schedule.cc index 6434af209d138..0f5d9809ebae7 100644 --- a/xla/service/gpu/gpu_hlo_schedule.cc +++ b/xla/service/gpu/gpu_hlo_schedule.cc @@ -685,11 +685,15 @@ absl::Status IsProfileApplicable( instruction_names.insert(instr->name()); } } - + int64_t total_instruction_count = instruction_names.size(); + int64_t cost_miss_count = 0; + int64_t cost_hit_count = 0; for (const auto& cost : profile.costs()) { if (!instruction_names.contains(cost.name())) { + 
cost_miss_count++; // profile inst name not in this module return absl::InvalidArgumentError(absl::StrFormat( "cost name %s not in module %s", cost.name(), module->name())); + } else { cost_hit_count++; } } for (const auto& latency : profile.latencies()) {