Skip to content

Commit edf725e

Browse files
authored
[Torch] add AtenAsStridedOp in torch dialect (llvm#3706)
1 parent 3f07077 commit edf725e

File tree

5 files changed

+46
-5
lines changed

5 files changed

+46
-5
lines changed

include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td

+25
Original file line numberDiff line numberDiff line change
@@ -13195,6 +13195,31 @@ def Torch_AtenAsStridedCopyOp : Torch_Op<"aten.as_strided_copy", [
1319513195
}];
1319613196
}
1319713197

13198+
def Torch_AtenAsStridedOp : Torch_Op<"aten.as_strided", [
13199+
AllowsTypeRefinement,
13200+
ReadOnly
13201+
]> {
13202+
let summary = "Generated op for `aten::as_strided : (Tensor, int[], int[], int?) -> (Tensor)`";
13203+
let arguments = (ins
13204+
AnyTorchTensorType:$self,
13205+
AnyTorchListOfTorchIntType:$size,
13206+
AnyTorchListOfTorchIntType:$stride,
13207+
AnyTorchOptionalIntType:$storage_offset
13208+
);
13209+
let results = (outs
13210+
AnyTorchOptionalTensorType:$result
13211+
);
13212+
let hasCustomAssemblyFormat = 1;
13213+
let extraClassDefinition = [{
13214+
ParseResult AtenAsStridedOp::parse(OpAsmParser &parser, OperationState &result) {
13215+
return parseDefaultTorchOp(parser, result, 4, 1);
13216+
}
13217+
void AtenAsStridedOp::print(OpAsmPrinter &printer) {
13218+
printDefaultTorchOp(printer, *this, 4, 1);
13219+
}
13220+
}];
13221+
}
13222+
1319813223
def Torch_AtenDiagonalOp : Torch_Op<"aten.diagonal", [
1319913224
AllowsTypeRefinement,
1320013225
ReadOnly

lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp

+7
Original file line numberDiff line numberDiff line change
@@ -10002,6 +10002,9 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
1000210002
" %0 = call @__torch__.torch.jit._shape_functions.slice(%arg0, %arg1, %arg2, %arg3, %arg4) : (!torch.list<int>, !torch.int, !torch.optional<int>, !torch.optional<int>, !torch.int) -> !torch.list<int>\n"
1000310003
" return %0 : !torch.list<int>\n"
1000410004
" }\n"
10005+
" func.func @\"__torch_mlir_shape_fn.aten.as_strided\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.list<int>, %arg3: !torch.optional<int>) -> !torch.list<int> {\n"
10006+
" return %arg1 : !torch.list<int>\n"
10007+
" }\n"
1000510008
" func.func @\"__torch_mlir_shape_fn.aten.sort\"(%arg0: !torch.list<int>, %arg1: !torch.int, %arg2: !torch.bool) -> !torch.tuple<list<int>, list<int>> {\n"
1000610009
" %0 = torch.prim.TupleConstruct %arg0, %arg0 : !torch.list<int>, !torch.list<int> -> !torch.tuple<list<int>, list<int>>\n"
1000710010
" return %0 : !torch.tuple<list<int>, list<int>>\n"
@@ -12297,6 +12300,10 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
1229712300
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
1229812301
" return %0#1 : !torch.int\n"
1229912302
" }\n"
12303+
" func.func @\"__torch_mlir_dtype_fn.aten.as_strided\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.list<int>, %arg2: !torch.list<int>, %arg3: !torch.optional<int>) -> !torch.int {\n"
12304+
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
12305+
" return %0#1 : !torch.int\n"
12306+
" }\n"
1230012307
" func.func @\"__torch_mlir_dtype_fn.aten._softmax_backward_data\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.tuple<int, int>, %arg2: !torch.int, %arg3: !torch.int) -> !torch.int {\n"
1230112308
" return %arg3 : !torch.int\n"
1230212309
" }\n"

lib/Dialect/Torch/Utils/Utils.cpp

+6-5
Original file line numberDiff line numberDiff line change
@@ -247,11 +247,12 @@ bool Torch::isViewLikeOp(Operation *op) {
247247
// correct. We could potentially be more precise and identify the cases
248248
// that it does not return a view and treat those as having value
249249
// semantics.
250-
return isa<AtenBroadcastToOp, AtenContiguousOp, AtenDetachOp, AtenExpandAsOp,
251-
AtenExpandOp, AtenFlattenUsingIntsOp, AtenUnflattenIntOp,
252-
AtenPermuteOp, AtenReshapeOp, Aten_ReshapeAliasOp, AtenSelectIntOp,
253-
AtenSliceTensorOp, AtenSqueezeDimOp, AtenSqueezeOp, AtenTOp,
254-
AtenToDtypeOp, AtenTransposeIntOp, AtenUnsqueezeOp, AtenViewOp,
250+
return isa<AtenAsStridedOp, AtenBroadcastToOp, AtenContiguousOp, AtenDetachOp,
251+
AtenExpandAsOp, AtenExpandOp, AtenFlattenUsingIntsOp,
252+
AtenUnflattenIntOp, AtenPermuteOp, AtenReshapeOp,
253+
Aten_ReshapeAliasOp, AtenSelectIntOp, AtenSliceTensorOp,
254+
AtenSqueezeDimOp, AtenSqueezeOp, AtenTOp, AtenToDtypeOp,
255+
AtenTransposeIntOp, AtenUnsqueezeOp, AtenViewOp,
255256
TensorStaticInfoCastOp, AtenToDtypeLayoutOp, AtenNumpyTOp,
256257
AtenNarrowOp, AtenNarrowTensorOp, AtenToDeviceOp, PrimsSqueezeOp,
257258
AtenMovedimIntOp, PrimsViewOfOp, AtenRealOp, AtenImagOp,

projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py

+7
Original file line numberDiff line numberDiff line change
@@ -1849,6 +1849,9 @@ def aten〇_weight_norm_interface〡shape(v: List[int], g: List[int], dim: int =
18491849
def aten〇slice〇Tensor〡shape(self: List[int], dim: int = 0, start: Optional[int] = None, end: Optional[int] = None, step: int = 1) -> List[int]:
18501850
return upstream_shape_functions.slice(self, dim, start, end, step)
18511851

1852+
def aten〇as_strided〡shape(self: List[int], size: List[int], stride: List[int], storage_offset: Optional[int] = None) -> List[int]:
1853+
return size
1854+
18521855
def aten〇sort〡shape(self: List[int], dim: int = -1, descending: bool = False) -> Tuple[List[int], List[int]]:
18531856
return self, self
18541857

@@ -3377,6 +3380,10 @@ def aten〇slice〇Tensor〡dtype(self_rank_dtype: Tuple[int, int], dim: int = 0
33773380
self_rank, self_dtype = self_rank_dtype
33783381
return self_dtype
33793382

3383+
def aten〇as_strided〡dtype(self_rank_dtype: Tuple[int, int], size: List[int], stride: List[int], storage_offset: Optional[int] = None) -> int:
3384+
self_rank, self_dtype = self_rank_dtype
3385+
return self_dtype
3386+
33803387
@check_dtype_function(
33813388
_check_tensors_with_the_same_dtype(num_of_tensors=2, dim=0, input_dtype=torch.float32) +
33823389
_check_tensors_with_the_same_dtype(num_of_tensors=2, dim=0, input_dtype=torch.float64) +

projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py

+1
Original file line numberDiff line numberDiff line change
@@ -968,6 +968,7 @@ def emit_with_mutating_variants(key, **kwargs):
968968
emit("aten::alias_copy : (Tensor) -> (Tensor)")
969969
emit("aten::alias : (Tensor) -> (Tensor)", has_folder=True)
970970
emit("aten::as_strided_copy : (Tensor, int[], int[], int?) -> (Tensor)")
971+
emit("aten::as_strided : (Tensor, int[], int[], int?) -> (Tensor)")
971972
emit("aten::diagonal : (Tensor, int, int, int) -> (Tensor)")
972973
emit("aten::diagonal_copy : (Tensor, int, int, int) -> (Tensor)")
973974
emit("aten::expand_copy : (Tensor, int[], bool) -> (Tensor)")

0 commit comments

Comments
 (0)