Implement lowering of aten.vstack #3636

Status: Open. Wants to merge 1 commit into base: main.
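The lowering follows the PyTorch reference decomposition: promote every input to rank 2 with atleast_2d, then concatenate along dim 0. A quick eager-mode illustration of that equivalence (plain PyTorch, not part of this diff):

import torch

a = torch.rand(3)     # rank-1 input; atleast_2d treats it as shape (1, 3)
b = torch.rand(2, 3)  # rank-2 input; left unchanged

# vstack is equivalent to atleast_2d on each input followed by cat along dim 0.
ref = torch.cat([torch.atleast_2d(a), torch.atleast_2d(b)], dim=0)
out = torch.vstack([a, b])

assert out.shape == (3, 3)
assert torch.equal(out, ref)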
23 changes: 23 additions & 0 deletions include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
@@ -10300,6 +10300,29 @@ def Torch_AtenAtleast2dOp : Torch_Op<"aten.atleast_2d", [
}];
}

def Torch_AtenVstackOp : Torch_Op<"aten.vstack", [
AllowsTypeRefinement,
HasValueSemantics,
ReadOnly
]> {
let summary = "Generated op for `aten::vstack : (Tensor[]) -> (Tensor)`";
let arguments = (ins
AnyTorchListOfTensorType:$tensors
);
let results = (outs
AnyTorchOptionalTensorType:$result
);
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult AtenVstackOp::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 1, 1);
}
void AtenVstackOp::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 1, 1);
}
}];
}

def Torch_AtenEinsumOp : Torch_Op<"aten.einsum", [
AllowsTypeRefinement,
HasValueSemantics,
42 changes: 42 additions & 0 deletions lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp
@@ -10568,6 +10568,21 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" }\n"
" return %2 : !torch.list<int>\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.aten.vstack\"(%arg0: !torch.list<list<int>>) -> !torch.list<int> {\n"
" %true = torch.constant.bool true\n"
" %int0 = torch.constant.int 0\n"
" %0 = torch.prim.ListConstruct : () -> !torch.list<list<int>>\n"
" %1 = torch.aten.len.t %arg0 : !torch.list<list<int>> -> !torch.int\n"
" torch.prim.Loop %1, %true, init() {\n"
" ^bb0(%arg1: !torch.int):\n"
" %3 = torch.aten.__getitem__.t %arg0, %arg1 : !torch.list<list<int>>, !torch.int -> !torch.list<int>\n"
" %4 = func.call @\"__torch_mlir_shape_fn.aten.atleast_2d\"(%3) : (!torch.list<int>) -> !torch.list<int>\n"
" %5 = torch.aten.append.t %0, %4 : !torch.list<list<int>>, !torch.list<int> -> !torch.list<list<int>>\n"
" torch.prim.Loop.condition %true, iter()\n"
" } : (!torch.int, !torch.bool) -> ()\n"
" %2 = call @__torch__.torch.jit._shape_functions.cat(%0, %int0) : (!torch.list<list<int>>, !torch.int) -> !torch.list<int>\n"
" return %2 : !torch.list<int>\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.aten.stack\"(%arg0: !torch.list<list<int>>, %arg1: !torch.int) -> !torch.list<int> {\n"
" %0 = call @__torch__.torch.jit._shape_functions.stack(%arg0, %arg1) : (!torch.list<list<int>>, !torch.int) -> !torch.list<int>\n"
" return %0 : !torch.list<int>\n"
@@ -15070,6 +15085,33 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" return %0#1 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.vstack\"(%arg0: !torch.list<tuple<int, int>>) -> !torch.int {\n"
" %true = torch.constant.bool true\n"
" %none = torch.constant.none\n"
" %str = torch.constant.str \"AssertionError: \"\n"
" %int0 = torch.constant.int 0\n"
" %0 = torch.prim.ListConstruct : () -> !torch.list<optional<int>>\n"
" %1 = torch.prim.ListConstruct : () -> !torch.list<int>\n"
" %2 = torch.aten.len.t %arg0 : !torch.list<tuple<int, int>> -> !torch.int\n"
" %3 = torch.aten.ne.int %2, %int0 : !torch.int, !torch.int -> !torch.bool\n"
" torch.prim.If %3 -> () {\n"
" torch.prim.If.yield\n"
" } else {\n"
" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
" torch.prim.If.yield\n"
" }\n"
" %4 = torch.aten.len.t %arg0 : !torch.list<tuple<int, int>> -> !torch.int\n"
" torch.prim.Loop %4, %true, init() {\n"
" ^bb0(%arg1: !torch.int):\n"
" %6 = torch.aten.__getitem__.t %arg0, %arg1 : !torch.list<tuple<int, int>>, !torch.int -> !torch.tuple<int, int>\n"
" %7:2 = torch.prim.TupleUnpack %6 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" %8 = torch.aten.append.t %0, %7#0 : !torch.list<optional<int>>, !torch.int -> !torch.list<optional<int>>\n"
" %9 = torch.aten.append.t %1, %7#1 : !torch.list<int>, !torch.int -> !torch.list<int>\n"
" torch.prim.Loop.condition %true, iter()\n"
" } : (!torch.int, !torch.bool) -> ()\n"
" %5 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.promote_dtypes(%0, %1) : (!torch.list<optional<int>>, !torch.list<int>) -> !torch.int\n"
" return %5 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.einsum\"(%arg0: !torch.str, %arg1: !torch.list<tuple<int, int>>, %arg2: !torch.optional<list<int>>) -> !torch.int {\n"
" %true = torch.constant.bool true\n"
" %none = torch.constant.none\n"
92 changes: 92 additions & 0 deletions lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
@@ -1846,6 +1846,97 @@ class DecomposeAtenAtleast2dOp : public OpRewritePattern<AtenAtleast2dOp> {
};
} // namespace

namespace {
// Decompose aten.vstack into aten.atleast_2d and aten.cat. See:
// https://github.com/pytorch/pytorch/blob/9a8ab778d34bd24c5caceb340837483decc4c311/torch/_refs/__init__.py#L3887
// @out_wrapper()
// def vstack(tensors: TensorSequenceType) -> TensorLikeType:
//     torch._check(len(tensors) > 0,
//                  lambda: "vstack expects a non-empty TensorList")
//     aligned_tensors = atleast_2d(*tensors)
//     return cat(aligned_tensors, 0)
class DecomposeAtenVstackOp : public OpRewritePattern<AtenVstackOp> {
public:
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(AtenVstackOp op,
PatternRewriter &rewriter) const override {
Location loc = op.getLoc();

SmallVector<Value> inputTensors;
if (!getListConstructElements(op.getTensors(), inputTensors)) {
return rewriter.notifyMatchFailure(
op, "input should come from a PrimListConstructOp");
}
if (inputTensors.empty()) {
return rewriter.notifyMatchFailure(
op, "vstack expects a non-empty TensorList");
}

for (Value tensor : inputTensors) {
BaseTensorType tensorType = cast<BaseTensorType>(tensor.getType());
if (!tensorType.hasSizes()) {
return rewriter.notifyMatchFailure(
op, "unimplemented: at least one tensor does not have known sizes");
}
}

// Calculate the dimensions of the tensors after they've all had atleast_2d
// applied to them.
SmallVector<SmallVector<int64_t>> tensorShapes;
for (Value tensor : inputTensors) {
auto tensorType = cast<BaseTensorType>(tensor.getType());
SmallVector<int64_t> tensorShape{tensorType.getSizes()};
switch (tensorShape.size()) {
case 0:
tensorShape = SmallVector<int64_t>{1, 1};
break;
case 1:
int64_t x = tensorShape[0];
tensorShape = SmallVector<int64_t>{1, x};
break;
}
tensorShapes.push_back(tensorShape);
}

// Check that all tensors match in all dimensions except the 0th dimension.
SmallVector<int64_t> fstTensorShape = tensorShapes[0];
int64_t fstTensorRank = fstTensorShape.size();
for (auto tensorShape : tensorShapes) {
if (fstTensorShape.size() != tensorShape.size() ||
!std::equal(fstTensorShape.begin() + 1,
fstTensorShape.begin() + fstTensorRank,
tensorShape.begin() + 1)) {
return rewriter.notifyMatchFailure(
op, "tensors must have all matching dimensions except for 0");
}
}

// If so, proceed and create all the necessary ops.
SmallVector<Value> atleast2dTensorOps;
for (size_t i = 0; i < inputTensors.size(); i++) {
auto tensorType = cast<BaseTensorType>(inputTensors[i].getType());
auto newTensorType = rewriter.getType<ValueTensorType>(
tensorShapes[i], tensorType.getOptionalDtype());
auto tensorOp =
rewriter.create<AtenAtleast2dOp>(loc, newTensorType, inputTensors[i]);
atleast2dTensorOps.push_back(tensorOp);
}

auto elemType = cast<BaseTensorType>(atleast2dTensorOps[0].getType())
.getWithSizesAndDtype(std::nullopt, nullptr);
Value atleast2dTensorList = rewriter.create<PrimListConstructOp>(
loc, Torch::ListType::get(elemType), atleast2dTensorOps);

auto zero =
rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(0));
rewriter.replaceOpWithNewOp<AtenCatOp>(op, op.getType(),
atleast2dTensorList, zero);

return success();
}
};
} // namespace

namespace {
// Decompose AtenEinsumOp to AtenMatmulOp, and supports possible reduce
// operation and permute operation. Currently, this pass doesn't support
@@ -9477,6 +9568,7 @@ class DecomposeComplexOpsPass
addPatternIfTargetOpIsIllegal<DecomposeAtenCeluOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenAtleast1dOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenAtleast2dOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenVstackOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenEinsumOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenTraceOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenHardswishOp>(patterns);
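Note: the DecomposeAtenVstackOp pattern above only fires when every input has known sizes and, after the implicit atleast_2d, all shapes agree in every dimension except dim 0. A minimal Python mirror of that static check, for illustration only (not code from this PR):

def vstack_decomposable(shapes):
    """Illustrative mirror of the legality check in DecomposeAtenVstackOp
    (assumes every input shape is statically known)."""
    def atleast_2d(shape):
        if len(shape) == 0:
            return [1, 1]
        if len(shape) == 1:
            return [1, shape[0]]
        return list(shape)

    aligned = [atleast_2d(s) for s in shapes]
    first = aligned[0]
    return all(len(s) == len(first) and s[1:] == first[1:] for s in aligned)

assert vstack_decomposable([[3], [2, 3]])         # trailing dims match
assert not vstack_decomposable([[2, 3], [2, 4]])  # mismatch in dim 1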
1 change: 1 addition & 0 deletions lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp
@@ -397,6 +397,7 @@ static void markDecomposedOpsAsIllegal(MLIRContext *context,
target.addIllegalOp<AtenTanhBackwardOp>();
target.addIllegalOp<AtenAtleast1dOp>();
target.addIllegalOp<AtenAtleast2dOp>();
target.addIllegalOp<AtenVstackOp>();
target.addIllegalOp<AtenEinsumOp>();
target.addIllegalOp<AtenTraceOp>();
target.addIllegalOp<AtenAddmmOp>();
17 changes: 17 additions & 0 deletions projects/pt1/e2e_testing/xfail_sets.py
@@ -882,6 +882,11 @@
"Atleast2dModule0dInput_basic",
"Atleast2dModule1dInput_basic",
"Atleast2dModule2dInput_basic",
"AtenVstackModule_basic",
"AtenVstackStaticModule_basic",
"AtenVstackSingleElementListModule_basic",
"AtenVstackMatchingZerothDimModule_basic",
"AtenVstackPromoteDTypeModule_basic",
"AtenLinear1D_basic",
"AtenLinear2D_basic",
"AtenLinear3DBias_basic",
@@ -1582,6 +1587,8 @@
"Atleast2dModule0dInput_basic",
"Atleast2dModule1dInput_basic",
"Atleast2dModule2dInput_basic",
"AtenVstackStaticModule_basic",
"AtenVstackSingleElementListModule_basic",
"AtenLinear2D_basic",
"AtenLinear3DBias_basic",
"ElementwiseAddScalar_NumToTensorFloat_Module_basic",
@@ -2419,6 +2426,11 @@
"AtenSubFloatModule_basic",
"AtenTopKModule_basic",
"AtenTopKSmallestModule_basic",
"AtenVstackModule_basic",
"AtenVstackStaticModule_basic",
"AtenVstackSingleElementListModule_basic",
"AtenVstackMatchingZerothDimModule_basic",
"AtenVstackPromoteDTypeModule_basic",
"Aten_EmbeddingBagExample_basic",
"AvgPool2dWithoutPadModule_basic",
"BatchMlpLayerModule_basic",
@@ -4538,6 +4550,11 @@
"TensorsConcatModule_basic",
"TensorsConcatNegativeDimModule_basic",
"TensorsConcatPromoteDTypeModule_basic",
"AtenVstackModule_basic",
"AtenVstackStaticModule_basic",
"AtenVstackSingleElementListModule_basic",
"AtenVstackMatchingZerothDimModule_basic",
"AtenVstackPromoteDTypeModule_basic",
"TensorsStackModule_basic",
"TensorsStackNegativeDimModule_basic",
"TensorsStackPromoteDTypeModule_basic",
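The AtenVstack* names added to the pass and backend sets above refer to e2e test modules; the module definitions themselves are not shown in this diff. A sketch of what such a test typically looks like in torch-mlir's e2e framework (the class name, annotations, and shapes here are illustrative, not copied from the PR):

import torch
from torch_mlir_e2e_test.framework import TestUtils
from torch_mlir_e2e_test.registry import register_test_case
from torch_mlir_e2e_test.annotations import annotate_args, export


class VstackExampleModule(torch.nn.Module):
    @export
    @annotate_args([
        None,
        ([-1, -1], torch.float32, True),
        ([-1, -1], torch.float32, True),
    ])
    def forward(self, a, b):
        return torch.vstack([a, b])


@register_test_case(module_factory=lambda: VstackExampleModule())
def VstackExampleModule_basic(module, tu: TestUtils):
    module.forward(tu.rand(2, 3), tu.rand(4, 3))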
projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py
@@ -2130,6 +2130,10 @@ def aten〇atleast_2d〡shape(self: List[int]) -> List[int]:
else:
return self

def aten〇vstack〡shape(tensors: List[List[int]]) -> List[int]:
tensors_atleast2d = [ aten〇atleast_2d〡shape(tensor) for tensor in tensors ]
return upstream_shape_functions.cat(tensors_atleast2d, 0)

def aten〇stack〡shape(tensors: List[List[int]], dim: int = 0) -> List[int]:
return upstream_shape_functions.stack(tensors, dim)

@@ -5279,6 +5283,22 @@ def aten〇atleast_2d〡dtype(self_rank_dtype: Tuple[int, int]) -> int:
self_rank, self_dtype = self_rank_dtype
return self_dtype

@check_dtype_function(
[Invocation([NonZeroDTensorWithDtype(torch.float32), NonZeroDTensorWithDtype(torch.int32)]),
Invocation([NonZeroDTensorWithDtype(torch.float16), NonZeroDTensorWithDtype(torch.float64)]),
Invocation([NonZeroDTensorWithDtype(torch.float32), NonZeroDTensorWithDtype(torch.int32),
NonZeroDTensorWithDtype(torch.complex64)])])
def aten〇vstack〡dtype(tensors_rank_dtype: List[Tuple[int, int]]) -> int:
ranks: List[Optional[int]] = []
dtypes: List[int] = []
assert len(tensors_rank_dtype) != 0
for tensor_rank_dtype in tensors_rank_dtype:
tensor_rank, tensor_dtype = tensor_rank_dtype
ranks.append(tensor_rank)
dtypes.append(tensor_dtype)

return promote_dtypes(ranks, dtypes)

@check_dtype_function(
[Invocation("i,j->ij", [TensorOfShape(1, dtype=torch.float32),
TensorOfShape(1, dtype=torch.int32)]),])
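The dtype function added above routes the input ranks and dtypes through promote_dtypes, so mixed-dtype tensor lists follow the usual PyTorch type-promotion rules. An eager-mode check of the expected behavior (plain PyTorch, for reference only):

import torch

a = torch.ones(2, 3, dtype=torch.float32)
b = torch.ones(1, 3, dtype=torch.int32)

# Mixed float32/int32 inputs promote to float32, matching what the
# aten.vstack dtype function computes via promote_dtypes.
assert torch.vstack([a, b]).dtype == torch.float32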
projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py
@@ -796,6 +796,7 @@ def emit_with_mutating_variants(key, **kwargs):
emit("aten::one_hot : (Tensor, int) -> (Tensor)")
emit("aten::atleast_1d : (Tensor) -> (Tensor)")
emit("aten::atleast_2d : (Tensor) -> (Tensor)")
emit("aten::vstack : (Tensor[]) -> (Tensor)")
emit("aten::einsum : (str, Tensor[], int[]?) -> (Tensor)")
emit("aten::trace : (Tensor) -> (Tensor)")
emit("aten::bucketize.Tensor : (Tensor, Tensor, bool, bool) -> (Tensor)")