Skip to content

Commit 1ea12de

Browse files
Integrate LLVM at 8885b5c0626065274cb8f8a634d45779a0f6ff2b (#4089)
Update LLVM to llvm/llvm-project@8885b5c TOSA Updates Summary: 1: [TOSA] Update rescale's input_/output_unsigned attrs as required Update tosa.rescale's input_unsigned and output_unsigned attributes as required in align with TOSA v1.0 spec 2: [TOSA] Update LIT test for tosa.avg_pool2d TOSA v1.0 updates tosa.avg_pool2d's input_zp and output_zp as inputs. Update LIT tests accordingly. 3: [TOSA] Update rescale op's multiplier and shift as inputs Update tosa.rescale's multiplier and shift parameters from attributes to inputs in alignment with TOSA v1.0 spec 4: [TOSA] Update ConstShape and ConstShapeOp's `value` to `values` Update ConstShape and ConstShapeOp's `value` parameter to `values` in alignment with TOSA v1.0 spec 5: [TOSA] Update tosa.matmul zero points to inputs Update tosa.matmul's A_zp and B_zp to inputs in alignment with TOSA v1.0 --------- Signed-off-by: Vivek Khandelwal <[email protected]> Co-authored-by: Justin Ngo <[email protected]>
1 parent 711560c commit 1ea12de

File tree

6 files changed

+537
-459
lines changed

6 files changed

+537
-459
lines changed

externals/llvm-project

Submodule llvm-project updated 3017 files

lib/Conversion/TorchToTosa/TorchToTosa.cpp

+41-11
Original file line numberDiff line numberDiff line change
@@ -1848,22 +1848,52 @@ class ConvertAtenMatmulBaseOp : public OpConversionPattern<AtenOpT> {
18481848
SmallVector<int64_t> matmulOutputShape(
18491849
{matmulLhsShape[0], matmulLhsShape[1], matmulRhsShape[2]});
18501850
Type outputElemTy;
1851-
if (isa<mlir::FloatType>(lhsElemTy)) {
1852-
outputElemTy = lhsElemTy;
1853-
} else { // qint8 emits i32 matmul output
1851+
1852+
bool isInputElemTyQInt8 = false;
1853+
if (isa<mlir::quant::UniformQuantizedType>(lhsElemTy)) {
1854+
mlir::quant::UniformQuantizedType inputQTy =
1855+
dyn_cast<mlir::quant::UniformQuantizedType>(lhsElemTy);
1856+
if (inputQTy.getStorageTypeIntegralWidth() == 8)
1857+
isInputElemTyQInt8 = true;
1858+
}
1859+
1860+
if (isInputElemTyQInt8) {
1861+
// qint8 emits i32 matmul output
18541862
outputElemTy = rewriter.getIntegerType(32);
1863+
} else {
1864+
outputElemTy = lhsElemTy;
18551865
}
18561866

18571867
auto mmOutputTy = RankedTensorType::get(
18581868
makeShapeLLVMCompatible(matmulOutputShape), outputElemTy);
1859-
auto mmOpResult =
1860-
rewriter
1861-
.create<tosa::MatMulOp>(
1862-
op->getLoc(),
1863-
OpConversionPattern<AtenOpT>::getTypeConverter()->convertType(
1864-
mmOutputTy),
1865-
matmulLhs, matmulRhs)
1866-
.getResult();
1869+
1870+
Value mmOpResult;
1871+
if (!isInputElemTyQInt8) {
1872+
// LHS and RHS tensors' zero points must be zero for non-int8 types
1873+
Value lhsZp =
1874+
tosa::createZeroPointTensor(rewriter, op->getLoc(), lhsElemTy, 0)
1875+
.value();
1876+
Value rhsZp =
1877+
tosa::createZeroPointTensor(rewriter, op->getLoc(), rhsElemTy, 0)
1878+
.value();
1879+
mmOpResult =
1880+
rewriter
1881+
.create<tosa::MatMulOp>(
1882+
op->getLoc(),
1883+
OpConversionPattern<AtenOpT>::getTypeConverter()->convertType(
1884+
mmOutputTy),
1885+
matmulLhs, matmulRhs, lhsZp, rhsZp)
1886+
.getResult();
1887+
} else {
1888+
mmOpResult =
1889+
rewriter
1890+
.create<tosa::MatMulOp>(
1891+
op->getLoc(),
1892+
OpConversionPattern<AtenOpT>::getTypeConverter()->convertType(
1893+
mmOutputTy),
1894+
matmulLhs, matmulRhs)
1895+
.getResult();
1896+
}
18671897

18681898
// Perform the reshape to output shape. This is always required unless max
18691899
// input rank=3 and there was no broadcasting, in which case the tosa.matmul

lib/Conversion/TorchToTosa/TosaLegalizeUtils.cpp

+59-15
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,21 @@
1515
namespace mlir {
1616
namespace tosa {
1717

18+
Value buildRescaleMultiplier(bool scale32, PatternRewriter &rewriter,
19+
Operation *op, ArrayRef<int32_t> multipliers) {
20+
if (scale32) {
21+
return tosa::getConstTensor<int32_t>(
22+
rewriter, op, multipliers,
23+
{static_cast<int64_t>(multipliers.size())})
24+
.value();
25+
} else {
26+
SmallVector<int16_t> vec(multipliers.begin(), multipliers.end());
27+
return tosa::getConstTensor<int16_t>(rewriter, op, vec,
28+
{static_cast<int64_t>(vec.size())})
29+
.value();
30+
}
31+
}
32+
1833
// Create a TOSA rescale op from input framework tensor, zero points and
1934
// rounding mode
2035
Value buildRescale(PatternRewriter &rewriter, Operation *op,
@@ -28,14 +43,22 @@ Value buildRescale(PatternRewriter &rewriter, Operation *op,
2843

2944
computeMultiplierAndShift(scale, multiplier, shift, scale_width);
3045

46+
Value multiplier_val =
47+
buildRescaleMultiplier(scale32, rewriter, op, {multiplier});
48+
auto shift_val = tosa::getConstTensor<int8_t>(
49+
rewriter, op, {static_cast<int8_t>(shift)}, {1})
50+
.value();
51+
52+
bool input_unsigned = input_val.getType().isUnsignedInteger();
53+
bool output_unsigned = output_type.isUnsignedInteger();
54+
3155
auto rescale_op = CreateOpAndInfer<tosa::RescaleOp>(
32-
rewriter, op->getLoc(), output_type, input_val,
56+
rewriter, op->getLoc(), output_type, input_val, multiplier_val, shift_val,
3357
rewriter.getI32IntegerAttr(static_cast<int32_t>(input_zp)),
3458
rewriter.getI32IntegerAttr(static_cast<int32_t>(output_zp)),
35-
rewriter.getDenseI32ArrayAttr({multiplier}),
36-
rewriter.getDenseI8ArrayAttr({static_cast<int8_t>(shift)}),
3759
rewriter.getBoolAttr(scale32), rewriter.getBoolAttr(double_round),
38-
rewriter.getBoolAttr(false));
60+
rewriter.getBoolAttr(false), rewriter.getBoolAttr(input_unsigned),
61+
rewriter.getBoolAttr(output_unsigned));
3962

4063
return rescale_op.getResult();
4164
}
@@ -70,6 +93,9 @@ Value buildRescaleOpConvOutput(PatternRewriter &rewriter, Operation *op,
7093
bool scale32 = isScale32(output_qtype);
7194
int32_t scale_width = scale32 ? 32 : 16;
7295

96+
bool input_unsigned = input_qtype.isUnsignedInteger();
97+
bool output_unsigned = output_qtype.isUnsignedInteger();
98+
7399
if (auto weight_per_tensor_qtype =
74100
dyn_cast<mlir::quant::UniformQuantizedType>(
75101
weight_type.getElementType())) {
@@ -83,13 +109,19 @@ Value buildRescaleOpConvOutput(PatternRewriter &rewriter, Operation *op,
83109

84110
computeMultiplierAndShift(op_tensor_scale, multiplier, shift, scale_width);
85111

112+
Value multiplier_val =
113+
buildRescaleMultiplier(scale32, rewriter, op, {multiplier});
114+
auto shift_val = tosa::getConstTensor<int8_t>(
115+
rewriter, op, {static_cast<int8_t>(shift)}, {1})
116+
.value();
117+
86118
auto rescale_op = CreateOpAndInfer<tosa::RescaleOp>(
87-
rewriter, op->getLoc(), output_type, conv_val,
88-
rewriter.getI32IntegerAttr(0), rewriter.getI32IntegerAttr(output_zp),
89-
rewriter.getDenseI32ArrayAttr({multiplier}),
90-
rewriter.getDenseI8ArrayAttr({static_cast<int8_t>(shift)}),
91-
rewriter.getBoolAttr(scale32), rewriter.getBoolAttr(true),
92-
rewriter.getBoolAttr(false));
119+
rewriter, op->getLoc(), output_type, conv_val, multiplier_val,
120+
shift_val, rewriter.getI32IntegerAttr(0),
121+
rewriter.getI32IntegerAttr(output_zp), rewriter.getBoolAttr(scale32),
122+
rewriter.getBoolAttr(true), rewriter.getBoolAttr(false),
123+
rewriter.getBoolAttr(input_unsigned),
124+
rewriter.getBoolAttr(output_unsigned));
93125

94126
return rescale_op.getResult();
95127

@@ -120,12 +152,20 @@ Value buildRescaleOpConvOutput(PatternRewriter &rewriter, Operation *op,
120152
shift_arr.push_back(static_cast<int8_t>(shift));
121153
}
122154

155+
Value multiplier_val =
156+
buildRescaleMultiplier(scale32, rewriter, op, multiplier_arr);
157+
auto shift_val =
158+
tosa::getConstTensor<int8_t>(rewriter, op, shift_arr,
159+
{static_cast<int64_t>(shift_arr.size())})
160+
.value();
161+
123162
auto rescale_op = CreateOpAndInfer<tosa::RescaleOp>(
124-
rewriter, op->getLoc(), output_type, conv_val,
125-
rewriter.getI32IntegerAttr(0), rewriter.getI32IntegerAttr(output_zp),
126-
rewriter.getDenseI32ArrayAttr(multiplier_arr),
127-
rewriter.getDenseI8ArrayAttr(shift_arr), rewriter.getBoolAttr(scale32),
128-
rewriter.getBoolAttr(true), rewriter.getBoolAttr(true));
163+
rewriter, op->getLoc(), output_type, conv_val, multiplier_val,
164+
shift_val, rewriter.getI32IntegerAttr(0),
165+
rewriter.getI32IntegerAttr(output_zp), rewriter.getBoolAttr(scale32),
166+
rewriter.getBoolAttr(true), rewriter.getBoolAttr(true),
167+
rewriter.getBoolAttr(input_unsigned),
168+
rewriter.getBoolAttr(output_unsigned));
129169

130170
return rescale_op.getResult();
131171

@@ -408,6 +448,10 @@ template std::optional<Value>
408448
getConstTensor<int8_t>(PatternRewriter &, Operation *, ArrayRef<int8_t> vec,
409449
ArrayRef<int64_t> shape, std::optional<Type> dtype);
410450

451+
template std::optional<Value>
452+
getConstTensor<int16_t>(PatternRewriter &, Operation *, ArrayRef<int16_t> vec,
453+
ArrayRef<int64_t> shape, std::optional<Type> dtype);
454+
411455
template std::optional<Value>
412456
getConstTensor<int32_t>(PatternRewriter &, Operation *, ArrayRef<int32_t> vec,
413457
ArrayRef<int64_t> shape, std::optional<Type> dtype);

projects/pt1/e2e_testing/xfail_sets.py

-2
Original file line numberDiff line numberDiff line change
@@ -3472,7 +3472,6 @@
34723472
"AtenMatmulQint8VM_basic",
34733473
"AtenMatmulQint8VV_basic",
34743474
"AtenMatmulQint8_basic",
3475-
"AtenMmIntTypes_basic",
34763475
"AtenMmQMixedSigni8_basic",
34773476
"AtenMmQint8_basic",
34783477
"AtenMmQuint8_basic",
@@ -3496,7 +3495,6 @@
34963495
"BincountMinlengthModule_basic",
34973496
"BincountModule_basic",
34983497
"BincountStaticSizeModule_basic",
3499-
"BmmIntModule_basic",
35003498
"BoolFloatConstantModule_basic",
35013499
"BoolFloatFalseModule_basic",
35023500
"BoolFloatTrueModule_basic",

0 commit comments

Comments
 (0)