Commit 2511cf4

[onnx] Fix onnx.RNN for layout attribute (llvm#3620)
The `layout` attribute was not considered for the `onnx.RNN` operation. Added support for the attribute by transposing the inputs / outputs of the RNN when the layout value is valid.
1 parent af67f9e commit 2511cf4
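For context, the ONNX spec defines `layout` for RNN as follows: with `layout = 0` (the default), X is shaped [seq_length, batch_size, input_size], initial_h and Y_h are [num_directions, batch_size, hidden_size], and Y is [seq_length, num_directions, batch_size, hidden_size]; `layout = 1` moves batch_size to the leading dimension of each of those shapes. The change below normalizes a `layout = 1` graph to the `layout = 0` form before expansion and transposes the results back afterwards. A minimal standalone sketch of that shape bookkeeping (plain C++, no MLIR dependencies; the concrete dimension values are illustrative, not taken from the commit):

```cpp
// Standalone illustration of the layout == 1 transposes; mirrors the
// shape swap performed inside StaticTranspose. Dimension values are made up.
#include <cstdint>
#include <iostream>
#include <utility>
#include <vector>

// Swap two dimensions of a shape, as StaticTranspose does for its result type.
static std::vector<int64_t> transposeDims(std::vector<int64_t> shape,
                                          size_t dim0, size_t dim1) {
  std::swap(shape[dim0], shape[dim1]);
  return shape;
}

int main() {
  // layout == 1 input: X is [batch_size, seq_length, input_size].
  std::vector<int64_t> x = {8, 16, 32};
  // One swap yields the layout == 0 form [seq_length, batch_size, input_size]
  // that the expander is written against.
  x = transposeDims(x, 0, 1); // -> {16, 8, 32}

  // The expander produces Y as
  // [seq_length, num_directions, batch_size, hidden_size].
  std::vector<int64_t> y = {16, 1, 8, 64};
  // Two swaps restore the layout == 1 form
  // [batch_size, seq_length, num_directions, hidden_size].
  y = transposeDims(y, 1, 2); // -> {16, 8, 1, 64}
  y = transposeDims(y, 0, 1); // -> {8, 16, 1, 64}

  for (int64_t d : y)
    std::cout << d << ' ';
  std::cout << '\n'; // prints: 8 16 1 64
}
```

The two consecutive swaps on `y` correspond to the pair of `StaticTranspose` calls the patch emits for `Y_unsqueezed`; `Y_h` only needs the single (0, 1) swap.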

File tree

1 file changed: +52 −2 lines changed

lib/Conversion/TorchOnnxToTorch/OnnxRecurrentLayerOpExpanders.cpp (+52 −2)

@@ -151,6 +151,22 @@ RnnLayerOutput rnn_layer(ImplicitLocOpBuilder &b, Value X, Value initial_h,
   output.Y_h = loop.getResult(1);
   return output;
 }
+
+static Value StaticTranspose(ImplicitLocOpBuilder b, Value value, int64_t dim0,
+                             int64_t dim1) {
+  auto valueTy = cast<ValueTensorType>(value.getType());
+
+  SmallVector<int64_t> valueShape(valueTy.getSizes());
+  std::swap(valueShape[dim0], valueShape[dim1]);
+  valueTy = b.getType<ValueTensorType>(valueShape, valueTy.getDtype());
+
+  auto intType = b.getType<IntType>();
+  Value dim0v = b.create<ConstantIntOp>(intType, b.getI64IntegerAttr(dim0));
+  Value dim1v = b.create<ConstantIntOp>(intType, b.getI64IntegerAttr(dim1));
+
+  return b.create<AtenTransposeIntOp>(valueTy, value, dim0v, dim1v);
+}
+
 LogicalResult OnnxRnnExpander(OpBinder binder,
                               ConversionPatternRewriter &rewriter) {
   Location loc = binder.getLoc();
@@ -201,9 +217,19 @@ LogicalResult OnnxRnnExpander(OpBinder binder,
     return rewriter.notifyMatchFailure(
         binder.op, "Missing required attribute hidden_size");

+  // Other attributes
+  int64_t layout;
+  if (binder.s64IntegerAttr(layout, "layout", 0))
+    return rewriter.notifyMatchFailure(binder.op,
+                                       "Unsupported layout attribute type.");
+
+  if (layout < 0 || layout > 1)
+    return rewriter.notifyMatchFailure(binder.op,
+                                       "Unsupported layout attribute value.");
+
   // Result types
   ValueTensorType yTy, Y_hType;
-  if (binder.tensorResultTypeAtIndex(yTy, 0) ||
+  if (binder.tensorResultTypeAtIndex(yTy, 0) &&
       binder.tensorResultTypeAtIndex(Y_hType, 1)) {
     return rewriter.notifyMatchFailure(binder.op,
                                        "At least one output must be present");
@@ -229,6 +255,12 @@ LogicalResult OnnxRnnExpander(OpBinder binder,
     initial_h = nullptr;
   }

+  if (layout == 1) {
+    X = StaticTranspose(b, X, 0, 1);
+    if (initial_h)
+      initial_h = StaticTranspose(b, initial_h, 0, 1);
+  }
+
   // validation
   auto xTy = cast<ValueTensorType>(X.getType());
   auto wTy = cast<ValueTensorType>(W.getType());
@@ -238,6 +270,7 @@ LogicalResult OnnxRnnExpander(OpBinder binder,
   auto rShape = rTy.getSizes();
   assert(wShape.size() == 3);

+  int64_t seq_len = xShape[0];
   int64_t batch_size = xShape[1];
   int64_t x_input_size = xShape[2];

@@ -368,7 +401,24 @@ LogicalResult OnnxRnnExpander(OpBinder binder,
   Value Y_h_unsqueezed = b.create<AtenUnsqueezeOp>(Y_h_unsqueezed_type,
                                                    rnnLayerOutput.Y_h, cstZero);

-  Value Y_unsqueezed = b.create<AtenUnsqueezeOp>(yTy, rnnLayerOutput.Y, cstOne);
+  auto Y_unsqueezed_type = b.getType<ValueTensorType>(
+      llvm::SmallVector<int64_t>{seq_len, num_directions, batch_size,
+                                 hidden_size},
+      cast<ValueTensorType>(rnnLayerOutput.Y_h.getType()).getDtype());
+  Value Y_unsqueezed =
+      b.create<AtenUnsqueezeOp>(Y_unsqueezed_type, rnnLayerOutput.Y, cstOne);
+
+  if (layout == 1) {
+    Y_h_unsqueezed = StaticTranspose(b, Y_h_unsqueezed, 0, 1);
+    Y_unsqueezed = StaticTranspose(b, Y_unsqueezed, 1, 2);
+    Y_unsqueezed = StaticTranspose(b, Y_unsqueezed, 0, 1);
+  }
+
+  if (!yTy)
+    Y_unsqueezed = cstNone;
+  if (!Y_hType)
+    Y_h_unsqueezed = cstNone;
+
   rewriter.replaceOp(binder.op, {Y_unsqueezed, Y_h_unsqueezed});
   return success();
 }
