diff --git a/third_party/xla/xla/service/BUILD b/third_party/xla/xla/service/BUILD index 2736a49cacb762..64f0fd64bf1912 100644 --- a/third_party/xla/xla/service/BUILD +++ b/third_party/xla/xla/service/BUILD @@ -438,6 +438,7 @@ cc_library( local_defines = if_cuda_is_configured(["GOOGLE_CUDA=1"]), deps = [ ":hlo_graph_dumper", + ":hlo_proto_cc", ":hlo_proto_util", "//xla:util", "//xla:xla_proto_cc", diff --git a/third_party/xla/xla/service/dump.cc b/third_party/xla/xla/service/dump.cc index b9c5857d25ceed..eaf5066bf31e80 100644 --- a/third_party/xla/xla/service/dump.cc +++ b/third_party/xla/xla/service/dump.cc @@ -48,6 +48,7 @@ limitations under the License. #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_module.h" #include "xla/hlo/ir/hlo_opcode.h" +#include "xla/service/hlo.pb.h" #include "xla/service/hlo_graph_dumper.h" #include "xla/service/hlo_proto_util.h" #include "xla/tsl/lib/io/zlib_compression_options.h" @@ -884,6 +885,36 @@ void DumpHloSnapshotIfEnabled(const HloSnapshot& snapshot, DumpToFileInDirImpl(filename, pb, canonical_opts); } +void DumpHloUnoptimizedSnapshotIfEnabled( + const HloUnoptimizedSnapshot& hlo_snapshot, const DebugOptions& opts) { + CanonicalDebugOptions canonical_opts(opts); + std::string name = hlo_snapshot.hlo_module().name(); + if (!canonical_opts.should_dump_module(name) || + !canonical_opts.dump_unoptimized_snapshots) { + return; + } + + if (hlo_snapshot.partitions_size() == 0) { + LOG(ERROR) << "Refusing to write unoptimized HLO snapshot proto for module " + << name << ": no partitions input found."; + return; + } + int64_t execution_count; + { + static absl::Mutex mu(absl::kConstInit); + static auto& module_id_to_execution_count ABSL_GUARDED_BY(mu) = + *new absl::flat_hash_map(); + absl::MutexLock lock(&mu); + execution_count = + module_id_to_execution_count[hlo_snapshot.hlo_module().id()]++; + } + std::string filename = FilenameFor( + hlo_snapshot.hlo_module().id(), hlo_snapshot.hlo_module().name(), "", + absl::StrFormat("execution_%04d.hlo_unoptimized_snapshot", + execution_count)); + DumpProtobufToFile(hlo_snapshot, opts, filename, nullptr); +} + void DumpHloModuleMetadataIfEnabled(const std::vector& modules) { absl::flat_hash_set dumped_module_ids; for (const HloModule* module : modules) { diff --git a/third_party/xla/xla/service/dump.h b/third_party/xla/xla/service/dump.h index 623e7298fb9306..263da1af4c4c64 100644 --- a/third_party/xla/xla/service/dump.h +++ b/third_party/xla/xla/service/dump.h @@ -152,6 +152,11 @@ void DumpHloSnapshotIfEnabled(const HloModule& module, void DumpHloSnapshotIfEnabled(const HloSnapshot& snapshot, const DebugOptions& opts); +// Dumps the given HloUnoptimisedSnapshot to the module's xla_dump_dir, if this +// is enabled. +void DumpHloUnoptimizedSnapshotIfEnabled( + const HloUnoptimizedSnapshot& hlo_snapshot, const DebugOptions& opts); + void DumpHloModuleMetadataIfEnabled(const std::vector& modules); // Returns true if we should dump data for an HloModule. This is useful if you diff --git a/third_party/xla/xla/service/dump_test.cc b/third_party/xla/xla/service/dump_test.cc index 7d8e6d79b3cbbb..65a7cbbe7f0102 100644 --- a/third_party/xla/xla/service/dump_test.cc +++ b/third_party/xla/xla/service/dump_test.cc @@ -181,5 +181,34 @@ TEST(DumpTest, DumpFdoProfileToFileWhenEnabled) { EXPECT_TRUE(absl::StrContains(data, fdo_profile)); } +TEST(DumpTest, DumpHloUnoptimizedSnapshot) { + HloUnoptimizedSnapshot hlo_snapshot; + HloModuleProto module; + module.set_name("hello"); + *hlo_snapshot.mutable_hlo_module() = module; + *hlo_snapshot.add_partitions() = HloInputs(); + + HloModuleConfig config; + DebugOptions options = config.debug_options(); + + options.set_xla_dump_to(tsl::testing::TmpDir()); + options.set_xla_dump_hlo_as_text(true); + options.set_xla_gpu_dump_hlo_unoptimized_snapshots(true); + config.set_debug_options(options); + + DumpHloUnoptimizedSnapshotIfEnabled(hlo_snapshot, options); + + std::vector matches; + std::string pattern_filename = + tsl::io::JoinPath(tsl::testing::TmpDir(), "*hlo_unoptimized_snapshot*"); + TF_ASSERT_OK( + tsl::Env::Default()->GetMatchingPaths(pattern_filename, &matches)); + EXPECT_THAT(matches, Not(IsEmpty())); + + HloUnoptimizedSnapshot hlo_snapshot_loaded; + TF_ASSERT_OK(tsl::ReadTextProto(tsl::Env::Default(), matches.front(), + &hlo_snapshot_loaded)); + EXPECT_EQ(hlo_snapshot_loaded.hlo_module().name(), module.name()); +} } // namespace } // namespace xla