diff --git a/third_party/xla/xla/service/gpu/BUILD b/third_party/xla/xla/service/gpu/BUILD
index 2be909debb9601..f15ea2b1697611 100644
--- a/third_party/xla/xla/service/gpu/BUILD
+++ b/third_party/xla/xla/service/gpu/BUILD
@@ -343,6 +343,7 @@ cc_library(
         "//xla/ffi:ffi_api",
         "//xla/ffi/api:c_api",
         "//xla/hlo/ir:hlo",
+        "//xla/mlir/utils:error_util",
         "//xla/mlir_hlo:transforms_gpu_passes",
         "//xla/service:buffer_assignment",
         "//xla/service:collective_ops_utils",
diff --git a/third_party/xla/xla/service/gpu/ir_emitter_unnested.cc b/third_party/xla/xla/service/gpu/ir_emitter_unnested.cc
index 1d0b56f96d7d11..23fd64c4eb0871 100644
--- a/third_party/xla/xla/service/gpu/ir_emitter_unnested.cc
+++ b/third_party/xla/xla/service/gpu/ir_emitter_unnested.cc
@@ -80,6 +80,7 @@ limitations under the License.
 #include "xla/layout.h"
 #include "xla/layout_util.h"
 #include "xla/literal.h"
+#include "xla/mlir/utils/error_util.h"
 #include "xla/mlir_hlo/transforms/gpu_passes.h"
 #include "xla/primitive_util.h"
 #include "xla/service/buffer_assignment.h"
@@ -1380,9 +1381,19 @@ absl::Status IrEmitterUnnested::EmitTritonCustomCall(
       ir_emitter_context_->name_uniquer()->GetUniqueName(call.name);
 
   VLOG(3) << "Generating: " << kernel_name;
-  auto triton_module =
-      mlir::parseSourceString<mlir::ModuleOp>(call.ir, &mlir_context);
-  TF_RET_CHECK(triton_module) << "Failed to parse Triton module: " << call.ir;
+  mlir::OwningOpRef<mlir::ModuleOp> triton_module;
+  {
+    mlir::BaseScopedDiagnosticHandler diagnostic_handler(&mlir_context);
+    triton_module =
+        mlir::parseSourceString<mlir::ModuleOp>(call.ir, &mlir_context);
+    if (!triton_module) {
+      return absl::InvalidArgumentError(
+          absl::StrCat("Failed to parse Triton module: ",
+                       diagnostic_handler.ConsumeStatus().message(),
+                       "\ninput ir: ", call.ir));
+    }
+  }
+
   auto triton_fn =
       triton_module->lookupSymbol<mlir::triton::FuncOp>(call.name);
   TF_RET_CHECK(triton_fn)
diff --git a/third_party/xla/xla/service/gpu/tests/BUILD b/third_party/xla/xla/service/gpu/tests/BUILD
index b5d467f3a64237..5c47eb814cb745 100644
--- a/third_party/xla/xla/service/gpu/tests/BUILD
+++ b/third_party/xla/xla/service/gpu/tests/BUILD
@@ -328,7 +328,6 @@ xla_test(
         "//xla/hlo/ir:hlo",
         "//xla/stream_executor:device_description",
         "//xla/tests:hlo_test_base",
-        "//xla/tests:verified_hlo_module",
         "@com_google_absl//absl/status",
         "@com_google_absl//absl/strings:string_view",
         "@com_google_googletest//:gtest_main",
diff --git a/third_party/xla/xla/service/gpu/tests/gpu_triton_custom_call_test.cc b/third_party/xla/xla/service/gpu/tests/gpu_triton_custom_call_test.cc
index 8572a345815bef..73dfe14387e2ba 100644
--- a/third_party/xla/xla/service/gpu/tests/gpu_triton_custom_call_test.cc
+++ b/third_party/xla/xla/service/gpu/tests/gpu_triton_custom_call_test.cc
@@ -195,32 +195,11 @@ TEST_F(GpuIrEmitterUnnestedTest, CanNotEmitTritonCustomCallOnPreAmpereGpu) {
                     "(compute capability 8.0) and up, but got")));
 }
 
-class TritonCustomCallTest : public HloTestBase {};
-
-TEST_F(TritonCustomCallTest, NoArgumentDeduplication) {
-  if (auto cc = backend()
-                    .default_stream_executor()
-                    ->GetDeviceDescription()
-                    .cuda_compute_capability();
-      !cc.IsAtLeastAmpere()) {
+TEST_F(GpuIrEmitterUnnestedTest, FailGracefullyIfTritonModuleIsNotParseable) {
+  if (!GetCudaComputeCapability().IsAtLeastAmpere()) {
     GTEST_SKIP() << "Triton support is only enabled for Ampere GPUs and up.";
   }
 
-  // Tests that no argument deduplication is done for Triton kernels.
-  //
-  // Triton kernels are compiled on the first call and re-used for all the
-  // following calls. So, if we are unlucky, we could end up calling the
-  // compiled kernel with fewer arguments than it expects in the presence
-  // of argument deduplication.
-  //
-  // For example,
-  //
-  // * The first call is f(x, y). The arguments are distinct, no deduplication
-  //   is done at compilation time and the compiled kernel expects two
-  //   arguments.
-  // * The second call is f(x, x). The arguments are deduplicated and we
-  //   call the previously compiled kernel with just x, causing a crash.
-
   HloComputation::Builder computation_builder(TestName());
 
   Shape scalar_shape = xla::ShapeUtil::MakeShape(xla::F32, {});
@@ -232,23 +211,19 @@ TEST_F(TritonCustomCallTest, NoArgumentDeduplication) {
   HloInstruction* param_1 = computation_builder.AddInstruction(
       HloInstruction::CreateParameter(1, scalar_shape, "arg_1"));
 
-  auto* instr_0 = computation_builder.AddInstruction(CreateTritonCustomCall(
-      tuple_shape, param_0, param_1, kMLIRText, kCallName));
-  computation_builder.AddInstruction(CreateTritonCustomCall(
-      tuple_shape, instr_0, instr_0, kMLIRText, kCallName));
+  computation_builder.AddInstruction(
+      CreateTritonCustomCall(tuple_shape, param_0, param_1,
+                             /*mlir_text=*/"unparseable_mlir", kCallName));
 
   auto module = CreateNewVerifiedModule();
   module->AddEntryComputation(computation_builder.Build());
-  EXPECT_TRUE(Run(std::move(module), /*run_hlo_passes=*/false));
+  EXPECT_THAT(Run(std::move(module), /*run_hlo_passes=*/false).message(),
+              HasSubstr("Failed to parse Triton module"));
 }
 
-TEST_F(TritonCustomCallTest, FailGracefullyIfTritonModuleIsNotParseable) {
-  if (auto cc = backend()
-                    .default_stream_executor()
-                    ->GetDeviceDescription()
-                    .cuda_compute_capability();
-      cc.IsAtLeastAmpere()) {
-    GTEST_SKIP() << "Running on Ampere or more recent GPU, skipping.";
+TEST_F(GpuIrEmitterUnnestedTest, FailGracefullyIfCallNameIsInvalid) {
+  if (!GetCudaComputeCapability().IsAtLeastAmpere()) {
+    GTEST_SKIP() << "Triton support is only enabled for Ampere GPUs and up.";
   }
 
   HloComputation::Builder computation_builder(TestName());
@@ -263,24 +238,41 @@ TEST_F(TritonCustomCallTest, FailGracefullyIfTritonModuleIsNotParseable) {
       HloInstruction::CreateParameter(1, scalar_shape, "arg_1"));
 
   computation_builder.AddInstruction(
-      CreateTritonCustomCall(tuple_shape, param_0, param_1,
-                             /*mlir_text=*/"unparseable_mlir", kCallName));
+      CreateTritonCustomCall(tuple_shape, param_0, param_1, kMLIRText,
+                             /*call_name=*/"invalid_call_name"));
 
   auto module = CreateNewVerifiedModule();
   module->AddEntryComputation(computation_builder.Build());
   EXPECT_THAT(Run(std::move(module), /*run_hlo_passes=*/false).message(),
-              HasSubstr("Failed to parse Triton module"));
+              HasSubstr("Call name not found in the Triton module"));
 }
 
-TEST_F(TritonCustomCallTest, FailGracefullyIfCallNameIsInvalid) {
+class TritonCustomCallTest : public HloTestBase {};
+
+TEST_F(TritonCustomCallTest, NoArgumentDeduplication) {
   if (auto cc = backend()
                     .default_stream_executor()
                     ->GetDeviceDescription()
                     .cuda_compute_capability();
-      cc.IsAtLeastAmpere()) {
-    GTEST_SKIP() << "Running on Ampere or more recent GPU, skipping.";
+      !cc.IsAtLeastAmpere()) {
+    GTEST_SKIP() << "Triton support is only enabled for Ampere GPUs and up.";
   }
 
+  // Tests that no argument deduplication is done for Triton kernels.
+  //
+  // Triton kernels are compiled on the first call and re-used for all the
+  // following calls. So, if we are unlucky, we could end up calling the
+  // compiled kernel with fewer arguments than it expects in the presence
+  // of argument deduplication.
+  //
+  // For example,
+  //
+  // * The first call is f(x, y). The arguments are distinct, no deduplication
+  //   is done at compilation time and the compiled kernel expects two
+  //   arguments.
+  // * The second call is f(x, x). The arguments are deduplicated and we
+  //   call the previously compiled kernel with just x, causing a crash.
+
   HloComputation::Builder computation_builder(TestName());
 
   Shape scalar_shape = xla::ShapeUtil::MakeShape(xla::F32, {});
@@ -292,14 +284,14 @@ TEST_F(TritonCustomCallTest, FailGracefullyIfCallNameIsInvalid) {
   HloInstruction* param_1 = computation_builder.AddInstruction(
       HloInstruction::CreateParameter(1, scalar_shape, "arg_1"));
 
-  computation_builder.AddInstruction(
-      CreateTritonCustomCall(tuple_shape, param_0, param_1, kMLIRText,
-                             /*call_name=*/"invalid_call_name"));
+  auto* instr_0 = computation_builder.AddInstruction(CreateTritonCustomCall(
+      tuple_shape, param_0, param_1, kMLIRText, kCallName));
+  computation_builder.AddInstruction(CreateTritonCustomCall(
+      tuple_shape, instr_0, instr_0, kMLIRText, kCallName));
 
   auto module = CreateNewVerifiedModule();
   module->AddEntryComputation(computation_builder.Build());
-  EXPECT_THAT(Run(std::move(module), /*run_hlo_passes=*/false).message(),
-              HasSubstr("Call name not found in the Triton module"));
+  EXPECT_TRUE(Run(std::move(module), /*run_hlo_passes=*/false));
 }
 
 }  // namespace gpu
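Reviewer note on the error-reporting pattern above: the new code works because MLIR diagnostics are captured only while a scoped handler is alive, which is why the parse is wrapped in its own block. A minimal standalone sketch of the same pattern follows. The helper name ParseTritonModule is invented for illustration; only APIs already visible in this diff (mlir::parseSourceString, mlir::BaseScopedDiagnosticHandler and its ConsumeStatus()) are assumed.

#include <string>

#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/str_cat.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/OwningOpRef.h"
#include "mlir/Parser/Parser.h"
#include "xla/mlir/utils/error_util.h"

// Hypothetical helper: parse MLIR text into a module, folding any parser
// diagnostics into the returned absl::Status instead of losing them.
absl::StatusOr<mlir::OwningOpRef<mlir::ModuleOp>> ParseTritonModule(
    const std::string& ir, mlir::MLIRContext& context) {
  // While in scope, the handler intercepts diagnostics emitted on `context`,
  // so parse errors are captured rather than printed to stderr.
  mlir::BaseScopedDiagnosticHandler diagnostic_handler(&context);
  mlir::OwningOpRef<mlir::ModuleOp> module =
      mlir::parseSourceString<mlir::ModuleOp>(ir, &context);
  if (!module) {
    // ConsumeStatus() converts the captured diagnostics into an absl::Status,
    // which makes the returned error self-describing.
    return absl::InvalidArgumentError(
        absl::StrCat("Failed to parse Triton module: ",
                     diagnostic_handler.ConsumeStatus().message(),
                     "\ninput ir: ", ir));
  }
  return module;
}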