@@ -151,6 +151,22 @@ RnnLayerOutput rnn_layer(ImplicitLocOpBuilder &b, Value X, Value initial_h,
151
151
output.Y_h = loop.getResult (1 );
152
152
return output;
153
153
}
154
+
155
+ static Value StaticTranspose (ImplicitLocOpBuilder b, Value value, int64_t dim0,
156
+ int64_t dim1) {
157
+ auto valueTy = cast<ValueTensorType>(value.getType ());
158
+
159
+ SmallVector<int64_t > valueShape (valueTy.getSizes ());
160
+ std::swap (valueShape[dim0], valueShape[dim1]);
161
+ valueTy = b.getType <ValueTensorType>(valueShape, valueTy.getDtype ());
162
+
163
+ auto intType = b.getType <IntType>();
164
+ Value dim0v = b.create <ConstantIntOp>(intType, b.getI64IntegerAttr (dim0));
165
+ Value dim1v = b.create <ConstantIntOp>(intType, b.getI64IntegerAttr (dim1));
166
+
167
+ return b.create <AtenTransposeIntOp>(valueTy, value, dim0v, dim1v);
168
+ }
169
+
154
170
// OnnxRnnExpander: expands an ONNX RNN op via OpBinder /
// ConversionPatternRewriter into torch-dialect ops.
// NOTE(review): what follows are non-contiguous diff hunks of this
// function; lines between the @@ headers are not shown here.
LogicalResult OnnxRnnExpander (OpBinder binder,
155
171
ConversionPatternRewriter &rewriter) {
156
172
Location loc = binder.getLoc ();
@@ -201,9 +217,19 @@ LogicalResult OnnxRnnExpander(OpBinder binder,
201
217
return rewriter.notifyMatchFailure (
202
218
binder.op , " Missing required attribute hidden_size" );
203
219
220
+ // Other attributes
// New: parse the ONNX `layout` attribute (default 0) and reject anything
// outside {0, 1}. Presumably 0 = seq-major, 1 = batch-major per the ONNX
// RNN spec — TODO confirm against the spec/importer.
+ int64_t layout;
+ if (binder.s64IntegerAttr (layout, " layout" , 0 ))
+ return rewriter.notifyMatchFailure (binder.op ,
+ " Unsupported layout attribute type." );
+
+ if (layout < 0 || layout > 1 )
+ return rewriter.notifyMatchFailure (binder.op ,
+ " Unsupported layout attribute value." );
+
204
230
// Result types
205
231
ValueTensorType yTy, Y_hType;
206
// Changed || -> &&: the match now fails only when BOTH result types are
// missing, i.e. at least one of Y / Y_h must be present.
// NOTE(review): with &&, short-circuiting can leave Y_hType unset when
// yTy binds — the !yTy / !Y_hType checks near the end substitute cstNone
// for whichever output is absent.
- if (binder.tensorResultTypeAtIndex (yTy, 0 ) ||
232
+ if (binder.tensorResultTypeAtIndex (yTy, 0 ) &&
207
233
binder.tensorResultTypeAtIndex (Y_hType, 1 )) {
208
234
return rewriter.notifyMatchFailure (binder.op ,
209
235
" At least one output must be present" );
@@ -229,6 +255,12 @@ LogicalResult OnnxRnnExpander(OpBinder binder,
229
255
initial_h = nullptr ;
230
256
}
231
257
258
// For layout == 1, transpose dims 0 and 1 of X (and of initial_h when
// present) up front so the rest of the expansion can assume the
// layout == 0 ordering.
+ if (layout == 1 ) {
+ X = StaticTranspose (b, X, 0 , 1 );
+ if (initial_h)
+ initial_h = StaticTranspose (b, initial_h, 0 , 1 );
+ }
+
232
264
// validation
233
265
auto xTy = cast<ValueTensorType>(X.getType ());
234
266
auto wTy = cast<ValueTensorType>(W.getType ());
@@ -238,6 +270,7 @@ LogicalResult OnnxRnnExpander(OpBinder binder,
238
270
auto rShape = rTy.getSizes ();
239
271
assert (wShape.size () == 3 );
240
272
273
// seq_len comes from X's dim 0 — after the layout == 1 transpose above,
// so it is the sequence dimension in either layout.
+ int64_t seq_len = xShape[0 ];
241
274
int64_t batch_size = xShape[1 ];
242
275
int64_t x_input_size = xShape[2 ];
243
276
@@ -368,7 +401,24 @@ LogicalResult OnnxRnnExpander(OpBinder binder,
368
401
Value Y_h_unsqueezed = b.create <AtenUnsqueezeOp>(Y_h_unsqueezed_type,
369
402
rnnLayerOutput.Y_h , cstZero);
370
403
371
// Changed: build an explicit [seq_len, num_directions, batch_size,
// hidden_size] result type for Y (dtype taken from Y_h) instead of using
// the declared yTy, which may be null when only Y_h was requested.
- Value Y_unsqueezed = b.create <AtenUnsqueezeOp>(yTy, rnnLayerOutput.Y , cstOne);
404
+ auto Y_unsqueezed_type = b.getType <ValueTensorType>(
405
+ llvm::SmallVector<int64_t >{seq_len, num_directions, batch_size,
406
+ hidden_size},
407
+ cast<ValueTensorType>(rnnLayerOutput.Y_h .getType ()).getDtype ());
408
+ Value Y_unsqueezed =
409
+ b.create <AtenUnsqueezeOp>(Y_unsqueezed_type, rnnLayerOutput.Y , cstOne);
410
+
411
// Undo the input transpose for layout == 1: Y_h needs one swap (0<->1);
// Y is rank 4, so it takes two swaps (1<->2 then 0<->1) to restore the
// batch-major ordering.
+ if (layout == 1 ) {
+ Y_h_unsqueezed = StaticTranspose (b, Y_h_unsqueezed, 0 , 1 );
+ Y_unsqueezed = StaticTranspose (b, Y_unsqueezed, 1 , 2 );
+ Y_unsqueezed = StaticTranspose (b, Y_unsqueezed, 0 , 1 );
+ }
+
416
417
// Any output whose result type was not declared is replaced with cstNone
// before replacing the op.
+ if (!yTy)
+ Y_unsqueezed = cstNone;
+ if (!Y_hType)
+ Y_h_unsqueezed = cstNone;
+
372
422
rewriter.replaceOp (binder.op , {Y_unsqueezed, Y_h_unsqueezed});
373
423
return success ();
374
424
}
0 commit comments