15
15
namespace mlir {
16
16
namespace tosa {
17
17
18
+ Value buildRescaleMultiplier (bool scale32, PatternRewriter &rewriter,
19
+ Operation *op, ArrayRef<int32_t > multipliers) {
20
+ if (scale32) {
21
+ return tosa::getConstTensor<int32_t >(
22
+ rewriter, op, multipliers,
23
+ {static_cast <int64_t >(multipliers.size ())})
24
+ .value ();
25
+ } else {
26
+ SmallVector<int16_t > vec (multipliers.begin (), multipliers.end ());
27
+ return tosa::getConstTensor<int16_t >(rewriter, op, vec,
28
+ {static_cast <int64_t >(vec.size ())})
29
+ .value ();
30
+ }
31
+ }
32
+
18
33
// Create a TOSA rescale op from input framework tensor, zero points and
19
34
// rounding mode
20
35
Value buildRescale (PatternRewriter &rewriter, Operation *op,
@@ -28,14 +43,22 @@ Value buildRescale(PatternRewriter &rewriter, Operation *op,
28
43
29
44
computeMultiplierAndShift (scale, multiplier, shift, scale_width);
30
45
46
+ Value multiplier_val =
47
+ buildRescaleMultiplier (scale32, rewriter, op, {multiplier});
48
+ auto shift_val = tosa::getConstTensor<int8_t >(
49
+ rewriter, op, {static_cast <int8_t >(shift)}, {1 })
50
+ .value ();
51
+
52
+ bool input_unsigned = input_val.getType ().isUnsignedInteger ();
53
+ bool output_unsigned = output_type.isUnsignedInteger ();
54
+
31
55
auto rescale_op = CreateOpAndInfer<tosa::RescaleOp>(
32
- rewriter, op->getLoc (), output_type, input_val,
56
+ rewriter, op->getLoc (), output_type, input_val, multiplier_val, shift_val,
33
57
rewriter.getI32IntegerAttr (static_cast <int32_t >(input_zp)),
34
58
rewriter.getI32IntegerAttr (static_cast <int32_t >(output_zp)),
35
- rewriter.getDenseI32ArrayAttr ({multiplier}),
36
- rewriter.getDenseI8ArrayAttr ({static_cast <int8_t >(shift)}),
37
59
rewriter.getBoolAttr (scale32), rewriter.getBoolAttr (double_round),
38
- rewriter.getBoolAttr (false ));
60
+ rewriter.getBoolAttr (false ), rewriter.getBoolAttr (input_unsigned),
61
+ rewriter.getBoolAttr (output_unsigned));
39
62
40
63
return rescale_op.getResult ();
41
64
}
@@ -70,6 +93,9 @@ Value buildRescaleOpConvOutput(PatternRewriter &rewriter, Operation *op,
70
93
bool scale32 = isScale32 (output_qtype);
71
94
int32_t scale_width = scale32 ? 32 : 16 ;
72
95
96
+ bool input_unsigned = input_qtype.isUnsignedInteger ();
97
+ bool output_unsigned = output_qtype.isUnsignedInteger ();
98
+
73
99
if (auto weight_per_tensor_qtype =
74
100
dyn_cast<mlir::quant::UniformQuantizedType>(
75
101
weight_type.getElementType ())) {
@@ -83,13 +109,19 @@ Value buildRescaleOpConvOutput(PatternRewriter &rewriter, Operation *op,
83
109
84
110
computeMultiplierAndShift (op_tensor_scale, multiplier, shift, scale_width);
85
111
112
+ Value multiplier_val =
113
+ buildRescaleMultiplier (scale32, rewriter, op, {multiplier});
114
+ auto shift_val = tosa::getConstTensor<int8_t >(
115
+ rewriter, op, {static_cast <int8_t >(shift)}, {1 })
116
+ .value ();
117
+
86
118
auto rescale_op = CreateOpAndInfer<tosa::RescaleOp>(
87
- rewriter, op->getLoc (), output_type, conv_val,
88
- rewriter. getI32IntegerAttr ( 0 ) , rewriter.getI32IntegerAttr (output_zp ),
89
- rewriter.getDenseI32ArrayAttr ({multiplier} ),
90
- rewriter.getDenseI8ArrayAttr ({ static_cast < int8_t >(shift)} ),
91
- rewriter.getBoolAttr (scale32), rewriter. getBoolAttr ( true ),
92
- rewriter.getBoolAttr (false ));
119
+ rewriter, op->getLoc (), output_type, conv_val, multiplier_val,
120
+ shift_val , rewriter.getI32IntegerAttr (0 ),
121
+ rewriter.getI32IntegerAttr (output_zp), rewriter. getBoolAttr (scale32 ),
122
+ rewriter.getBoolAttr ( true ), rewriter. getBoolAttr ( false ),
123
+ rewriter.getBoolAttr (input_unsigned ),
124
+ rewriter.getBoolAttr (output_unsigned ));
93
125
94
126
return rescale_op.getResult ();
95
127
@@ -120,12 +152,20 @@ Value buildRescaleOpConvOutput(PatternRewriter &rewriter, Operation *op,
120
152
shift_arr.push_back (static_cast <int8_t >(shift));
121
153
}
122
154
155
+ Value multiplier_val =
156
+ buildRescaleMultiplier (scale32, rewriter, op, multiplier_arr);
157
+ auto shift_val =
158
+ tosa::getConstTensor<int8_t >(rewriter, op, shift_arr,
159
+ {static_cast <int64_t >(shift_arr.size ())})
160
+ .value ();
161
+
123
162
auto rescale_op = CreateOpAndInfer<tosa::RescaleOp>(
124
- rewriter, op->getLoc (), output_type, conv_val,
125
- rewriter.getI32IntegerAttr (0 ), rewriter.getI32IntegerAttr (output_zp),
126
- rewriter.getDenseI32ArrayAttr (multiplier_arr),
127
- rewriter.getDenseI8ArrayAttr (shift_arr), rewriter.getBoolAttr (scale32),
128
- rewriter.getBoolAttr (true ), rewriter.getBoolAttr (true ));
163
+ rewriter, op->getLoc (), output_type, conv_val, multiplier_val,
164
+ shift_val, rewriter.getI32IntegerAttr (0 ),
165
+ rewriter.getI32IntegerAttr (output_zp), rewriter.getBoolAttr (scale32),
166
+ rewriter.getBoolAttr (true ), rewriter.getBoolAttr (true ),
167
+ rewriter.getBoolAttr (input_unsigned),
168
+ rewriter.getBoolAttr (output_unsigned));
129
169
130
170
return rescale_op.getResult ();
131
171
@@ -408,6 +448,10 @@ template std::optional<Value>
408
448
getConstTensor<int8_t >(PatternRewriter &, Operation *, ArrayRef<int8_t > vec,
409
449
ArrayRef<int64_t > shape, std::optional<Type> dtype);
410
450
451
+ template std::optional<Value>
452
+ getConstTensor<int16_t >(PatternRewriter &, Operation *, ArrayRef<int16_t > vec,
453
+ ArrayRef<int64_t > shape, std::optional<Type> dtype);
454
+
411
455
template std::optional<Value>
412
456
getConstTensor<int32_t >(PatternRewriter &, Operation *, ArrayRef<int32_t > vec,
413
457
ArrayRef<int64_t > shape, std::optional<Type> dtype);
0 commit comments