diff --git a/third_party/xla/xla/service/gpu/fusions/reduction_mlir.cc b/third_party/xla/xla/service/gpu/fusions/reduction_mlir.cc index 8250c408edf619..d58f809b1c3c41 100644 --- a/third_party/xla/xla/service/gpu/fusions/reduction_mlir.cc +++ b/third_party/xla/xla/service/gpu/fusions/reduction_mlir.cc @@ -742,7 +742,12 @@ MlirMultiRowReductionFusion::MlirMultiRowReductionFusion( // This vector size is always valid: we know that the reduced dimension is a // power of 2, since otherwise RowReductionGetRowsPerWarp would have // returned 1. - int vector_size = 32 / smallest_input_or_output_bits; + // Our codegen can't currently deal with vectorization across rows, so we + // limit the vector size to the size of the row. Note that this emitter + // essentially reverts to the loop emitter in this case, except for side + // outputs. + int vector_size = std::min(static_cast(input_shape_[kRowMinorReduced]), + 32 / smallest_input_or_output_bits); // We target 8 warps per block, which means there could be up to 8 blocks per // SM, but we have no good way of knowing. In practice, enabling vectorization diff --git a/third_party/xla/xla/service/gpu/fusions/reduction_mlir_test.cc b/third_party/xla/xla/service/gpu/fusions/reduction_mlir_test.cc index 8fc0ba6af05bb7..214e9b582cb123 100644 --- a/third_party/xla/xla/service/gpu/fusions/reduction_mlir_test.cc +++ b/third_party/xla/xla/service/gpu/fusions/reduction_mlir_test.cc @@ -161,6 +161,24 @@ constexpr auto kMultiRowReductionX2VectorX4 = R"( ROOT fusion = (pred[76800]{0}, pred[76800]{0}) fusion(p0, p1), kind=kInput, calls=fusion })"; +constexpr auto kMultiRowReductionX16VectorX2 = R"( + or { + tmp_0 = pred[] parameter(0) + tmp_1 = pred[] parameter(1) + ROOT tmp_2 = pred[] or(tmp_0, tmp_1) + } + + fusion { + p0 = pred[76800,2] parameter(0) + c0 = pred[] constant(false) + ROOT reduce = pred[76800] reduce(p0, c0), dimensions={1}, to_apply=or + } + + ENTRY main { + p0 = pred[76800,2] parameter(0) + ROOT fusion = pred[76800] fusion(p0), kind=kInput, calls=fusion + })"; + constexpr std::string_view kRowReductionSideOutput = R"( Add { lhs = f32[] parameter(0) @@ -855,6 +873,11 @@ TEST_F(MlirMultiRowReductionTest, VectorizedX4Indexing) { ElementsAre(1 /* major reduced */, 4 /* vector size */)); } +TEST_F(MlirMultiRowReductionTest, LimitedVectorizationCorrectness) { + EXPECT_TRUE( + RunAndCompareNoHloPasses(kMultiRowReductionX16VectorX2, ErrorSpec{1e-3})); +} + TEST_F(MlirMultiRowReductionTest, VectorizedX4Correctness) { EXPECT_TRUE( RunAndCompareNoHloPasses(kMultiRowReductionX2VectorX4, ErrorSpec{1e-3}));