Skip to content

Commit

Permalink
Use the plugin's preferred StableHLO version.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 688135422
  • Loading branch information
matthiaskramm authored and tensorflower-gardener committed Oct 21, 2024
1 parent 8699ab1 commit c04a132
Show file tree
Hide file tree
Showing 3 changed files with 54 additions and 6 deletions.
6 changes: 6 additions & 0 deletions third_party/xla/xla/pjrt/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -817,6 +817,7 @@ cc_library(
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/strings:str_format",
"@com_google_absl//absl/synchronization",
"@com_google_absl//absl/types:span",
"@llvm-project//llvm:Support",
Expand All @@ -838,6 +839,7 @@ xla_cc_test(
"nomsan",
],
deps = [
":mlir_to_hlo",
":pjrt_api",
":pjrt_c_api_client",
":pjrt_client",
Expand All @@ -848,12 +850,16 @@ xla_cc_test(
"//xla:shape_util",
"//xla/hlo/builder:xla_builder",
"//xla/pjrt/c:pjrt_c_api_cpu_internal",
"//xla/pjrt/c:pjrt_c_api_hdrs",
"//xla/tests:literal_test_util",
"//xla/tsl/lib/core:status_test_util",
"@com_google_absl//absl/status",
"@com_google_absl//absl/strings:str_format",
"@com_google_googletest//:gtest_main",
"@llvm-project//mlir:IR",
"@local_tsl//tsl/platform:statusor",
"@local_tsl//tsl/platform:test",
"@stablehlo//:version",
],
)

Expand Down
21 changes: 15 additions & 6 deletions third_party/xla/xla/pjrt/pjrt_c_api_client.cc
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ limitations under the License.
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_format.h"
#include "absl/strings/string_view.h"
#include "absl/synchronization/mutex.h"
#include "absl/types/span.h"
Expand Down Expand Up @@ -66,7 +67,6 @@ limitations under the License.
#include "xla/pjrt/pjrt_layout.h"
#include "xla/service/computation_placer.h"
#include "xla/service/hlo.pb.h"
#include "xla/service/hlo_module_config.h"
#include "xla/shape.h"
#include "xla/shape_util.h"
#include "xla/tsl/framework/allocator.h"
Expand Down Expand Up @@ -393,11 +393,20 @@ absl::StatusOr<std::unique_ptr<PjRtLoadedExecutable>> PjRtCApiClient::Compile(
absl::StatusOr<std::unique_ptr<PjRtLoadedExecutable>> PjRtCApiClient::Compile(
mlir::ModuleOp module, CompileOptions options) {
if (!pjrt_c_api()) llvm::report_fatal_error("pjrt_c_api is null");
TF_ASSIGN_OR_RETURN(
std::string serialized,
xla::Serialize(module,
xla::GetDefaultStablehloVersion(
plugin_attributes()->pjrt_c_api_minor_version)));

auto attributes = plugin_attributes()->attributes;
std::string version_string;
auto version = attributes.find("stablehlo_current_version");
if (version != attributes.end()) {
std::vector<int64_t> v = std::get<std::vector<int64_t>>(version->second);
version_string = absl::StrFormat("%d.%d.%d", v[0], v[1], v[2]);
} else {
version_string = xla::GetDefaultStablehloVersion(
plugin_attributes()->pjrt_c_api_minor_version);
}

TF_ASSIGN_OR_RETURN(std::string serialized,
xla::Serialize(module, version_string));
std::string format(pjrt::kMlirFormat);
return InitializeArgsAndCompile(this, c_api_, c_client_.get(), options,
serialized, format);
Expand Down
33 changes: 33 additions & 0 deletions third_party/xla/xla/pjrt/pjrt_c_api_client_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -25,10 +25,17 @@ limitations under the License.
#include <gmock/gmock.h>
#include <gtest/gtest.h>
#include "absl/status/status.h"
#include "absl/strings/str_format.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/OwningOpRef.h"
#include "stablehlo/dialect/Version.h"
#include "xla/cpu_function_runtime.h"
#include "xla/hlo/builder/xla_builder.h"
#include "xla/literal_util.h"
#include "xla/pjrt/c/pjrt_c_api.h"
#include "xla/pjrt/c/pjrt_c_api_cpu_internal.h"
#include "xla/pjrt/mlir_to_hlo.h"
#include "xla/pjrt/pjrt_api.h"
#include "xla/pjrt/pjrt_client.h"
#include "xla/pjrt/pjrt_compiler.h"
Expand Down Expand Up @@ -163,5 +170,31 @@ TEST(PjRtClientTest, CreateViewAndCopyToDeviceAsyncExternalCpuOnly) {
*literal));
}

TEST(PjRtClientTest, CompileUsesStableHloVersion) {
SetUpCpuPjRtApi();
TF_ASSERT_OK_AND_ASSIGN(const PJRT_Api* c_api, pjrt::PjrtApi("cpu"));
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<PjRtClient> client,
GetCApiClient("cpu"));
static auto PJRT_Client_Compile_Orig = c_api->PJRT_Client_Compile;
constexpr char kProgram[] = "func.func @main() {return}";
mlir::MLIRContext context;
TF_ASSERT_OK_AND_ASSIGN(mlir::OwningOpRef<mlir::ModuleOp> module,
ParseMlirModuleString(kProgram, context));
const_cast<PJRT_Api*>(c_api)->PJRT_Client_Compile =
[](PJRT_Client_Compile_Args* args) -> PJRT_Error* {
mlir::vhlo::Version version = mlir::vhlo::Version::getCurrentVersion();
std::string version_string = absl::StrFormat(
"%d.%d.%d", version.getMajor(), version.getMinor(), version.getPatch());
// MLIR doesn't have any functionality for retrieving the producer of
// bytecode files, so just scan the raw string.
EXPECT_TRUE(llvm::StringRef(args->program->code, args->program->code_size)
.contains(version_string));
return PJRT_Client_Compile_Orig(args);
};
std::unique_ptr<PjRtLoadedExecutable> executable =
client->Compile(*module, CompileOptions()).value();
const_cast<PJRT_Api*>(c_api)->PJRT_Client_Compile = PJRT_Client_Compile_Orig;
}

} // namespace
} // namespace xla

0 comments on commit c04a132

Please sign in to comment.