Skip to content

Commit

Permalink
support export to mps and json; [WIP] convert xplane to offline sqlite
Browse files Browse the repository at this point in the history
  • Loading branch information
zjjott committed May 21, 2024
1 parent 0356daa commit 8ada939
Show file tree
Hide file tree
Showing 10 changed files with 393 additions and 187 deletions.
16 changes: 15 additions & 1 deletion xla/hlo/experimental/auto_reorder/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,11 @@ cc_library(
":auto_reorder_solver"
],
)
# Proto describing per-instruction profile info (instr_profile_info.proto).
# The generated cc target ":instr_profile_info_proto_cc" is consumed by the
# :convert_xplane library below.
tf_proto_library(
    name = "instr_profile_info_proto",
    srcs = ["instr_profile_info.proto"],
)

cc_library(
name="convert_xplane",
srcs=["convert_xplane.cc"],
Expand All @@ -97,6 +102,7 @@ cc_library(
"//xla:xla_proto_cc",
"//xla/hlo/ir:hlo",
"//xla/service:hlo_proto_cc",
"//xla:shape_util",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/status",
"@com_google_absl//absl/strings",
Expand All @@ -111,6 +117,8 @@ cc_library(
"@tsl//tsl/profiler/utils:xplane_schema",
"@tsl//tsl/profiler/utils:xplane_utils",
"@tsl//tsl/profiler/utils:xplane_visitor",
"@com_google_protobuf//:protobuf",
":instr_profile_info_proto_cc"
]
)
xla_cc_test(
Expand All @@ -135,16 +143,22 @@ xla_cc_test(
],
)

cc_binary(
xla_cc_binary(
name="convert_xplane_tools",
linkopts = [
"-Wl,--allow-multiple-definition",
"-lstdc++fs", # For std::filesystem
],
srcs=["convert_xplane_bin.cc"],
deps=[
":convert_xplane",
":auto_reorder",
"//xla/hlo/ir:hlo",
"//xla/service:hlo_parser",
"//xla/service:latency_hiding_scheduler",
"//xla/service/gpu:gpu_hlo_schedule",
"//xla/service:gpu_plugin",
"//xla:device_util",
"@com_google_absl//absl/log",
"@tsl//tsl/platform:statusor",

Expand Down
10 changes: 6 additions & 4 deletions xla/hlo/experimental/auto_reorder/auto_reorder.cc
Original file line number Diff line number Diff line change
Expand Up @@ -171,13 +171,15 @@ AutoReorderPass::ScheduleComputation(HloComputation* computation) {
}
if (reorder::solve_debug) {
// save to pid related file
solver_->RenderGraphviz(absl::StrCat("gantt_before_", computation->name()));
solver_->SaveGraphviz(absl::StrCat("gantt_before_", computation->name()));
solver_->SaveJSON(absl::StrCat("gantt_before_", computation->name()));
}
auto status = solver_->Solve();
auto status =
solver_->Solve(absl::StrCat("mps_file_of_", computation->name()));
if (reorder::solve_debug) {
// save to pid related file
solver_->RenderGantt(absl::StrCat("gantt_", computation->name()));
solver_->RenderGraphviz(absl::StrCat("gantt_", computation->name()));
solver_->SaveGantt(absl::StrCat("gantt_", computation->name()));
solver_->SaveGraphviz(absl::StrCat("gantt_", computation->name()));
}

if (status.ok()) {
Expand Down
116 changes: 105 additions & 11 deletions xla/hlo/experimental/auto_reorder/auto_reorder_solver.cc
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,31 @@ using Task =
std::tuple<int8_t, CostType>; // (channel, processing_time), we have two
// channel now:communication and computation
using Job = std::vector<Task>;
namespace reorder {
// Returns the solver timeout (seconds) from the XLA_AUTOREORDER_TIMEOUT
// environment variable, falling back to ksolveTimeout when the variable is
// unset or does not parse to a positive integer. std::atoi returns 0 on
// malformed input, which previously leaked through as a zero timeout.
uint32_t get_autoreorder_timeout() {
  const char* env = std::getenv("XLA_AUTOREORDER_TIMEOUT");
  if (env == nullptr) {
    return ksolveTimeout;
  }
  const int parsed = std::atoi(env);
  return parsed > 0 ? static_cast<uint32_t>(parsed) : ksolveTimeout;
}

// Horizon (upper bound on any schedule time variable) for a given maximum
// serialized execution time.
// TODO: scale should be fit with module?
int get_horizon(int max_time) { return max_time * 2; }

// True iff XLA_KEEP_COMMUNICATE_ORDER is set to exactly "true"; presumably
// keeps communication ops in their original relative order — confirm against
// the solver's use of this flag.
const bool is_keep_communicate_order() {
  const char* env = std::getenv("XLA_KEEP_COMMUNICATE_ORDER");
  return env != nullptr && std::strcmp(env, "true") == 0;
}

// Number of hardware threads of the current machine; note
// std::thread::hardware_concurrency() may return 0 when undeterminable.
int get_cpu_number() { return std::thread::hardware_concurrency(); }

}  // namespace reorder
template <typename ContainerType, typename ElementType>
LinearProgramScheduler<ContainerType, ElementType>::~LinearProgramScheduler() {
uuid2container.clear();
Expand Down Expand Up @@ -102,7 +126,8 @@ LPSchedulerFunc(StatusOr<TaskType>)::AddNodeToTask(ContainerType* node) {
node_to_task_.emplace(node->UUID(), std::make_tuple(node, task));
return task;
};
LPSchedulerFunc(tsl::Status)::Solve() {

LPSchedulerFunc(tsl::Status)::Solve(std::string mps_filename) {
uint32_t max_execution_time = 0;
for (auto node : nodes_) {
node->Freeze();
Expand Down Expand Up @@ -154,9 +179,8 @@ LPSchedulerFunc(tsl::Status)::Solve() {
if (IsSingleChannel(dep_type)) {
auto dep_task = std::get<1>(node_to_task_.at(dep_node->UUID()));
// interval
IntervalVar interval = cp_model_.NewIntervalVar(
dep_task.end, cost,
node_task.start);
IntervalVar interval =
cp_model_.NewIntervalVar(dep_task.end, cost, node_task.start);
no_overlap_edges.push_back(interval);
}
}
Expand All @@ -171,7 +195,6 @@ LPSchedulerFunc(tsl::Status)::Solve() {
}
cp_model_.AddMaxEquality(obj_var, ends);
cp_model_.Minimize(obj_var);

// cp_model_.
// VLOG(2)<<"Number of variables:"<<cp_model_.NumVariables()<<" Number of
// constraint:"<<cp_model_.NumConstraints();
Expand All @@ -188,9 +211,24 @@ LPSchedulerFunc(tsl::Status)::Solve() {
// parameters.set_log_search_progress(true);
}
parameters.set_num_search_workers(1);
auto model = cp_model_.Build();
// model is operations_research::sat::CpModelProto type
// need operations_research::MPModelProto& type, so we need to convert it
// model
if (mps_filename.size() > 0) {
operations_research::MPModelProto output;
operations_research::sat::ConvertCpModelProtoToMPModelProto(model, &output);
auto status_of_string = operations_research::ExportModelAsMpsFormat(output);
if (status_of_string.ok()) {
VLOG(2) << "ExportModelAsMpsFormat success";
std::ofstream out(absl::StrCat("/tmp/", mps_filename, ".mps"));
out << status_of_string.value();
out.close();
}
}

const operations_research::sat::CpSolverResponse response =
operations_research::sat::SolveWithParameters(cp_model_.Build(),
parameters);
operations_research::sat::SolveWithParameters(model, parameters);
uint64_t solve_time = response.wall_time();
VLOG(1) << "Solve finish:" << response.status()
<< " solve time:" << solve_time;
Expand Down Expand Up @@ -224,7 +262,7 @@ std::string ReplaceUnusedChar(const std::string str,
}
return result;
}
LPSchedulerFunc(std::vector<ContainerType*>)::GetSortedNodes() const{
LPSchedulerFunc(std::vector<ContainerType*>)::GetSortedNodes() const {
std::vector<ContainerType*> sorted_nodes;
sorted_nodes.reserve(nodes_.size());
for (auto node : nodes_) {
Expand All @@ -239,7 +277,64 @@ LPSchedulerFunc(std::vector<ContainerType*>)::GetSortedNodes() const{
});
return sorted_nodes;
}
LPSchedulerFunc(void)::RenderGraphviz(std::string filename) const{
LPSchedulerFunc(void)::SaveJSON(std::string filename) const {
  // Debug helper: dump the scheduling DAG as {"nodes":[...],"edges":[...]}
  // to /tmp/<filename>.json for offline inspection/visualization.
  //
  // Escape characters that would break a JSON string literal. The previous
  // code only stripped "'" from names, so a '"' or '\' inside an instruction
  // name produced invalid JSON.
  auto escape_json = [](const std::string& s) {
    std::string out;
    out.reserve(s.size());
    for (char c : s) {
      switch (c) {
        case '"':
          out += "\\\"";
          break;
        case '\\':
          out += "\\\\";
          break;
        case '\n':
          out += "\\n";
          break;
        case '\t':
          out += "\\t";
          break;
        default:
          out += c;
      }
    }
    return out;
  };
  std::string json_file = absl::StrCat("/tmp/", filename, ".json");
  std::ofstream json_out(json_file);
  json_out << "{" << std::endl;
  json_out << "\"nodes\": [" << std::endl;
  int32_t node_count = 0;
  for (auto node : this->GetSortedNodes()) {
    // Node "typename" distinguishes communication from compute tasks.
    const char* type_name =
        node->IsCommunication() ? "communication" : "compute";
    if (node_count > 0) {
      json_out << ",\n";  // separator before every element but the first
    }
    json_out << "{ \"uuid\": \"" << node->UUID() << "\",\"typename\": \""
             << type_name << "\", \"name\": \""
             << escape_json(ReplaceUnusedChar(node->GetName(), "'"))
             << "\", \"cost\": " << node->GetCost() << " }";
    node_count++;
  }
  json_out << "]," << std::endl;
  json_out << "\"edges\": [" << std::endl;
  int32_t edge_count = 0;
  for (auto node : this->GetSortedNodes()) {
    for (auto dep_pair : node->GetDeps()) {
      auto dep_node = std::get<0>(dep_pair);
      auto dep_cost = std::get<1>(dep_pair);
      NodeType dep_type = std::get<2>(dep_pair);
      const char* type_name =
          IsSingleChannel(dep_type) ? "communication" : "compute";
      if (edge_count > 0) {
        json_out << ",\n";
      }
      json_out << "{ \"from\": \"" << dep_node->UUID() << "\", \"to\": \""
               << node->UUID() << "\", \"typename\": \"" << type_name
               << "\", \"cost\": " << dep_cost << " }";
      edge_count++;
    }
  }
  json_out << "]" << std::endl;
  json_out << "}" << std::endl;
}
LPSchedulerFunc(void)::SaveGraphviz(std::string filename) const {
// write a dot file
std::string dot_file = absl::StrCat("/tmp/", filename, ".dot");
std::ofstream out(dot_file);
Expand Down Expand Up @@ -287,7 +382,7 @@ LPSchedulerFunc(void)::RenderGraphviz(std::string filename) const{
auto status = system(cmd.c_str());
VLOG(4) << cmd << " execute status:" << status << std::endl;
}
LPSchedulerFunc(void)::RenderGantt(std::string filename) const{
LPSchedulerFunc(void)::SaveGantt(std::string filename) const {
// https://g2.antv.antgroup.com/en/examples/storytelling/storytelling/#gantt
// { name: 'compute',label:'kernel name1', startTime: 1, endTime: 4 },
VLOG(4) << "write node number:" << nodes_.size() << " to /tmp/" << filename
Expand Down Expand Up @@ -346,7 +441,6 @@ LPSchedulerFunc(void)::RenderGantt(std::string filename) const{
chart.render();)";
}


// Whether container `a` is already registered as an operand of this DAG node.
LPContainerDAGFunc(bool)::IsIn(LPContainer<ElementType>* a) {
  return operands_.count(a) != 0;
}
Expand Down
99 changes: 12 additions & 87 deletions xla/hlo/experimental/auto_reorder/auto_reorder_solver.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,92 +23,15 @@ using CpModelBuilder = operations_research::sat::CpModelBuilder;
using IntervalVar = operations_research::sat::IntervalVar;
namespace reorder {
const uint32_t ksolveTimeout = 180; // 3min
uint32_t get_autoreorder_timeout() {
const char* env = std::getenv("XLA_AUTOREORDER_TIMEOUT");
if (env == nullptr) {
return ksolveTimeout;
}
return std::atoi(env);
};
static const int kChannelNumber = 2;
int get_horizon(int max_time) {
// scale should be fit with module?
return max_time * 2;
}
bool solve_debug = true;
uint32_t get_autoreorder_timeout();
constexpr const int kChannelNumber = 2;
int get_horizon(int max_time);
constexpr bool solve_debug = true;
// TODO: no keep order will cause hung on multi processing, we should consider
// how to resolve it
// get cpu number of current machine
const bool is_keep_communicate_order() {
const char* env = std::getenv("XLA_KEEP_COMMUNICATE_ORDER");
if (env == nullptr) {
return false;
}
return std::strcmp(env, "true") == 0;
};
void save_to_cache(const std::string& content) {
const char* cache_filename = std::getenv("XLA_REORDER_CACHE_FILE");
if (cache_filename == nullptr) {
cache_filename = "reorder.cache";
}
std::ofstream file(cache_filename);
file << content;
file.close();
};
bool is_cache_enable() {
const char* cache_filename = std::getenv("XLA_REORDER_CACHE_FILE");
if (cache_filename == nullptr) {
cache_filename = "reorder.cache";
}
// check file exists
return std::filesystem::exists(cache_filename);
};
std::string load_from_cache() {
const char* cache_filename = std::getenv("XLA_REORDER_CACHE_FILE");
if (cache_filename == nullptr) {
cache_filename = "reorder.cache";
}

std::ifstream file(cache_filename);
std::string content;
std::string line;
while (std::getline(file, line)) {
content += line;
}
file.close();
return content;
};
bool accuired_reorder_lock() {
const char* lock_filename = std::getenv("XLA_REORDER_LOCK_FILE");
if (lock_filename == nullptr) {
lock_filename = "/tmp/reorder.lock";
}
mode_t m = umask(0);
int fd = open(lock_filename, O_RDWR | O_CREAT, 0666);
umask(m);
if (fd >= 0 && flock(fd, LOCK_EX | LOCK_NB) < 0) {
close(fd);
fd = -1;
}
return fd >= 0;
};
void release_reorder_lock() {
const char* lock_filename = std::getenv("XLA_REORDER_LOCK_FILE");
if (lock_filename == nullptr) {
lock_filename = "/tmp/reorder.lock";
}
mode_t m = umask(0);
int fd = open(lock_filename, O_RDWR | O_CREAT, 0666);
umask(m);
if (fd >= 0 && flock(fd, LOCK_UN) < 0) {
close(fd);
fd = -1;
}
};
int get_cpu_number() {
// return 8;
return std::thread::hardware_concurrency();
}
const bool is_keep_communicate_order();
int get_cpu_number();
} // namespace reorder
enum class NodeType {
kCompute = 0,
Expand Down Expand Up @@ -307,18 +230,20 @@ class LinearProgramScheduler {
// add Node to scheduler, its deps will execute before it
Status AddConstraint(ContainerType* node);
// solve the LP problem
Status Solve();
// Status Solve();
Status Solve(std::string mps_filename);
// find instruction,if not exist, return error
StatusOr<ContainerType*> FindInstructionLPNode(ElementType instruction);
// find LPNode by instruction,if not exist,create it
ContainerType* FindLPNodeOrCreate(ElementType instruction, CostType cost,
NodeType type);
// ContainerType*
std::vector<ContainerType*> GetSortedNodes() const;
// for debug: render graph viz
void RenderGraphviz(std::string filename) const;
// for debug: save graph viz file
void SaveGraphviz(std::string filename) const;
// for debug: render gantt chart
void RenderGantt(std::string filename) const;
void SaveGantt(std::string filename) const;
void SaveJSON(std::string filename) const;
// set max start time as horizon
void SetHorizon(uint32_t horizon) { horizon_ = horizon; }
StatusOr<TaskType> FindTask(ContainerType* node);
Expand Down
Loading

0 comments on commit 8ada939

Please sign in to comment.