Skip to content

Commit

Permalink
Fix vectorization of tiny multi-row reductions.
Browse files Browse the repository at this point in the history
For such reductions we could end up choosing a vectorization factor greater
than the row length, which is not something codegen currently
supports.

PiperOrigin-RevId: 654698797
  • Loading branch information
jreiffers authored and tensorflower-gardener committed Jul 22, 2024
1 parent b96b7e7 commit c10f5e8
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 1 deletion.
7 changes: 6 additions & 1 deletion third_party/xla/xla/service/gpu/fusions/reduction_mlir.cc
Original file line number Diff line number Diff line change
Expand Up @@ -742,7 +742,12 @@ MlirMultiRowReductionFusion::MlirMultiRowReductionFusion(
// This vector size is always valid: we know that the reduced dimension is a
// power of 2, since otherwise RowReductionGetRowsPerWarp would have
// returned 1.
int vector_size = 32 / smallest_input_or_output_bits;
// Our codegen can't currently deal with vectorization across rows, so we
// limit the vector size to the size of the row. Note that this emitter
// essentially reverts to the loop emitter in this case, except for side
// outputs.
int vector_size = std::min(static_cast<int>(input_shape_[kRowMinorReduced]),
32 / smallest_input_or_output_bits);

// We target 8 warps per block, which means there could be up to 8 blocks per
// SM, but we have no good way of knowing. In practice, enabling vectorization
Expand Down
23 changes: 23 additions & 0 deletions third_party/xla/xla/service/gpu/fusions/reduction_mlir_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -161,6 +161,24 @@ constexpr auto kMultiRowReductionX2VectorX4 = R"(
ROOT fusion = (pred[76800]{0}, pred[76800]{0}) fusion(p0, p1), kind=kInput, calls=fusion
})";

// HLO module containing a single input fusion that OR-reduces a pred[76800,2]
// parameter along dimension 1, producing pred[76800]. Each reduced row has
// only 2 elements. Consumed by the LimitedVectorizationCorrectness test.
// NOTE(review): the "X16VectorX2" suffix presumably means 16 rows per warp and
// a vector size of 2 (i.e. capped at the row length) -- confirm against the
// emitter.
constexpr auto kMultiRowReductionX16VectorX2 = R"(
or {
tmp_0 = pred[] parameter(0)
tmp_1 = pred[] parameter(1)
ROOT tmp_2 = pred[] or(tmp_0, tmp_1)
}
fusion {
p0 = pred[76800,2] parameter(0)
c0 = pred[] constant(false)
ROOT reduce = pred[76800] reduce(p0, c0), dimensions={1}, to_apply=or
}
ENTRY main {
p0 = pred[76800,2] parameter(0)
ROOT fusion = pred[76800] fusion(p0), kind=kInput, calls=fusion
})";

constexpr std::string_view kRowReductionSideOutput = R"(
Add {
lhs = f32[] parameter(0)
Expand Down Expand Up @@ -855,6 +873,11 @@ TEST_F(MlirMultiRowReductionTest, VectorizedX4Indexing) {
ElementsAre(1 /* major reduced */, 4 /* vector size */));
}

// Executes the kMultiRowReductionX16VectorX2 module (rows of length 2) via
// RunAndCompareNoHloPasses with an ErrorSpec of 1e-3 and expects success.
// NOTE(review): presumably this compares against a reference implementation
// and guards the case where the vector size must be limited to the row
// length -- confirm against the test base class.
TEST_F(MlirMultiRowReductionTest, LimitedVectorizationCorrectness) {
EXPECT_TRUE(
RunAndCompareNoHloPasses(kMultiRowReductionX16VectorX2, ErrorSpec{1e-3}));
}

TEST_F(MlirMultiRowReductionTest, VectorizedX4Correctness) {
EXPECT_TRUE(
RunAndCompareNoHloPasses(kMultiRowReductionX2VectorX4, ErrorSpec{1e-3}));
Expand Down

0 comments on commit c10f5e8

Please sign in to comment.