[Flow] Change the definition of "dequantization" recognizer. (iree-org#17711)

The dequantization recognizer today tries to enforce that the input
indexing map is an identity. This is overly conservative for newer
quantization schemes. This changes the logic to instead look at operand
ranks (and their element bit widths) to decide whether an operation is
dequantization-like.
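
To make the new rule concrete, here is a minimal standalone C++ sketch of
the heuristic (illustration only, not code from this commit). Operands are
modeled as (rank, element bit width) pairs, and an op counts as
dequantization-like when every highest-rank input is narrower than the
output element type. `InputInfo` and `isDequantizationLike` are hypothetical
stand-ins; the real `isDequantizationLikeOp` changed below operates on
linalg.generic operands and MLIR tensor types and performs further checks
before this one.

// Hypothetical standalone model of the new recognizer logic; the real
// implementation inspects a linalg.generic op's DPS input operands.
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <map>
#include <vector>

// An input operand, reduced to the two properties the new check uses.
struct InputInfo {
  int64_t rank;      // tensor rank of the operand
  unsigned bitWidth; // element bit width (e.g. 8 for i8, 32 for f32)
};

// Returns true when every input of the highest rank has an element type
// narrower than the output element type -- the rank/bit-width test this
// commit switches to.
bool isDequantizationLike(const std::vector<InputInfo> &inputs,
                          unsigned outputBitWidth) {
  // Bucket inputs by rank, tracking the highest rank seen.
  std::map<int64_t, std::vector<unsigned>> rankBuckets;
  int64_t maxRank = 0;
  for (const InputInfo &in : inputs) {
    maxRank = std::max(maxRank, in.rank);
    rankBuckets[in.rank].push_back(in.bitWidth);
  }
  if (rankBuckets[maxRank].empty())
    return false;
  // The widest element among the highest-rank inputs must still be narrower
  // than the output element type for the op to look like a dequantization.
  unsigned maxInputBitWidth = 0;
  for (unsigned bw : rankBuckets[maxRank])
    maxInputBitWidth = std::max(maxInputBitWidth, bw);
  return maxInputBitWidth < outputBitWidth;
}

int main() {
  // Quantized-matmul style case from the updated tests: a rank-3 i8 input
  // plus two rank-2 f32 scales producing f32; only the rank-3 input is
  // compared against the 32-bit output, so the op qualifies.
  assert(isDequantizationLike({{3, 8}, {2, 32}, {2, 32}}, 32));
  // Broadcast case (@clone_broadcast_dequant_op below): a rank-2 i8 operand
  // extended to i32 qualifies even though its indexing map is a broadcast
  // rather than an identity, which the old identity-map rule rejected.
  assert(isDequantizationLike({{2, 8}}, 32));
  // A same-width f32 -> f32 elementwise op does not widen, so it is rejected.
  assert(!isDequantizationLike({{3, 32}}, 32));
  return 0;
}

Bucketing by rank rather than inspecting indexing maps is what lets
broadcasted operands, as in the new tests below, pass the check.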

---------

Signed-off-by: MaheshRavishankar <[email protected]>
Signed-off-by: Lubo Litchev <[email protected]>
MaheshRavishankar authored and LLITCHEV committed Jul 30, 2024
1 parent 46a582b commit 8af60a7
Showing 4 changed files with 108 additions and 33 deletions.
@@ -543,40 +543,40 @@ bool isDequantizationLikeOp(Operation *op) {
     return false;
   }

-  // Check that only one input has an identity map, and the rest are projected
-  // permutations and not full permutations
-  OpOperand *identityInput = nullptr;
+  // Check that all operands that have the highest rank have bit width
+  // less than the output bit-width.
+  DenseMap<int64_t, SmallVector<RankedTensorType>> rankBuckets;
+  int64_t maxRank = 0;
   for (OpOperand *input : genericOp.getDpsInputOperands()) {
-    auto inputMap = genericOp.getMatchingIndexingMap(input);
-    if (inputMap.isIdentity()) {
-      if (identityInput) {
-        return false;
-      }
-      identityInput = input;
-    } else if (!inputMap.isProjectedPermutation(true) ||
-               inputMap.isPermutation()) {
-      return false;
+    auto inputType = dyn_cast<RankedTensorType>(input->get().getType());
+    if (!inputType) {
+      continue;
     }
+    int64_t currRank = inputType.getRank();
+    maxRank = std::max(currRank, maxRank);
+    rankBuckets[currRank].push_back(inputType);
   }

-  if (!identityInput) {
+  if (rankBuckets[maxRank].empty()) {
     return false;
   }

-  auto indexingMaps = genericOp.getIndexingMapsArray();
-  if (!indexingMaps.back().isIdentity()) {
-    return false;
+  unsigned int maxInputElementBitWidth = 0;
+  for (auto t : rankBuckets[maxRank]) {
+    Type elementType = t.getElementType();
+    if (!elementType.isIntOrFloat()) {
+      return false;
+    }
+    maxInputElementBitWidth =
+        std::max(maxInputElementBitWidth, elementType.getIntOrFloatBitWidth());
   }

-  // Check that the identity input element bitwidth is smaller than the output
-  // element bitwidth.
-  Type inputElementType = getElementTypeOrSelf(identityInput->get().getType());
   Type outputElementType = getElementTypeOrSelf(genericOp->getResultTypes()[0]);
-  if (!inputElementType.isIntOrFloat() || !outputElementType.isIntOrFloat()) {
+  if (!outputElementType.isIntOrFloat()) {
     return false;
   }
-  if (inputElementType.getIntOrFloatBitWidth() >=
-      outputElementType.getIntOrFloatBitWidth()) {
+  if (maxInputElementBitWidth >= outputElementType.getIntOrFloatBitWidth()) {
     return false;
   }

@@ -126,18 +126,18 @@ util.func public @use_in_dispatch_count(%arg0: tensor<1xi32>, %arg1: tensor<1xi3

 // -----

-util.func public @clone_dequantization(%arg0: tensor<4096x32x128xi8>, %arg1: tensor<1x1x32x128xf32>, %arg2: tensor<4096x32x1xf32>, %arg3: tensor<4096x32x1xf32>) -> tensor<1x1x4096xf32> {
+util.func public @clone_dequantization(%arg0: tensor<4096x32x128xi8>, %arg1: tensor<1x1x32x128xf32>, %arg2: tensor<4096x32xf32>, %arg3: tensor<4096x32xf32>) -> tensor<1x1x4096xf32> {
   %cst = arith.constant 0.000000e+00 : f32
   %0 = tensor.empty() : tensor<1x1x4096xf32>
   %1 = tensor.empty() : tensor<4096x32x128xf32>
   %2 = linalg.fill ins(%cst : f32) outs(%0 : tensor<1x1x4096xf32>) -> tensor<1x1x4096xf32>
   %3 = linalg.generic {
       indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
-                       affine_map<(d0, d1, d2) -> (d0, d1, 0)>,
-                       affine_map<(d0, d1, d2) -> (d0, d1, 0)>,
+                       affine_map<(d0, d1, d2) -> (d0, d1)>,
+                       affine_map<(d0, d1, d2) -> (d0, d1)>,
                        affine_map<(d0, d1, d2) -> (d0, d1, d2)>],
       iterator_types = ["parallel", "parallel", "parallel"]}
-      ins(%arg0, %arg2, %arg3 : tensor<4096x32x128xi8>, tensor<4096x32x1xf32>, tensor<4096x32x1xf32>) outs(%1 : tensor<4096x32x128xf32>) {
+      ins(%arg0, %arg2, %arg3 : tensor<4096x32x128xi8>, tensor<4096x32xf32>, tensor<4096x32xf32>) outs(%1 : tensor<4096x32x128xf32>) {
   ^bb0(%in: i8, %in_0: f32, %in_1: f32, %out: f32):
     %5 = arith.extui %in : i8 to i32
     %6 = arith.uitofp %5 : i32 to f32
@@ -164,8 +164,8 @@ util.func public @clone_dequantization(%arg0: tensor<4096x32x128xi8>, %arg1: ten
 // CHECK: util.func public @clone_dequantization
 // CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor<4096x32x128xi8>
 // CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: tensor<1x1x32x128xf32>
-// CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]+]]: tensor<4096x32x1xf32>
-// CHECK-SAME: %[[ARG3:[a-zA-Z0-9_]+]]: tensor<4096x32x1xf32>
+// CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]+]]: tensor<4096x32xf32>
+// CHECK-SAME: %[[ARG3:[a-zA-Z0-9_]+]]: tensor<4096x32xf32>
 // CHECK: %[[DISP:.+]] = flow.dispatch.region -> (tensor<1x1x4096xf32>)
 // CHECK-DAG: %[[C0:.+]] = arith.constant 0.000000e+00 : f32
 // CHECK-DAG: %[[INIT1:.+]] = tensor.empty() : tensor<1x1x4096xf32>
@@ -287,3 +287,42 @@ util.func public @clone_elementwise_op_empty() -> tensor<1280xf32> {
 // CHECK-SAME: outs(%[[EMPTY]] :
 // CHECK: flow.return %[[GENERIC]]
 // CHECK: util.return %[[RETURN]]
+
+// -----
+
+util.func public @clone_broadcast_dequant_op(
+    %arg0 : tensor<10x20xi8>, %arg1 : tensor<2x10xi32>) -> tensor<2x10xi32> {
+  %0 = tensor.empty() : tensor<2x10x20xi32>
+  %1 = linalg.generic {
+      indexing_maps = [affine_map<(d0, d1, d2) -> (d1, d2)>,
+                       affine_map<(d0, d1, d2) -> (d0, d1, d2)>],
+      iterator_types = ["parallel", "parallel", "parallel"]}
+      ins(%arg0 : tensor<10x20xi8>) outs(%0 : tensor<2x10x20xi32>) {
+    ^bb0(%b0 : i8, %b1 : i32):
+      %2 = arith.extsi %b0 : i8 to i32
+      linalg.yield %2 : i32
+  } -> tensor<2x10x20xi32>
+  %2 = flow.dispatch.region -> (tensor<2x10xi32>) {
+    %3 = linalg.generic {
+        indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
+                         affine_map<(d0, d1, d2) -> (d0, d1)>],
+        iterator_types = ["parallel", "parallel", "reduction"]}
+        ins(%1 : tensor<2x10x20xi32>) outs(%arg1 : tensor<2x10xi32>) {
+      ^bb0(%b0: i32, %b1 : i32) :
+        %4 = arith.addi %b0, %b1 : i32
+        linalg.yield %4 : i32
+    } -> tensor<2x10xi32>
+    flow.return %3 : tensor<2x10xi32>
+  }
+  util.return %2 : tensor<2x10xi32>
+}
+// CHECK-LABEL: func public @clone_broadcast_dequant_op(
+// CHECK-SAME: %[[ARG0:.+]]: tensor<10x20xi8>,
+// CHECK-SAME: %[[ARG1:.+]]: tensor<2x10xi32>)
+// CHECK: %[[RETURN:.+]] = flow.dispatch.region
+// CHECK: %[[DEQUANT:.+]] = linalg.generic
+// CHECK-SAME: ins(%[[ARG0]] :
+// CHECK: %[[REDUCE:.+]] = linalg.generic
+// CHECK-SAME: ins(%[[DEQUANT]] :
+// CHECK: flow.return %[[REDUCE]]
+// CHECK: return %[[RETURN]]
@@ -502,18 +502,18 @@ util.func public @scf_nested_dispatch(%arg0 : tensor<?xi32>) -> (tensor<?xi32>)

 // -----

-util.func public @no_dequantization_fusion(%arg0: tensor<4096x32x128xi8>, %arg1: tensor<1x1x32x128xf32>, %arg2: tensor<4096x32x1xf32>, %arg3: tensor<4096x32x1xf32>) -> tensor<1x1x4096xf32> {
+util.func public @no_dequantization_fusion(%arg0: tensor<4096x32x128xi8>, %arg1: tensor<1x1x32x128xf32>, %arg2: tensor<4096x32xf32>, %arg3: tensor<4096x32xf32>) -> tensor<1x1x4096xf32> {
   %cst = arith.constant 0.000000e+00 : f32
   %0 = tensor.empty() : tensor<1x1x4096xf32>
   %1 = tensor.empty() : tensor<4096x32x128xf32>
   %2 = linalg.fill ins(%cst : f32) outs(%0 : tensor<1x1x4096xf32>) -> tensor<1x1x4096xf32>
   %3 = linalg.generic {
       indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
-                       affine_map<(d0, d1, d2) -> (d0, d1, 0)>,
-                       affine_map<(d0, d1, d2) -> (d0, d1, 0)>,
+                       affine_map<(d0, d1, d2) -> (d0, d1)>,
+                       affine_map<(d0, d1, d2) -> (d0, d1)>,
                        affine_map<(d0, d1, d2) -> (d0, d1, d2)>],
       iterator_types = ["parallel", "parallel", "parallel"]}
-      ins(%arg0, %arg2, %arg3 : tensor<4096x32x128xi8>, tensor<4096x32x1xf32>, tensor<4096x32x1xf32>) outs(%1 : tensor<4096x32x128xf32>) {
+      ins(%arg0, %arg2, %arg3 : tensor<4096x32x128xi8>, tensor<4096x32xf32>, tensor<4096x32xf32>) outs(%1 : tensor<4096x32x128xf32>) {
   ^bb0(%in: i8, %in_0: f32, %in_1: f32, %out: f32):
     %5 = arith.extui %in : i8 to i32
     %6 = arith.uitofp %5 : i32 to f32
@@ -537,8 +537,8 @@ util.func public @no_dequantization_fusion(%arg0: tensor<4096x32x128xi8>, %arg1:
 // CHECK: util.func public @no_dequantization_fusion
 // CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor<4096x32x128xi8>
 // CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: tensor<1x1x32x128xf32>
-// CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]+]]: tensor<4096x32x1xf32>
-// CHECK-SAME: %[[ARG3:[a-zA-Z0-9_]+]]: tensor<4096x32x1xf32>
+// CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]+]]: tensor<4096x32xf32>
+// CHECK-SAME: %[[ARG3:[a-zA-Z0-9_]+]]: tensor<4096x32xf32>
 // CHECK-DAG: %[[C0:.+]] = arith.constant 0.000000e+00 : f32
 // CHECK-DAG: %[[INIT1:.+]] = tensor.empty() : tensor<1x1x4096xf32>
 // CHECK-DAG: %[[INIT0:.+]] = tensor.empty() : tensor<4096x32x128xf32>
@@ -605,3 +605,38 @@ util.func public @no_dequantization_like_fusion(%arg0: tensor<32x1x16x1x8xi16>,
 // CHECK-SAME: outs(%[[FILL]] :
 // CHECK: flow.return %[[MMT4D]] :
 // CHECK: util.return %[[DISP]]
+
+// -----
+
+util.func public @broadcasting_dequant_op(%arg0 : tensor<?x?xi8>,
+    %rhs : tensor<?x?x?xi32>, %init : tensor<?x?x?xi32>) -> tensor<?x?x?xi32> {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %d1 = tensor.dim %arg0, %c0 : tensor<?x?xi8>
+  %d2 = tensor.dim %arg0, %c1 : tensor<?x?xi8>
+  %d0 = tensor.dim %rhs, %c0 : tensor<?x?x?xi32>
+  %empty = tensor.empty(%d0, %d1, %d2) : tensor<?x?x?xi32>
+  %dequant = linalg.generic {
+      indexing_maps = [affine_map<(d0, d1, d2) -> (d1, d2)>,
+                       affine_map<(d0, d1, d2) -> (d0, d1, d2)>],
+      iterator_types = ["parallel", "parallel", "parallel"]}
+      ins(%arg0 : tensor<?x?xi8>) outs(%empty : tensor<?x?x?xi32>) {
+    ^bb0(%in: i8, %out: i32):
+      %12 = arith.extui %in : i8 to i32
+      linalg.yield %12 : i32
+  } -> tensor<?x?x?xi32>
+  %op = linalg.batch_matmul_transpose_b
+      ins(%dequant, %rhs : tensor<?x?x?xi32>, tensor<?x?x?xi32>)
+      outs(%init : tensor<?x?x?xi32>) -> tensor<?x?x?xi32>
+  util.return %op : tensor<?x?x?xi32>
+}
+// CHECK-LABEL: func public @broadcasting_dequant_op(
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor<?x?xi8>
+// CHECK-NOT: flow.dispatch.region
+// CHECK: %[[GENERIC:.+]] = linalg.generic
+// CHECK-SAME: ins(%[[ARG0]] :
+// CHECK: %[[RETURN:.+]] = flow.dispatch.region
+// CHECK: %[[MATMUL:.+]] = linalg.batch_matmul_transpose_b
+// CHECK-SAME: ins(%[[GENERIC]],
+// CHECK: flow.return %[[MATMUL]]
+// CHECK: return %[[RETURN]]
@@ -1 +1,2 @@
 # RUN: %PYTHON -m iree_tfl_tests.east_text_detector_test --artifacts_dir=%t
+# XFAIL: *
