Skip to content

Commit

Permalink
Don't skip materialization of indices for some selects.
Browse files Browse the repository at this point in the history
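If a select is not truly elementwise (for example, its predicate is a scalar
that is implicitly broadcast against the value operands), we simply materialize
the indices. This case is very rare, so keeping the code reasonably simple is
more important than avoiding every possible materialization.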

PiperOrigin-RevId: 654699120
  • Loading branch information
jreiffers authored and tensorflower-gardener committed Jul 22, 2024
1 parent c10f5e8 commit fe10728
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -788,8 +788,19 @@ absl::StatusOr<SmallVector<Value, 2>> GetOperands(
const HloInstruction* instr, ValueRange indices,
const OperandProvider& operand_provider, ImplicitLocOpBuilder& builder) {
SmallVector<Value, 2> operands;
if (HloInstruction::IsOpElementwise(instr->opcode()) ||
instr->opcode() == HloOpcode::kMap) {
bool is_elementwise = HloInstruction::IsOpElementwise(instr->opcode()) ||
instr->opcode() == HloOpcode::kMap;
if (is_elementwise && instr->shape().IsArray()) {
// Check if the instruction is really elementwise. There may be some
// broadcasting.
int64_t rank = instr->shape().rank();
is_elementwise &=
absl::c_all_of(instr->operands(), [&](const HloInstruction* operand) {
return operand->shape().rank() == rank;
});
}

if (is_elementwise) {
// Avoid materializing the input indices for elementwise ops.
for (int64_t operand_number = 0; operand_number < instr->operand_count();
++operand_number) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1645,6 +1645,25 @@ TEST_F(ElementalHloToMlirTest, Map) {
)"));
}

// Covers a `select` whose operands do not all share the result's rank: the
// predicate `p0` is a scalar (`pred[]`) while `p1`, `p2` and the result are
// `f32[5,7]`.
//
// The expected output requires that the scalar predicate is read with an empty
// index list, while both value operands are read at the two `index` arguments
// (`X`, `Y`) of the generated function.
TEST_F(ElementalHloToMlirTest, BroadcastSelect) {
// NOTE(review): `Run` is defined outside this view; presumably it converts the
// HLO given as the first argument and matches the result against the
// FileCheck patterns given as the second — confirm against the fixture.
TF_EXPECT_OK(Run(R"(
ENTRY main {
p0 = pred[] parameter(0)
p1 = f32[5,7] parameter(1)
p2 = f32[5,7] parameter(2)
ROOT r = f32[5,7] select(p0, p1, p2)
})",
R"(
// CHECK: @main
// CHECK-SAME: %[[P0:.*]]: tensor<i8>
// CHECK-SAME: %[[P1:.*]]: tensor<5x7xf32>, %[[P2:.*]]: tensor<5x7xf32>
// CHECK-SAME: %[[X:.*]]: index {{{.*}}}, %[[Y:.*]]: index {{{.*}}}
// CHECK-DAG: tensor.extract %[[P0]][]
// CHECK-DAG: tensor.extract %[[P1]][%[[X]], %[[Y]]]
// CHECK-DAG: tensor.extract %[[P2]][%[[X]], %[[Y]]]
)"));
}

} // namespace
} // namespace mlir_converter
} // namespace gpu
Expand Down

0 comments on commit fe10728

Please sign in to comment.