[XLA:GPU] Log the error if parsing of Triton IR from custom call fails.
PiperOrigin-RevId: 702299454
olegshyshkov authored and tensorflower-gardener committed Dec 3, 2024
1 parent 2b9dbdd commit 484e69f
Showing 4 changed files with 53 additions and 50 deletions.
1 change: 1 addition & 0 deletions third_party/xla/xla/service/gpu/BUILD
@@ -343,6 +343,7 @@ cc_library(
         "//xla/ffi:ffi_api",
         "//xla/ffi/api:c_api",
         "//xla/hlo/ir:hlo",
+        "//xla/mlir/utils:error_util",
         "//xla/mlir_hlo:transforms_gpu_passes",
         "//xla/service:buffer_assignment",
         "//xla/service:collective_ops_utils",
17 changes: 14 additions & 3 deletions third_party/xla/xla/service/gpu/ir_emitter_unnested.cc
@@ -80,6 +80,7 @@ limitations under the License.
 #include "xla/layout.h"
 #include "xla/layout_util.h"
 #include "xla/literal.h"
+#include "xla/mlir/utils/error_util.h"
 #include "xla/mlir_hlo/transforms/gpu_passes.h"
 #include "xla/primitive_util.h"
 #include "xla/service/buffer_assignment.h"
@@ -1380,9 +1381,19 @@ absl::Status IrEmitterUnnested::EmitTritonCustomCall(
       ir_emitter_context_->name_uniquer()->GetUniqueName(call.name);
   VLOG(3) << "Generating: " << kernel_name;
 
-  auto triton_module =
-      mlir::parseSourceString<mlir::ModuleOp>(call.ir, &mlir_context);
-  TF_RET_CHECK(triton_module) << "Failed to parse Triton module: " << call.ir;
+  mlir::OwningOpRef<mlir::ModuleOp> triton_module;
+  {
+    mlir::BaseScopedDiagnosticHandler diagnostic_handler(&mlir_context);
+    triton_module =
+        mlir::parseSourceString<mlir::ModuleOp>(call.ir, &mlir_context);
+    if (!triton_module) {
+      return absl::InvalidArgumentError(
+          absl::StrCat("Failed to parse Triton module: ",
+                       diagnostic_handler.ConsumeStatus().message(),
+                       "\ninput ir: ", call.ir));
+    }
+  }
+
   auto triton_fn =
       triton_module->lookupSymbol<mlir::triton::FuncOp>(call.name);
   TF_RET_CHECK(triton_fn)
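For readers skimming the change: the new error path relies on MLIR's scoped diagnostic capture. Below is a minimal sketch of that pattern in isolation, assuming only the MLIR parser, Abseil status utilities, and XLA's error_util header used above; the free-standing helper name is illustrative, not an XLA API.

// Minimal sketch (not XLA code): capture MLIR parse diagnostics and
// surface them in the returned status instead of dropping them.
#include <string>

#include "absl/status/statusor.h"
#include "absl/strings/str_cat.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/OwningOpRef.h"
#include "mlir/Parser/Parser.h"
#include "xla/mlir/utils/error_util.h"

absl::StatusOr<mlir::OwningOpRef<mlir::ModuleOp>> ParseTritonModuleOrError(
    const std::string& ir, mlir::MLIRContext* context) {
  // While in scope, the handler intercepts every diagnostic emitted on
  // `context`, so the parser's complaints are buffered rather than lost.
  mlir::BaseScopedDiagnosticHandler diagnostic_handler(context);
  mlir::OwningOpRef<mlir::ModuleOp> module =
      mlir::parseSourceString<mlir::ModuleOp>(ir, context);
  if (!module) {
    // ConsumeStatus() folds the buffered diagnostics into an error status,
    // which is what makes the final message actionable.
    return absl::InvalidArgumentError(
        absl::StrCat("Failed to parse Triton module: ",
                     diagnostic_handler.ConsumeStatus().message(),
                     "\ninput ir: ", ir));
  }
  return module;
}

The scoping matters: the handler detaches when it goes out of scope, so diagnostics from later compilation stages are not swallowed. That is presumably why the commit wraps the parse in its own block.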
1 change: 0 additions & 1 deletion third_party/xla/xla/service/gpu/tests/BUILD
@@ -328,7 +328,6 @@ xla_test(
         "//xla/hlo/ir:hlo",
         "//xla/stream_executor:device_description",
         "//xla/tests:hlo_test_base",
-        "//xla/tests:verified_hlo_module",
         "@com_google_absl//absl/status",
         "@com_google_absl//absl/strings:string_view",
         "@com_google_googletest//:gtest_main",
@@ -195,32 +195,11 @@ TEST_F(GpuIrEmitterUnnestedTest, CanNotEmitTritonCustomCallOnPreAmpereGpu) {
                   "(compute capability 8.0) and up, but got")));
 }
 
-class TritonCustomCallTest : public HloTestBase {};
-
-TEST_F(TritonCustomCallTest, NoArgumentDeduplication) {
-  if (auto cc = backend()
-                    .default_stream_executor()
-                    ->GetDeviceDescription()
-                    .cuda_compute_capability();
-      !cc.IsAtLeastAmpere()) {
+TEST_F(GpuIrEmitterUnnestedTest, FailGracefullyIfTritonModuleIsNotParseable) {
+  if (!GetCudaComputeCapability().IsAtLeastAmpere()) {
     GTEST_SKIP() << "Triton support is only enabled for Ampere GPUs and up.";
   }
 
-  // Tests that no argument deduplication is done for Triton kernels.
-  //
-  // Triton kernels are compiled on the first call and re-used for all the
-  // following calls. So, if we are unlucky, we could end up calling the
-  // compiled kernel with fewer arguments than it expects in the presence
-  // of argument deduplication.
-  //
-  // For example,
-  //
-  // * The first call is f(x, y). The arguments are distinct, no deduplication
-  //   is done at compilation time and the compiled kernel expects two
-  //   arguments.
-  // * The second call is f(x, x). The arguments are deduplicated and we
-  //   call the previously compiled kernel with just x, causing a crash.
-
   HloComputation::Builder computation_builder(TestName());
 
   Shape scalar_shape = xla::ShapeUtil::MakeShape(xla::F32, {});
@@ -232,23 +211,19 @@ TEST_F(TritonCustomCallTest, NoArgumentDeduplication) {
   HloInstruction* param_1 = computation_builder.AddInstruction(
       HloInstruction::CreateParameter(1, scalar_shape, "arg_1"));
 
-  auto* instr_0 = computation_builder.AddInstruction(CreateTritonCustomCall(
-      tuple_shape, param_0, param_1, kMLIRText, kCallName));
-  computation_builder.AddInstruction(CreateTritonCustomCall(
-      tuple_shape, instr_0, instr_0, kMLIRText, kCallName));
+  computation_builder.AddInstruction(
+      CreateTritonCustomCall(tuple_shape, param_0, param_1,
+                             /*mlir_text=*/"unparseable_mlir", kCallName));
 
   auto module = CreateNewVerifiedModule();
   module->AddEntryComputation(computation_builder.Build());
-  EXPECT_TRUE(Run(std::move(module), /*run_hlo_passes=*/false));
+  EXPECT_THAT(Run(std::move(module), /*run_hlo_passes=*/false).message(),
+              HasSubstr("Failed to parse Triton module"));
 }
 
-TEST_F(TritonCustomCallTest, FailGracefullyIfTritonModuleIsNotParseable) {
-  if (auto cc = backend()
-                    .default_stream_executor()
-                    ->GetDeviceDescription()
-                    .cuda_compute_capability();
-      cc.IsAtLeastAmpere()) {
-    GTEST_SKIP() << "Running on Ampere or more recent GPU, skipping.";
+TEST_F(GpuIrEmitterUnnestedTest, FailGracefullyIfCallNameIsInvalid) {
+  if (!GetCudaComputeCapability().IsAtLeastAmpere()) {
+    GTEST_SKIP() << "Triton support is only enabled for Ampere GPUs and up.";
   }
 
   HloComputation::Builder computation_builder(TestName());
Expand All @@ -263,24 +238,41 @@ TEST_F(TritonCustomCallTest, FailGracefullyIfTritonModuleIsNotParseable) {
HloInstruction::CreateParameter(1, scalar_shape, "arg_1"));

computation_builder.AddInstruction(
CreateTritonCustomCall(tuple_shape, param_0, param_1,
/*mlir_text=*/"unparseable_mlir", kCallName));
CreateTritonCustomCall(tuple_shape, param_0, param_1, kMLIRText,
/*call_name=*/"invalid_call_name"));

auto module = CreateNewVerifiedModule();
module->AddEntryComputation(computation_builder.Build());
EXPECT_THAT(Run(std::move(module), /*run_hlo_passes=*/false).message(),
HasSubstr("Failed to parse Triton module"));
HasSubstr("Call name not found in the Triton module"));
}

TEST_F(TritonCustomCallTest, FailGracefullyIfCallNameIsInvalid) {
class TritonCustomCallTest : public HloTestBase {};

TEST_F(TritonCustomCallTest, NoArgumentDeduplication) {
if (auto cc = backend()
.default_stream_executor()
->GetDeviceDescription()
.cuda_compute_capability();
cc.IsAtLeastAmpere()) {
GTEST_SKIP() << "Running on Ampere or more recent GPU, skipping.";
!cc.IsAtLeastAmpere()) {
GTEST_SKIP() << "Triton support is only enabled for Ampere GPUs and up.";
}

// Tests that no argument deduplication is done for Triton kernels.
//
// Triton kernels are compiled on the first call and re-used for all the
// following calls. So, if we are unlucky, we could end up calling the
// compiled kernel with fewer arguments than it expects in the presence
// of argument deduplication.
//
// For example,
//
// * The first call is f(x, y). The arguments are distinct, no deduplication
// is done at compilation time and the compiled kernel expects two
// arguments.
// * The second call is f(x, x). The arguments are deduplicated and we
// call the previously compiled kernel with just x, causing a crash.

HloComputation::Builder computation_builder(TestName());

Shape scalar_shape = xla::ShapeUtil::MakeShape(xla::F32, {});
@@ -292,14 +284,14 @@ TEST_F(TritonCustomCallTest, FailGracefullyIfCallNameIsInvalid) {
   HloInstruction* param_1 = computation_builder.AddInstruction(
       HloInstruction::CreateParameter(1, scalar_shape, "arg_1"));
 
-  computation_builder.AddInstruction(
-      CreateTritonCustomCall(tuple_shape, param_0, param_1, kMLIRText,
-                             /*call_name=*/"invalid_call_name"));
+  auto* instr_0 = computation_builder.AddInstruction(CreateTritonCustomCall(
+      tuple_shape, param_0, param_1, kMLIRText, kCallName));
+  computation_builder.AddInstruction(CreateTritonCustomCall(
+      tuple_shape, instr_0, instr_0, kMLIRText, kCallName));
 
   auto module = CreateNewVerifiedModule();
   module->AddEntryComputation(computation_builder.Build());
-  EXPECT_THAT(Run(std::move(module), /*run_hlo_passes=*/false).message(),
-              HasSubstr("Call name not found in the Triton module"));
+  EXPECT_TRUE(Run(std::move(module), /*run_hlo_passes=*/false));
 }
 
 }  // namespace gpu
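As an aside on the NoArgumentDeduplication comment above: the hazard it describes is easiest to see with a toy, name-keyed kernel cache. The sketch below is hypothetical plain C++ — none of these types or functions exist in XLA — and only illustrates why a kernel compiled for f(x, y) cannot later be launched as f(x, x) with deduplicated arguments.

// Hypothetical illustration (not XLA code) of the caching hazard: a kernel
// compiled on first use is cached by name, freezing its argument count.
#include <iostream>
#include <map>
#include <string>
#include <vector>

struct CompiledKernel {
  size_t expected_args;  // Fixed when the kernel is first compiled.
};

std::map<std::string, CompiledKernel> kernel_cache;

void LaunchTritonKernel(const std::string& name,
                        const std::vector<float*>& args) {
  // Compile on the first call; every later call reuses the cached kernel.
  auto [it, inserted] =
      kernel_cache.try_emplace(name, CompiledKernel{args.size()});
  (void)inserted;
  if (args.size() != it->second.expected_args) {
    // f(x, y) compiled the kernel with two parameters; a deduplicated
    // second call f(x, x) arrives here with only {x}, and a real GPU
    // launch would crash. Hence the test pins down that XLA never
    // deduplicates Triton custom-call arguments.
    std::cout << name << ": arity mismatch, launch would crash\n";
  }
}

int main() {
  float x = 0.f, y = 0.f;
  LaunchTritonKernel("f", {&x, &y});  // Compiled expecting two arguments.
  LaunchTritonKernel("f", {&x});      // Deduplicated call: mismatch.
  return 0;
}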
