diff --git a/third_party/xla/xla/pjrt/BUILD b/third_party/xla/xla/pjrt/BUILD index 3a21694b61b8cc..391f19dabdb82f 100644 --- a/third_party/xla/xla/pjrt/BUILD +++ b/third_party/xla/xla/pjrt/BUILD @@ -817,6 +817,7 @@ cc_library( "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/synchronization", "@com_google_absl//absl/types:span", "@llvm-project//llvm:Support", @@ -838,6 +839,7 @@ xla_cc_test( "nomsan", ], deps = [ + ":mlir_to_hlo", ":pjrt_api", ":pjrt_c_api_client", ":pjrt_client", @@ -848,12 +850,16 @@ xla_cc_test( "//xla:shape_util", "//xla/hlo/builder:xla_builder", "//xla/pjrt/c:pjrt_c_api_cpu_internal", + "//xla/pjrt/c:pjrt_c_api_hdrs", "//xla/tests:literal_test_util", "//xla/tsl/lib/core:status_test_util", "@com_google_absl//absl/status", + "@com_google_absl//absl/strings:str_format", "@com_google_googletest//:gtest_main", + "@llvm-project//mlir:IR", "@local_tsl//tsl/platform:statusor", "@local_tsl//tsl/platform:test", + "@stablehlo//:version", ], ) diff --git a/third_party/xla/xla/pjrt/pjrt_c_api_client.cc b/third_party/xla/xla/pjrt/pjrt_c_api_client.cc index 81c250b14c5b1f..e6cf911d7b7a04 100644 --- a/third_party/xla/xla/pjrt/pjrt_c_api_client.cc +++ b/third_party/xla/xla/pjrt/pjrt_c_api_client.cc @@ -34,6 +34,7 @@ limitations under the License. #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/str_cat.h" +#include "absl/strings/str_format.h" #include "absl/strings/string_view.h" #include "absl/synchronization/mutex.h" #include "absl/types/span.h" @@ -66,7 +67,6 @@ limitations under the License. #include "xla/pjrt/pjrt_layout.h" #include "xla/service/computation_placer.h" #include "xla/service/hlo.pb.h" -#include "xla/service/hlo_module_config.h" #include "xla/shape.h" #include "xla/shape_util.h" #include "xla/tsl/framework/allocator.h" @@ -393,11 +393,20 @@ absl::StatusOr> PjRtCApiClient::Compile( absl::StatusOr> PjRtCApiClient::Compile( mlir::ModuleOp module, CompileOptions options) { if (!pjrt_c_api()) llvm::report_fatal_error("pjrt_c_api is null"); - TF_ASSIGN_OR_RETURN( - std::string serialized, - xla::Serialize(module, - xla::GetDefaultStablehloVersion( - plugin_attributes()->pjrt_c_api_minor_version))); + + auto attributes = plugin_attributes()->attributes; + std::string version_string; + auto version = attributes.find("stablehlo_current_version"); + if (version != attributes.end()) { + std::vector v = std::get>(version->second); + version_string = absl::StrFormat("%d.%d.%d", v[0], v[1], v[2]); + } else { + version_string = xla::GetDefaultStablehloVersion( + plugin_attributes()->pjrt_c_api_minor_version); + } + + TF_ASSIGN_OR_RETURN(std::string serialized, + xla::Serialize(module, version_string)); std::string format(pjrt::kMlirFormat); return InitializeArgsAndCompile(this, c_api_, c_client_.get(), options, serialized, format); diff --git a/third_party/xla/xla/pjrt/pjrt_c_api_client_test.cc b/third_party/xla/xla/pjrt/pjrt_c_api_client_test.cc index 993c89da713e5e..7963a5d2e2ec4f 100644 --- a/third_party/xla/xla/pjrt/pjrt_c_api_client_test.cc +++ b/third_party/xla/xla/pjrt/pjrt_c_api_client_test.cc @@ -25,10 +25,17 @@ limitations under the License. #include #include #include "absl/status/status.h" +#include "absl/strings/str_format.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/MLIRContext.h" +#include "mlir/IR/OwningOpRef.h" +#include "stablehlo/dialect/Version.h" #include "xla/cpu_function_runtime.h" #include "xla/hlo/builder/xla_builder.h" #include "xla/literal_util.h" +#include "xla/pjrt/c/pjrt_c_api.h" #include "xla/pjrt/c/pjrt_c_api_cpu_internal.h" +#include "xla/pjrt/mlir_to_hlo.h" #include "xla/pjrt/pjrt_api.h" #include "xla/pjrt/pjrt_client.h" #include "xla/pjrt/pjrt_compiler.h" @@ -163,5 +170,31 @@ TEST(PjRtClientTest, CreateViewAndCopyToDeviceAsyncExternalCpuOnly) { *literal)); } +TEST(PjRtClientTest, CompileUsesStableHloVersion) { + SetUpCpuPjRtApi(); + TF_ASSERT_OK_AND_ASSIGN(const PJRT_Api* c_api, pjrt::PjrtApi("cpu")); + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr client, + GetCApiClient("cpu")); + static auto PJRT_Client_Compile_Orig = c_api->PJRT_Client_Compile; + constexpr char kProgram[] = "func.func @main() {return}"; + mlir::MLIRContext context; + TF_ASSERT_OK_AND_ASSIGN(mlir::OwningOpRef module, + ParseMlirModuleString(kProgram, context)); + const_cast(c_api)->PJRT_Client_Compile = + [](PJRT_Client_Compile_Args* args) -> PJRT_Error* { + mlir::vhlo::Version version = mlir::vhlo::Version::getCurrentVersion(); + std::string version_string = absl::StrFormat( + "%d.%d.%d", version.getMajor(), version.getMinor(), version.getPatch()); + // MLIR doesn't have any functionality for retrieving the producer of + // bytecode files, so just scan the raw string. + EXPECT_TRUE(llvm::StringRef(args->program->code, args->program->code_size) + .contains(version_string)); + return PJRT_Client_Compile_Orig(args); + }; + std::unique_ptr executable = + client->Compile(*module, CompileOptions()).value(); + const_cast(c_api)->PJRT_Client_Compile = PJRT_Client_Compile_Orig; +} + } // namespace } // namespace xla