From d8c8f29d7e1c7ef7b5a0c18c51ca474937c23ee3 Mon Sep 17 00:00:00 2001
From: Nhat Nguyen
Date: Fri, 25 Oct 2024 13:30:12 -0400
Subject: [PATCH] Fix mask analysis for when the entire tensor is masked off (#186)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

The current formula for computing masks does not work when the mask bound
is smaller than the start of the mask range:

```
---|-------|-----------|
   ^       ^           ^
 bound   start        end
```

Current formula:

```
new_end = min(end, bound)
new_dim = new_end - start
```

For the above case, this formula will produce a negative `new_dim`.

To fix this issue, we optionally move `new_end` back to `start`, so that
when `bound < start`, `new_dim` is 0. The new formula is:

```
new_end_tmp = min(end, bound)
new_end = max(new_end_tmp, start)
new_dim = new_end - start
```

Another formula that could work in theory is:

```
new_end = min(end, bound)
new_dim_potentially_neg = new_end - start
new_dim = max(new_dim_potentially_neg, 0)
```

But this approach does not work in MaskAnalysis, because we operate on the
`index` type, which is unsigned: `new_dim_potentially_neg` would wrap
around and we would end up with a large positive number instead.

# Changes

+ Update the formula (a standalone sketch follows the diffstat below)
+ Add a flag to opt out of the fix, since the change is quite invasive
+ Update lit tests
+ Remove some of the old TritonToLinalg tests; we will remove the old pass
  in a future PR
---
 include/triton-shared/Analysis/MaskAnalysis.h |    3 +
 .../Analysis/OpFoldResultUtils.h              |    2 +
 .../AnalysisStructured/PtrAnalysis.h          |    6 +-
 .../Conversion/TritonToStructured/Passes.td   |    4 +-
 lib/Analysis/MaskAnalysis.cpp                 |   16 +-
 lib/Analysis/OpFoldResultUtils.cpp            |   28 ++
 lib/AnalysisStructured/PtrAnalysis.cpp        |   16 +-
 .../TritonToStructuredPass.cpp                |    2 +-
 python/examples/test_mask.py                  |   39 ++
 python/examples/test_tensor_index_iterargs.py |   35 --
 .../convert_extern_elementwise.mlir           |  454 +++++++++---------
 .../kernel-01-vector-add.mlir                 |   25 +-
 .../kernel-02-fused-softmax.mlir              |   73 +--
 .../kernel-03-matrix-multiplication.mlir      |   48 +-
 .../kernel-05-layer-norm-dwdb.mlir            |   87 ++--
 .../kernel-05-layer-norm-fwd.mlir             |  127 ++---
 .../StructuredToMemref/masked_ldst_1d.mlir    |   23 +-
 .../StructuredToMemref/masked_ldst_2d.mlir    |   38 +-
 .../masked_ldst_sitofp_other.mlir             |   19 +-
 .../TritonToLinalg/kernel-01-vector-add.mlir  |   77 ---
 .../kernel-02-fused-softmax.mlir              |  105 ----
 .../kernel-03-matrix-multiplication.mlir      |  217 ---------
 .../kernel-05-layer-norm-dwdb.mlir            |  189 --------
 .../kernel-05-layer-norm-fwd.mlir             |  313 ------------
 .../TritonToLinalg/masked_ldst_1d.mlir        |   45 --
 .../TritonToLinalg/masked_ldst_2d.mlir        |  108 -----
 .../masked_ldst_sitofp_other.mlir             |   47 --
 ...ensor_indices_loop_iterarg_with_masks.mlir |   79 ---
 .../kernel-01-vector-add.mlir                 |   37 +-
 .../kernel-02-fused-softmax.mlir              |   47 +-
 .../kernel-03-matrix-multiplication.mlir      |   40 +-
 .../kernel-05-layer-norm-dwdb.mlir            |  114 ++--
 .../kernel-05-layer-norm-fwd.mlir             |  124 ++--
 .../TritonToStructured/masked_ldst_1d.mlir    |   11 +-
 .../TritonToStructured/masked_ldst_2d.mlir    |   46 +-
 .../masked_ldst_sitofp_other.mlir             |   11 +-
 .../sign_extend_i32_to_i64.mlir               |   19 +-
 ...ensor_indices_loop_iterarg_with_masks.mlir |   19 +-
 38 files changed, 819 insertions(+), 1874 deletions(-)
 create mode 100644 python/examples/test_mask.py
 delete mode 100644 test/Conversion/TritonToLinalg/kernel-01-vector-add.mlir
 delete mode 100644 test/Conversion/TritonToLinalg/kernel-02-fused-softmax.mlir
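Before the file-by-file changes, here is a minimal, self-contained C++ sketch
of the bug and the fix. The helper names are hypothetical and exist only for
illustration; the concrete numbers match the fully-masked load in
`python/examples/test_mask.py` below (start = 100, end = 104, bound = 4):

```
#include <algorithm>
#include <cassert>
#include <cstdint>

// Model MLIR `index` values as unsigned 64-bit integers, per the note above
// about unsigned wrap-around. oldDim/newDim are illustrative stand-ins only.
uint64_t oldDim(uint64_t start, uint64_t end, uint64_t bound) {
  uint64_t newEnd = std::min(end, bound);
  return newEnd - start; // wraps around when bound < start
}

uint64_t newDim(uint64_t start, uint64_t end, uint64_t bound) {
  uint64_t newEnd = std::max(std::min(end, bound), start); // clamp to start
  return newEnd - start; // 0 when the range is entirely masked off
}

int main() {
  // Fully masked-off load: offs = 100 + arange(0, 4), mask = offs < 4,
  // so start = 100, end = 104, bound = 4.
  assert(oldDim(100, 104, 4) != 0); // 4 - 100 wraps to 2^64 - 96: wrong
  assert(newDim(100, 104, 4) == 0); // clamped: nothing is loaded
  return 0;
}
```

This clamp is exactly what the new `maxOFRs` helper materializes as
`arith.maxsi` in the lit-test updates below (`minsi` -> `maxsi` -> `subi`).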
 delete mode 100644 test/Conversion/TritonToLinalg/kernel-03-matrix-multiplication.mlir
 delete mode 100644 test/Conversion/TritonToLinalg/kernel-05-layer-norm-dwdb.mlir
 delete mode 100644 test/Conversion/TritonToLinalg/kernel-05-layer-norm-fwd.mlir
 delete mode 100644 test/Conversion/TritonToLinalg/masked_ldst_1d.mlir
 delete mode 100644 test/Conversion/TritonToLinalg/masked_ldst_2d.mlir
 delete mode 100644 test/Conversion/TritonToLinalg/masked_ldst_sitofp_other.mlir
 delete mode 100644 test/Conversion/TritonToLinalg/tensor_indices_loop_iterarg_with_masks.mlir

diff --git a/include/triton-shared/Analysis/MaskAnalysis.h b/include/triton-shared/Analysis/MaskAnalysis.h
index 6d67112f..a86bfc19 100644
--- a/include/triton-shared/Analysis/MaskAnalysis.h
+++ b/include/triton-shared/Analysis/MaskAnalysis.h
@@ -49,6 +49,9 @@ struct MaskState {
   OpFoldResult end;
   SmallVector<OpFoldResult> dims;
   OpFoldResult scalar;
+  const bool useUnsafeMask;
+
+  MaskState(bool useUnsafeMask = false) : useUnsafeMask(useUnsafeMask) {}
 
   int64_t getRank() const { return dims.size(); }
 
diff --git a/include/triton-shared/Analysis/OpFoldResultUtils.h b/include/triton-shared/Analysis/OpFoldResultUtils.h
index 148c52c4..bcc8287e 100644
--- a/include/triton-shared/Analysis/OpFoldResultUtils.h
+++ b/include/triton-shared/Analysis/OpFoldResultUtils.h
@@ -55,6 +55,8 @@ OpFoldResult mulOFRValue(const OpFoldResult lhs, const Value rhs,
 OpFoldResult minOFRs(const OpFoldResult lhs, const OpFoldResult rhs,
                      const Location loc, OpBuilder &b);
 
+OpFoldResult maxOFRs(const OpFoldResult lhs, const OpFoldResult rhs,
+                     const Location loc, OpBuilder &b);
+
 } // namespace mlir
 
 #endif
 
diff --git a/include/triton-shared/AnalysisStructured/PtrAnalysis.h b/include/triton-shared/AnalysisStructured/PtrAnalysis.h
index d104cd77..3e8b6b60 100644
--- a/include/triton-shared/AnalysisStructured/PtrAnalysis.h
+++ b/include/triton-shared/AnalysisStructured/PtrAnalysis.h
@@ -258,11 +258,11 @@ class PtrAnalysis {
   // strides, offsets, and modulos.
   LogicalResult rewriteForOp(scf::ForOp op);
 
-  LogicalResult rewriteLoadOp(triton::LoadOp op);
+  LogicalResult rewriteLoadOp(triton::LoadOp op, bool useUnsafeMask = false);
 
-  LogicalResult rewriteStoreOp(triton::StoreOp op);
+  LogicalResult rewriteStoreOp(triton::StoreOp op, bool useUnsafeMask = false);
 
-  LogicalResult rewriteOp(Operation *op);
+  LogicalResult rewriteOp(Operation *op, bool useUnsafeMask = false);
 };
 
 } // namespace tts
 
diff --git a/include/triton-shared/Conversion/TritonToStructured/Passes.td b/include/triton-shared/Conversion/TritonToStructured/Passes.td
index 89488d64..e2702464 100644
--- a/include/triton-shared/Conversion/TritonToStructured/Passes.td
+++ b/include/triton-shared/Conversion/TritonToStructured/Passes.td
@@ -10,7 +10,9 @@ def TritonToStructured : Pass<"triton-to-structured", "mlir::ModuleOp"> {
     Option<"runPrepassOnly", "run-prepass-only", "bool", /*default*/"false",
            "Only run the pre-processing pass which inserts tts.get_structured_state ops used in scf.for">,
     Option<"skipPrepass", "skip-prepass", "bool", /*default*/"false",
-           "Skip the prepass">
+           "Skip the prepass">,
+    Option<"useUnsafeMask", "use-unsafe-mask", "bool", /*default*/"false",
+           "Assume that the mask bounds are never less than starting offsets. May produce incorrect results.">
   ];
 }
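With this option in place, the fix is on by default and `use-unsafe-mask=true`
opts back into the old formula. A sketch, assuming the standard MLIR textual
pipeline plumbing and that the pass is linked into the consuming tool (the
function name is hypothetical; the option spellings come from the tablegen
entry above):

```
#include "mlir/Pass/PassManager.h"
#include "mlir/Pass/PassRegistry.h"

// Hypothetical helper: opt back into the old, unclamped mask formula.
// Assumes `pm` is anchored on builtin.module and that the
// triton-to-structured pass has been registered.
mlir::LogicalResult addUnsafeMaskPipeline(mlir::PassManager &pm) {
  return mlir::parsePassPipeline("triton-to-structured{use-unsafe-mask=true}",
                                 pm);
}
```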
diff --git a/lib/Analysis/MaskAnalysis.cpp b/lib/Analysis/MaskAnalysis.cpp
index 5dd4bc74..31891190 100644
--- a/lib/Analysis/MaskAnalysis.cpp
+++ b/lib/Analysis/MaskAnalysis.cpp
@@ -7,8 +7,6 @@
 #include "triton-shared/Analysis/MaskAnalysis.h"
 
 #include "mlir/Dialect/Arith/IR/Arith.h"
-#include "mlir/Support/LogicalResult.h"
-#include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/SCF/IR/SCF.h"
 #include "mlir/Support/LogicalResult.h"
 
@@ -341,7 +339,21 @@ LogicalResult MaskState::parseCmp(arith::CmpIOp cmpOp, const Location loc,
   assert(cmpDim != -1 &&
          "Unexpected case where no dimension has size larger than 1");
 
+  // Important:
+  // In the case where the values we are loading are entirely masked off like
+  // the following:
+  //
+  // ---|-------|-----------|
+  //    ^       ^           ^
+  //  scalar  start        end
+  //
+  // newEnd = min(end, scalar) = scalar
+  // Now scalar < start, so simply doing dim = newEnd - start is incorrect.
+  //
+  // The correct formula is to optionally move `newEnd` back to `start` using
+  // max(newEnd, start).
   auto newEnd = minOFRs(lhsState.end, rhsState.scalar, loc, builder);
+  newEnd = maxOFRs(newEnd, lhsState.start, loc, builder);
   auto newDim = subOFRs(newEnd, lhsState.start, loc, builder);
 
   for (int32_t i = 0; i < lhsState.getRank(); i++) {
 
diff --git a/lib/Analysis/OpFoldResultUtils.cpp b/lib/Analysis/OpFoldResultUtils.cpp
index 66584008..b0757b77 100644
--- a/lib/Analysis/OpFoldResultUtils.cpp
+++ b/lib/Analysis/OpFoldResultUtils.cpp
@@ -217,4 +217,32 @@ OpFoldResult minOFRs(const OpFoldResult lhs, const OpFoldResult rhs,
   return minOp.getResult();
 }
 
+OpFoldResult maxOFRs(const OpFoldResult lhs, const OpFoldResult rhs,
+                     const Location loc, OpBuilder &b) {
+  auto lhsIntAttr = getIntAttr(lhs);
+  auto rhsIntAttr = getIntAttr(rhs);
+
+  // both lhs and rhs are constants, return result directly
+  if (lhsIntAttr && rhsIntAttr)
+    return b.getIndexAttr(std::max(lhsIntAttr.value(), rhsIntAttr.value()));
+
+  // otherwise, need to create instructions to calculate new attribute value
+  auto lhsValue = dyn_cast<Value>(lhs);
+  if (lhsIntAttr) {
+    auto lhsOp =
+        b.create<arith::ConstantOp>(loc, b.getIndexAttr(lhsIntAttr.value()));
+    lhsValue = lhsOp.getResult();
+  }
+
+  auto rhsValue = dyn_cast<Value>(rhs);
+  if (rhsIntAttr) {
+    auto rhsOp =
+        b.create<arith::ConstantOp>(loc, b.getIndexAttr(rhsIntAttr.value()));
+    rhsValue = rhsOp.getResult();
+  }
+
+  auto maxOp = b.create<arith::MaxSIOp>(loc, lhsValue, rhsValue);
+  return maxOp.getResult();
+}
+
 } // namespace mlir
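`maxOFRs` mirrors the existing `minOFRs`: two attribute operands fold at
analysis time with no IR emitted, while a mixed attribute/`Value` pair
materializes an `arith.constant` plus an `arith.maxsi`. A small sketch of
both paths, assuming a builder positioned inside a function body with the
arith dialect loaded (`demoMaxOFRs` is a hypothetical name):

```
#include "mlir/IR/Builders.h"
#include "triton-shared/Analysis/OpFoldResultUtils.h"

// Sketch only: `b`, `loc`, and `v` come from the surrounding rewrite.
void demoMaxOFRs(mlir::OpBuilder &b, mlir::Location loc, mlir::Value v) {
  // Both operands constant: folds to IndexAttr(100); no ops are created.
  mlir::OpFoldResult folded =
      mlir::maxOFRs(b.getIndexAttr(4), b.getIndexAttr(100), loc, b);
  // Mixed operands: materializes `arith.constant 4 : index` + `arith.maxsi`.
  mlir::OpFoldResult materialized = mlir::maxOFRs(b.getIndexAttr(4), v, loc, b);
  (void)folded;
  (void)materialized;
}
```

In `parseCmp` above this means fully-constant mask shapes keep folding to
attributes exactly as before; only dynamic bounds pay for the extra
`arith.maxsi`.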
diff --git a/lib/AnalysisStructured/PtrAnalysis.cpp b/lib/AnalysisStructured/PtrAnalysis.cpp
index e3a2deea..eaaa9874 100644
--- a/lib/AnalysisStructured/PtrAnalysis.cpp
+++ b/lib/AnalysisStructured/PtrAnalysis.cpp
@@ -1090,7 +1090,8 @@ PtrAnalysis::rewriteGetStructuredStateOp(tts::GetStructuredStateOp op) {
   return success();
 }
 
-LogicalResult PtrAnalysis::rewriteLoadOp(triton::LoadOp op) {
+LogicalResult PtrAnalysis::rewriteLoadOp(triton::LoadOp op,
+                                         bool useUnsafeMask) {
   auto ptr = ptrMap.lookupOrNull(op.getPtr());
   auto mask = op.getMask();
   auto other = op.getOther();
@@ -1109,7 +1110,7 @@ LogicalResult PtrAnalysis::rewriteLoadOp(triton::LoadOp op) {
   }
 
   ArrayRef<OpFoldResult> dims;
-  mlir::triton::MaskState mstate;
+  mlir::triton::MaskState mstate(useUnsafeMask);
   Value scalarOther;
 
   OpBuilder builder(op);
@@ -1226,7 +1227,8 @@ void PtrAnalysis::initializeMaybeStructuredArgs(Operation *op) {
   }
 }
 
-LogicalResult PtrAnalysis::rewriteStoreOp(triton::StoreOp op) {
+LogicalResult PtrAnalysis::rewriteStoreOp(triton::StoreOp op,
+                                          bool useUnsafeMask) {
   auto ptr = ptrMap.lookupOrNull(op.getPtr());
   auto val = op.getValue();
   auto mask = op.getMask();
@@ -1245,7 +1247,7 @@ LogicalResult PtrAnalysis::rewriteStoreOp(triton::StoreOp op) {
   }
 
   ArrayRef<OpFoldResult> dims;
-  mlir::triton::MaskState mstate;
+  mlir::triton::MaskState mstate(useUnsafeMask);
 
   OpBuilder builder(op);
@@ -1270,7 +1272,7 @@ LogicalResult PtrAnalysis::rewriteStoreOp(triton::StoreOp op) {
   return success();
 }
 
-LogicalResult PtrAnalysis::rewriteOp(Operation *rootOp) {
+LogicalResult PtrAnalysis::rewriteOp(Operation *rootOp, bool useUnsafeMask) {
   LLVM_DEBUG({
     llvm::dbgs() << "rewriting rootOp\n";
     rootOp->dump();
   });
@@ -1301,14 +1303,14 @@ LogicalResult PtrAnalysis::rewriteOp(Operation *rootOp) {
           return WalkResult::advance();
         })
         .Case<triton::LoadOp>([&](auto load) {
-          if (rewriteLoadOp(load).failed()) {
+          if (rewriteLoadOp(load, useUnsafeMask).failed()) {
            load->emitRemark("PtrAnalysis: Failed to rewrite LoadOp");
             return WalkResult::advance();
           }
           return WalkResult::skip();
         })
         .Case<triton::StoreOp>([&](auto store) {
-          if (rewriteStoreOp(store).failed()) {
+          if (rewriteStoreOp(store, useUnsafeMask).failed()) {
            store->emitRemark("PtrAnalysis: Failed to rewrite StoreOp");
             return WalkResult::advance();
           }
 
diff --git a/lib/Conversion/TritonToStructured/TritonToStructuredPass.cpp b/lib/Conversion/TritonToStructured/TritonToStructuredPass.cpp
index 45ff8ab2..bcfea253 100644
--- a/lib/Conversion/TritonToStructured/TritonToStructuredPass.cpp
+++ b/lib/Conversion/TritonToStructured/TritonToStructuredPass.cpp
@@ -319,7 +319,7 @@ class TritonToStructuredPass
     mlir::tts::PtrAnalysis ptrAnalysis;
     ptrAnalysis.initializeMaybeStructuredArgs(moduleOp);
 
-    if (failed(ptrAnalysis.rewriteOp(moduleOp))) {
+    if (failed(ptrAnalysis.rewriteOp(moduleOp, useUnsafeMask))) {
       moduleOp->emitWarning("PtrAnalysis failed");
     }
 
diff --git a/python/examples/test_mask.py b/python/examples/test_mask.py
new file mode 100644
index 00000000..499dfa1d
--- /dev/null
+++ b/python/examples/test_mask.py
@@ -0,0 +1,39 @@
+import torch
+
+import triton
+import triton.language as tl
+
+from triton.backends.triton_shared.driver import CPUDriver
+
+
+def test_mask(device):
+    @triton.jit
+    def test(in0, out0):
+        offs = 100 + tl.arange(0, 4)
+        out_offs = tl.arange(0, 4)
+        a = tl.load(in0 + offs, mask=offs < 4, other=-1)
+        tl.store(out0 + out_offs, a)
+
+    SIZE = 8
+    input = torch.arange(0, SIZE, device=device, dtype=torch.int32)
+    output = torch.full((SIZE,), -2, device=device, dtype=torch.int32)
+
+    if device == 'cpu':
+        triton.runtime.driver.set_active(CPUDriver())
+
+    grid = lambda meta: (1,)
+
+    src = triton.compiler.ASTSource(
+        fn=test,
+        signature="*fp32,*fp32,i32",
+    )
+    ret = triton.compile(
+        src,
+    )
+    print(ret.asm["ttir"])
+
+    print(output)
+    test[grid](input, output)
+    print(input)
+    print(output)
+    torch.testing.assert_close(output, torch.tensor([-1, -1, -1, -1, -2, -2, -2, -2], device=device, dtype=torch.int32))
 
diff --git a/python/examples/test_tensor_index_iterargs.py b/python/examples/test_tensor_index_iterargs.py
index 6bc52d66..3f5878f0 100644
--- a/python/examples/test_tensor_index_iterargs.py
+++ b/python/examples/test_tensor_index_iterargs.py
@@ -112,38 +112,3 @@ def test_1(out0):
         src,
     )
     print(ret.asm["ttir"])
-
-
-
-def disabled_test_mask(device):
-    # TODO: This fails to compile in StructuredToMemref
-    @triton.jit
-    def test_1(in0, out0, batch):
-        offs = 4 + tl.arange(0, 4)
-        out_offs = tl.arange(0, 4)
-        a = tl.load(in0 + offs, mask=offs < 0, other=-1)
-        tl.store(out0 + out_offs, a)
-
-    # TODO: This segfauls in the CPU backend
-    # Crashes when the batch value will mask off
all of the tensors - @triton.jit - def test_2(in0, out0, batch): - offs = 4 + tl.arange(0, 4) - out_offs = tl.arange(0, 4) - a = tl.load(in0 + offs, mask=offs < 0, other=-1) - tl.store(out0 + out_offs, a) - - - SIZE = 8 - input = torch.arange(0, SIZE, device=device, dtype=torch.int32) - output = torch.full((SIZE,), -1, device=device, dtype=torch.int32) - - if device == 'cpu': - triton.runtime.driver.set_active(CPUDriver()) - - grid = lambda meta: (1,) - - print(output) - test_1[grid](input, output, 0) - print(input) - print(output) diff --git a/test/Conversion/StructuredToMemref/convert_extern_elementwise.mlir b/test/Conversion/StructuredToMemref/convert_extern_elementwise.mlir index 334bc0e2..c017ec74 100644 --- a/test/Conversion/StructuredToMemref/convert_extern_elementwise.mlir +++ b/test/Conversion/StructuredToMemref/convert_extern_elementwise.mlir @@ -636,27 +636,28 @@ module { // CHECK-DAG: [[VAR_2_:%.+]] = arith.addi [[VAR_1_]], [[CST_32_1_]] : index // CHECK-DAG: [[VAR_3_:%.+]] = arith.index_cast [[PARAM_3_]] : i32 to index // CHECK: [[VAR_4_:%.+]] = arith.minsi [[VAR_2_]], [[VAR_3_]] : index -// CHECK-DAG: [[VAR_5_:%.+]] = arith.subi [[VAR_4_]], [[VAR_1_]] : index +// CHECK: [[VAR_5_:%.+]] = arith.maxsi [[VAR_4_]], [[VAR_1_]] : index +// CHECK-DAG: [[VAR_6_:%.+]] = arith.subi [[VAR_5_]], [[VAR_1_]] : index // CHECK-DAG: [[RES_:%.+]] = memref.alloc() : memref<32xf32> // CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: [[VAR_subview_:%.+]] = memref.subview [[VAR_reinterpret_cast_]][0] {{.}}[[VAR_5_]]{{.}} [1] : memref<32xf32, strided<[1], offset: ?>> to memref> -// CHECK-DAG: [[VAR_subview_0_:%.+]] = memref.subview [[RES_]][0] {{.}}[[VAR_5_]]{{.}} [1] : memref<32xf32> to memref> +// CHECK-DAG: [[VAR_subview_:%.+]] = memref.subview [[VAR_reinterpret_cast_]][0] {{.}}[[VAR_6_]]{{.}} [1] : memref<32xf32, strided<[1], offset: ?>> to memref> +// CHECK-DAG: [[VAR_subview_0_:%.+]] = memref.subview [[RES_]][0] {{.}}[[VAR_6_]]{{.}} [1] : memref<32xf32> to memref> // CHECK: memref.copy [[VAR_subview_]], [[VAR_subview_0_]] : memref> to memref> -// CHECK-DAG: [[VAR_6_:%.+]] = bufferization.to_tensor [[RES_]] restrict writable : memref<32xf32> +// CHECK-DAG: [[VAR_7_:%.+]] = bufferization.to_tensor [[RES_]] restrict writable : memref<32xf32> // CHECK-DAG: [[VAR_reinterpret_cast_1_:%.+]] = memref.reinterpret_cast [[PARAM_1_]] to offset: {{.}}[[VAR_1_]]{{.}}, sizes: [32], strides: [1] : memref<*xf32> to memref<32xf32, strided<[1], offset: ?>> // CHECK-DAG: [[RES_1_:%.+]] = memref.alloc() : memref<32xf32> // CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: [[VAR_subview_3_:%.+]] = memref.subview [[VAR_reinterpret_cast_1_]][0] {{.}}[[VAR_5_]]{{.}} [1] : memref<32xf32, strided<[1], offset: ?>> to memref> -// CHECK-DAG: [[VAR_subview_4_:%.+]] = memref.subview [[RES_1_]][0] {{.}}[[VAR_5_]]{{.}} [1] : memref<32xf32> to memref> +// CHECK-DAG: [[VAR_subview_3_:%.+]] = memref.subview [[VAR_reinterpret_cast_1_]][0] {{.}}[[VAR_6_]]{{.}} [1] : memref<32xf32, strided<[1], offset: ?>> to memref> +// CHECK-DAG: [[VAR_subview_4_:%.+]] = memref.subview [[RES_1_]][0] {{.}}[[VAR_6_]]{{.}} [1] : memref<32xf32> to memref> // CHECK: memref.copy [[VAR_subview_3_]], [[VAR_subview_4_]] : memref> to memref> -// CHECK: [[VAR_7_:%.+]] = bufferization.to_tensor [[RES_1_]] restrict writable : memref<32xf32> -// CHECK: [[VAR_8_:%.+]] = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]} ins([[VAR_6_]], [[VAR_7_]] : tensor<32xf32>, tensor<32xf32>) outs([[VAR_6_]] : tensor<32xf32>) { 
+// CHECK: [[VAR_8_:%.+]] = bufferization.to_tensor [[RES_1_]] restrict writable : memref<32xf32> +// CHECK: [[VAR_9_:%.+]] = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]} ins([[VAR_7_]], [[VAR_8_]] : tensor<32xf32>, tensor<32xf32>) outs([[VAR_7_]] : tensor<32xf32>) { // CHECK: ^bb0([[IN_0_:%.+]]: f32, [[IN_1_:%.+]]: f32, [[IN_2_:%.+]]: f32): -// CHECK: [[VAR_9_:%.+]] = math.atan2 [[IN_0_]], [[IN_1_]] : f32 -// CHECK: linalg.yield [[VAR_9_]] : f32 +// CHECK: [[VAR_10_:%.+]] = math.atan2 [[IN_0_]], [[IN_1_]] : f32 +// CHECK: linalg.yield [[VAR_10_]] : f32 // CHECK: } -> tensor<32xf32> // CHECK: [[VAR_reinterpret_cast_5_:%.+]] = memref.reinterpret_cast [[PARAM_2_]] to offset: {{.}}[[VAR_1_]]{{.}}, sizes: [32], strides: [1] : memref<*xf32> to memref<32xf32, strided<[1], offset: ?>> -// CHECK: bufferization.materialize_in_destination [[VAR_8_]] in writable [[VAR_reinterpret_cast_5_]] : (tensor<32xf32>, memref<32xf32, strided<[1], offset: ?>>) -> () +// CHECK: bufferization.materialize_in_destination [[VAR_9_]] in writable [[VAR_reinterpret_cast_5_]] : (tensor<32xf32>, memref<32xf32, strided<[1], offset: ?>>) -> () // CHECK: return // CHECK: } // CHECK-DAG: [[MAP_0_:#.+]] = affine_map<(d0) -> (d0)> @@ -670,27 +671,28 @@ module { // CHECK-DAG: [[VAR_2_:%.+]] = arith.addi [[VAR_1_]], [[CST_32_1_]] : index // CHECK-DAG: [[VAR_3_:%.+]] = arith.index_cast [[PARAM_3_]] : i32 to index // CHECK: [[VAR_4_:%.+]] = arith.minsi [[VAR_2_]], [[VAR_3_]] : index -// CHECK-DAG: [[VAR_5_:%.+]] = arith.subi [[VAR_4_]], [[VAR_1_]] : index +// CHECK: [[VAR_5_:%.+]] = arith.maxsi [[VAR_4_]], [[VAR_1_]] : index +// CHECK-DAG: [[VAR_6_:%.+]] = arith.subi [[VAR_5_]], [[VAR_1_]] : index // CHECK-DAG: [[RES_:%.+]] = memref.alloc() : memref<32xf32> // CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: [[VAR_subview_:%.+]] = memref.subview [[VAR_reinterpret_cast_]][0] {{.}}[[VAR_5_]]{{.}} [1] : memref<32xf32, strided<[1], offset: ?>> to memref> -// CHECK-DAG: [[VAR_subview_0_:%.+]] = memref.subview [[RES_]][0] {{.}}[[VAR_5_]]{{.}} [1] : memref<32xf32> to memref> +// CHECK-DAG: [[VAR_subview_:%.+]] = memref.subview [[VAR_reinterpret_cast_]][0] {{.}}[[VAR_6_]]{{.}} [1] : memref<32xf32, strided<[1], offset: ?>> to memref> +// CHECK-DAG: [[VAR_subview_0_:%.+]] = memref.subview [[RES_]][0] {{.}}[[VAR_6_]]{{.}} [1] : memref<32xf32> to memref> // CHECK: memref.copy [[VAR_subview_]], [[VAR_subview_0_]] : memref> to memref> -// CHECK-DAG: [[VAR_6_:%.+]] = bufferization.to_tensor [[RES_]] restrict writable : memref<32xf32> +// CHECK-DAG: [[VAR_7_:%.+]] = bufferization.to_tensor [[RES_]] restrict writable : memref<32xf32> // CHECK-DAG: [[VAR_reinterpret_cast_1_:%.+]] = memref.reinterpret_cast [[PARAM_1_]] to offset: {{.}}[[VAR_1_]]{{.}}, sizes: [32], strides: [1] : memref<*xf32> to memref<32xf32, strided<[1], offset: ?>> // CHECK-DAG: [[RES_1_:%.+]] = memref.alloc() : memref<32xf32> // CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: [[VAR_subview_3_:%.+]] = memref.subview [[VAR_reinterpret_cast_1_]][0] {{.}}[[VAR_5_]]{{.}} [1] : memref<32xf32, strided<[1], offset: ?>> to memref> -// CHECK-DAG: [[VAR_subview_4_:%.+]] = memref.subview [[RES_1_]][0] {{.}}[[VAR_5_]]{{.}} [1] : memref<32xf32> to memref> +// CHECK-DAG: [[VAR_subview_3_:%.+]] = memref.subview [[VAR_reinterpret_cast_1_]][0] {{.}}[[VAR_6_]]{{.}} [1] : memref<32xf32, strided<[1], offset: ?>> to memref> +// CHECK-DAG: [[VAR_subview_4_:%.+]] = memref.subview [[RES_1_]][0] {{.}}[[VAR_6_]]{{.}} [1] : memref<32xf32> to memref> // 
CHECK: memref.copy [[VAR_subview_3_]], [[VAR_subview_4_]] : memref> to memref> -// CHECK: [[VAR_7_:%.+]] = bufferization.to_tensor [[RES_1_]] restrict writable : memref<32xf32> -// CHECK: [[VAR_8_:%.+]] = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]} ins([[VAR_6_]], [[VAR_7_]] : tensor<32xf32>, tensor<32xf32>) outs([[VAR_6_]] : tensor<32xf32>) { +// CHECK: [[VAR_8_:%.+]] = bufferization.to_tensor [[RES_1_]] restrict writable : memref<32xf32> +// CHECK: [[VAR_9_:%.+]] = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]} ins([[VAR_7_]], [[VAR_8_]] : tensor<32xf32>, tensor<32xf32>) outs([[VAR_7_]] : tensor<32xf32>) { // CHECK: ^bb0([[IN_0_:%.+]]: f32, [[IN_1_:%.+]]: f32, [[IN_2_:%.+]]: f32): -// CHECK: [[VAR_9_:%.+]] = math.powf [[IN_0_]], [[IN_1_]] : f32 -// CHECK: linalg.yield [[VAR_9_]] : f32 +// CHECK: [[VAR_10_:%.+]] = math.powf [[IN_0_]], [[IN_1_]] : f32 +// CHECK: linalg.yield [[VAR_10_]] : f32 // CHECK: } -> tensor<32xf32> // CHECK: [[VAR_reinterpret_cast_5_:%.+]] = memref.reinterpret_cast [[PARAM_2_]] to offset: {{.}}[[VAR_1_]]{{.}}, sizes: [32], strides: [1] : memref<*xf32> to memref<32xf32, strided<[1], offset: ?>> -// CHECK: bufferization.materialize_in_destination [[VAR_8_]] in writable [[VAR_reinterpret_cast_5_]] : (tensor<32xf32>, memref<32xf32, strided<[1], offset: ?>>) -> () +// CHECK: bufferization.materialize_in_destination [[VAR_9_]] in writable [[VAR_reinterpret_cast_5_]] : (tensor<32xf32>, memref<32xf32, strided<[1], offset: ?>>) -> () // CHECK: return // CHECK: } // CHECK-DAG: [[MAP_0_:#.+]] = affine_map<(d0) -> (d0)> @@ -704,20 +706,21 @@ module { // CHECK-DAG: [[VAR_2_:%.+]] = arith.addi [[VAR_1_]], [[CST_32_1_]] : index // CHECK-DAG: [[VAR_3_:%.+]] = arith.index_cast [[PARAM_2_]] : i32 to index // CHECK: [[VAR_4_:%.+]] = arith.minsi [[VAR_2_]], [[VAR_3_]] : index -// CHECK-DAG: [[VAR_5_:%.+]] = arith.subi [[VAR_4_]], [[VAR_1_]] : index +// CHECK: [[VAR_5_:%.+]] = arith.maxsi [[VAR_4_]], [[VAR_1_]] : index +// CHECK-DAG: [[VAR_6_:%.+]] = arith.subi [[VAR_5_]], [[VAR_1_]] : index // CHECK-DAG: [[RES_:%.+]] = memref.alloc() : memref<32xf32> // CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: [[VAR_subview_:%.+]] = memref.subview [[VAR_reinterpret_cast_]][0] {{.}}[[VAR_5_]]{{.}} [1] : memref<32xf32, strided<[1], offset: ?>> to memref> -// CHECK-DAG: [[VAR_subview_0_:%.+]] = memref.subview [[RES_]][0] {{.}}[[VAR_5_]]{{.}} [1] : memref<32xf32> to memref> +// CHECK-DAG: [[VAR_subview_:%.+]] = memref.subview [[VAR_reinterpret_cast_]][0] {{.}}[[VAR_6_]]{{.}} [1] : memref<32xf32, strided<[1], offset: ?>> to memref> +// CHECK-DAG: [[VAR_subview_0_:%.+]] = memref.subview [[RES_]][0] {{.}}[[VAR_6_]]{{.}} [1] : memref<32xf32> to memref> // CHECK: memref.copy [[VAR_subview_]], [[VAR_subview_0_]] : memref> to memref> -// CHECK: [[VAR_6_:%.+]] = bufferization.to_tensor [[RES_]] restrict writable : memref<32xf32> -// CHECK: [[VAR_7_:%.+]] = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel"]} ins([[VAR_6_]] : tensor<32xf32>) outs([[VAR_6_]] : tensor<32xf32>) { +// CHECK: [[VAR_7_:%.+]] = bufferization.to_tensor [[RES_]] restrict writable : memref<32xf32> +// CHECK: [[VAR_8_:%.+]] = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel"]} ins([[VAR_7_]] : tensor<32xf32>) outs([[VAR_7_]] : tensor<32xf32>) { // CHECK: ^bb0([[IN_0_:%.+]]: f32, [[IN_1_:%.+]]: f32): -// CHECK: [[VAR_8_:%.+]] = math.absf [[IN_0_]] : f32 -// CHECK: linalg.yield [[VAR_8_]] : f32 +// CHECK: 
[[VAR_9_:%.+]] = math.absf [[IN_0_]] : f32 +// CHECK: linalg.yield [[VAR_9_]] : f32 // CHECK: } -> tensor<32xf32> // CHECK: [[VAR_reinterpret_cast_1_:%.+]] = memref.reinterpret_cast [[PARAM_1_]] to offset: {{.}}[[VAR_1_]]{{.}}, sizes: [32], strides: [1] : memref<*xf32> to memref<32xf32, strided<[1], offset: ?>> -// CHECK: bufferization.materialize_in_destination [[VAR_7_]] in writable [[VAR_reinterpret_cast_1_]] : (tensor<32xf32>, memref<32xf32, strided<[1], offset: ?>>) -> () +// CHECK: bufferization.materialize_in_destination [[VAR_8_]] in writable [[VAR_reinterpret_cast_1_]] : (tensor<32xf32>, memref<32xf32, strided<[1], offset: ?>>) -> () // CHECK: return // CHECK: } // CHECK-DAG: [[MAP_0_:#.+]] = affine_map<(d0) -> (d0)> @@ -731,20 +734,21 @@ module { // CHECK-DAG: [[VAR_2_:%.+]] = arith.addi [[VAR_1_]], [[CST_32_1_]] : index // CHECK-DAG: [[VAR_3_:%.+]] = arith.index_cast [[PARAM_2_]] : i32 to index // CHECK: [[VAR_4_:%.+]] = arith.minsi [[VAR_2_]], [[VAR_3_]] : index -// CHECK-DAG: [[VAR_5_:%.+]] = arith.subi [[VAR_4_]], [[VAR_1_]] : index +// CHECK: [[VAR_5_:%.+]] = arith.maxsi [[VAR_4_]], [[VAR_1_]] : index +// CHECK-DAG: [[VAR_6_:%.+]] = arith.subi [[VAR_5_]], [[VAR_1_]] : index // CHECK-DAG: [[RES_:%.+]] = memref.alloc() : memref<32xf32> // CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: [[VAR_subview_:%.+]] = memref.subview [[VAR_reinterpret_cast_]][0] {{.}}[[VAR_5_]]{{.}} [1] : memref<32xf32, strided<[1], offset: ?>> to memref> -// CHECK-DAG: [[VAR_subview_0_:%.+]] = memref.subview [[RES_]][0] {{.}}[[VAR_5_]]{{.}} [1] : memref<32xf32> to memref> +// CHECK-DAG: [[VAR_subview_:%.+]] = memref.subview [[VAR_reinterpret_cast_]][0] {{.}}[[VAR_6_]]{{.}} [1] : memref<32xf32, strided<[1], offset: ?>> to memref> +// CHECK-DAG: [[VAR_subview_0_:%.+]] = memref.subview [[RES_]][0] {{.}}[[VAR_6_]]{{.}} [1] : memref<32xf32> to memref> // CHECK: memref.copy [[VAR_subview_]], [[VAR_subview_0_]] : memref> to memref> -// CHECK: [[VAR_6_:%.+]] = bufferization.to_tensor [[RES_]] restrict writable : memref<32xf32> -// CHECK: [[VAR_7_:%.+]] = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel"]} ins([[VAR_6_]] : tensor<32xf32>) outs([[VAR_6_]] : tensor<32xf32>) { +// CHECK: [[VAR_7_:%.+]] = bufferization.to_tensor [[RES_]] restrict writable : memref<32xf32> +// CHECK: [[VAR_8_:%.+]] = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel"]} ins([[VAR_7_]] : tensor<32xf32>) outs([[VAR_7_]] : tensor<32xf32>) { // CHECK: ^bb0([[IN_0_:%.+]]: f32, [[IN_1_:%.+]]: f32): -// CHECK: [[VAR_8_:%.+]] = math.sin [[IN_0_]] : f32 -// CHECK: linalg.yield [[VAR_8_]] : f32 +// CHECK: [[VAR_9_:%.+]] = math.sin [[IN_0_]] : f32 +// CHECK: linalg.yield [[VAR_9_]] : f32 // CHECK: } -> tensor<32xf32> // CHECK: [[VAR_reinterpret_cast_1_:%.+]] = memref.reinterpret_cast [[PARAM_1_]] to offset: {{.}}[[VAR_1_]]{{.}}, sizes: [32], strides: [1] : memref<*xf32> to memref<32xf32, strided<[1], offset: ?>> -// CHECK: bufferization.materialize_in_destination [[VAR_7_]] in writable [[VAR_reinterpret_cast_1_]] : (tensor<32xf32>, memref<32xf32, strided<[1], offset: ?>>) -> () +// CHECK: bufferization.materialize_in_destination [[VAR_8_]] in writable [[VAR_reinterpret_cast_1_]] : (tensor<32xf32>, memref<32xf32, strided<[1], offset: ?>>) -> () // CHECK: return // CHECK: } // CHECK-DAG: [[MAP_0_:#.+]] = affine_map<(d0) -> (d0)> @@ -758,20 +762,21 @@ module { // CHECK-DAG: [[VAR_2_:%.+]] = arith.addi [[VAR_1_]], [[CST_32_1_]] : index // CHECK-DAG: [[VAR_3_:%.+]] = arith.index_cast 
[[PARAM_2_]] : i32 to index // CHECK: [[VAR_4_:%.+]] = arith.minsi [[VAR_2_]], [[VAR_3_]] : index -// CHECK-DAG: [[VAR_5_:%.+]] = arith.subi [[VAR_4_]], [[VAR_1_]] : index +// CHECK: [[VAR_5_:%.+]] = arith.maxsi [[VAR_4_]], [[VAR_1_]] : index +// CHECK-DAG: [[VAR_6_:%.+]] = arith.subi [[VAR_5_]], [[VAR_1_]] : index // CHECK-DAG: [[RES_:%.+]] = memref.alloc() : memref<32xf32> // CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: [[VAR_subview_:%.+]] = memref.subview [[VAR_reinterpret_cast_]][0] {{.}}[[VAR_5_]]{{.}} [1] : memref<32xf32, strided<[1], offset: ?>> to memref> -// CHECK-DAG: [[VAR_subview_0_:%.+]] = memref.subview [[RES_]][0] {{.}}[[VAR_5_]]{{.}} [1] : memref<32xf32> to memref> +// CHECK-DAG: [[VAR_subview_:%.+]] = memref.subview [[VAR_reinterpret_cast_]][0] {{.}}[[VAR_6_]]{{.}} [1] : memref<32xf32, strided<[1], offset: ?>> to memref> +// CHECK-DAG: [[VAR_subview_0_:%.+]] = memref.subview [[RES_]][0] {{.}}[[VAR_6_]]{{.}} [1] : memref<32xf32> to memref> // CHECK: memref.copy [[VAR_subview_]], [[VAR_subview_0_]] : memref> to memref> -// CHECK: [[VAR_6_:%.+]] = bufferization.to_tensor [[RES_]] restrict writable : memref<32xf32> -// CHECK: [[VAR_7_:%.+]] = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel"]} ins([[VAR_6_]] : tensor<32xf32>) outs([[VAR_6_]] : tensor<32xf32>) { +// CHECK: [[VAR_7_:%.+]] = bufferization.to_tensor [[RES_]] restrict writable : memref<32xf32> +// CHECK: [[VAR_8_:%.+]] = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel"]} ins([[VAR_7_]] : tensor<32xf32>) outs([[VAR_7_]] : tensor<32xf32>) { // CHECK: ^bb0([[IN_0_:%.+]]: f32, [[IN_1_:%.+]]: f32): -// CHECK: [[VAR_8_:%.+]] = math.cos [[IN_0_]] : f32 -// CHECK: linalg.yield [[VAR_8_]] : f32 +// CHECK: [[VAR_9_:%.+]] = math.cos [[IN_0_]] : f32 +// CHECK: linalg.yield [[VAR_9_]] : f32 // CHECK: } -> tensor<32xf32> // CHECK: [[VAR_reinterpret_cast_1_:%.+]] = memref.reinterpret_cast [[PARAM_1_]] to offset: {{.}}[[VAR_1_]]{{.}}, sizes: [32], strides: [1] : memref<*xf32> to memref<32xf32, strided<[1], offset: ?>> -// CHECK: bufferization.materialize_in_destination [[VAR_7_]] in writable [[VAR_reinterpret_cast_1_]] : (tensor<32xf32>, memref<32xf32, strided<[1], offset: ?>>) -> () +// CHECK: bufferization.materialize_in_destination [[VAR_8_]] in writable [[VAR_reinterpret_cast_1_]] : (tensor<32xf32>, memref<32xf32, strided<[1], offset: ?>>) -> () // CHECK: return // CHECK: } // CHECK-DAG: [[MAP_0_:#.+]] = affine_map<(d0) -> (d0)> @@ -785,20 +790,21 @@ module { // CHECK-DAG: [[VAR_2_:%.+]] = arith.addi [[VAR_1_]], [[CST_32_1_]] : index // CHECK-DAG: [[VAR_3_:%.+]] = arith.index_cast [[PARAM_2_]] : i32 to index // CHECK: [[VAR_4_:%.+]] = arith.minsi [[VAR_2_]], [[VAR_3_]] : index -// CHECK-DAG: [[VAR_5_:%.+]] = arith.subi [[VAR_4_]], [[VAR_1_]] : index +// CHECK: [[VAR_5_:%.+]] = arith.maxsi [[VAR_4_]], [[VAR_1_]] : index +// CHECK-DAG: [[VAR_6_:%.+]] = arith.subi [[VAR_5_]], [[VAR_1_]] : index // CHECK-DAG: [[RES_:%.+]] = memref.alloc() : memref<32xf32> // CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: [[VAR_subview_:%.+]] = memref.subview [[VAR_reinterpret_cast_]][0] {{.}}[[VAR_5_]]{{.}} [1] : memref<32xf32, strided<[1], offset: ?>> to memref> -// CHECK-DAG: [[VAR_subview_0_:%.+]] = memref.subview [[RES_]][0] {{.}}[[VAR_5_]]{{.}} [1] : memref<32xf32> to memref> +// CHECK-DAG: [[VAR_subview_:%.+]] = memref.subview [[VAR_reinterpret_cast_]][0] {{.}}[[VAR_6_]]{{.}} [1] : memref<32xf32, strided<[1], offset: ?>> to memref> +// CHECK-DAG: [[VAR_subview_0_:%.+]] = 
memref.subview [[RES_]][0] {{.}}[[VAR_6_]]{{.}} [1] : memref<32xf32> to memref> // CHECK: memref.copy [[VAR_subview_]], [[VAR_subview_0_]] : memref> to memref> -// CHECK: [[VAR_6_:%.+]] = bufferization.to_tensor [[RES_]] restrict writable : memref<32xf32> -// CHECK: [[VAR_7_:%.+]] = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel"]} ins([[VAR_6_]] : tensor<32xf32>) outs([[VAR_6_]] : tensor<32xf32>) { +// CHECK: [[VAR_7_:%.+]] = bufferization.to_tensor [[RES_]] restrict writable : memref<32xf32> +// CHECK: [[VAR_8_:%.+]] = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel"]} ins([[VAR_7_]] : tensor<32xf32>) outs([[VAR_7_]] : tensor<32xf32>) { // CHECK: ^bb0([[IN_0_:%.+]]: f32, [[IN_1_:%.+]]: f32): -// CHECK: [[VAR_8_:%.+]] = math.tan [[IN_0_]] : f32 -// CHECK: linalg.yield [[VAR_8_]] : f32 +// CHECK: [[VAR_9_:%.+]] = math.tan [[IN_0_]] : f32 +// CHECK: linalg.yield [[VAR_9_]] : f32 // CHECK: } -> tensor<32xf32> // CHECK: [[VAR_reinterpret_cast_1_:%.+]] = memref.reinterpret_cast [[PARAM_1_]] to offset: {{.}}[[VAR_1_]]{{.}}, sizes: [32], strides: [1] : memref<*xf32> to memref<32xf32, strided<[1], offset: ?>> -// CHECK: bufferization.materialize_in_destination [[VAR_7_]] in writable [[VAR_reinterpret_cast_1_]] : (tensor<32xf32>, memref<32xf32, strided<[1], offset: ?>>) -> () +// CHECK: bufferization.materialize_in_destination [[VAR_8_]] in writable [[VAR_reinterpret_cast_1_]] : (tensor<32xf32>, memref<32xf32, strided<[1], offset: ?>>) -> () // CHECK: return // CHECK: } // CHECK-DAG: [[MAP_0_:#.+]] = affine_map<(d0) -> (d0)> @@ -812,20 +818,21 @@ module { // CHECK-DAG: [[VAR_2_:%.+]] = arith.addi [[VAR_1_]], [[CST_32_1_]] : index // CHECK-DAG: [[VAR_3_:%.+]] = arith.index_cast [[PARAM_2_]] : i32 to index // CHECK: [[VAR_4_:%.+]] = arith.minsi [[VAR_2_]], [[VAR_3_]] : index -// CHECK-DAG: [[VAR_5_:%.+]] = arith.subi [[VAR_4_]], [[VAR_1_]] : index +// CHECK: [[VAR_5_:%.+]] = arith.maxsi [[VAR_4_]], [[VAR_1_]] : index +// CHECK-DAG: [[VAR_6_:%.+]] = arith.subi [[VAR_5_]], [[VAR_1_]] : index // CHECK-DAG: [[RES_:%.+]] = memref.alloc() : memref<32xf32> // CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: [[VAR_subview_:%.+]] = memref.subview [[VAR_reinterpret_cast_]][0] {{.}}[[VAR_5_]]{{.}} [1] : memref<32xf32, strided<[1], offset: ?>> to memref> -// CHECK-DAG: [[VAR_subview_0_:%.+]] = memref.subview [[RES_]][0] {{.}}[[VAR_5_]]{{.}} [1] : memref<32xf32> to memref> +// CHECK-DAG: [[VAR_subview_:%.+]] = memref.subview [[VAR_reinterpret_cast_]][0] {{.}}[[VAR_6_]]{{.}} [1] : memref<32xf32, strided<[1], offset: ?>> to memref> +// CHECK-DAG: [[VAR_subview_0_:%.+]] = memref.subview [[RES_]][0] {{.}}[[VAR_6_]]{{.}} [1] : memref<32xf32> to memref> // CHECK: memref.copy [[VAR_subview_]], [[VAR_subview_0_]] : memref> to memref> -// CHECK: [[VAR_6_:%.+]] = bufferization.to_tensor [[RES_]] restrict writable : memref<32xf32> -// CHECK: [[VAR_7_:%.+]] = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel"]} ins([[VAR_6_]] : tensor<32xf32>) outs([[VAR_6_]] : tensor<32xf32>) { +// CHECK: [[VAR_7_:%.+]] = bufferization.to_tensor [[RES_]] restrict writable : memref<32xf32> +// CHECK: [[VAR_8_:%.+]] = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel"]} ins([[VAR_7_]] : tensor<32xf32>) outs([[VAR_7_]] : tensor<32xf32>) { // CHECK: ^bb0([[IN_0_:%.+]]: f32, [[IN_1_:%.+]]: f32): -// CHECK: [[VAR_8_:%.+]] = math.asin [[IN_0_]] : f32 -// CHECK: linalg.yield [[VAR_8_]] : f32 +// CHECK: [[VAR_9_:%.+]] = math.asin [[IN_0_]] : 
f32 +// CHECK: linalg.yield [[VAR_9_]] : f32 // CHECK: } -> tensor<32xf32> // CHECK: [[VAR_reinterpret_cast_1_:%.+]] = memref.reinterpret_cast [[PARAM_1_]] to offset: {{.}}[[VAR_1_]]{{.}}, sizes: [32], strides: [1] : memref<*xf32> to memref<32xf32, strided<[1], offset: ?>> -// CHECK: bufferization.materialize_in_destination [[VAR_7_]] in writable [[VAR_reinterpret_cast_1_]] : (tensor<32xf32>, memref<32xf32, strided<[1], offset: ?>>) -> () +// CHECK: bufferization.materialize_in_destination [[VAR_8_]] in writable [[VAR_reinterpret_cast_1_]] : (tensor<32xf32>, memref<32xf32, strided<[1], offset: ?>>) -> () // CHECK: return // CHECK: } // CHECK-DAG: [[MAP_0_:#.+]] = affine_map<(d0) -> (d0)> @@ -839,20 +846,21 @@ module { // CHECK-DAG: [[VAR_2_:%.+]] = arith.addi [[VAR_1_]], [[CST_32_1_]] : index // CHECK-DAG: [[VAR_3_:%.+]] = arith.index_cast [[PARAM_2_]] : i32 to index // CHECK: [[VAR_4_:%.+]] = arith.minsi [[VAR_2_]], [[VAR_3_]] : index -// CHECK-DAG: [[VAR_5_:%.+]] = arith.subi [[VAR_4_]], [[VAR_1_]] : index +// CHECK: [[VAR_5_:%.+]] = arith.maxsi [[VAR_4_]], [[VAR_1_]] : index +// CHECK-DAG: [[VAR_6_:%.+]] = arith.subi [[VAR_5_]], [[VAR_1_]] : index // CHECK-DAG: [[RES_:%.+]] = memref.alloc() : memref<32xf32> // CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: [[VAR_subview_:%.+]] = memref.subview [[VAR_reinterpret_cast_]][0] {{.}}[[VAR_5_]]{{.}} [1] : memref<32xf32, strided<[1], offset: ?>> to memref> -// CHECK-DAG: [[VAR_subview_0_:%.+]] = memref.subview [[RES_]][0] {{.}}[[VAR_5_]]{{.}} [1] : memref<32xf32> to memref> +// CHECK-DAG: [[VAR_subview_:%.+]] = memref.subview [[VAR_reinterpret_cast_]][0] {{.}}[[VAR_6_]]{{.}} [1] : memref<32xf32, strided<[1], offset: ?>> to memref> +// CHECK-DAG: [[VAR_subview_0_:%.+]] = memref.subview [[RES_]][0] {{.}}[[VAR_6_]]{{.}} [1] : memref<32xf32> to memref> // CHECK: memref.copy [[VAR_subview_]], [[VAR_subview_0_]] : memref> to memref> -// CHECK: [[VAR_6_:%.+]] = bufferization.to_tensor [[RES_]] restrict writable : memref<32xf32> -// CHECK: [[VAR_7_:%.+]] = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel"]} ins([[VAR_6_]] : tensor<32xf32>) outs([[VAR_6_]] : tensor<32xf32>) { +// CHECK: [[VAR_7_:%.+]] = bufferization.to_tensor [[RES_]] restrict writable : memref<32xf32> +// CHECK: [[VAR_8_:%.+]] = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel"]} ins([[VAR_7_]] : tensor<32xf32>) outs([[VAR_7_]] : tensor<32xf32>) { // CHECK: ^bb0([[IN_0_:%.+]]: f32, [[IN_1_:%.+]]: f32): -// CHECK: [[VAR_8_:%.+]] = math.acos [[IN_0_]] : f32 -// CHECK: linalg.yield [[VAR_8_]] : f32 +// CHECK: [[VAR_9_:%.+]] = math.acos [[IN_0_]] : f32 +// CHECK: linalg.yield [[VAR_9_]] : f32 // CHECK: } -> tensor<32xf32> // CHECK: [[VAR_reinterpret_cast_1_:%.+]] = memref.reinterpret_cast [[PARAM_1_]] to offset: {{.}}[[VAR_1_]]{{.}}, sizes: [32], strides: [1] : memref<*xf32> to memref<32xf32, strided<[1], offset: ?>> -// CHECK: bufferization.materialize_in_destination [[VAR_7_]] in writable [[VAR_reinterpret_cast_1_]] : (tensor<32xf32>, memref<32xf32, strided<[1], offset: ?>>) -> () +// CHECK: bufferization.materialize_in_destination [[VAR_8_]] in writable [[VAR_reinterpret_cast_1_]] : (tensor<32xf32>, memref<32xf32, strided<[1], offset: ?>>) -> () // CHECK: return // CHECK: } // CHECK-DAG: [[MAP_0_:#.+]] = affine_map<(d0) -> (d0)> @@ -866,20 +874,21 @@ module { // CHECK-DAG: [[VAR_2_:%.+]] = arith.addi [[VAR_1_]], [[CST_32_1_]] : index // CHECK-DAG: [[VAR_3_:%.+]] = arith.index_cast [[PARAM_2_]] : i32 to index // CHECK: 
[[VAR_4_:%.+]] = arith.minsi [[VAR_2_]], [[VAR_3_]] : index -// CHECK-DAG: [[VAR_5_:%.+]] = arith.subi [[VAR_4_]], [[VAR_1_]] : index +// CHECK: [[VAR_5_:%.+]] = arith.maxsi [[VAR_4_]], [[VAR_1_]] : index +// CHECK-DAG: [[VAR_6_:%.+]] = arith.subi [[VAR_5_]], [[VAR_1_]] : index // CHECK-DAG: [[RES_:%.+]] = memref.alloc() : memref<32xf32> // CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: [[VAR_subview_:%.+]] = memref.subview [[VAR_reinterpret_cast_]][0] {{.}}[[VAR_5_]]{{.}} [1] : memref<32xf32, strided<[1], offset: ?>> to memref> -// CHECK-DAG: [[VAR_subview_0_:%.+]] = memref.subview [[RES_]][0] {{.}}[[VAR_5_]]{{.}} [1] : memref<32xf32> to memref> +// CHECK-DAG: [[VAR_subview_:%.+]] = memref.subview [[VAR_reinterpret_cast_]][0] {{.}}[[VAR_6_]]{{.}} [1] : memref<32xf32, strided<[1], offset: ?>> to memref> +// CHECK-DAG: [[VAR_subview_0_:%.+]] = memref.subview [[RES_]][0] {{.}}[[VAR_6_]]{{.}} [1] : memref<32xf32> to memref> // CHECK: memref.copy [[VAR_subview_]], [[VAR_subview_0_]] : memref> to memref> -// CHECK: [[VAR_6_:%.+]] = bufferization.to_tensor [[RES_]] restrict writable : memref<32xf32> -// CHECK: [[VAR_7_:%.+]] = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel"]} ins([[VAR_6_]] : tensor<32xf32>) outs([[VAR_6_]] : tensor<32xf32>) { +// CHECK: [[VAR_7_:%.+]] = bufferization.to_tensor [[RES_]] restrict writable : memref<32xf32> +// CHECK: [[VAR_8_:%.+]] = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel"]} ins([[VAR_7_]] : tensor<32xf32>) outs([[VAR_7_]] : tensor<32xf32>) { // CHECK: ^bb0([[IN_0_:%.+]]: f32, [[IN_1_:%.+]]: f32): -// CHECK: [[VAR_8_:%.+]] = math.atan [[IN_0_]] : f32 -// CHECK: linalg.yield [[VAR_8_]] : f32 +// CHECK: [[VAR_9_:%.+]] = math.atan [[IN_0_]] : f32 +// CHECK: linalg.yield [[VAR_9_]] : f32 // CHECK: } -> tensor<32xf32> // CHECK: [[VAR_reinterpret_cast_1_:%.+]] = memref.reinterpret_cast [[PARAM_1_]] to offset: {{.}}[[VAR_1_]]{{.}}, sizes: [32], strides: [1] : memref<*xf32> to memref<32xf32, strided<[1], offset: ?>> -// CHECK: bufferization.materialize_in_destination [[VAR_7_]] in writable [[VAR_reinterpret_cast_1_]] : (tensor<32xf32>, memref<32xf32, strided<[1], offset: ?>>) -> () +// CHECK: bufferization.materialize_in_destination [[VAR_8_]] in writable [[VAR_reinterpret_cast_1_]] : (tensor<32xf32>, memref<32xf32, strided<[1], offset: ?>>) -> () // CHECK: return // CHECK: } // CHECK-DAG: [[MAP_0_:#.+]] = affine_map<(d0) -> (d0)> @@ -893,20 +902,21 @@ module { // CHECK-DAG: [[VAR_2_:%.+]] = arith.addi [[VAR_1_]], [[CST_32_1_]] : index // CHECK-DAG: [[VAR_3_:%.+]] = arith.index_cast [[PARAM_2_]] : i32 to index // CHECK: [[VAR_4_:%.+]] = arith.minsi [[VAR_2_]], [[VAR_3_]] : index -// CHECK-DAG: [[VAR_5_:%.+]] = arith.subi [[VAR_4_]], [[VAR_1_]] : index +// CHECK: [[VAR_5_:%.+]] = arith.maxsi [[VAR_4_]], [[VAR_1_]] : index +// CHECK-DAG: [[VAR_6_:%.+]] = arith.subi [[VAR_5_]], [[VAR_1_]] : index // CHECK-DAG: [[RES_:%.+]] = memref.alloc() : memref<32xf32> // CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: [[VAR_subview_:%.+]] = memref.subview [[VAR_reinterpret_cast_]][0] {{.}}[[VAR_5_]]{{.}} [1] : memref<32xf32, strided<[1], offset: ?>> to memref> -// CHECK-DAG: [[VAR_subview_0_:%.+]] = memref.subview [[RES_]][0] {{.}}[[VAR_5_]]{{.}} [1] : memref<32xf32> to memref> +// CHECK-DAG: [[VAR_subview_:%.+]] = memref.subview [[VAR_reinterpret_cast_]][0] {{.}}[[VAR_6_]]{{.}} [1] : memref<32xf32, strided<[1], offset: ?>> to memref> +// CHECK-DAG: [[VAR_subview_0_:%.+]] = memref.subview [[RES_]][0] 
{{.}}[[VAR_6_]]{{.}} [1] : memref<32xf32> to memref> // CHECK: memref.copy [[VAR_subview_]], [[VAR_subview_0_]] : memref> to memref> -// CHECK: [[VAR_6_:%.+]] = bufferization.to_tensor [[RES_]] restrict writable : memref<32xf32> -// CHECK: [[VAR_7_:%.+]] = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel"]} ins([[VAR_6_]] : tensor<32xf32>) outs([[VAR_6_]] : tensor<32xf32>) { +// CHECK: [[VAR_7_:%.+]] = bufferization.to_tensor [[RES_]] restrict writable : memref<32xf32> +// CHECK: [[VAR_8_:%.+]] = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel"]} ins([[VAR_7_]] : tensor<32xf32>) outs([[VAR_7_]] : tensor<32xf32>) { // CHECK: ^bb0([[IN_0_:%.+]]: f32, [[IN_1_:%.+]]: f32): -// CHECK: [[VAR_8_:%.+]] = math.sinh [[IN_0_]] : f32 -// CHECK: linalg.yield [[VAR_8_]] : f32 +// CHECK: [[VAR_9_:%.+]] = math.sinh [[IN_0_]] : f32 +// CHECK: linalg.yield [[VAR_9_]] : f32 // CHECK: } -> tensor<32xf32> // CHECK: [[VAR_reinterpret_cast_1_:%.+]] = memref.reinterpret_cast [[PARAM_1_]] to offset: {{.}}[[VAR_1_]]{{.}}, sizes: [32], strides: [1] : memref<*xf32> to memref<32xf32, strided<[1], offset: ?>> -// CHECK: bufferization.materialize_in_destination [[VAR_7_]] in writable [[VAR_reinterpret_cast_1_]] : (tensor<32xf32>, memref<32xf32, strided<[1], offset: ?>>) -> () +// CHECK: bufferization.materialize_in_destination [[VAR_8_]] in writable [[VAR_reinterpret_cast_1_]] : (tensor<32xf32>, memref<32xf32, strided<[1], offset: ?>>) -> () // CHECK: return // CHECK: } // CHECK-DAG: [[MAP_0_:#.+]] = affine_map<(d0) -> (d0)> @@ -920,20 +930,21 @@ module { // CHECK-DAG: [[VAR_2_:%.+]] = arith.addi [[VAR_1_]], [[CST_32_1_]] : index // CHECK-DAG: [[VAR_3_:%.+]] = arith.index_cast [[PARAM_2_]] : i32 to index // CHECK: [[VAR_4_:%.+]] = arith.minsi [[VAR_2_]], [[VAR_3_]] : index -// CHECK-DAG: [[VAR_5_:%.+]] = arith.subi [[VAR_4_]], [[VAR_1_]] : index +// CHECK: [[VAR_5_:%.+]] = arith.maxsi [[VAR_4_]], [[VAR_1_]] : index +// CHECK-DAG: [[VAR_6_:%.+]] = arith.subi [[VAR_5_]], [[VAR_1_]] : index // CHECK-DAG: [[RES_:%.+]] = memref.alloc() : memref<32xf32> // CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: [[VAR_subview_:%.+]] = memref.subview [[VAR_reinterpret_cast_]][0] {{.}}[[VAR_5_]]{{.}} [1] : memref<32xf32, strided<[1], offset: ?>> to memref> -// CHECK-DAG: [[VAR_subview_0_:%.+]] = memref.subview [[RES_]][0] {{.}}[[VAR_5_]]{{.}} [1] : memref<32xf32> to memref> +// CHECK-DAG: [[VAR_subview_:%.+]] = memref.subview [[VAR_reinterpret_cast_]][0] {{.}}[[VAR_6_]]{{.}} [1] : memref<32xf32, strided<[1], offset: ?>> to memref> +// CHECK-DAG: [[VAR_subview_0_:%.+]] = memref.subview [[RES_]][0] {{.}}[[VAR_6_]]{{.}} [1] : memref<32xf32> to memref> // CHECK: memref.copy [[VAR_subview_]], [[VAR_subview_0_]] : memref> to memref> -// CHECK: [[VAR_6_:%.+]] = bufferization.to_tensor [[RES_]] restrict writable : memref<32xf32> -// CHECK: [[VAR_7_:%.+]] = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel"]} ins([[VAR_6_]] : tensor<32xf32>) outs([[VAR_6_]] : tensor<32xf32>) { +// CHECK: [[VAR_7_:%.+]] = bufferization.to_tensor [[RES_]] restrict writable : memref<32xf32> +// CHECK: [[VAR_8_:%.+]] = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel"]} ins([[VAR_7_]] : tensor<32xf32>) outs([[VAR_7_]] : tensor<32xf32>) { // CHECK: ^bb0([[IN_0_:%.+]]: f32, [[IN_1_:%.+]]: f32): -// CHECK: [[VAR_8_:%.+]] = math.cosh [[IN_0_]] : f32 -// CHECK: linalg.yield [[VAR_8_]] : f32 +// CHECK: [[VAR_9_:%.+]] = math.cosh [[IN_0_]] : f32 +// CHECK: 
linalg.yield [[VAR_9_]] : f32 // CHECK: } -> tensor<32xf32> // CHECK: [[VAR_reinterpret_cast_1_:%.+]] = memref.reinterpret_cast [[PARAM_1_]] to offset: {{.}}[[VAR_1_]]{{.}}, sizes: [32], strides: [1] : memref<*xf32> to memref<32xf32, strided<[1], offset: ?>> -// CHECK: bufferization.materialize_in_destination [[VAR_7_]] in writable [[VAR_reinterpret_cast_1_]] : (tensor<32xf32>, memref<32xf32, strided<[1], offset: ?>>) -> () +// CHECK: bufferization.materialize_in_destination [[VAR_8_]] in writable [[VAR_reinterpret_cast_1_]] : (tensor<32xf32>, memref<32xf32, strided<[1], offset: ?>>) -> () // CHECK: return // CHECK: } // CHECK-DAG: [[MAP_0_:#.+]] = affine_map<(d0) -> (d0)> @@ -947,20 +958,21 @@ module { // CHECK-DAG: [[VAR_2_:%.+]] = arith.addi [[VAR_1_]], [[CST_32_1_]] : index // CHECK-DAG: [[VAR_3_:%.+]] = arith.index_cast [[PARAM_2_]] : i32 to index // CHECK: [[VAR_4_:%.+]] = arith.minsi [[VAR_2_]], [[VAR_3_]] : index -// CHECK-DAG: [[VAR_5_:%.+]] = arith.subi [[VAR_4_]], [[VAR_1_]] : index +// CHECK: [[VAR_5_:%.+]] = arith.maxsi [[VAR_4_]], [[VAR_1_]] : index +// CHECK-DAG: [[VAR_6_:%.+]] = arith.subi [[VAR_5_]], [[VAR_1_]] : index // CHECK-DAG: [[RES_:%.+]] = memref.alloc() : memref<32xf32> // CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: [[VAR_subview_:%.+]] = memref.subview [[VAR_reinterpret_cast_]][0] {{.}}[[VAR_5_]]{{.}} [1] : memref<32xf32, strided<[1], offset: ?>> to memref> -// CHECK-DAG: [[VAR_subview_0_:%.+]] = memref.subview [[RES_]][0] {{.}}[[VAR_5_]]{{.}} [1] : memref<32xf32> to memref> +// CHECK-DAG: [[VAR_subview_:%.+]] = memref.subview [[VAR_reinterpret_cast_]][0] {{.}}[[VAR_6_]]{{.}} [1] : memref<32xf32, strided<[1], offset: ?>> to memref> +// CHECK-DAG: [[VAR_subview_0_:%.+]] = memref.subview [[RES_]][0] {{.}}[[VAR_6_]]{{.}} [1] : memref<32xf32> to memref> // CHECK: memref.copy [[VAR_subview_]], [[VAR_subview_0_]] : memref> to memref> -// CHECK: [[VAR_6_:%.+]] = bufferization.to_tensor [[RES_]] restrict writable : memref<32xf32> -// CHECK: [[VAR_7_:%.+]] = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel"]} ins([[VAR_6_]] : tensor<32xf32>) outs([[VAR_6_]] : tensor<32xf32>) { +// CHECK: [[VAR_7_:%.+]] = bufferization.to_tensor [[RES_]] restrict writable : memref<32xf32> +// CHECK: [[VAR_8_:%.+]] = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel"]} ins([[VAR_7_]] : tensor<32xf32>) outs([[VAR_7_]] : tensor<32xf32>) { // CHECK: ^bb0([[IN_0_:%.+]]: f32, [[IN_1_:%.+]]: f32): -// CHECK: [[VAR_8_:%.+]] = math.tanh [[IN_0_]] : f32 -// CHECK: linalg.yield [[VAR_8_]] : f32 +// CHECK: [[VAR_9_:%.+]] = math.tanh [[IN_0_]] : f32 +// CHECK: linalg.yield [[VAR_9_]] : f32 // CHECK: } -> tensor<32xf32> // CHECK: [[VAR_reinterpret_cast_1_:%.+]] = memref.reinterpret_cast [[PARAM_1_]] to offset: {{.}}[[VAR_1_]]{{.}}, sizes: [32], strides: [1] : memref<*xf32> to memref<32xf32, strided<[1], offset: ?>> -// CHECK: bufferization.materialize_in_destination [[VAR_7_]] in writable [[VAR_reinterpret_cast_1_]] : (tensor<32xf32>, memref<32xf32, strided<[1], offset: ?>>) -> () +// CHECK: bufferization.materialize_in_destination [[VAR_8_]] in writable [[VAR_reinterpret_cast_1_]] : (tensor<32xf32>, memref<32xf32, strided<[1], offset: ?>>) -> () // CHECK: return // CHECK: } // CHECK-DAG: [[MAP_0_:#.+]] = affine_map<(d0) -> (d0)> @@ -974,20 +986,21 @@ module { // CHECK-DAG: [[VAR_2_:%.+]] = arith.addi [[VAR_1_]], [[CST_32_1_]] : index // CHECK-DAG: [[VAR_3_:%.+]] = arith.index_cast [[PARAM_2_]] : i32 to index // CHECK: [[VAR_4_:%.+]] = 
arith.minsi [[VAR_2_]], [[VAR_3_]] : index -// CHECK-DAG: [[VAR_5_:%.+]] = arith.subi [[VAR_4_]], [[VAR_1_]] : index +// CHECK: [[VAR_5_:%.+]] = arith.maxsi [[VAR_4_]], [[VAR_1_]] : index +// CHECK-DAG: [[VAR_6_:%.+]] = arith.subi [[VAR_5_]], [[VAR_1_]] : index // CHECK-DAG: [[RES_:%.+]] = memref.alloc() : memref<32xf32> // CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: [[VAR_subview_:%.+]] = memref.subview [[VAR_reinterpret_cast_]][0] {{.}}[[VAR_5_]]{{.}} [1] : memref<32xf32, strided<[1], offset: ?>> to memref> -// CHECK-DAG: [[VAR_subview_0_:%.+]] = memref.subview [[RES_]][0] {{.}}[[VAR_5_]]{{.}} [1] : memref<32xf32> to memref> +// CHECK-DAG: [[VAR_subview_:%.+]] = memref.subview [[VAR_reinterpret_cast_]][0] {{.}}[[VAR_6_]]{{.}} [1] : memref<32xf32, strided<[1], offset: ?>> to memref> +// CHECK-DAG: [[VAR_subview_0_:%.+]] = memref.subview [[RES_]][0] {{.}}[[VAR_6_]]{{.}} [1] : memref<32xf32> to memref> // CHECK: memref.copy [[VAR_subview_]], [[VAR_subview_0_]] : memref> to memref> -// CHECK: [[VAR_6_:%.+]] = bufferization.to_tensor [[RES_]] restrict writable : memref<32xf32> -// CHECK: [[VAR_7_:%.+]] = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel"]} ins([[VAR_6_]] : tensor<32xf32>) outs([[VAR_6_]] : tensor<32xf32>) { +// CHECK: [[VAR_7_:%.+]] = bufferization.to_tensor [[RES_]] restrict writable : memref<32xf32> +// CHECK: [[VAR_8_:%.+]] = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel"]} ins([[VAR_7_]] : tensor<32xf32>) outs([[VAR_7_]] : tensor<32xf32>) { // CHECK: ^bb0([[IN_0_:%.+]]: f32, [[IN_1_:%.+]]: f32): -// CHECK: [[VAR_8_:%.+]] = math.asinh [[IN_0_]] : f32 -// CHECK: linalg.yield [[VAR_8_]] : f32 +// CHECK: [[VAR_9_:%.+]] = math.asinh [[IN_0_]] : f32 +// CHECK: linalg.yield [[VAR_9_]] : f32 // CHECK: } -> tensor<32xf32> // CHECK: [[VAR_reinterpret_cast_1_:%.+]] = memref.reinterpret_cast [[PARAM_1_]] to offset: {{.}}[[VAR_1_]]{{.}}, sizes: [32], strides: [1] : memref<*xf32> to memref<32xf32, strided<[1], offset: ?>> -// CHECK: bufferization.materialize_in_destination [[VAR_7_]] in writable [[VAR_reinterpret_cast_1_]] : (tensor<32xf32>, memref<32xf32, strided<[1], offset: ?>>) -> () +// CHECK: bufferization.materialize_in_destination [[VAR_8_]] in writable [[VAR_reinterpret_cast_1_]] : (tensor<32xf32>, memref<32xf32, strided<[1], offset: ?>>) -> () // CHECK: return // CHECK: } // CHECK-DAG: [[MAP_0_:#.+]] = affine_map<(d0) -> (d0)> @@ -1001,20 +1014,21 @@ module { // CHECK-DAG: [[VAR_2_:%.+]] = arith.addi [[VAR_1_]], [[CST_32_1_]] : index // CHECK-DAG: [[VAR_3_:%.+]] = arith.index_cast [[PARAM_2_]] : i32 to index // CHECK: [[VAR_4_:%.+]] = arith.minsi [[VAR_2_]], [[VAR_3_]] : index -// CHECK-DAG: [[VAR_5_:%.+]] = arith.subi [[VAR_4_]], [[VAR_1_]] : index +// CHECK: [[VAR_5_:%.+]] = arith.maxsi [[VAR_4_]], [[VAR_1_]] : index +// CHECK-DAG: [[VAR_6_:%.+]] = arith.subi [[VAR_5_]], [[VAR_1_]] : index // CHECK-DAG: [[RES_:%.+]] = memref.alloc() : memref<32xf32> // CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: [[VAR_subview_:%.+]] = memref.subview [[VAR_reinterpret_cast_]][0] {{.}}[[VAR_5_]]{{.}} [1] : memref<32xf32, strided<[1], offset: ?>> to memref> -// CHECK-DAG: [[VAR_subview_0_:%.+]] = memref.subview [[RES_]][0] {{.}}[[VAR_5_]]{{.}} [1] : memref<32xf32> to memref> +// CHECK-DAG: [[VAR_subview_:%.+]] = memref.subview [[VAR_reinterpret_cast_]][0] {{.}}[[VAR_6_]]{{.}} [1] : memref<32xf32, strided<[1], offset: ?>> to memref> +// CHECK-DAG: [[VAR_subview_0_:%.+]] = memref.subview [[RES_]][0] {{.}}[[VAR_6_]]{{.}} 
[1] : memref<32xf32> to memref> // CHECK: memref.copy [[VAR_subview_]], [[VAR_subview_0_]] : memref> to memref> -// CHECK: [[VAR_6_:%.+]] = bufferization.to_tensor [[RES_]] restrict writable : memref<32xf32> -// CHECK: [[VAR_7_:%.+]] = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel"]} ins([[VAR_6_]] : tensor<32xf32>) outs([[VAR_6_]] : tensor<32xf32>) { +// CHECK: [[VAR_7_:%.+]] = bufferization.to_tensor [[RES_]] restrict writable : memref<32xf32> +// CHECK: [[VAR_8_:%.+]] = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel"]} ins([[VAR_7_]] : tensor<32xf32>) outs([[VAR_7_]] : tensor<32xf32>) { // CHECK: ^bb0([[IN_0_:%.+]]: f32, [[IN_1_:%.+]]: f32): -// CHECK: [[VAR_8_:%.+]] = math.acosh [[IN_0_]] : f32 -// CHECK: linalg.yield [[VAR_8_]] : f32 +// CHECK: [[VAR_9_:%.+]] = math.acosh [[IN_0_]] : f32 +// CHECK: linalg.yield [[VAR_9_]] : f32 // CHECK: } -> tensor<32xf32> // CHECK: [[VAR_reinterpret_cast_1_:%.+]] = memref.reinterpret_cast [[PARAM_1_]] to offset: {{.}}[[VAR_1_]]{{.}}, sizes: [32], strides: [1] : memref<*xf32> to memref<32xf32, strided<[1], offset: ?>> -// CHECK: bufferization.materialize_in_destination [[VAR_7_]] in writable [[VAR_reinterpret_cast_1_]] : (tensor<32xf32>, memref<32xf32, strided<[1], offset: ?>>) -> () +// CHECK: bufferization.materialize_in_destination [[VAR_8_]] in writable [[VAR_reinterpret_cast_1_]] : (tensor<32xf32>, memref<32xf32, strided<[1], offset: ?>>) -> () // CHECK: return // CHECK: } // CHECK-DAG: [[MAP_0_:#.+]] = affine_map<(d0) -> (d0)> @@ -1028,20 +1042,21 @@ module { // CHECK-DAG: [[VAR_2_:%.+]] = arith.addi [[VAR_1_]], [[CST_32_1_]] : index // CHECK-DAG: [[VAR_3_:%.+]] = arith.index_cast [[PARAM_2_]] : i32 to index // CHECK: [[VAR_4_:%.+]] = arith.minsi [[VAR_2_]], [[VAR_3_]] : index -// CHECK-DAG: [[VAR_5_:%.+]] = arith.subi [[VAR_4_]], [[VAR_1_]] : index +// CHECK: [[VAR_5_:%.+]] = arith.maxsi [[VAR_4_]], [[VAR_1_]] : index +// CHECK-DAG: [[VAR_6_:%.+]] = arith.subi [[VAR_5_]], [[VAR_1_]] : index // CHECK-DAG: [[RES_:%.+]] = memref.alloc() : memref<32xf32> // CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: [[VAR_subview_:%.+]] = memref.subview [[VAR_reinterpret_cast_]][0] {{.}}[[VAR_5_]]{{.}} [1] : memref<32xf32, strided<[1], offset: ?>> to memref> -// CHECK-DAG: [[VAR_subview_0_:%.+]] = memref.subview [[RES_]][0] {{.}}[[VAR_5_]]{{.}} [1] : memref<32xf32> to memref> +// CHECK-DAG: [[VAR_subview_:%.+]] = memref.subview [[VAR_reinterpret_cast_]][0] {{.}}[[VAR_6_]]{{.}} [1] : memref<32xf32, strided<[1], offset: ?>> to memref> +// CHECK-DAG: [[VAR_subview_0_:%.+]] = memref.subview [[RES_]][0] {{.}}[[VAR_6_]]{{.}} [1] : memref<32xf32> to memref> // CHECK: memref.copy [[VAR_subview_]], [[VAR_subview_0_]] : memref> to memref> -// CHECK: [[VAR_6_:%.+]] = bufferization.to_tensor [[RES_]] restrict writable : memref<32xf32> -// CHECK: [[VAR_7_:%.+]] = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel"]} ins([[VAR_6_]] : tensor<32xf32>) outs([[VAR_6_]] : tensor<32xf32>) { +// CHECK: [[VAR_7_:%.+]] = bufferization.to_tensor [[RES_]] restrict writable : memref<32xf32> +// CHECK: [[VAR_8_:%.+]] = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel"]} ins([[VAR_7_]] : tensor<32xf32>) outs([[VAR_7_]] : tensor<32xf32>) { // CHECK: ^bb0([[IN_0_:%.+]]: f32, [[IN_1_:%.+]]: f32): -// CHECK: [[VAR_8_:%.+]] = math.atanh [[IN_0_]] : f32 -// CHECK: linalg.yield [[VAR_8_]] : f32 +// CHECK: [[VAR_9_:%.+]] = math.atanh [[IN_0_]] : f32 +// CHECK: linalg.yield [[VAR_9_]] : 
f32 // CHECK: } -> tensor<32xf32> // CHECK: [[VAR_reinterpret_cast_1_:%.+]] = memref.reinterpret_cast [[PARAM_1_]] to offset: {{.}}[[VAR_1_]]{{.}}, sizes: [32], strides: [1] : memref<*xf32> to memref<32xf32, strided<[1], offset: ?>> -// CHECK: bufferization.materialize_in_destination [[VAR_7_]] in writable [[VAR_reinterpret_cast_1_]] : (tensor<32xf32>, memref<32xf32, strided<[1], offset: ?>>) -> () +// CHECK: bufferization.materialize_in_destination [[VAR_8_]] in writable [[VAR_reinterpret_cast_1_]] : (tensor<32xf32>, memref<32xf32, strided<[1], offset: ?>>) -> () // CHECK: return // CHECK: } // CHECK-DAG: [[MAP_0_:#.+]] = affine_map<(d0) -> (d0)> @@ -1055,20 +1070,21 @@ module { // CHECK-DAG: [[VAR_2_:%.+]] = arith.addi [[VAR_1_]], [[CST_32_1_]] : index // CHECK-DAG: [[VAR_3_:%.+]] = arith.index_cast [[PARAM_2_]] : i32 to index // CHECK: [[VAR_4_:%.+]] = arith.minsi [[VAR_2_]], [[VAR_3_]] : index -// CHECK-DAG: [[VAR_5_:%.+]] = arith.subi [[VAR_4_]], [[VAR_1_]] : index +// CHECK: [[VAR_5_:%.+]] = arith.maxsi [[VAR_4_]], [[VAR_1_]] : index +// CHECK-DAG: [[VAR_6_:%.+]] = arith.subi [[VAR_5_]], [[VAR_1_]] : index // CHECK-DAG: [[RES_:%.+]] = memref.alloc() : memref<32xf32> // CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: [[VAR_subview_:%.+]] = memref.subview [[VAR_reinterpret_cast_]][0] {{.}}[[VAR_5_]]{{.}} [1] : memref<32xf32, strided<[1], offset: ?>> to memref> -// CHECK-DAG: [[VAR_subview_0_:%.+]] = memref.subview [[RES_]][0] {{.}}[[VAR_5_]]{{.}} [1] : memref<32xf32> to memref> +// CHECK-DAG: [[VAR_subview_:%.+]] = memref.subview [[VAR_reinterpret_cast_]][0] {{.}}[[VAR_6_]]{{.}} [1] : memref<32xf32, strided<[1], offset: ?>> to memref> +// CHECK-DAG: [[VAR_subview_0_:%.+]] = memref.subview [[RES_]][0] {{.}}[[VAR_6_]]{{.}} [1] : memref<32xf32> to memref> // CHECK: memref.copy [[VAR_subview_]], [[VAR_subview_0_]] : memref> to memref> -// CHECK: [[VAR_6_:%.+]] = bufferization.to_tensor [[RES_]] restrict writable : memref<32xf32> -// CHECK: [[VAR_7_:%.+]] = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel"]} ins([[VAR_6_]] : tensor<32xf32>) outs([[VAR_6_]] : tensor<32xf32>) { +// CHECK: [[VAR_7_:%.+]] = bufferization.to_tensor [[RES_]] restrict writable : memref<32xf32> +// CHECK: [[VAR_8_:%.+]] = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel"]} ins([[VAR_7_]] : tensor<32xf32>) outs([[VAR_7_]] : tensor<32xf32>) { // CHECK: ^bb0([[IN_0_:%.+]]: f32, [[IN_1_:%.+]]: f32): -// CHECK: [[VAR_8_:%.+]] = math.log [[IN_0_]] : f32 -// CHECK: linalg.yield [[VAR_8_]] : f32 +// CHECK: [[VAR_9_:%.+]] = math.log [[IN_0_]] : f32 +// CHECK: linalg.yield [[VAR_9_]] : f32 // CHECK: } -> tensor<32xf32> // CHECK: [[VAR_reinterpret_cast_1_:%.+]] = memref.reinterpret_cast [[PARAM_1_]] to offset: {{.}}[[VAR_1_]]{{.}}, sizes: [32], strides: [1] : memref<*xf32> to memref<32xf32, strided<[1], offset: ?>> -// CHECK: bufferization.materialize_in_destination [[VAR_7_]] in writable [[VAR_reinterpret_cast_1_]] : (tensor<32xf32>, memref<32xf32, strided<[1], offset: ?>>) -> () +// CHECK: bufferization.materialize_in_destination [[VAR_8_]] in writable [[VAR_reinterpret_cast_1_]] : (tensor<32xf32>, memref<32xf32, strided<[1], offset: ?>>) -> () // CHECK: return // CHECK: } // CHECK-DAG: [[MAP_0_:#.+]] = affine_map<(d0) -> (d0)> @@ -1082,20 +1098,21 @@ module { // CHECK-DAG: [[VAR_2_:%.+]] = arith.addi [[VAR_1_]], [[CST_32_1_]] : index // CHECK-DAG: [[VAR_3_:%.+]] = arith.index_cast [[PARAM_2_]] : i32 to index // CHECK: [[VAR_4_:%.+]] = arith.minsi [[VAR_2_]], 
-// CHECK-DAG: [[VAR_5_:%.+]] = arith.subi [[VAR_4_]], [[VAR_1_]] : index
+// CHECK: [[VAR_5_:%.+]] = arith.maxsi [[VAR_4_]], [[VAR_1_]] : index
+// CHECK-DAG: [[VAR_6_:%.+]] = arith.subi [[VAR_5_]], [[VAR_1_]] : index
// CHECK-DAG: [[RES_:%.+]] = memref.alloc() : memref<32xf32>
// CHECK-NOT: separator of consecutive DAGs
-// CHECK-DAG: [[VAR_subview_:%.+]] = memref.subview [[VAR_reinterpret_cast_]][0] {{.}}[[VAR_5_]]{{.}} [1] : memref<32xf32, strided<[1], offset: ?>> to memref<?xf32, strided<[1], offset: ?>>
-// CHECK-DAG: [[VAR_subview_0_:%.+]] = memref.subview [[RES_]][0] {{.}}[[VAR_5_]]{{.}} [1] : memref<32xf32> to memref<?xf32, strided<[1]>>
+// CHECK-DAG: [[VAR_subview_:%.+]] = memref.subview [[VAR_reinterpret_cast_]][0] {{.}}[[VAR_6_]]{{.}} [1] : memref<32xf32, strided<[1], offset: ?>> to memref<?xf32, strided<[1], offset: ?>>
+// CHECK-DAG: [[VAR_subview_0_:%.+]] = memref.subview [[RES_]][0] {{.}}[[VAR_6_]]{{.}} [1] : memref<32xf32> to memref<?xf32, strided<[1]>>
// CHECK: memref.copy [[VAR_subview_]], [[VAR_subview_0_]] : memref<?xf32, strided<[1], offset: ?>> to memref<?xf32, strided<[1]>>
-// CHECK: [[VAR_6_:%.+]] = bufferization.to_tensor [[RES_]] restrict writable : memref<32xf32>
-// CHECK: [[VAR_7_:%.+]] = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel"]} ins([[VAR_6_]] : tensor<32xf32>) outs([[VAR_6_]] : tensor<32xf32>) {
+// CHECK: [[VAR_7_:%.+]] = bufferization.to_tensor [[RES_]] restrict writable : memref<32xf32>
+// CHECK: [[VAR_8_:%.+]] = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel"]} ins([[VAR_7_]] : tensor<32xf32>) outs([[VAR_7_]] : tensor<32xf32>) {
// CHECK: ^bb0([[IN_0_:%.+]]: f32, [[IN_1_:%.+]]: f32):
-// CHECK: [[VAR_8_:%.+]] = math.log10 [[IN_0_]] : f32
-// CHECK: linalg.yield [[VAR_8_]] : f32
+// CHECK: [[VAR_9_:%.+]] = math.log10 [[IN_0_]] : f32
+// CHECK: linalg.yield [[VAR_9_]] : f32
// CHECK: } -> tensor<32xf32>
// CHECK: [[VAR_reinterpret_cast_1_:%.+]] = memref.reinterpret_cast [[PARAM_1_]] to offset: {{.}}[[VAR_1_]]{{.}}, sizes: [32], strides: [1] : memref<*xf32> to memref<32xf32, strided<[1], offset: ?>>
-// CHECK: bufferization.materialize_in_destination [[VAR_7_]] in writable [[VAR_reinterpret_cast_1_]] : (tensor<32xf32>, memref<32xf32, strided<[1], offset: ?>>) -> ()
+// CHECK: bufferization.materialize_in_destination [[VAR_8_]] in writable [[VAR_reinterpret_cast_1_]] : (tensor<32xf32>, memref<32xf32, strided<[1], offset: ?>>) -> ()
// CHECK: return
// CHECK: }
// CHECK-DAG: [[MAP_0_:#.+]] = affine_map<(d0) -> (d0)>
@@ -1109,20 +1126,21 @@ module {
// CHECK-DAG: [[VAR_2_:%.+]] = arith.addi [[VAR_1_]], [[CST_32_1_]] : index
// CHECK-DAG: [[VAR_3_:%.+]] = arith.index_cast [[PARAM_2_]] : i32 to index
// CHECK: [[VAR_4_:%.+]] = arith.minsi [[VAR_2_]], [[VAR_3_]] : index
-// CHECK-DAG: [[VAR_5_:%.+]] = arith.subi [[VAR_4_]], [[VAR_1_]] : index
+// CHECK: [[VAR_5_:%.+]] = arith.maxsi [[VAR_4_]], [[VAR_1_]] : index
+// CHECK-DAG: [[VAR_6_:%.+]] = arith.subi [[VAR_5_]], [[VAR_1_]] : index
// CHECK-DAG: [[RES_:%.+]] = memref.alloc() : memref<32xf32>
// CHECK-NOT: separator of consecutive DAGs
-// CHECK-DAG: [[VAR_subview_:%.+]] = memref.subview [[VAR_reinterpret_cast_]][0] {{.}}[[VAR_5_]]{{.}} [1] : memref<32xf32, strided<[1], offset: ?>> to memref<?xf32, strided<[1], offset: ?>>
-// CHECK-DAG: [[VAR_subview_0_:%.+]] = memref.subview [[RES_]][0] {{.}}[[VAR_5_]]{{.}} [1] : memref<32xf32> to memref<?xf32, strided<[1]>>
+// CHECK-DAG: [[VAR_subview_:%.+]] = memref.subview [[VAR_reinterpret_cast_]][0] {{.}}[[VAR_6_]]{{.}} [1] : memref<32xf32, strided<[1], offset: ?>> to memref<?xf32, strided<[1], offset: ?>>
+// CHECK-DAG: [[VAR_subview_0_:%.+]] = memref.subview [[RES_]][0] {{.}}[[VAR_6_]]{{.}} [1] : memref<32xf32> to memref<?xf32, strided<[1]>>
// CHECK: memref.copy [[VAR_subview_]], [[VAR_subview_0_]] : memref<?xf32, strided<[1], offset: ?>> to memref<?xf32, strided<[1]>>
-// CHECK: [[VAR_6_:%.+]] = bufferization.to_tensor [[RES_]] restrict writable : memref<32xf32>
-// CHECK: [[VAR_7_:%.+]] = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel"]} ins([[VAR_6_]] : tensor<32xf32>) outs([[VAR_6_]] : tensor<32xf32>) {
+// CHECK: [[VAR_7_:%.+]] = bufferization.to_tensor [[RES_]] restrict writable : memref<32xf32>
+// CHECK: [[VAR_8_:%.+]] = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel"]} ins([[VAR_7_]] : tensor<32xf32>) outs([[VAR_7_]] : tensor<32xf32>) {
// CHECK: ^bb0([[IN_0_:%.+]]: f32, [[IN_1_:%.+]]: f32):
-// CHECK: [[VAR_8_:%.+]] = math.log1p [[IN_0_]] : f32
-// CHECK: linalg.yield [[VAR_8_]] : f32
+// CHECK: [[VAR_9_:%.+]] = math.log1p [[IN_0_]] : f32
+// CHECK: linalg.yield [[VAR_9_]] : f32
// CHECK: } -> tensor<32xf32>
// CHECK: [[VAR_reinterpret_cast_1_:%.+]] = memref.reinterpret_cast [[PARAM_1_]] to offset: {{.}}[[VAR_1_]]{{.}}, sizes: [32], strides: [1] : memref<*xf32> to memref<32xf32, strided<[1], offset: ?>>
-// CHECK: bufferization.materialize_in_destination [[VAR_7_]] in writable [[VAR_reinterpret_cast_1_]] : (tensor<32xf32>, memref<32xf32, strided<[1], offset: ?>>) -> ()
+// CHECK: bufferization.materialize_in_destination [[VAR_8_]] in writable [[VAR_reinterpret_cast_1_]] : (tensor<32xf32>, memref<32xf32, strided<[1], offset: ?>>) -> ()
// CHECK: return
// CHECK: }
// CHECK-DAG: [[MAP_0_:#.+]] = affine_map<(d0) -> (d0)>
@@ -1136,20 +1154,21 @@ module {
// CHECK-DAG: [[VAR_2_:%.+]] = arith.addi [[VAR_1_]], [[CST_32_1_]] : index
// CHECK-DAG: [[VAR_3_:%.+]] = arith.index_cast [[PARAM_2_]] : i32 to index
// CHECK: [[VAR_4_:%.+]] = arith.minsi [[VAR_2_]], [[VAR_3_]] : index
-// CHECK-DAG: [[VAR_5_:%.+]] = arith.subi [[VAR_4_]], [[VAR_1_]] : index
+// CHECK: [[VAR_5_:%.+]] = arith.maxsi [[VAR_4_]], [[VAR_1_]] : index
+// CHECK-DAG: [[VAR_6_:%.+]] = arith.subi [[VAR_5_]], [[VAR_1_]] : index
// CHECK-DAG: [[RES_:%.+]] = memref.alloc() : memref<32xf32>
// CHECK-NOT: separator of consecutive DAGs
-// CHECK-DAG: [[VAR_subview_:%.+]] = memref.subview [[VAR_reinterpret_cast_]][0] {{.}}[[VAR_5_]]{{.}} [1] : memref<32xf32, strided<[1], offset: ?>> to memref<?xf32, strided<[1], offset: ?>>
-// CHECK-DAG: [[VAR_subview_0_:%.+]] = memref.subview [[RES_]][0] {{.}}[[VAR_5_]]{{.}} [1] : memref<32xf32> to memref<?xf32, strided<[1]>>
+// CHECK-DAG: [[VAR_subview_:%.+]] = memref.subview [[VAR_reinterpret_cast_]][0] {{.}}[[VAR_6_]]{{.}} [1] : memref<32xf32, strided<[1], offset: ?>> to memref<?xf32, strided<[1], offset: ?>>
+// CHECK-DAG: [[VAR_subview_0_:%.+]] = memref.subview [[RES_]][0] {{.}}[[VAR_6_]]{{.}} [1] : memref<32xf32> to memref<?xf32, strided<[1]>>
// CHECK: memref.copy [[VAR_subview_]], [[VAR_subview_0_]] : memref<?xf32, strided<[1], offset: ?>> to memref<?xf32, strided<[1]>>
-// CHECK: [[VAR_6_:%.+]] = bufferization.to_tensor [[RES_]] restrict writable : memref<32xf32>
-// CHECK: [[VAR_7_:%.+]] = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel"]} ins([[VAR_6_]] : tensor<32xf32>) outs([[VAR_6_]] : tensor<32xf32>) {
+// CHECK: [[VAR_7_:%.+]] = bufferization.to_tensor [[RES_]] restrict writable : memref<32xf32>
+// CHECK: [[VAR_8_:%.+]] = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel"]} ins([[VAR_7_]] : tensor<32xf32>) outs([[VAR_7_]] : tensor<32xf32>) {
// CHECK: ^bb0([[IN_0_:%.+]]: f32, [[IN_1_:%.+]]: f32):
-// CHECK: [[VAR_8_:%.+]] = math.exp [[IN_0_]] : f32
-// CHECK: linalg.yield [[VAR_8_]] : f32
+// CHECK: [[VAR_9_:%.+]] = math.exp [[IN_0_]] : f32
+// CHECK: linalg.yield [[VAR_9_]] : f32
// CHECK: } -> tensor<32xf32>
// CHECK: [[VAR_reinterpret_cast_1_:%.+]] = memref.reinterpret_cast [[PARAM_1_]] to offset: {{.}}[[VAR_1_]]{{.}}, sizes: [32], strides: [1] : memref<*xf32> to memref<32xf32, strided<[1], offset: ?>>
-// CHECK: bufferization.materialize_in_destination [[VAR_7_]] in writable [[VAR_reinterpret_cast_1_]] : (tensor<32xf32>, memref<32xf32, strided<[1], offset: ?>>) -> ()
+// CHECK: bufferization.materialize_in_destination [[VAR_8_]] in writable [[VAR_reinterpret_cast_1_]] : (tensor<32xf32>, memref<32xf32, strided<[1], offset: ?>>) -> ()
// CHECK: return
// CHECK: }
// CHECK-DAG: [[MAP_0_:#.+]] = affine_map<(d0) -> (d0)>
@@ -1163,20 +1182,21 @@ module {
// CHECK-DAG: [[VAR_2_:%.+]] = arith.addi [[VAR_1_]], [[CST_32_1_]] : index
// CHECK-DAG: [[VAR_3_:%.+]] = arith.index_cast [[PARAM_2_]] : i32 to index
// CHECK: [[VAR_4_:%.+]] = arith.minsi [[VAR_2_]], [[VAR_3_]] : index
-// CHECK-DAG: [[VAR_5_:%.+]] = arith.subi [[VAR_4_]], [[VAR_1_]] : index
+// CHECK: [[VAR_5_:%.+]] = arith.maxsi [[VAR_4_]], [[VAR_1_]] : index
+// CHECK-DAG: [[VAR_6_:%.+]] = arith.subi [[VAR_5_]], [[VAR_1_]] : index
// CHECK-DAG: [[RES_:%.+]] = memref.alloc() : memref<32xf32>
// CHECK-NOT: separator of consecutive DAGs
-// CHECK-DAG: [[VAR_subview_:%.+]] = memref.subview [[VAR_reinterpret_cast_]][0] {{.}}[[VAR_5_]]{{.}} [1] : memref<32xf32, strided<[1], offset: ?>> to memref<?xf32, strided<[1], offset: ?>>
-// CHECK-DAG: [[VAR_subview_0_:%.+]] = memref.subview [[RES_]][0] {{.}}[[VAR_5_]]{{.}} [1] : memref<32xf32> to memref<?xf32, strided<[1]>>
+// CHECK-DAG: [[VAR_subview_:%.+]] = memref.subview [[VAR_reinterpret_cast_]][0] {{.}}[[VAR_6_]]{{.}} [1] : memref<32xf32, strided<[1], offset: ?>> to memref<?xf32, strided<[1], offset: ?>>
+// CHECK-DAG: [[VAR_subview_0_:%.+]] = memref.subview [[RES_]][0] {{.}}[[VAR_6_]]{{.}} [1] : memref<32xf32> to memref<?xf32, strided<[1]>>
// CHECK: memref.copy [[VAR_subview_]], [[VAR_subview_0_]] : memref<?xf32, strided<[1], offset: ?>> to memref<?xf32, strided<[1]>>
-// CHECK: [[VAR_6_:%.+]] = bufferization.to_tensor [[RES_]] restrict writable : memref<32xf32>
-// CHECK: [[VAR_7_:%.+]] = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel"]} ins([[VAR_6_]] : tensor<32xf32>) outs([[VAR_6_]] : tensor<32xf32>) {
+// CHECK: [[VAR_7_:%.+]] = bufferization.to_tensor [[RES_]] restrict writable : memref<32xf32>
+// CHECK: [[VAR_8_:%.+]] = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel"]} ins([[VAR_7_]] : tensor<32xf32>) outs([[VAR_7_]] : tensor<32xf32>) {
// CHECK: ^bb0([[IN_0_:%.+]]: f32, [[IN_1_:%.+]]: f32):
-// CHECK: [[VAR_8_:%.+]] = math.exp2 [[IN_0_]] : f32
-// CHECK: linalg.yield [[VAR_8_]] : f32
+// CHECK: [[VAR_9_:%.+]] = math.exp2 [[IN_0_]] : f32
+// CHECK: linalg.yield [[VAR_9_]] : f32
// CHECK: } -> tensor<32xf32>
// CHECK: [[VAR_reinterpret_cast_1_:%.+]] = memref.reinterpret_cast [[PARAM_1_]] to offset: {{.}}[[VAR_1_]]{{.}}, sizes: [32], strides: [1] : memref<*xf32> to memref<32xf32, strided<[1], offset: ?>>
-// CHECK: bufferization.materialize_in_destination [[VAR_7_]] in writable [[VAR_reinterpret_cast_1_]] : (tensor<32xf32>, memref<32xf32, strided<[1], offset: ?>>) -> ()
+// CHECK: bufferization.materialize_in_destination [[VAR_8_]] in writable [[VAR_reinterpret_cast_1_]] : (tensor<32xf32>, memref<32xf32, strided<[1], offset: ?>>) -> ()
// CHECK: return
// CHECK: }
// CHECK-DAG: [[MAP_0_:#.+]] = affine_map<(d0) -> (d0)>
@@ -1190,20 +1210,21 @@ module {
// CHECK-DAG: [[VAR_2_:%.+]] = arith.addi [[VAR_1_]], [[CST_32_1_]] : index
// CHECK-DAG: [[VAR_3_:%.+]] = arith.index_cast [[PARAM_2_]] : i32 to index
// CHECK: [[VAR_4_:%.+]] = arith.minsi [[VAR_2_]], [[VAR_3_]] : index
-// CHECK-DAG: [[VAR_5_:%.+]] = arith.subi [[VAR_4_]], [[VAR_1_]] : index
+// CHECK: [[VAR_5_:%.+]] = arith.maxsi [[VAR_4_]], [[VAR_1_]] : index
+// CHECK-DAG: [[VAR_6_:%.+]] = arith.subi [[VAR_5_]], [[VAR_1_]] : index
// CHECK-DAG: [[RES_:%.+]] = memref.alloc() : memref<32xf32>
// CHECK-NOT: separator of consecutive DAGs
-// CHECK-DAG: [[VAR_subview_:%.+]] = memref.subview [[VAR_reinterpret_cast_]][0] {{.}}[[VAR_5_]]{{.}} [1] : memref<32xf32, strided<[1], offset: ?>> to memref<?xf32, strided<[1], offset: ?>>
-// CHECK-DAG: [[VAR_subview_0_:%.+]] = memref.subview [[RES_]][0] {{.}}[[VAR_5_]]{{.}} [1] : memref<32xf32> to memref<?xf32, strided<[1]>>
+// CHECK-DAG: [[VAR_subview_:%.+]] = memref.subview [[VAR_reinterpret_cast_]][0] {{.}}[[VAR_6_]]{{.}} [1] : memref<32xf32, strided<[1], offset: ?>> to memref<?xf32, strided<[1], offset: ?>>
+// CHECK-DAG: [[VAR_subview_0_:%.+]] = memref.subview [[RES_]][0] {{.}}[[VAR_6_]]{{.}} [1] : memref<32xf32> to memref<?xf32, strided<[1]>>
// CHECK: memref.copy [[VAR_subview_]], [[VAR_subview_0_]] : memref<?xf32, strided<[1], offset: ?>> to memref<?xf32, strided<[1]>>
-// CHECK: [[VAR_6_:%.+]] = bufferization.to_tensor [[RES_]] restrict writable : memref<32xf32>
-// CHECK: [[VAR_7_:%.+]] = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel"]} ins([[VAR_6_]] : tensor<32xf32>) outs([[VAR_6_]] : tensor<32xf32>) {
+// CHECK: [[VAR_7_:%.+]] = bufferization.to_tensor [[RES_]] restrict writable : memref<32xf32>
+// CHECK: [[VAR_8_:%.+]] = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel"]} ins([[VAR_7_]] : tensor<32xf32>) outs([[VAR_7_]] : tensor<32xf32>) {
// CHECK: ^bb0([[IN_0_:%.+]]: f32, [[IN_1_:%.+]]: f32):
-// CHECK: [[VAR_8_:%.+]] = math.erf [[IN_0_]] : f32
-// CHECK: linalg.yield [[VAR_8_]] : f32
+// CHECK: [[VAR_9_:%.+]] = math.erf [[IN_0_]] : f32
+// CHECK: linalg.yield [[VAR_9_]] : f32
// CHECK: } -> tensor<32xf32>
// CHECK: [[VAR_reinterpret_cast_1_:%.+]] = memref.reinterpret_cast [[PARAM_1_]] to offset: {{.}}[[VAR_1_]]{{.}}, sizes: [32], strides: [1] : memref<*xf32> to memref<32xf32, strided<[1], offset: ?>>
-// CHECK: bufferization.materialize_in_destination [[VAR_7_]] in writable [[VAR_reinterpret_cast_1_]] : (tensor<32xf32>, memref<32xf32, strided<[1], offset: ?>>) -> ()
+// CHECK: bufferization.materialize_in_destination [[VAR_8_]] in writable [[VAR_reinterpret_cast_1_]] : (tensor<32xf32>, memref<32xf32, strided<[1], offset: ?>>) -> ()
// CHECK: return
// CHECK: }
// CHECK-DAG: [[MAP_0_:#.+]] = affine_map<(d0) -> (d0)>
@@ -1217,20 +1238,21 @@ module {
// CHECK-DAG: [[VAR_2_:%.+]] = arith.addi [[VAR_1_]], [[CST_32_1_]] : index
// CHECK-DAG: [[VAR_3_:%.+]] = arith.index_cast [[PARAM_2_]] : i32 to index
// CHECK: [[VAR_4_:%.+]] = arith.minsi [[VAR_2_]], [[VAR_3_]] : index
-// CHECK-DAG: [[VAR_5_:%.+]] = arith.subi [[VAR_4_]], [[VAR_1_]] : index
+// CHECK: [[VAR_5_:%.+]] = arith.maxsi [[VAR_4_]], [[VAR_1_]] : index
+// CHECK-DAG: [[VAR_6_:%.+]] = arith.subi [[VAR_5_]], [[VAR_1_]] : index
// CHECK-DAG: [[RES_:%.+]] = memref.alloc() : memref<32xf32>
// CHECK-NOT: separator of consecutive DAGs
-// CHECK-DAG: [[VAR_subview_:%.+]] = memref.subview [[VAR_reinterpret_cast_]][0] {{.}}[[VAR_5_]]{{.}} [1] : memref<32xf32, strided<[1], offset: ?>> to memref<?xf32, strided<[1], offset: ?>>
-// CHECK-DAG: [[VAR_subview_0_:%.+]] = memref.subview [[RES_]][0] {{.}}[[VAR_5_]]{{.}} [1] : memref<32xf32> to memref<?xf32, strided<[1]>>
+// CHECK-DAG: [[VAR_subview_:%.+]] = memref.subview [[VAR_reinterpret_cast_]][0] {{.}}[[VAR_6_]]{{.}} [1] : memref<32xf32, strided<[1], offset: ?>> to memref<?xf32, strided<[1], offset: ?>>
+// CHECK-DAG: [[VAR_subview_0_:%.+]] = memref.subview [[RES_]][0] {{.}}[[VAR_6_]]{{.}} [1] : memref<32xf32> to memref<?xf32, strided<[1]>>
// CHECK: memref.copy [[VAR_subview_]], [[VAR_subview_0_]] : memref<?xf32, strided<[1], offset: ?>> to memref<?xf32, strided<[1]>>
-// CHECK: [[VAR_6_:%.+]] = bufferization.to_tensor [[RES_]] restrict writable : memref<32xf32>
-// CHECK: [[VAR_7_:%.+]] = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel"]} ins([[VAR_6_]] : tensor<32xf32>) outs([[VAR_6_]] : tensor<32xf32>) {
+// CHECK: [[VAR_7_:%.+]] = bufferization.to_tensor [[RES_]] restrict writable : memref<32xf32>
+// CHECK: [[VAR_8_:%.+]] = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel"]} ins([[VAR_7_]] : tensor<32xf32>) outs([[VAR_7_]] : tensor<32xf32>) {
// CHECK: ^bb0([[IN_0_:%.+]]: f32, [[IN_1_:%.+]]: f32):
-// CHECK: [[VAR_8_:%.+]] = math.sqrt [[IN_0_]] : f32
-// CHECK: linalg.yield [[VAR_8_]] : f32
+// CHECK: [[VAR_9_:%.+]] = math.sqrt [[IN_0_]] : f32
+// CHECK: linalg.yield [[VAR_9_]] : f32
// CHECK: } -> tensor<32xf32>
// CHECK: [[VAR_reinterpret_cast_1_:%.+]] = memref.reinterpret_cast [[PARAM_1_]] to offset: {{.}}[[VAR_1_]]{{.}}, sizes: [32], strides: [1] : memref<*xf32> to memref<32xf32, strided<[1], offset: ?>>
-// CHECK: bufferization.materialize_in_destination [[VAR_7_]] in writable [[VAR_reinterpret_cast_1_]] : (tensor<32xf32>, memref<32xf32, strided<[1], offset: ?>>) -> ()
+// CHECK: bufferization.materialize_in_destination [[VAR_8_]] in writable [[VAR_reinterpret_cast_1_]] : (tensor<32xf32>, memref<32xf32, strided<[1], offset: ?>>) -> ()
// CHECK: return
// CHECK: }
// CHECK-DAG: [[MAP_0_:#.+]] = affine_map<(d0) -> (d0)>
@@ -1244,20 +1266,21 @@ module {
// CHECK-DAG: [[VAR_2_:%.+]] = arith.addi [[VAR_1_]], [[CST_32_1_]] : index
// CHECK-DAG: [[VAR_3_:%.+]] = arith.index_cast [[PARAM_2_]] : i32 to index
// CHECK: [[VAR_4_:%.+]] = arith.minsi [[VAR_2_]], [[VAR_3_]] : index
-// CHECK-DAG: [[VAR_5_:%.+]] = arith.subi [[VAR_4_]], [[VAR_1_]] : index
+// CHECK: [[VAR_5_:%.+]] = arith.maxsi [[VAR_4_]], [[VAR_1_]] : index
+// CHECK-DAG: [[VAR_6_:%.+]] = arith.subi [[VAR_5_]], [[VAR_1_]] : index
// CHECK-DAG: [[RES_:%.+]] = memref.alloc() : memref<32xf32>
// CHECK-NOT: separator of consecutive DAGs
-// CHECK-DAG: [[VAR_subview_:%.+]] = memref.subview [[VAR_reinterpret_cast_]][0] {{.}}[[VAR_5_]]{{.}} [1] : memref<32xf32, strided<[1], offset: ?>> to memref<?xf32, strided<[1], offset: ?>>
-// CHECK-DAG: [[VAR_subview_0_:%.+]] = memref.subview [[RES_]][0] {{.}}[[VAR_5_]]{{.}} [1] : memref<32xf32> to memref<?xf32, strided<[1]>>
+// CHECK-DAG: [[VAR_subview_:%.+]] = memref.subview [[VAR_reinterpret_cast_]][0] {{.}}[[VAR_6_]]{{.}} [1] : memref<32xf32, strided<[1], offset: ?>> to memref<?xf32, strided<[1], offset: ?>>
+// CHECK-DAG: [[VAR_subview_0_:%.+]] = memref.subview [[RES_]][0] {{.}}[[VAR_6_]]{{.}} [1] : memref<32xf32> to memref<?xf32, strided<[1]>>
// CHECK: memref.copy [[VAR_subview_]], [[VAR_subview_0_]] : memref<?xf32, strided<[1], offset: ?>> to memref<?xf32, strided<[1]>>
-// CHECK: [[VAR_6_:%.+]] = bufferization.to_tensor [[RES_]] restrict writable : memref<32xf32>
-// CHECK: [[VAR_7_:%.+]] = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel"]} ins([[VAR_6_]] : tensor<32xf32>) outs([[VAR_6_]] : tensor<32xf32>) {
+// CHECK: [[VAR_7_:%.+]] = bufferization.to_tensor [[RES_]] restrict writable : memref<32xf32>
+// CHECK: [[VAR_8_:%.+]] = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel"]} ins([[VAR_7_]] : tensor<32xf32>) outs([[VAR_7_]] : tensor<32xf32>) {
// CHECK: ^bb0([[IN_0_:%.+]]: f32, [[IN_1_:%.+]]: f32):
-// CHECK: [[VAR_8_:%.+]] = math.rsqrt [[IN_0_]] : f32
-// CHECK: linalg.yield [[VAR_8_]] : f32
+// CHECK: [[VAR_9_:%.+]] = math.rsqrt [[IN_0_]] : f32
+// CHECK: linalg.yield [[VAR_9_]] : f32
// CHECK: } -> tensor<32xf32>
// CHECK: [[VAR_reinterpret_cast_1_:%.+]] = memref.reinterpret_cast [[PARAM_1_]] to offset: {{.}}[[VAR_1_]]{{.}}, sizes: [32], strides: [1] : memref<*xf32> to memref<32xf32, strided<[1], offset: ?>>
-// CHECK: bufferization.materialize_in_destination [[VAR_7_]] in writable [[VAR_reinterpret_cast_1_]] : (tensor<32xf32>, memref<32xf32, strided<[1], offset: ?>>) -> ()
+// CHECK: bufferization.materialize_in_destination [[VAR_8_]] in writable [[VAR_reinterpret_cast_1_]] : (tensor<32xf32>, memref<32xf32, strided<[1], offset: ?>>) -> ()
// CHECK: return
// CHECK: }
// CHECK-DAG: [[MAP_0_:#.+]] = affine_map<(d0) -> (d0)>
@@ -1271,20 +1294,21 @@ module {
// CHECK-DAG: [[VAR_2_:%.+]] = arith.addi [[VAR_1_]], [[CST_32_1_]] : index
// CHECK-DAG: [[VAR_3_:%.+]] = arith.index_cast [[PARAM_2_]] : i32 to index
// CHECK: [[VAR_4_:%.+]] = arith.minsi [[VAR_2_]], [[VAR_3_]] : index
-// CHECK-DAG: [[VAR_5_:%.+]] = arith.subi [[VAR_4_]], [[VAR_1_]] : index
+// CHECK: [[VAR_5_:%.+]] = arith.maxsi [[VAR_4_]], [[VAR_1_]] : index
+// CHECK-DAG: [[VAR_6_:%.+]] = arith.subi [[VAR_5_]], [[VAR_1_]] : index
// CHECK-DAG: [[RES_:%.+]] = memref.alloc() : memref<32xf32>
// CHECK-NOT: separator of consecutive DAGs
-// CHECK-DAG: [[VAR_subview_:%.+]] = memref.subview [[VAR_reinterpret_cast_]][0] {{.}}[[VAR_5_]]{{.}} [1] : memref<32xf32, strided<[1], offset: ?>> to memref<?xf32, strided<[1], offset: ?>>
-// CHECK-DAG: [[VAR_subview_0_:%.+]] = memref.subview [[RES_]][0] {{.}}[[VAR_5_]]{{.}} [1] : memref<32xf32> to memref<?xf32, strided<[1]>>
+// CHECK-DAG: [[VAR_subview_:%.+]] = memref.subview [[VAR_reinterpret_cast_]][0] {{.}}[[VAR_6_]]{{.}} [1] : memref<32xf32, strided<[1], offset: ?>> to memref<?xf32, strided<[1], offset: ?>>
+// CHECK-DAG: [[VAR_subview_0_:%.+]] = memref.subview [[RES_]][0] {{.}}[[VAR_6_]]{{.}} [1] : memref<32xf32> to memref<?xf32, strided<[1]>>
// CHECK: memref.copy [[VAR_subview_]], [[VAR_subview_0_]] : memref<?xf32, strided<[1], offset: ?>> to memref<?xf32, strided<[1]>>
-// CHECK: [[VAR_6_:%.+]] = bufferization.to_tensor [[RES_]] restrict writable : memref<32xf32>
-// CHECK: [[VAR_7_:%.+]] = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel"]} ins([[VAR_6_]] : tensor<32xf32>) outs([[VAR_6_]] : tensor<32xf32>) {
+// CHECK: [[VAR_7_:%.+]] = bufferization.to_tensor [[RES_]] restrict writable : memref<32xf32>
+// CHECK: [[VAR_8_:%.+]] = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel"]} ins([[VAR_7_]] : tensor<32xf32>) outs([[VAR_7_]] : tensor<32xf32>) {
// CHECK: ^bb0([[IN_0_:%.+]]: f32, [[IN_1_:%.+]]: f32):
-// CHECK: [[VAR_8_:%.+]] = math.ceil [[IN_0_]] : f32
-// CHECK: linalg.yield [[VAR_8_]] : f32
+// CHECK: [[VAR_9_:%.+]] = math.ceil [[IN_0_]] : f32
+// CHECK: linalg.yield [[VAR_9_]] : f32
// CHECK: } -> tensor<32xf32>
// CHECK: [[VAR_reinterpret_cast_1_:%.+]] = memref.reinterpret_cast [[PARAM_1_]] to offset: {{.}}[[VAR_1_]]{{.}}, sizes: [32], strides: [1] : memref<*xf32> to memref<32xf32, strided<[1], offset: ?>>
-// CHECK: bufferization.materialize_in_destination [[VAR_7_]] in writable [[VAR_reinterpret_cast_1_]] : (tensor<32xf32>, memref<32xf32, strided<[1], offset: ?>>) -> ()
+// CHECK: bufferization.materialize_in_destination [[VAR_8_]] in writable [[VAR_reinterpret_cast_1_]] : (tensor<32xf32>, memref<32xf32, strided<[1], offset: ?>>) -> ()
// CHECK: return
// CHECK: }
// CHECK-DAG: [[MAP_0_:#.+]] = affine_map<(d0) -> (d0)>
@@ -1298,20 +1322,21 @@ module {
// CHECK-DAG: [[VAR_2_:%.+]] = arith.addi [[VAR_1_]], [[CST_32_1_]] : index
// CHECK-DAG: [[VAR_3_:%.+]] = arith.index_cast [[PARAM_2_]] : i32 to index
// CHECK: [[VAR_4_:%.+]] = arith.minsi [[VAR_2_]], [[VAR_3_]] : index
-// CHECK-DAG: [[VAR_5_:%.+]] = arith.subi [[VAR_4_]], [[VAR_1_]] : index
+// CHECK: [[VAR_5_:%.+]] = arith.maxsi [[VAR_4_]], [[VAR_1_]] : index
+// CHECK-DAG: [[VAR_6_:%.+]] = arith.subi [[VAR_5_]], [[VAR_1_]] : index
// CHECK-DAG: [[RES_:%.+]] = memref.alloc() : memref<32xf32>
// CHECK-NOT: separator of consecutive DAGs
-// CHECK-DAG: [[VAR_subview_:%.+]] = memref.subview [[VAR_reinterpret_cast_]][0] {{.}}[[VAR_5_]]{{.}} [1] : memref<32xf32, strided<[1], offset: ?>> to memref<?xf32, strided<[1], offset: ?>>
-// CHECK-DAG: [[VAR_subview_0_:%.+]] = memref.subview [[RES_]][0] {{.}}[[VAR_5_]]{{.}} [1] : memref<32xf32> to memref<?xf32, strided<[1]>>
+// CHECK-DAG: [[VAR_subview_:%.+]] = memref.subview [[VAR_reinterpret_cast_]][0] {{.}}[[VAR_6_]]{{.}} [1] : memref<32xf32, strided<[1], offset: ?>> to memref<?xf32, strided<[1], offset: ?>>
+// CHECK-DAG: [[VAR_subview_0_:%.+]] = memref.subview [[RES_]][0] {{.}}[[VAR_6_]]{{.}} [1] : memref<32xf32> to memref<?xf32, strided<[1]>>
// CHECK: memref.copy [[VAR_subview_]], [[VAR_subview_0_]] : memref<?xf32, strided<[1], offset: ?>> to memref<?xf32, strided<[1]>>
-// CHECK: [[VAR_6_:%.+]] = bufferization.to_tensor [[RES_]] restrict writable : memref<32xf32>
-// CHECK: [[VAR_7_:%.+]] = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel"]} ins([[VAR_6_]] : tensor<32xf32>) outs([[VAR_6_]] : tensor<32xf32>) {
+// CHECK: [[VAR_7_:%.+]] = bufferization.to_tensor [[RES_]] restrict writable : memref<32xf32>
+// CHECK: [[VAR_8_:%.+]] = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel"]} ins([[VAR_7_]] : tensor<32xf32>) outs([[VAR_7_]] : tensor<32xf32>) {
// CHECK: ^bb0([[IN_0_:%.+]]: f32, [[IN_1_:%.+]]: f32):
-// CHECK: [[VAR_8_:%.+]] = math.floor [[IN_0_]] : f32
-// CHECK: linalg.yield [[VAR_8_]] : f32
+// CHECK: [[VAR_9_:%.+]] = math.floor [[IN_0_]] : f32
+// CHECK: linalg.yield [[VAR_9_]] : f32
// CHECK: } -> tensor<32xf32>
// CHECK: [[VAR_reinterpret_cast_1_:%.+]] = memref.reinterpret_cast [[PARAM_1_]] to offset: {{.}}[[VAR_1_]]{{.}}, sizes: [32], strides: [1] : memref<*xf32> to memref<32xf32, strided<[1], offset: ?>>
-// CHECK: bufferization.materialize_in_destination [[VAR_7_]] in writable [[VAR_reinterpret_cast_1_]] : (tensor<32xf32>, memref<32xf32, strided<[1], offset: ?>>) -> ()
+// CHECK: bufferization.materialize_in_destination [[VAR_8_]] in writable [[VAR_reinterpret_cast_1_]] : (tensor<32xf32>, memref<32xf32, strided<[1], offset: ?>>) -> ()
// CHECK: return
// CHECK: }
// CHECK-DAG: [[MAP_0_:#.+]] = affine_map<(d0) -> (d0)>
@@ -1325,19 +1350,20 @@ module {
// CHECK-DAG: [[VAR_2_:%.+]] = arith.addi [[VAR_1_]], [[CST_32_1_]] : index
// CHECK-DAG: [[VAR_3_:%.+]] = arith.index_cast [[PARAM_2_]] : i32 to index
// CHECK: [[VAR_4_:%.+]] = arith.minsi [[VAR_2_]], [[VAR_3_]] : index
-// CHECK-DAG: [[VAR_5_:%.+]] = arith.subi [[VAR_4_]], [[VAR_1_]] : index
+// CHECK: [[VAR_5_:%.+]] = arith.maxsi [[VAR_4_]], [[VAR_1_]] : index
+// CHECK-DAG: [[VAR_6_:%.+]] = arith.subi [[VAR_5_]], [[VAR_1_]] : index
// CHECK-DAG: [[RES_:%.+]] = memref.alloc() : memref<32xf32>
// CHECK-NOT: separator of consecutive DAGs
-// CHECK-DAG: [[VAR_subview_:%.+]] = memref.subview [[VAR_reinterpret_cast_]][0] {{.}}[[VAR_5_]]{{.}} [1] : memref<32xf32, strided<[1], offset: ?>> to memref<?xf32, strided<[1], offset: ?>>
-// CHECK-DAG: [[VAR_subview_0_:%.+]] = memref.subview [[RES_]][0] {{.}}[[VAR_5_]]{{.}} [1] : memref<32xf32> to memref<?xf32, strided<[1]>>
+// CHECK-DAG: [[VAR_subview_:%.+]] = memref.subview [[VAR_reinterpret_cast_]][0] {{.}}[[VAR_6_]]{{.}} [1] : memref<32xf32, strided<[1], offset: ?>> to memref<?xf32, strided<[1], offset: ?>>
+// CHECK-DAG: [[VAR_subview_0_:%.+]] = memref.subview [[RES_]][0] {{.}}[[VAR_6_]]{{.}} [1] : memref<32xf32> to memref<?xf32, strided<[1]>>
// CHECK: memref.copy [[VAR_subview_]], [[VAR_subview_0_]] : memref<?xf32, strided<[1], offset: ?>> to memref<?xf32, strided<[1]>>
-// CHECK: [[VAR_6_:%.+]] = bufferization.to_tensor [[RES_]] restrict writable : memref<32xf32>
-// CHECK: [[VAR_7_:%.+]] = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel"]} ins([[VAR_6_]] : tensor<32xf32>) outs([[VAR_6_]] : tensor<32xf32>) {
+// CHECK: [[VAR_7_:%.+]] = bufferization.to_tensor [[RES_]] restrict writable : memref<32xf32>
+// CHECK: [[VAR_8_:%.+]] = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel"]} ins([[VAR_7_]] : tensor<32xf32>) outs([[VAR_7_]] : tensor<32xf32>) {
// CHECK: ^bb0([[IN_0_:%.+]]: f32, [[IN_1_:%.+]]: f32):
-// CHECK: [[VAR_8_:%.+]] = math.trunc [[IN_0_]] : f32
-// CHECK: linalg.yield [[VAR_8_]] : f32
+// CHECK: [[VAR_9_:%.+]] = math.trunc [[IN_0_]] : f32
+// CHECK: linalg.yield [[VAR_9_]] : f32
// CHECK: } -> tensor<32xf32>
// CHECK: [[VAR_reinterpret_cast_1_:%.+]] = memref.reinterpret_cast [[PARAM_1_]] to offset: {{.}}[[VAR_1_]]{{.}}, sizes: [32], strides: [1] : memref<*xf32> to memref<32xf32, strided<[1], offset: ?>>
-// CHECK: bufferization.materialize_in_destination [[VAR_7_]] in writable [[VAR_reinterpret_cast_1_]] : (tensor<32xf32>, memref<32xf32, strided<[1], offset: ?>>) -> ()
+// CHECK: bufferization.materialize_in_destination [[VAR_8_]] in writable [[VAR_reinterpret_cast_1_]] : (tensor<32xf32>, memref<32xf32, strided<[1], offset: ?>>) -> ()
// CHECK: return
// CHECK: }
diff --git a/test/Conversion/StructuredToMemref/kernel-01-vector-add.mlir b/test/Conversion/StructuredToMemref/kernel-01-vector-add.mlir
index 80a4cd4e..7ccca578 100644
--- a/test/Conversion/StructuredToMemref/kernel-01-vector-add.mlir
+++ b/test/Conversion/StructuredToMemref/kernel-01-vector-add.mlir
@@ -35,28 +35,29 @@ module {
// CHECK-DAG: [[VAR_2_:%.+]] = arith.addi [[VAR_1_]], [[CST_1024_1_]] : index
// CHECK-DAG: [[VAR_3_:%.+]] = arith.index_cast [[PARAM_3_]] : i32 to index
// CHECK: [[VAR_4_:%.+]] = arith.minsi [[VAR_2_]], [[VAR_3_]] : index
-// CHECK-DAG: [[VAR_5_:%.+]] = arith.subi [[VAR_4_]], [[VAR_1_]] : index
+// CHECK: [[VAR_5_:%.+]] = arith.maxsi [[VAR_4_]], [[VAR_1_]] : index
+// CHECK-DAG: [[VAR_6_:%.+]] = arith.subi [[VAR_5_]], [[VAR_1_]] : index
// CHECK-DAG: [[RES_:%.+]] = memref.alloc() : memref<1024xf32>
// CHECK-NOT: separator of consecutive DAGs
-// CHECK-DAG: [[VAR_subview_:%.+]] = memref.subview [[VAR_reinterpret_cast_]][0] {{.}}[[VAR_5_]]{{.}} [1] : memref<1024xf32, strided<[1], offset: ?>> to memref<?xf32, strided<[1], offset: ?>>
-// CHECK-DAG: [[VAR_subview_0_:%.+]] = memref.subview [[RES_]][0] {{.}}[[VAR_5_]]{{.}} [1] : memref<1024xf32> to memref<?xf32, strided<[1]>>
+// CHECK-DAG: [[VAR_subview_:%.+]] = memref.subview [[VAR_reinterpret_cast_]][0] {{.}}[[VAR_6_]]{{.}} [1] : memref<1024xf32, strided<[1], offset: ?>> to memref<?xf32, strided<[1], offset: ?>>
+// CHECK-DAG: [[VAR_subview_0_:%.+]] = memref.subview [[RES_]][0] {{.}}[[VAR_6_]]{{.}} [1] : memref<1024xf32> to memref<?xf32, strided<[1]>>
// CHECK: memref.copy [[VAR_subview_]], [[VAR_subview_0_]] : memref<?xf32, strided<[1], offset: ?>> to memref<?xf32, strided<[1]>>
-// CHECK-DAG: [[VAR_6_:%.+]] = bufferization.to_tensor [[RES_]] restrict writable : memref<1024xf32>
+// CHECK-DAG: [[VAR_7_:%.+]] = bufferization.to_tensor [[RES_]] restrict writable : memref<1024xf32>
// CHECK-DAG: [[VAR_reinterpret_cast_1_:%.+]] = memref.reinterpret_cast [[PARAM_1_]] to offset: {{.}}[[VAR_1_]]{{.}}, sizes: [1024], strides: [1] : memref<*xf32> to memref<1024xf32, strided<[1], offset: ?>>
// CHECK-DAG: [[RES_1_:%.+]] = memref.alloc() : memref<1024xf32>
// CHECK-NOT: separator of consecutive DAGs
-// CHECK-DAG: [[VAR_subview_3_:%.+]] = memref.subview [[VAR_reinterpret_cast_1_]][0] {{.}}[[VAR_5_]]{{.}} [1] : memref<1024xf32, strided<[1], offset: ?>> to memref<?xf32, strided<[1], offset: ?>>
-// CHECK-DAG: [[VAR_subview_4_:%.+]] = memref.subview [[RES_1_]][0] {{.}}[[VAR_5_]]{{.}} [1] : memref<1024xf32> to memref<?xf32, strided<[1]>>
+// CHECK-DAG: [[VAR_subview_3_:%.+]] = memref.subview [[VAR_reinterpret_cast_1_]][0] {{.}}[[VAR_6_]]{{.}} [1] : memref<1024xf32, strided<[1], offset: ?>> to memref<?xf32, strided<[1], offset: ?>>
+// CHECK-DAG: [[VAR_subview_4_:%.+]] = memref.subview [[RES_1_]][0] {{.}}[[VAR_6_]]{{.}} [1] : memref<1024xf32> to memref<?xf32, strided<[1]>>
// CHECK: memref.copy [[VAR_subview_3_]], [[VAR_subview_4_]] : memref<?xf32, strided<[1], offset: ?>> to memref<?xf32, strided<[1]>>
-// CHECK: [[VAR_7_:%.+]] = bufferization.to_tensor [[RES_1_]] restrict writable : memref<1024xf32>
-// CHECK: [[VAR_8_:%.+]] = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]} ins([[VAR_6_]], [[VAR_7_]] : tensor<1024xf32>, tensor<1024xf32>) outs([[VAR_6_]] : tensor<1024xf32>) {
+// CHECK: [[VAR_8_:%.+]] = bufferization.to_tensor [[RES_1_]] restrict writable : memref<1024xf32>
+// CHECK: [[VAR_9_:%.+]] = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]} ins([[VAR_7_]], [[VAR_8_]] : tensor<1024xf32>, tensor<1024xf32>) outs([[VAR_7_]] : tensor<1024xf32>) {
// CHECK: ^bb0([[IN_0_:%.+]]: f32, [[IN_1_:%.+]]: f32, [[IN_2_:%.+]]: f32):
-// CHECK: [[VAR_9_:%.+]] = arith.addf [[IN_0_]], [[IN_1_]] : f32
-// CHECK: linalg.yield [[VAR_9_]] : f32
+// CHECK: [[VAR_10_:%.+]] = arith.addf [[IN_0_]], [[IN_1_]] : f32
+// CHECK: linalg.yield [[VAR_10_]] : f32
// CHECK: } -> tensor<1024xf32>
// CHECK-DAG: [[VAR_reinterpret_cast_5_:%.+]] = memref.reinterpret_cast [[PARAM_2_]] to offset: {{.}}[[VAR_1_]]{{.}}, sizes: [1024], strides: [1] : memref<*xf32> to memref<1024xf32, strided<[1], offset: ?>>
-// CHECK-DAG: [[VAR_extracted_slice_:%.+]] = tensor.extract_slice [[VAR_8_]][0] {{.}}[[VAR_5_]]{{.}} [1] : tensor<1024xf32> to tensor<?xf32>
-// CHECK: [[VAR_subview_6_:%.+]] = memref.subview [[VAR_reinterpret_cast_5_]][0] {{.}}[[VAR_5_]]{{.}} [1] : memref<1024xf32, strided<[1], offset: ?>> to memref<?xf32, strided<[1], offset: ?>>
+// CHECK-DAG: [[VAR_extracted_slice_:%.+]] = tensor.extract_slice [[VAR_9_]][0] {{.}}[[VAR_6_]]{{.}} [1] : tensor<1024xf32> to tensor<?xf32>
+// CHECK: [[VAR_subview_6_:%.+]] = memref.subview [[VAR_reinterpret_cast_5_]][0] {{.}}[[VAR_6_]]{{.}} [1] : memref<1024xf32, strided<[1], offset: ?>> to memref<?xf32, strided<[1], offset: ?>>
// CHECK: bufferization.materialize_in_destination [[VAR_extracted_slice_]] in writable [[VAR_subview_6_]] : (tensor<?xf32>, memref<?xf32, strided<[1], offset: ?>>) -> ()
// CHECK: return
// CHECK: }
diff --git a/test/Conversion/StructuredToMemref/kernel-02-fused-softmax.mlir b/test/Conversion/StructuredToMemref/kernel-02-fused-softmax.mlir
index 4a9345e4..f21125c3 100644
--- a/test/Conversion/StructuredToMemref/kernel-02-fused-softmax.mlir
+++ b/test/Conversion/StructuredToMemref/kernel-02-fused-softmax.mlir
@@ -42,62 +42,63 @@ module {
// CHECK-LABEL: func.func @softmax_kernel_012345
// CHECK-SAME: ([[PARAM_0_:%.+]]: memref<*xf32>, [[PARAM_1_:%.+]]: memref<*xf32>, [[PARAM_2_:%.+]]: i32, [[PARAM_3_:%.+]]: i32, [[PARAM_4_:%.+]]: i32, [[PARAM_5_:%.+]]: i32, [[PARAM_6_:%.+]]: i32, [[PARAM_7_:%.+]]: i32, [[PARAM_8_:%.+]]: i32, [[PARAM_9_:%.+]]: i32, [[PARAM_10_:%.+]]: i32) {
// CHECK-DAG: [[CST_0_dot_000000_:%.+]] = arith.constant 0.000000e+00 : f32
-// CHECK-DAG: [[CST_0_:%.+]] = arith.constant 0xFF800000 : f32
// CHECK-DAG: [[CST_128_:%.+]] = arith.constant 128 : index
+// CHECK-DAG: [[CST_0_:%.+]] = arith.constant 0 : index
+// CHECK-DAG: [[CST_0_1_:%.+]] = arith.constant 0xFF800000 : f32
// CHECK-DAG: 
[[VAR_0_:%.+]] = arith.muli [[PARAM_8_]], [[PARAM_2_]] : i32 // CHECK: [[VAR_1_:%.+]] = arith.index_cast [[VAR_0_]] : i32 to index // CHECK-DAG: [[VAR_reinterpret_cast_:%.+]] = memref.reinterpret_cast [[PARAM_1_]] to offset: {{.}}[[VAR_1_]]{{.}}, sizes: [128], strides: [1] : memref<*xf32> to memref<128xf32, strided<[1], offset: ?>> // CHECK-DAG: [[VAR_2_:%.+]] = arith.index_cast [[PARAM_4_]] : i32 to index -// CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: [[VAR_3_:%.+]] = arith.minsi [[VAR_2_]], [[CST_128_]] : index +// CHECK: [[VAR_3_:%.+]] = arith.minsi [[VAR_2_]], [[CST_128_]] : index +// CHECK-DAG: [[VAR_4_:%.+]] = arith.maxsi [[VAR_3_]], [[CST_0_]] : index // CHECK-DAG: [[RES_:%.+]] = memref.alloc() : memref<128xf32> -// CHECK: [[VAR_4_:%.+]] = arith.cmpi slt, [[VAR_3_]], [[CST_128_]] : index -// CHECK: scf.if [[VAR_4_]] { -// CHECK: linalg.fill ins([[CST_0_]] : f32) outs([[RES_]] : memref<128xf32>) +// CHECK: [[VAR_5_:%.+]] = arith.cmpi slt, [[VAR_4_]], [[CST_128_]] : index +// CHECK: scf.if [[VAR_5_]] { +// CHECK: linalg.fill ins([[CST_0_1_]] : f32) outs([[RES_]] : memref<128xf32>) // CHECK: } -// CHECK-DAG: [[VAR_subview_:%.+]] = memref.subview [[VAR_reinterpret_cast_]][0] {{.}}[[VAR_3_]]{{.}} [1] : memref<128xf32, strided<[1], offset: ?>> to memref> -// CHECK-DAG: [[VAR_subview_1_:%.+]] = memref.subview [[RES_]][0] {{.}}[[VAR_3_]]{{.}} [1] : memref<128xf32> to memref> +// CHECK-DAG: [[VAR_subview_:%.+]] = memref.subview [[VAR_reinterpret_cast_]][0] {{.}}[[VAR_4_]]{{.}} [1] : memref<128xf32, strided<[1], offset: ?>> to memref> +// CHECK-DAG: [[VAR_subview_1_:%.+]] = memref.subview [[RES_]][0] {{.}}[[VAR_4_]]{{.}} [1] : memref<128xf32> to memref> // CHECK: memref.copy [[VAR_subview_]], [[VAR_subview_1_]] : memref> to memref> -// CHECK-DAG: [[VAR_5_:%.+]] = bufferization.to_tensor [[RES_]] restrict writable : memref<128xf32> -// CHECK-DAG: [[VAR_6_:%.+]] = bufferization.alloc_tensor() : tensor -// CHECK: [[VAR_inserted_:%.+]] = tensor.insert [[CST_0_]] into [[VAR_6_]][] : tensor -// CHECK: [[VAR_reduced_:%.+]] = linalg.reduce ins([[VAR_5_]] : tensor<128xf32>) outs([[VAR_inserted_]] : tensor) dimensions = [0] +// CHECK-DAG: [[VAR_6_:%.+]] = bufferization.to_tensor [[RES_]] restrict writable : memref<128xf32> +// CHECK-DAG: [[VAR_7_:%.+]] = bufferization.alloc_tensor() : tensor +// CHECK: [[VAR_inserted_:%.+]] = tensor.insert [[CST_0_1_]] into [[VAR_7_]][] : tensor +// CHECK: [[VAR_reduced_:%.+]] = linalg.reduce ins([[VAR_6_]] : tensor<128xf32>) outs([[VAR_inserted_]] : tensor) dimensions = [0] // CHECK: ([[in_:.+]]: f32, [[init_:.+]]: f32) { -// CHECK: [[VAR_16_:%.+]] = arith.maximumf [[in_]], [[init_]] : f32 -// CHECK: linalg.yield [[VAR_16_]] : f32 +// CHECK: [[VAR_17_:%.+]] = arith.maximumf [[in_]], [[init_]] : f32 +// CHECK: linalg.yield [[VAR_17_]] : f32 // CHECK: } // CHECK-DAG: [[VAR_extracted_:%.+]] = tensor.extract [[VAR_reduced_]][] : tensor -// CHECK-DAG: [[VAR_7_:%.+]] = tensor.empty() : tensor<128xf32> -// CHECK: [[VAR_8_:%.+]] = linalg.fill ins([[VAR_extracted_]] : f32) outs([[VAR_7_]] : tensor<128xf32>) -> tensor<128xf32> -// CHECK: [[VAR_9_:%.+]] = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]} ins([[VAR_5_]], [[VAR_8_]] : tensor<128xf32>, tensor<128xf32>) outs([[VAR_5_]] : tensor<128xf32>) { +// CHECK-DAG: [[VAR_8_:%.+]] = tensor.empty() : tensor<128xf32> +// CHECK: [[VAR_9_:%.+]] = linalg.fill ins([[VAR_extracted_]] : f32) outs([[VAR_8_]] : tensor<128xf32>) -> tensor<128xf32> +// CHECK: [[VAR_10_:%.+]] = linalg.generic 
{indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]} ins([[VAR_6_]], [[VAR_9_]] : tensor<128xf32>, tensor<128xf32>) outs([[VAR_6_]] : tensor<128xf32>) { // CHECK: ^bb0([[IN_0_:%.+]]: f32, [[IN_1_:%.+]]: f32, [[IN_2_:%.+]]: f32): -// CHECK: [[VAR_16_1_:%.+]] = arith.subf [[IN_0_]], [[IN_1_]] : f32 -// CHECK: linalg.yield [[VAR_16_1_]] : f32 +// CHECK: [[VAR_17_1_:%.+]] = arith.subf [[IN_0_]], [[IN_1_]] : f32 +// CHECK: linalg.yield [[VAR_17_1_]] : f32 // CHECK: } -> tensor<128xf32> -// CHECK: [[VAR_10_:%.+]] = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel"]} ins([[VAR_9_]] : tensor<128xf32>) outs([[VAR_9_]] : tensor<128xf32>) { +// CHECK: [[VAR_11_:%.+]] = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel"]} ins([[VAR_10_]] : tensor<128xf32>) outs([[VAR_10_]] : tensor<128xf32>) { // CHECK: ^bb0([[IN_3_:%.+]]: f32, [[IN_4_:%.+]]: f32): -// CHECK: [[VAR_16_2_:%.+]] = math.exp [[IN_3_]] : f32 -// CHECK: linalg.yield [[VAR_16_2_]] : f32 +// CHECK: [[VAR_17_2_:%.+]] = math.exp [[IN_3_]] : f32 +// CHECK: linalg.yield [[VAR_17_2_]] : f32 // CHECK: } -> tensor<128xf32> -// CHECK: [[VAR_11_:%.+]] = bufferization.alloc_tensor() : tensor -// CHECK: [[VAR_inserted_2_:%.+]] = tensor.insert [[CST_0_dot_000000_]] into [[VAR_11_]][] : tensor -// CHECK: [[VAR_reduced_3_:%.+]] = linalg.reduce ins([[VAR_10_]] : tensor<128xf32>) outs([[VAR_inserted_2_]] : tensor) dimensions = [0] +// CHECK: [[VAR_12_:%.+]] = bufferization.alloc_tensor() : tensor +// CHECK: [[VAR_inserted_2_:%.+]] = tensor.insert [[CST_0_dot_000000_]] into [[VAR_12_]][] : tensor +// CHECK: [[VAR_reduced_3_:%.+]] = linalg.reduce ins([[VAR_11_]] : tensor<128xf32>) outs([[VAR_inserted_2_]] : tensor) dimensions = [0] // CHECK: ([[IN_3_:.+]]: f32, [[init_:.+]]: f32) { -// CHECK: [[VAR_16_3_:%.+]] = arith.addf [[IN_3_]], [[init_]] : f32 -// CHECK: linalg.yield [[VAR_16_3_]] : f32 +// CHECK: [[VAR_17_3_:%.+]] = arith.addf [[IN_3_]], [[init_]] : f32 +// CHECK: linalg.yield [[VAR_17_3_]] : f32 // CHECK: } // CHECK: [[VAR_extracted_4_:%.+]] = tensor.extract [[VAR_reduced_3_]][] : tensor -// CHECK: [[VAR_12_:%.+]] = linalg.fill ins([[VAR_extracted_4_]] : f32) outs([[VAR_7_]] : tensor<128xf32>) -> tensor<128xf32> -// CHECK: [[VAR_13_:%.+]] = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]} ins([[VAR_10_]], [[VAR_12_]] : tensor<128xf32>, tensor<128xf32>) outs([[VAR_10_]] : tensor<128xf32>) { +// CHECK: [[VAR_13_:%.+]] = linalg.fill ins([[VAR_extracted_4_]] : f32) outs([[VAR_8_]] : tensor<128xf32>) -> tensor<128xf32> +// CHECK: [[VAR_14_:%.+]] = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]} ins([[VAR_11_]], [[VAR_13_]] : tensor<128xf32>, tensor<128xf32>) outs([[VAR_11_]] : tensor<128xf32>) { // CHECK: ^bb0([[IN_5_:%.+]]: f32, [[IN_6_:%.+]]: f32, [[IN_7_:%.+]]: f32): -// CHECK: [[VAR_16_4_:%.+]] = arith.divf [[IN_5_]], [[IN_6_]] : f32 -// CHECK: linalg.yield [[VAR_16_4_]] : f32 +// CHECK: [[VAR_17_4_:%.+]] = arith.divf [[IN_5_]], [[IN_6_]] : f32 +// CHECK: linalg.yield [[VAR_17_4_]] : f32 // CHECK: } -> tensor<128xf32> -// CHECK: [[VAR_14_:%.+]] = arith.muli [[PARAM_8_]], [[PARAM_3_]] : i32 -// CHECK: [[VAR_15_:%.+]] = arith.index_cast [[VAR_14_]] : i32 to index -// CHECK-DAG: [[VAR_reinterpret_cast_5_:%.+]] = memref.reinterpret_cast [[PARAM_0_]] to offset: {{.}}[[VAR_15_]]{{.}}, sizes: [128], strides: [1] : memref<*xf32> to memref<128xf32, strided<[1], offset: ?>> -// CHECK-DAG: [[VAR_extracted_slice_:%.+]] = 
tensor.extract_slice [[VAR_13_]][0] {{.}}[[VAR_3_]]{{.}} [1] : tensor<128xf32> to tensor -// CHECK: [[VAR_subview_6_:%.+]] = memref.subview [[VAR_reinterpret_cast_5_]][0] {{.}}[[VAR_3_]]{{.}} [1] : memref<128xf32, strided<[1], offset: ?>> to memref> +// CHECK: [[VAR_15_:%.+]] = arith.muli [[PARAM_8_]], [[PARAM_3_]] : i32 +// CHECK: [[VAR_16_:%.+]] = arith.index_cast [[VAR_15_]] : i32 to index +// CHECK-DAG: [[VAR_reinterpret_cast_5_:%.+]] = memref.reinterpret_cast [[PARAM_0_]] to offset: {{.}}[[VAR_16_]]{{.}}, sizes: [128], strides: [1] : memref<*xf32> to memref<128xf32, strided<[1], offset: ?>> +// CHECK-DAG: [[VAR_extracted_slice_:%.+]] = tensor.extract_slice [[VAR_14_]][0] {{.}}[[VAR_4_]]{{.}} [1] : tensor<128xf32> to tensor +// CHECK: [[VAR_subview_6_:%.+]] = memref.subview [[VAR_reinterpret_cast_5_]][0] {{.}}[[VAR_4_]]{{.}} [1] : memref<128xf32, strided<[1], offset: ?>> to memref> // CHECK: bufferization.materialize_in_destination [[VAR_extracted_slice_]] in writable [[VAR_subview_6_]] : (tensor, memref>) -> () // CHECK: return // CHECK: } diff --git a/test/Conversion/StructuredToMemref/kernel-03-matrix-multiplication.mlir b/test/Conversion/StructuredToMemref/kernel-03-matrix-multiplication.mlir index e5f6e3bf..52489552 100644 --- a/test/Conversion/StructuredToMemref/kernel-03-matrix-multiplication.mlir +++ b/test/Conversion/StructuredToMemref/kernel-03-matrix-multiplication.mlir @@ -108,10 +108,10 @@ module { // CHECK-DAG: [[CST_127_:%.+]] = arith.constant 127 : i32 // CHECK-DAG: [[CST_255_:%.+]] = arith.constant 255 : i32 // CHECK-DAG: [[CST_63_:%.+]] = arith.constant 63 : i32 +// CHECK-DAG: [[CST_0_dot_000000_:%.+]] = arith.constant 0.000000e+00 : f32 // CHECK-DAG: [[CST_0_1_:%.+]] = arith.constant 0 : index // CHECK-DAG: [[CST_128_1_:%.+]] = arith.constant 128 : index // CHECK-DAG: [[CST_256_1_:%.+]] = arith.constant 256 : index -// CHECK-DAG: [[CST_0_dot_000000_:%.+]] = arith.constant 0.000000e+00 : f32 // CHECK-DAG: [[VAR_0_:%.+]] = tensor.empty() : tensor<128x256xf32> // CHECK-NOT: separator of consecutive DAGs // CHECK-DAG: [[VAR_1_:%.+]] = linalg.fill ins([[CST_0_dot_000000_]] : f32) outs([[VAR_0_]] : tensor<128x256xf32>) -> tensor<128x256xf32> @@ -155,31 +155,31 @@ module { // CHECK-NOT: separator of consecutive DAGs // CHECK-DAG: [[VAR_30_:%.+]] = arith.index_cast [[VAR_29_]] : i32 to index // CHECK-DAG: [[VAR_31_:%.+]]:3 = scf.for [[VAR_arg18_:%.+]] = [[CST_0_]] to [[VAR_7_]] step [[CST_1_]] iter_args([[VAR_arg19_:%.+]] = [[VAR_1_]], [[VAR_arg20_:%.+]] = [[VAR_22_]], [[VAR_arg21_:%.+]] = [[CST_0_1_]]) -> (tensor<128x256xf32>, index, index) : i32 { -// CHECK-DAG: [[VAR_49_:%.+]] = arith.addi [[VAR_arg21_]], [[VAR_26_]] : index +// CHECK-DAG: [[VAR_51_:%.+]] = arith.addi [[VAR_arg21_]], [[VAR_26_]] : index // CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: [[VAR_reinterpret_cast_0_:%.+]] = memref.reinterpret_cast [[PARAM_1_]] to offset: {{.}}[[VAR_49_]]{{.}}, sizes: [64, 256], strides: {{.}}[[VAR_24_]], [[VAR_25_]]{{.}} : memref<*xbf16> to memref<64x256xbf16, strided<[?, ?], offset: ?>> +// CHECK-DAG: [[VAR_reinterpret_cast_0_:%.+]] = memref.reinterpret_cast [[PARAM_1_]] to offset: {{.}}[[VAR_51_]]{{.}}, sizes: [64, 256], strides: {{.}}[[VAR_24_]], [[VAR_25_]]{{.}} : memref<*xbf16> to memref<64x256xbf16, strided<[?, ?], offset: ?>> // CHECK-DAG: [[VAR_reinterpret_cast_1_:%.+]] = memref.reinterpret_cast [[PARAM_0_]] to offset: {{.}}[[VAR_arg20_]]{{.}}, sizes: [128, 64], strides: {{.}}[[VAR_21_]], [[VAR_23_]]{{.}} : memref<*xbf16> to memref<128x64xbf16, strided<[?, ?], 
offset: ?>> // CHECK-DAG: [[RES_:%.+]] = memref.alloc() : memref<128x64xbf16> // CHECK: memref.copy [[VAR_reinterpret_cast_1_]], [[RES_]] : memref<128x64xbf16, strided<[?, ?], offset: ?>> to memref<128x64xbf16> -// CHECK-DAG: [[VAR_50_:%.+]] = bufferization.to_tensor [[RES_]] restrict writable : memref<128x64xbf16> +// CHECK-DAG: [[VAR_52_:%.+]] = bufferization.to_tensor [[RES_]] restrict writable : memref<128x64xbf16> // CHECK-DAG: [[RES_1_:%.+]] = memref.alloc() : memref<64x256xbf16> // CHECK: memref.copy [[VAR_reinterpret_cast_0_]], [[RES_1_]] : memref<64x256xbf16, strided<[?, ?], offset: ?>> to memref<64x256xbf16> -// CHECK: [[VAR_51_:%.+]] = bufferization.to_tensor [[RES_1_]] restrict writable : memref<64x256xbf16> -// CHECK: [[VAR_52_:%.+]] = linalg.matmul ins([[VAR_50_]], [[VAR_51_]] : tensor<128x64xbf16>, tensor<64x256xbf16>) outs([[VAR_1_]] : tensor<128x256xf32>) -> tensor<128x256xf32> -// CHECK: [[VAR_53_:%.+]] = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel"]} ins([[VAR_arg19_]], [[VAR_52_]] : tensor<128x256xf32>, tensor<128x256xf32>) outs([[VAR_arg19_]] : tensor<128x256xf32>) { +// CHECK: [[VAR_53_:%.+]] = bufferization.to_tensor [[RES_1_]] restrict writable : memref<64x256xbf16> +// CHECK: [[VAR_54_:%.+]] = linalg.matmul ins([[VAR_52_]], [[VAR_53_]] : tensor<128x64xbf16>, tensor<64x256xbf16>) outs([[VAR_1_]] : tensor<128x256xf32>) -> tensor<128x256xf32> +// CHECK: [[VAR_55_:%.+]] = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel"]} ins([[VAR_arg19_]], [[VAR_54_]] : tensor<128x256xf32>, tensor<128x256xf32>) outs([[VAR_arg19_]] : tensor<128x256xf32>) { // CHECK: ^bb0([[IN_0_:%.+]]: f32, [[IN_1_:%.+]]: f32, [[IN_2_:%.+]]: f32): -// CHECK: [[VAR_56_:%.+]] = arith.addf [[IN_0_]], [[IN_1_]] : f32 -// CHECK: linalg.yield [[VAR_56_]] : f32 +// CHECK: [[VAR_58_:%.+]] = arith.addf [[IN_0_]], [[IN_1_]] : f32 +// CHECK: linalg.yield [[VAR_58_]] : f32 // CHECK: } -> tensor<128x256xf32> -// CHECK-DAG: [[VAR_54_:%.+]] = arith.addi [[VAR_arg20_]], [[VAR_28_]] : index -// CHECK-DAG: [[VAR_55_:%.+]] = arith.addi [[VAR_arg21_]], [[VAR_30_]] : index -// CHECK: scf.yield [[VAR_53_]], [[VAR_54_]], [[VAR_55_]] : tensor<128x256xf32>, index, index +// CHECK-DAG: [[VAR_56_:%.+]] = arith.addi [[VAR_arg20_]], [[VAR_28_]] : index +// CHECK-DAG: [[VAR_57_:%.+]] = arith.addi [[VAR_arg21_]], [[VAR_30_]] : index +// CHECK: scf.yield [[VAR_55_]], [[VAR_56_]], [[VAR_57_]] : tensor<128x256xf32>, index, index // CHECK: } // CHECK: [[VAR_32_:%.+]] = tensor.empty() : tensor<128x256xbf16> // CHECK: [[VAR_33_:%.+]] = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel"]} ins([[VAR_31_]]#0 : tensor<128x256xf32>) outs([[VAR_32_]] : tensor<128x256xbf16>) { // CHECK: ^bb0([[IN_3_:%.+]]: f32, [[IN_4_:%.+]]: bf16): -// CHECK: [[VAR_49_1_:%.+]] = arith.truncf [[IN_3_]] : f32 to bf16 -// CHECK: linalg.yield [[VAR_49_1_]] : bf16 +// CHECK: [[VAR_51_1_:%.+]] = arith.truncf [[IN_3_]] : f32 to bf16 +// CHECK: linalg.yield [[VAR_51_1_]] : bf16 // CHECK: } -> tensor<128x256xbf16> // CHECK: [[VAR_34_:%.+]] = arith.index_cast [[PARAM_10_]] : i32 to index // CHECK-DAG: [[VAR_35_:%.+]] = arith.muli [[VAR_18_]], [[VAR_34_]] : index @@ -190,15 +190,17 @@ module { // CHECK-DAG: [[VAR_39_:%.+]] = arith.addi [[VAR_18_]], [[CST_128_1_]] : index // CHECK-DAG: [[VAR_40_:%.+]] = arith.index_cast [[PARAM_3_]] : i32 to index // CHECK: [[VAR_41_:%.+]] = arith.minsi [[VAR_39_]], [[VAR_40_]] : index -// CHECK-DAG: 
[[VAR_42_:%.+]] = arith.subi [[VAR_41_]], [[VAR_18_]] : index -// CHECK-DAG: [[VAR_43_:%.+]] = arith.addi [[VAR_20_]], [[CST_256_1_]] : index -// CHECK-DAG: [[VAR_44_:%.+]] = arith.index_cast [[PARAM_4_]] : i32 to index -// CHECK: [[VAR_45_:%.+]] = arith.minsi [[VAR_43_]], [[VAR_44_]] : index -// CHECK-DAG: [[VAR_46_:%.+]] = arith.subi [[VAR_45_]], [[VAR_20_]] : index -// CHECK-DAG: [[VAR_47_:%.+]] = arith.minsi [[VAR_42_]], [[CST_128_1_]] : index -// CHECK: [[VAR_48_:%.+]] = arith.minsi [[VAR_46_]], [[CST_256_1_]] : index -// CHECK-DAG: [[VAR_extracted_slice_:%.+]] = tensor.extract_slice [[VAR_33_]][0, 0] {{.}}[[VAR_47_]], [[VAR_48_]]{{.}} [1, 1] : tensor<128x256xbf16> to tensor -// CHECK-DAG: [[VAR_subview_:%.+]] = memref.subview [[VAR_reinterpret_cast_]][0, 0] {{.}}[[VAR_47_]], [[VAR_48_]]{{.}} [1, 1] : memref<128x256xbf16, strided<[?, ?], offset: ?>> to memref> +// CHECK: [[VAR_42_:%.+]] = arith.maxsi [[VAR_41_]], [[VAR_18_]] : index +// CHECK-DAG: [[VAR_43_:%.+]] = arith.subi [[VAR_42_]], [[VAR_18_]] : index +// CHECK-DAG: [[VAR_44_:%.+]] = arith.addi [[VAR_20_]], [[CST_256_1_]] : index +// CHECK-DAG: [[VAR_45_:%.+]] = arith.index_cast [[PARAM_4_]] : i32 to index +// CHECK: [[VAR_46_:%.+]] = arith.minsi [[VAR_44_]], [[VAR_45_]] : index +// CHECK: [[VAR_47_:%.+]] = arith.maxsi [[VAR_46_]], [[VAR_20_]] : index +// CHECK-DAG: [[VAR_48_:%.+]] = arith.subi [[VAR_47_]], [[VAR_20_]] : index +// CHECK-DAG: [[VAR_49_:%.+]] = arith.minsi [[VAR_43_]], [[CST_128_1_]] : index +// CHECK: [[VAR_50_:%.+]] = arith.minsi [[VAR_48_]], [[CST_256_1_]] : index +// CHECK-DAG: [[VAR_extracted_slice_:%.+]] = tensor.extract_slice [[VAR_33_]][0, 0] {{.}}[[VAR_49_]], [[VAR_50_]]{{.}} [1, 1] : tensor<128x256xbf16> to tensor +// CHECK-DAG: [[VAR_subview_:%.+]] = memref.subview [[VAR_reinterpret_cast_]][0, 0] {{.}}[[VAR_49_]], [[VAR_50_]]{{.}} [1, 1] : memref<128x256xbf16, strided<[?, ?], offset: ?>> to memref> // CHECK: bufferization.materialize_in_destination [[VAR_extracted_slice_]] in writable [[VAR_subview_]] : (tensor, memref>) -> () // CHECK: return // CHECK: } diff --git a/test/Conversion/StructuredToMemref/kernel-05-layer-norm-dwdb.mlir b/test/Conversion/StructuredToMemref/kernel-05-layer-norm-dwdb.mlir index b4da393c..a32ea717 100644 --- a/test/Conversion/StructuredToMemref/kernel-05-layer-norm-dwdb.mlir +++ b/test/Conversion/StructuredToMemref/kernel-05-layer-norm-dwdb.mlir @@ -74,76 +74,79 @@ module { // CHECK-NOT: separator of consecutive DAGs // CHECK-DAG: [[VAR_3_:%.+]] = arith.index_cast [[VAR_2_]] : i32 to index // CHECK-DAG: [[VAR_4_:%.+]]:2 = scf.for [[VAR_arg12_:%.+]] = [[CST_0_]] to [[PARAM_4_]] step [[CST_256_]] iter_args([[VAR_arg13_:%.+]] = [[VAR_1_]], [[VAR_arg14_:%.+]] = [[VAR_1_]]) -> (tensor<256x256xf32>, tensor<256x256xf32>) : i32 { -// CHECK-DAG: [[VAR_11_:%.+]] = arith.index_cast [[VAR_arg12_]] : i32 to index -// CHECK-DAG: [[VAR_12_:%.+]] = arith.index_cast [[PARAM_5_]] : i32 to index -// CHECK: [[VAR_13_:%.+]] = arith.muli [[VAR_11_]], [[VAR_12_]] : index -// CHECK: [[VAR_14_:%.+]] = arith.addi [[VAR_13_]], [[VAR_3_]] : index -// CHECK-DAG: [[VAR_reinterpret_cast_4_:%.+]] = memref.reinterpret_cast [[PARAM_0_]] to offset: {{.}}[[VAR_14_]]{{.}}, sizes: [256, 256], strides: {{.}}[[VAR_12_]], 1] : memref<*xf32> to memref<256x256xf32, strided<[?, 1], offset: ?>> -// CHECK-DAG: [[VAR_15_:%.+]] = arith.addi [[VAR_11_]], [[CST_256_1_]] : index -// CHECK-DAG: [[VAR_16_:%.+]] = arith.index_cast [[PARAM_4_]] : i32 to index -// CHECK: [[VAR_17_:%.+]] = arith.minsi [[VAR_15_]], [[VAR_16_]] : 
index -// CHECK-DAG: [[VAR_18_:%.+]] = arith.subi [[VAR_17_]], [[VAR_11_]] : index -// CHECK-DAG: [[VAR_19_:%.+]] = arith.addi [[VAR_3_]], [[CST_256_1_]] : index -// CHECK: [[VAR_20_:%.+]] = arith.minsi [[VAR_19_]], [[VAR_12_]] : index -// CHECK-DAG: [[VAR_21_:%.+]] = arith.subi [[VAR_20_]], [[VAR_3_]] : index -// CHECK-DAG: [[VAR_22_:%.+]] = arith.minsi [[VAR_18_]], [[CST_256_1_]] : index +// CHECK-DAG: [[VAR_12_:%.+]] = arith.index_cast [[VAR_arg12_]] : i32 to index +// CHECK-DAG: [[VAR_13_:%.+]] = arith.index_cast [[PARAM_5_]] : i32 to index +// CHECK: [[VAR_14_:%.+]] = arith.muli [[VAR_12_]], [[VAR_13_]] : index +// CHECK: [[VAR_15_:%.+]] = arith.addi [[VAR_14_]], [[VAR_3_]] : index +// CHECK-DAG: [[VAR_reinterpret_cast_4_:%.+]] = memref.reinterpret_cast [[PARAM_0_]] to offset: {{.}}[[VAR_15_]]{{.}}, sizes: [256, 256], strides: {{.}}[[VAR_13_]], 1] : memref<*xf32> to memref<256x256xf32, strided<[?, 1], offset: ?>> +// CHECK-DAG: [[VAR_16_:%.+]] = arith.addi [[VAR_12_]], [[CST_256_1_]] : index +// CHECK-DAG: [[VAR_17_:%.+]] = arith.index_cast [[PARAM_4_]] : i32 to index +// CHECK: [[VAR_18_:%.+]] = arith.minsi [[VAR_16_]], [[VAR_17_]] : index +// CHECK: [[VAR_19_:%.+]] = arith.maxsi [[VAR_18_]], [[VAR_12_]] : index +// CHECK-DAG: [[VAR_20_:%.+]] = arith.subi [[VAR_19_]], [[VAR_12_]] : index +// CHECK-DAG: [[VAR_21_:%.+]] = arith.addi [[VAR_3_]], [[CST_256_1_]] : index +// CHECK: [[VAR_22_:%.+]] = arith.minsi [[VAR_21_]], [[VAR_13_]] : index +// CHECK: [[VAR_23_:%.+]] = arith.maxsi [[VAR_22_]], [[VAR_3_]] : index +// CHECK-DAG: [[VAR_24_:%.+]] = arith.subi [[VAR_23_]], [[VAR_3_]] : index +// CHECK-DAG: [[VAR_25_:%.+]] = arith.minsi [[VAR_20_]], [[CST_256_1_]] : index // CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: [[VAR_23_:%.+]] = arith.minsi [[VAR_21_]], [[CST_256_1_]] : index +// CHECK-DAG: [[VAR_26_:%.+]] = arith.minsi [[VAR_24_]], [[CST_256_1_]] : index // CHECK-DAG: [[RES_:%.+]] = memref.alloc() : memref<256x256xf32> -// CHECK-DAG: [[VAR_24_:%.+]] = arith.cmpi slt, [[VAR_22_]], [[CST_256_1_]] : index -// CHECK: [[VAR_25_:%.+]] = arith.cmpi slt, [[VAR_23_]], [[CST_256_1_]] : index -// CHECK: [[VAR_26_:%.+]] = arith.ori [[VAR_24_]], [[VAR_25_]] : i1 -// CHECK: scf.if [[VAR_26_]] { +// CHECK-DAG: [[VAR_27_:%.+]] = arith.cmpi slt, [[VAR_25_]], [[CST_256_1_]] : index +// CHECK: [[VAR_28_:%.+]] = arith.cmpi slt, [[VAR_26_]], [[CST_256_1_]] : index +// CHECK: [[VAR_29_:%.+]] = arith.ori [[VAR_27_]], [[VAR_28_]] : i1 +// CHECK: scf.if [[VAR_29_]] { // CHECK: linalg.fill ins([[CST_0_dot_000000_]] : f32) outs([[RES_]] : memref<256x256xf32>) // CHECK: } -// CHECK-DAG: [[VAR_subview_5_:%.+]] = memref.subview [[VAR_reinterpret_cast_4_]][0, 0] {{.}}[[VAR_22_]], [[VAR_23_]]{{.}} [1, 1] : memref<256x256xf32, strided<[?, 1], offset: ?>> to memref> -// CHECK-DAG: [[VAR_subview_6_:%.+]] = memref.subview [[RES_]][0, 0] {{.}}[[VAR_22_]], [[VAR_23_]]{{.}} [1, 1] : memref<256x256xf32> to memref> +// CHECK-DAG: [[VAR_subview_5_:%.+]] = memref.subview [[VAR_reinterpret_cast_4_]][0, 0] {{.}}[[VAR_25_]], [[VAR_26_]]{{.}} [1, 1] : memref<256x256xf32, strided<[?, 1], offset: ?>> to memref> +// CHECK-DAG: [[VAR_subview_6_:%.+]] = memref.subview [[RES_]][0, 0] {{.}}[[VAR_25_]], [[VAR_26_]]{{.}} [1, 1] : memref<256x256xf32> to memref> // CHECK: memref.copy [[VAR_subview_5_]], [[VAR_subview_6_]] : memref> to memref> -// CHECK: [[VAR_27_:%.+]] = bufferization.to_tensor [[RES_]] restrict writable : memref<256x256xf32> -// CHECK: [[VAR_28_:%.+]] = linalg.generic {indexing_maps = [#map, #map, #map], 
iterator_types = ["parallel", "parallel"]} ins([[VAR_arg13_]], [[VAR_27_]] : tensor<256x256xf32>, tensor<256x256xf32>) outs([[VAR_arg13_]] : tensor<256x256xf32>) { +// CHECK: [[VAR_30_:%.+]] = bufferization.to_tensor [[RES_]] restrict writable : memref<256x256xf32> +// CHECK: [[VAR_31_:%.+]] = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel"]} ins([[VAR_arg13_]], [[VAR_30_]] : tensor<256x256xf32>, tensor<256x256xf32>) outs([[VAR_arg13_]] : tensor<256x256xf32>) { // CHECK: ^bb0([[IN_0_:%.+]]: f32, [[IN_1_:%.+]]: f32, [[IN_2_:%.+]]: f32): -// CHECK: [[VAR_31_:%.+]] = arith.addf [[IN_0_]], [[IN_1_]] : f32 -// CHECK: linalg.yield [[VAR_31_]] : f32 +// CHECK: [[VAR_34_:%.+]] = arith.addf [[IN_0_]], [[IN_1_]] : f32 +// CHECK: linalg.yield [[VAR_34_]] : f32 // CHECK: } -> tensor<256x256xf32> -// CHECK-DAG: [[VAR_reinterpret_cast_7_:%.+]] = memref.reinterpret_cast [[PARAM_1_]] to offset: {{.}}[[VAR_14_]]{{.}}, sizes: [256, 256], strides: {{.}}[[VAR_12_]], 1] : memref<*xf32> to memref<256x256xf32, strided<[?, 1], offset: ?>> +// CHECK-DAG: [[VAR_reinterpret_cast_7_:%.+]] = memref.reinterpret_cast [[PARAM_1_]] to offset: {{.}}[[VAR_15_]]{{.}}, sizes: [256, 256], strides: {{.}}[[VAR_13_]], 1] : memref<*xf32> to memref<256x256xf32, strided<[?, 1], offset: ?>> // CHECK-DAG: [[RES_1_:%.+]] = memref.alloc() : memref<256x256xf32> -// CHECK: scf.if [[VAR_26_]] { +// CHECK: scf.if [[VAR_29_]] { // CHECK: linalg.fill ins([[CST_0_dot_000000_]] : f32) outs([[RES_1_]] : memref<256x256xf32>) // CHECK: } -// CHECK-DAG: [[VAR_subview_9_:%.+]] = memref.subview [[VAR_reinterpret_cast_7_]][0, 0] {{.}}[[VAR_22_]], [[VAR_23_]]{{.}} [1, 1] : memref<256x256xf32, strided<[?, 1], offset: ?>> to memref> -// CHECK-DAG: [[VAR_subview_10_:%.+]] = memref.subview [[RES_1_]][0, 0] {{.}}[[VAR_22_]], [[VAR_23_]]{{.}} [1, 1] : memref<256x256xf32> to memref> +// CHECK-DAG: [[VAR_subview_9_:%.+]] = memref.subview [[VAR_reinterpret_cast_7_]][0, 0] {{.}}[[VAR_25_]], [[VAR_26_]]{{.}} [1, 1] : memref<256x256xf32, strided<[?, 1], offset: ?>> to memref> +// CHECK-DAG: [[VAR_subview_10_:%.+]] = memref.subview [[RES_1_]][0, 0] {{.}}[[VAR_25_]], [[VAR_26_]]{{.}} [1, 1] : memref<256x256xf32> to memref> // CHECK: memref.copy [[VAR_subview_9_]], [[VAR_subview_10_]] : memref> to memref> -// CHECK: [[VAR_29_:%.+]] = bufferization.to_tensor [[RES_1_]] restrict writable : memref<256x256xf32> -// CHECK: [[VAR_30_:%.+]] = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel"]} ins([[VAR_arg14_]], [[VAR_29_]] : tensor<256x256xf32>, tensor<256x256xf32>) outs([[VAR_arg14_]] : tensor<256x256xf32>) { +// CHECK: [[VAR_32_:%.+]] = bufferization.to_tensor [[RES_1_]] restrict writable : memref<256x256xf32> +// CHECK: [[VAR_33_:%.+]] = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel"]} ins([[VAR_arg14_]], [[VAR_32_]] : tensor<256x256xf32>, tensor<256x256xf32>) outs([[VAR_arg14_]] : tensor<256x256xf32>) { // CHECK: ^bb0([[IN_3_:%.+]]: f32, [[IN_4_:%.+]]: f32, [[IN_5_:%.+]]: f32): -// CHECK: [[VAR_31_1_:%.+]] = arith.addf [[IN_3_]], [[IN_4_]] : f32 -// CHECK: linalg.yield [[VAR_31_1_]] : f32 +// CHECK: [[VAR_34_1_:%.+]] = arith.addf [[IN_3_]], [[IN_4_]] : f32 +// CHECK: linalg.yield [[VAR_34_1_]] : f32 // CHECK: } -> tensor<256x256xf32> -// CHECK: scf.yield [[VAR_28_]], [[VAR_30_]] : tensor<256x256xf32>, tensor<256x256xf32> +// CHECK: scf.yield [[VAR_31_]], [[VAR_33_]] : tensor<256x256xf32>, tensor<256x256xf32> // CHECK: } // CHECK: 
[[VAR_5_:%.+]] = tensor.empty() : tensor<256xf32> // CHECK: [[VAR_6_:%.+]] = linalg.fill ins([[CST_0_dot_000000_]] : f32) outs([[VAR_5_]] : tensor<256xf32>) -> tensor<256xf32> // CHECK: [[VAR_reduced_:%.+]] = linalg.reduce ins([[VAR_4_]]#0 : tensor<256x256xf32>) outs([[VAR_6_]] : tensor<256xf32>) dimensions = [0] // CHECK: ([[IN_3_:.+]]: f32, [[init_:.+]]: f32) { -// CHECK: [[VAR_11_1_:%.+]] = arith.addf [[IN_3_]], [[init_]] : f32 -// CHECK: linalg.yield [[VAR_11_1_]] : f32 +// CHECK: [[VAR_12_1_:%.+]] = arith.addf [[IN_3_]], [[init_]] : f32 +// CHECK: linalg.yield [[VAR_12_1_]] : f32 // CHECK: } // CHECK: [[VAR_reduced_0_:%.+]] = linalg.reduce ins([[VAR_4_]]#1 : tensor<256x256xf32>) outs([[VAR_6_]] : tensor<256xf32>) dimensions = [0] // CHECK: ([[IN_3_:.+]]: f32, [[init_:.+]]: f32) { -// CHECK: [[VAR_11_2_:%.+]] = arith.addf [[IN_3_]], [[init_]] : f32 -// CHECK: linalg.yield [[VAR_11_2_]] : f32 +// CHECK: [[VAR_12_2_:%.+]] = arith.addf [[IN_3_]], [[init_]] : f32 +// CHECK: linalg.yield [[VAR_12_2_]] : f32 // CHECK: } // CHECK-DAG: [[VAR_reinterpret_cast_:%.+]] = memref.reinterpret_cast [[PARAM_2_]] to offset: {{.}}[[VAR_3_]]{{.}}, sizes: [256], strides: [1] : memref<*xf32> to memref<256xf32, strided<[1], offset: ?>> // CHECK-DAG: [[VAR_7_:%.+]] = arith.addi [[VAR_3_]], [[CST_256_1_]] : index // CHECK-DAG: [[VAR_8_:%.+]] = arith.index_cast [[PARAM_5_]] : i32 to index // CHECK: [[VAR_9_:%.+]] = arith.minsi [[VAR_7_]], [[VAR_8_]] : index -// CHECK: [[VAR_10_:%.+]] = arith.subi [[VAR_9_]], [[VAR_3_]] : index -// CHECK-DAG: [[VAR_extracted_slice_:%.+]] = tensor.extract_slice [[VAR_reduced_]][0] {{.}}[[VAR_10_]]{{.}} [1] : tensor<256xf32> to tensor -// CHECK-DAG: [[VAR_subview_:%.+]] = memref.subview [[VAR_reinterpret_cast_]][0] {{.}}[[VAR_10_]]{{.}} [1] : memref<256xf32, strided<[1], offset: ?>> to memref> +// CHECK: [[VAR_10_:%.+]] = arith.maxsi [[VAR_9_]], [[VAR_3_]] : index +// CHECK: [[VAR_11_:%.+]] = arith.subi [[VAR_10_]], [[VAR_3_]] : index +// CHECK-DAG: [[VAR_extracted_slice_:%.+]] = tensor.extract_slice [[VAR_reduced_]][0] {{.}}[[VAR_11_]]{{.}} [1] : tensor<256xf32> to tensor +// CHECK-DAG: [[VAR_subview_:%.+]] = memref.subview [[VAR_reinterpret_cast_]][0] {{.}}[[VAR_11_]]{{.}} [1] : memref<256xf32, strided<[1], offset: ?>> to memref> // CHECK: bufferization.materialize_in_destination [[VAR_extracted_slice_]] in writable [[VAR_subview_]] : (tensor, memref>) -> () // CHECK-DAG: [[VAR_reinterpret_cast_1_:%.+]] = memref.reinterpret_cast [[PARAM_3_]] to offset: {{.}}[[VAR_3_]]{{.}}, sizes: [256], strides: [1] : memref<*xf32> to memref<256xf32, strided<[1], offset: ?>> -// CHECK-DAG: [[VAR_extracted_slice_2_:%.+]] = tensor.extract_slice [[VAR_reduced_0_]][0] {{.}}[[VAR_10_]]{{.}} [1] : tensor<256xf32> to tensor -// CHECK: [[VAR_subview_3_:%.+]] = memref.subview [[VAR_reinterpret_cast_1_]][0] {{.}}[[VAR_10_]]{{.}} [1] : memref<256xf32, strided<[1], offset: ?>> to memref> +// CHECK-DAG: [[VAR_extracted_slice_2_:%.+]] = tensor.extract_slice [[VAR_reduced_0_]][0] {{.}}[[VAR_11_]]{{.}} [1] : tensor<256xf32> to tensor +// CHECK: [[VAR_subview_3_:%.+]] = memref.subview [[VAR_reinterpret_cast_1_]][0] {{.}}[[VAR_11_]]{{.}} [1] : memref<256xf32, strided<[1], offset: ?>> to memref> // CHECK: bufferization.materialize_in_destination [[VAR_extracted_slice_2_]] in writable [[VAR_subview_3_]] : (tensor, memref>) -> () // CHECK: return // CHECK: } diff --git a/test/Conversion/StructuredToMemref/kernel-05-layer-norm-fwd.mlir b/test/Conversion/StructuredToMemref/kernel-05-layer-norm-fwd.mlir index 
5c3b843a..4d41da34 100644 --- a/test/Conversion/StructuredToMemref/kernel-05-layer-norm-fwd.mlir +++ b/test/Conversion/StructuredToMemref/kernel-05-layer-norm-fwd.mlir @@ -117,22 +117,23 @@ module { // CHECK-DAG: [[VAR_22_:%.+]] = arith.addi [[VAR_20_1_]], [[CST_256_1_]] : index // CHECK-DAG: [[VAR_23_:%.+]] = arith.index_cast [[PARAM_7_]] : i32 to index // CHECK: [[VAR_24_:%.+]] = arith.minsi [[VAR_22_]], [[VAR_23_]] : index -// CHECK-DAG: [[VAR_25_:%.+]] = arith.subi [[VAR_24_]], [[VAR_20_1_]] : index +// CHECK: [[VAR_25_:%.+]] = arith.maxsi [[VAR_24_]], [[VAR_20_1_]] : index +// CHECK-DAG: [[VAR_26_:%.+]] = arith.subi [[VAR_25_]], [[VAR_20_1_]] : index // CHECK-DAG: [[RES_:%.+]] = memref.alloc() : memref<256xf32> -// CHECK: [[VAR_26_:%.+]] = arith.cmpi slt, [[VAR_25_]], [[CST_256_1_]] : index -// CHECK: scf.if [[VAR_26_]] { +// CHECK: [[VAR_27_:%.+]] = arith.cmpi slt, [[VAR_26_]], [[CST_256_1_]] : index +// CHECK: scf.if [[VAR_27_]] { // CHECK: linalg.fill ins([[CST_0_dot_000000_]] : f32) outs([[RES_]] : memref<256xf32>) // CHECK: } -// CHECK-DAG: [[VAR_subview_:%.+]] = memref.subview [[VAR_reinterpret_cast_5_]][0] {{.}}[[VAR_25_]]{{.}} [1] : memref<256xf32, strided<[1], offset: ?>> to memref> -// CHECK-DAG: [[VAR_subview_6_:%.+]] = memref.subview [[RES_]][0] {{.}}[[VAR_25_]]{{.}} [1] : memref<256xf32> to memref> +// CHECK-DAG: [[VAR_subview_:%.+]] = memref.subview [[VAR_reinterpret_cast_5_]][0] {{.}}[[VAR_26_]]{{.}} [1] : memref<256xf32, strided<[1], offset: ?>> to memref> +// CHECK-DAG: [[VAR_subview_6_:%.+]] = memref.subview [[RES_]][0] {{.}}[[VAR_26_]]{{.}} [1] : memref<256xf32> to memref> // CHECK: memref.copy [[VAR_subview_]], [[VAR_subview_6_]] : memref> to memref> -// CHECK: [[VAR_27_:%.+]] = bufferization.to_tensor [[RES_]] restrict writable : memref<256xf32> -// CHECK: [[VAR_28_:%.+]] = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]} ins([[VAR_arg16_]], [[VAR_27_]] : tensor<256xf32>, tensor<256xf32>) outs([[VAR_arg16_]] : tensor<256xf32>) { +// CHECK: [[VAR_28_:%.+]] = bufferization.to_tensor [[RES_]] restrict writable : memref<256xf32> +// CHECK: [[VAR_29_:%.+]] = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]} ins([[VAR_arg16_]], [[VAR_28_]] : tensor<256xf32>, tensor<256xf32>) outs([[VAR_arg16_]] : tensor<256xf32>) { // CHECK: ^bb0([[IN_1_:%.+]]: f32, [[IN_2_:%.+]]: f32, [[IN_3_:%.+]]: f32): -// CHECK: [[VAR_29_:%.+]] = arith.addf [[IN_1_]], [[IN_2_]] : f32 -// CHECK: linalg.yield [[VAR_29_]] : f32 +// CHECK: [[VAR_30_:%.+]] = arith.addf [[IN_1_]], [[IN_2_]] : f32 +// CHECK: linalg.yield [[VAR_30_]] : f32 // CHECK: } -> tensor<256xf32> -// CHECK: scf.yield [[VAR_28_]] : tensor<256xf32> +// CHECK: scf.yield [[VAR_29_]] : tensor<256xf32> // CHECK: } // CHECK: [[VAR_8_:%.+]] = bufferization.alloc_tensor() : tensor // CHECK: [[VAR_inserted_:%.+]] = tensor.insert [[CST_0_dot_000000_]] into [[VAR_8_]][] : tensor @@ -149,14 +150,14 @@ module { // CHECK-DAG: [[VAR_20_3_:%.+]] = linalg.fill ins([[VAR_arg15_1_]] : i32) outs([[VAR_4_]] : tensor<256xi32>) -> tensor<256xi32> // CHECK: [[VAR_21_2_:%.+]] = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]} ins([[VAR_20_3_]], [[VAR_5_]] : tensor<256xi32>, tensor<256xi32>) outs([[VAR_20_3_]] : tensor<256xi32>) { // CHECK: ^bb0([[IN_4_:%.+]]: i32, [[IN_5_:%.+]]: i32, [[IN_6_:%.+]]: i32): -// CHECK: [[VAR_36_:%.+]] = arith.addi [[IN_4_]], [[IN_5_]] : i32 -// CHECK: linalg.yield [[VAR_36_]] : i32 +// CHECK: [[VAR_37_:%.+]] = arith.addi 
[[IN_4_]], [[IN_5_]] : i32 +// CHECK: linalg.yield [[VAR_37_]] : i32 // CHECK: } -> tensor<256xi32> // CHECK: [[VAR_22_1_:%.+]] = tensor.empty() : tensor<256xi1> // CHECK: [[VAR_23_1_:%.+]] = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]} ins([[VAR_21_2_]], [[VAR_6_]] : tensor<256xi32>, tensor<256xi32>) outs([[VAR_22_1_]] : tensor<256xi1>) { // CHECK: ^bb0([[IN_7_:%.+]]: i32, [[IN_8_:%.+]]: i32, [[IN_9_:%.+]]: i1): -// CHECK: [[VAR_36_1_:%.+]] = arith.cmpi slt, [[IN_7_]], [[IN_8_]] : i32 -// CHECK: linalg.yield [[VAR_36_1_]] : i1 +// CHECK: [[VAR_37_1_:%.+]] = arith.cmpi slt, [[IN_7_]], [[IN_8_]] : i32 +// CHECK: linalg.yield [[VAR_37_1_]] : i1 // CHECK: } -> tensor<256xi1> // CHECK: [[VAR_24_1_:%.+]] = arith.index_cast [[VAR_arg15_1_]] : i32 to index // CHECK: [[VAR_25_1_:%.+]] = arith.addi [[VAR_3_]], [[VAR_24_1_]] : index @@ -164,37 +165,38 @@ module { // CHECK-DAG: [[VAR_26_1_:%.+]] = arith.addi [[VAR_24_1_]], [[CST_256_1_]] : index // CHECK-DAG: [[VAR_27_1_:%.+]] = arith.index_cast [[PARAM_7_]] : i32 to index // CHECK: [[VAR_28_1_:%.+]] = arith.minsi [[VAR_26_1_]], [[VAR_27_1_]] : index -// CHECK-DAG: [[VAR_29_1_:%.+]] = arith.subi [[VAR_28_1_]], [[VAR_24_1_]] : index +// CHECK: [[VAR_29_1_:%.+]] = arith.maxsi [[VAR_28_1_]], [[VAR_24_1_]] : index +// CHECK-DAG: [[VAR_30_1_:%.+]] = arith.subi [[VAR_29_1_]], [[VAR_24_1_]] : index // CHECK-DAG: [[RES_1_:%.+]] = memref.alloc() : memref<256xf32> -// CHECK: [[VAR_30_:%.+]] = arith.cmpi slt, [[VAR_29_1_]], [[CST_256_1_]] : index -// CHECK: scf.if [[VAR_30_]] { +// CHECK: [[VAR_31_:%.+]] = arith.cmpi slt, [[VAR_30_1_]], [[CST_256_1_]] : index +// CHECK: scf.if [[VAR_31_]] { // CHECK: linalg.fill ins([[CST_0_dot_000000_]] : f32) outs([[RES_1_]] : memref<256xf32>) // CHECK: } -// CHECK-DAG: [[VAR_subview_1_:%.+]] = memref.subview [[VAR_reinterpret_cast_5_1_]][0] {{.}}[[VAR_29_1_]]{{.}} [1] : memref<256xf32, strided<[1], offset: ?>> to memref> -// CHECK-DAG: [[VAR_subview_6_1_:%.+]] = memref.subview [[RES_1_]][0] {{.}}[[VAR_29_1_]]{{.}} [1] : memref<256xf32> to memref> +// CHECK-DAG: [[VAR_subview_1_:%.+]] = memref.subview [[VAR_reinterpret_cast_5_1_]][0] {{.}}[[VAR_30_1_]]{{.}} [1] : memref<256xf32, strided<[1], offset: ?>> to memref> +// CHECK-DAG: [[VAR_subview_6_1_:%.+]] = memref.subview [[RES_1_]][0] {{.}}[[VAR_30_1_]]{{.}} [1] : memref<256xf32> to memref> // CHECK: memref.copy [[VAR_subview_1_]], [[VAR_subview_6_1_]] : memref> to memref> -// CHECK: [[VAR_31_:%.+]] = bufferization.to_tensor [[RES_1_]] restrict writable : memref<256xf32> -// CHECK: [[VAR_32_:%.+]] = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]} ins([[VAR_31_]], [[VAR_11_]] : tensor<256xf32>, tensor<256xf32>) outs([[VAR_31_]] : tensor<256xf32>) { +// CHECK: [[VAR_32_:%.+]] = bufferization.to_tensor [[RES_1_]] restrict writable : memref<256xf32> +// CHECK: [[VAR_33_:%.+]] = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]} ins([[VAR_32_]], [[VAR_11_]] : tensor<256xf32>, tensor<256xf32>) outs([[VAR_32_]] : tensor<256xf32>) { // CHECK: ^bb0([[IN_10_:%.+]]: f32, [[IN_11_:%.+]]: f32, [[IN_12_:%.+]]: f32): -// CHECK: [[VAR_36_2_:%.+]] = arith.subf [[IN_10_]], [[IN_11_]] : f32 -// CHECK: linalg.yield [[VAR_36_2_]] : f32 +// CHECK: [[VAR_37_2_:%.+]] = arith.subf [[IN_10_]], [[IN_11_]] : f32 +// CHECK: linalg.yield [[VAR_37_2_]] : f32 // CHECK: } -> tensor<256xf32> -// CHECK: [[VAR_33_:%.+]] = linalg.generic {indexing_maps = [#map, #map, #map, #map], iterator_types = ["parallel"]} 
ins([[VAR_23_1_]], [[VAR_32_]], [[VAR_1_]] : tensor<256xi1>, tensor<256xf32>, tensor<256xf32>) outs([[VAR_32_]] : tensor<256xf32>) { +// CHECK: [[VAR_34_:%.+]] = linalg.generic {indexing_maps = [#map, #map, #map, #map], iterator_types = ["parallel"]} ins([[VAR_23_1_]], [[VAR_33_]], [[VAR_1_]] : tensor<256xi1>, tensor<256xf32>, tensor<256xf32>) outs([[VAR_33_]] : tensor<256xf32>) { // CHECK: ^bb0([[IN_13_:%.+]]: i1, [[IN_14_:%.+]]: f32, [[IN_15_:%.+]]: f32, [[IN_16_:%.+]]: f32): -// CHECK: [[VAR_36_3_:%.+]] = arith.select [[IN_13_]], [[IN_14_]], [[IN_15_]] : f32 -// CHECK: linalg.yield [[VAR_36_3_]] : f32 +// CHECK: [[VAR_37_3_:%.+]] = arith.select [[IN_13_]], [[IN_14_]], [[IN_15_]] : f32 +// CHECK: linalg.yield [[VAR_37_3_]] : f32 // CHECK: } -> tensor<256xf32> -// CHECK: [[VAR_34_:%.+]] = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]} ins([[VAR_33_]], [[VAR_33_]] : tensor<256xf32>, tensor<256xf32>) outs([[VAR_33_]] : tensor<256xf32>) { +// CHECK: [[VAR_35_:%.+]] = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]} ins([[VAR_34_]], [[VAR_34_]] : tensor<256xf32>, tensor<256xf32>) outs([[VAR_34_]] : tensor<256xf32>) { // CHECK: ^bb0([[IN_17_:%.+]]: f32, [[IN_18_:%.+]]: f32, [[IN_19_:%.+]]: f32): -// CHECK: [[VAR_36_4_:%.+]] = arith.mulf [[IN_17_]], [[IN_18_]] : f32 -// CHECK: linalg.yield [[VAR_36_4_]] : f32 +// CHECK: [[VAR_37_4_:%.+]] = arith.mulf [[IN_17_]], [[IN_18_]] : f32 +// CHECK: linalg.yield [[VAR_37_4_]] : f32 // CHECK: } -> tensor<256xf32> -// CHECK: [[VAR_35_:%.+]] = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]} ins([[VAR_arg16_1_]], [[VAR_34_]] : tensor<256xf32>, tensor<256xf32>) outs([[VAR_arg16_1_]] : tensor<256xf32>) { +// CHECK: [[VAR_36_:%.+]] = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]} ins([[VAR_arg16_1_]], [[VAR_35_]] : tensor<256xf32>, tensor<256xf32>) outs([[VAR_arg16_1_]] : tensor<256xf32>) { // CHECK: ^bb0([[IN_20_:%.+]]: f32, [[IN_21_:%.+]]: f32, [[IN_22_:%.+]]: f32): -// CHECK: [[VAR_36_5_:%.+]] = arith.addf [[IN_20_]], [[IN_21_]] : f32 -// CHECK: linalg.yield [[VAR_36_5_]] : f32 +// CHECK: [[VAR_37_5_:%.+]] = arith.addf [[IN_20_]], [[IN_21_]] : f32 +// CHECK: linalg.yield [[VAR_37_5_]] : f32 // CHECK: } -> tensor<256xf32> -// CHECK: scf.yield [[VAR_35_]] : tensor<256xf32> +// CHECK: scf.yield [[VAR_36_]] : tensor<256xf32> // CHECK: } // CHECK: [[VAR_13_:%.+]] = bufferization.alloc_tensor() : tensor // CHECK: [[VAR_inserted_1_:%.+]] = tensor.insert [[CST_0_dot_000000_]] into [[VAR_13_]][] : tensor @@ -220,55 +222,56 @@ module { // CHECK-DAG: [[VAR_21_3_:%.+]] = arith.addi [[VAR_20_5_]], [[CST_256_1_]] : index // CHECK-DAG: [[VAR_22_2_:%.+]] = arith.index_cast [[PARAM_7_]] : i32 to index // CHECK: [[VAR_23_2_:%.+]] = arith.minsi [[VAR_21_3_]], [[VAR_22_2_]] : index -// CHECK-DAG: [[VAR_24_2_:%.+]] = arith.subi [[VAR_23_2_]], [[VAR_20_5_]] : index +// CHECK: [[VAR_24_2_:%.+]] = arith.maxsi [[VAR_23_2_]], [[VAR_20_5_]] : index +// CHECK-DAG: [[VAR_25_2_:%.+]] = arith.subi [[VAR_24_2_]], [[VAR_20_5_]] : index // CHECK-DAG: [[RES_2_:%.+]] = memref.alloc() : memref<256xf32> // CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: [[VAR_subview_2_:%.+]] = memref.subview [[VAR_reinterpret_cast_5_2_]][0] {{.}}[[VAR_24_2_]]{{.}} [1] : memref<256xf32, strided<[1], offset: ?>> to memref> -// CHECK-DAG: [[VAR_subview_6_2_:%.+]] = memref.subview [[RES_2_]][0] {{.}}[[VAR_24_2_]]{{.}} [1] : memref<256xf32> to memref> +// CHECK-DAG: 
[[VAR_subview_2_:%.+]] = memref.subview [[VAR_reinterpret_cast_5_2_]][0] {{.}}[[VAR_25_2_]]{{.}} [1] : memref<256xf32, strided<[1], offset: ?>> to memref> +// CHECK-DAG: [[VAR_subview_6_2_:%.+]] = memref.subview [[RES_2_]][0] {{.}}[[VAR_25_2_]]{{.}} [1] : memref<256xf32> to memref> // CHECK: memref.copy [[VAR_subview_2_]], [[VAR_subview_6_2_]] : memref> to memref> -// CHECK-DAG: [[VAR_25_2_:%.+]] = bufferization.to_tensor [[RES_2_]] restrict writable : memref<256xf32> +// CHECK-DAG: [[VAR_26_2_:%.+]] = bufferization.to_tensor [[RES_2_]] restrict writable : memref<256xf32> // CHECK-DAG: [[VAR_reinterpret_cast_7_:%.+]] = memref.reinterpret_cast [[PARAM_3_]] to offset: {{.}}[[VAR_20_5_]]{{.}}, sizes: [256], strides: [1] : memref<*xf32> to memref<256xf32, strided<[1], offset: ?>> // CHECK-DAG: [[RES_3_:%.+]] = memref.alloc() : memref<256xf32> // CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: [[VAR_subview_9_:%.+]] = memref.subview [[VAR_reinterpret_cast_7_]][0] {{.}}[[VAR_24_2_]]{{.}} [1] : memref<256xf32, strided<[1], offset: ?>> to memref> -// CHECK-DAG: [[VAR_subview_10_:%.+]] = memref.subview [[RES_3_]][0] {{.}}[[VAR_24_2_]]{{.}} [1] : memref<256xf32> to memref> +// CHECK-DAG: [[VAR_subview_9_:%.+]] = memref.subview [[VAR_reinterpret_cast_7_]][0] {{.}}[[VAR_25_2_]]{{.}} [1] : memref<256xf32, strided<[1], offset: ?>> to memref> +// CHECK-DAG: [[VAR_subview_10_:%.+]] = memref.subview [[RES_3_]][0] {{.}}[[VAR_25_2_]]{{.}} [1] : memref<256xf32> to memref> // CHECK: memref.copy [[VAR_subview_9_]], [[VAR_subview_10_]] : memref> to memref> -// CHECK-DAG: [[VAR_26_2_:%.+]] = bufferization.to_tensor [[RES_3_]] restrict writable : memref<256xf32> -// CHECK-DAG: [[VAR_27_2_:%.+]] = arith.addi [[VAR_3_]], [[VAR_20_5_]] : index +// CHECK-DAG: [[VAR_27_2_:%.+]] = bufferization.to_tensor [[RES_3_]] restrict writable : memref<256xf32> +// CHECK-DAG: [[VAR_28_2_:%.+]] = arith.addi [[VAR_3_]], [[VAR_20_5_]] : index // CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: [[VAR_reinterpret_cast_11_:%.+]] = memref.reinterpret_cast [[PARAM_0_]] to offset: {{.}}[[VAR_27_2_]]{{.}}, sizes: [256], strides: [1] : memref<*xf32> to memref<256xf32, strided<[1], offset: ?>> +// CHECK-DAG: [[VAR_reinterpret_cast_11_:%.+]] = memref.reinterpret_cast [[PARAM_0_]] to offset: {{.}}[[VAR_28_2_]]{{.}}, sizes: [256], strides: [1] : memref<*xf32> to memref<256xf32, strided<[1], offset: ?>> // CHECK-DAG: [[RES_4_:%.+]] = memref.alloc() : memref<256xf32> -// CHECK-DAG: [[VAR_28_2_:%.+]] = arith.cmpi slt, [[VAR_24_2_]], [[CST_256_1_]] : index -// CHECK: scf.if [[VAR_28_2_]] { +// CHECK-DAG: [[VAR_29_2_:%.+]] = arith.cmpi slt, [[VAR_25_2_]], [[CST_256_1_]] : index +// CHECK: scf.if [[VAR_29_2_]] { // CHECK: linalg.fill ins([[CST_0_dot_000000_]] : f32) outs([[RES_4_]] : memref<256xf32>) // CHECK: } -// CHECK-DAG: [[VAR_subview_13_:%.+]] = memref.subview [[VAR_reinterpret_cast_11_]][0] {{.}}[[VAR_24_2_]]{{.}} [1] : memref<256xf32, strided<[1], offset: ?>> to memref> -// CHECK-DAG: [[VAR_subview_14_:%.+]] = memref.subview [[RES_4_]][0] {{.}}[[VAR_24_2_]]{{.}} [1] : memref<256xf32> to memref> +// CHECK-DAG: [[VAR_subview_13_:%.+]] = memref.subview [[VAR_reinterpret_cast_11_]][0] {{.}}[[VAR_25_2_]]{{.}} [1] : memref<256xf32, strided<[1], offset: ?>> to memref> +// CHECK-DAG: [[VAR_subview_14_:%.+]] = memref.subview [[RES_4_]][0] {{.}}[[VAR_25_2_]]{{.}} [1] : memref<256xf32> to memref> // CHECK: memref.copy [[VAR_subview_13_]], [[VAR_subview_14_]] : memref> to memref> -// CHECK: [[VAR_29_2_:%.+]] = bufferization.to_tensor 
[[RES_4_]] restrict writable : memref<256xf32> -// CHECK: [[VAR_30_1_:%.+]] = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]} ins([[VAR_29_2_]], [[VAR_11_]] : tensor<256xf32>, tensor<256xf32>) outs([[VAR_29_2_]] : tensor<256xf32>) { +// CHECK: [[VAR_30_2_:%.+]] = bufferization.to_tensor [[RES_4_]] restrict writable : memref<256xf32> +// CHECK: [[VAR_31_1_:%.+]] = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]} ins([[VAR_30_2_]], [[VAR_11_]] : tensor<256xf32>, tensor<256xf32>) outs([[VAR_30_2_]] : tensor<256xf32>) { // CHECK: ^bb0([[IN_23_:%.+]]: f32, [[IN_24_:%.+]]: f32, [[IN_25_:%.+]]: f32): -// CHECK: [[VAR_34_1_:%.+]] = arith.subf [[IN_23_]], [[IN_24_]] : f32 -// CHECK: linalg.yield [[VAR_34_1_]] : f32 +// CHECK: [[VAR_35_1_:%.+]] = arith.subf [[IN_23_]], [[IN_24_]] : f32 +// CHECK: linalg.yield [[VAR_35_1_]] : f32 // CHECK: } -> tensor<256xf32> -// CHECK: [[VAR_31_1_:%.+]] = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]} ins([[VAR_30_1_]], [[VAR_19_]] : tensor<256xf32>, tensor<256xf32>) outs([[VAR_30_1_]] : tensor<256xf32>) { +// CHECK: [[VAR_32_1_:%.+]] = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]} ins([[VAR_31_1_]], [[VAR_19_]] : tensor<256xf32>, tensor<256xf32>) outs([[VAR_31_1_]] : tensor<256xf32>) { // CHECK: ^bb0([[IN_26_:%.+]]: f32, [[IN_27_:%.+]]: f32, [[IN_28_:%.+]]: f32): -// CHECK: [[VAR_34_2_:%.+]] = arith.mulf [[IN_26_]], [[IN_27_]] : f32 -// CHECK: linalg.yield [[VAR_34_2_]] : f32 +// CHECK: [[VAR_35_2_:%.+]] = arith.mulf [[IN_26_]], [[IN_27_]] : f32 +// CHECK: linalg.yield [[VAR_35_2_]] : f32 // CHECK: } -> tensor<256xf32> -// CHECK: [[VAR_32_1_:%.+]] = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]} ins([[VAR_31_1_]], [[VAR_25_2_]] : tensor<256xf32>, tensor<256xf32>) outs([[VAR_31_1_]] : tensor<256xf32>) { +// CHECK: [[VAR_33_1_:%.+]] = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]} ins([[VAR_32_1_]], [[VAR_26_2_]] : tensor<256xf32>, tensor<256xf32>) outs([[VAR_32_1_]] : tensor<256xf32>) { // CHECK: ^bb0([[IN_29_:%.+]]: f32, [[IN_30_:%.+]]: f32, [[IN_31_:%.+]]: f32): -// CHECK: [[VAR_34_3_:%.+]] = arith.mulf [[IN_29_]], [[IN_30_]] : f32 -// CHECK: linalg.yield [[VAR_34_3_]] : f32 +// CHECK: [[VAR_35_3_:%.+]] = arith.mulf [[IN_29_]], [[IN_30_]] : f32 +// CHECK: linalg.yield [[VAR_35_3_]] : f32 // CHECK: } -> tensor<256xf32> -// CHECK: [[VAR_33_1_:%.+]] = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]} ins([[VAR_32_1_]], [[VAR_26_2_]] : tensor<256xf32>, tensor<256xf32>) outs([[VAR_32_1_]] : tensor<256xf32>) { +// CHECK: [[VAR_34_1_:%.+]] = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]} ins([[VAR_33_1_]], [[VAR_27_2_]] : tensor<256xf32>, tensor<256xf32>) outs([[VAR_33_1_]] : tensor<256xf32>) { // CHECK: ^bb0([[IN_32_:%.+]]: f32, [[IN_33_:%.+]]: f32, [[IN_34_:%.+]]: f32): -// CHECK: [[VAR_34_4_:%.+]] = arith.addf [[IN_32_]], [[IN_33_]] : f32 -// CHECK: linalg.yield [[VAR_34_4_]] : f32 +// CHECK: [[VAR_35_4_:%.+]] = arith.addf [[IN_32_]], [[IN_33_]] : f32 +// CHECK: linalg.yield [[VAR_35_4_]] : f32 // CHECK: } -> tensor<256xf32> -// CHECK-DAG: [[VAR_reinterpret_cast_15_:%.+]] = memref.reinterpret_cast [[PARAM_1_]] to offset: {{.}}[[VAR_27_2_]]{{.}}, sizes: [256], strides: [1] : memref<*xf32> to memref<256xf32, strided<[1], offset: ?>> -// CHECK-DAG: [[VAR_extracted_slice_:%.+]] = 
-// CHECK: [[VAR_subview_16_:%.+]] = memref.subview [[VAR_reinterpret_cast_15_]][0] {{.}}[[VAR_24_2_]]{{.}} [1] : memref<256xf32, strided<[1], offset: ?>> to memref<?xf32, strided<[1], offset: ?>>
+// CHECK-DAG: [[VAR_reinterpret_cast_15_:%.+]] = memref.reinterpret_cast [[PARAM_1_]] to offset: {{.}}[[VAR_28_2_]]{{.}}, sizes: [256], strides: [1] : memref<*xf32> to memref<256xf32, strided<[1], offset: ?>>
+// CHECK-DAG: [[VAR_extracted_slice_:%.+]] = tensor.extract_slice [[VAR_34_1_]][0] {{.}}[[VAR_25_2_]]{{.}} [1] : tensor<256xf32> to tensor<?xf32>
+// CHECK: [[VAR_subview_16_:%.+]] = memref.subview [[VAR_reinterpret_cast_15_]][0] {{.}}[[VAR_25_2_]]{{.}} [1] : memref<256xf32, strided<[1], offset: ?>> to memref<?xf32, strided<[1], offset: ?>>
// CHECK: bufferization.materialize_in_destination [[VAR_extracted_slice_]] in writable [[VAR_subview_16_]] : (tensor<?xf32>, memref<?xf32, strided<[1], offset: ?>>) -> ()
// CHECK: }
// CHECK: return
diff --git a/test/Conversion/StructuredToMemref/masked_ldst_1d.mlir b/test/Conversion/StructuredToMemref/masked_ldst_1d.mlir
index 26a9c261..29bd432a 100644
--- a/test/Conversion/StructuredToMemref/masked_ldst_1d.mlir
+++ b/test/Conversion/StructuredToMemref/masked_ldst_1d.mlir
@@ -23,23 +23,24 @@ module {
// CHECK-LABEL: func.func @kernel
// CHECK-SAME: ([[PARAM_0_:%.+]]: memref<*xbf16>, [[PARAM_1_:%.+]]: memref<*xbf16>, [[PARAM_2_:%.+]]: i32, [[PARAM_3_:%.+]]: i32, [[PARAM_4_:%.+]]: i32, [[PARAM_5_:%.+]]: i32, [[PARAM_6_:%.+]]: i32, [[PARAM_7_:%.+]]: i32, [[PARAM_8_:%.+]]: i32) {
// CHECK-DAG: [[CST_128_:%.+]] = arith.constant 128 : index
-// CHECK-DAG: [[CST_0_:%.+]] = arith.constant 0xFF80 : bf16
+// CHECK-DAG: [[CST_0_:%.+]] = arith.constant 0 : index
+// CHECK-DAG: [[CST_0_1_:%.+]] = arith.constant 0xFF80 : bf16
// CHECK-DAG: [[VAR_reinterpret_cast_:%.+]] = memref.reinterpret_cast [[PARAM_0_]] to offset: [0], sizes: [128], strides: [1] : memref<*xbf16> to memref<128xbf16, strided<[1]>>
// CHECK-DAG: [[VAR_reinterpret_cast_0_:%.+]] = memref.reinterpret_cast [[PARAM_1_]] to offset: [0], sizes: [128], strides: [1] : memref<*xbf16> to memref<128xbf16, strided<[1]>>
// CHECK-DAG: [[VAR_0_:%.+]] = arith.index_cast [[PARAM_2_]] : i32 to index
-// CHECK-NOT: separator of consecutive DAGs
-// CHECK-DAG: [[VAR_1_:%.+]] = arith.minsi [[VAR_0_]], [[CST_128_]] : index
+// CHECK: [[VAR_1_:%.+]] = arith.minsi [[VAR_0_]], [[CST_128_]] : index
+// CHECK-DAG: [[VAR_2_:%.+]] = arith.maxsi [[VAR_1_]], [[CST_0_]] : index
// CHECK-DAG: [[RES_:%.+]] = memref.alloc() : memref<128xbf16>
-// CHECK: [[VAR_2_:%.+]] = arith.cmpi slt, [[VAR_1_]], [[CST_128_]] : index
-// CHECK: scf.if [[VAR_2_]] {
-// CHECK: linalg.fill ins([[CST_0_]] : bf16) outs([[RES_]] : memref<128xbf16>)
+// CHECK: [[VAR_3_:%.+]] = arith.cmpi slt, [[VAR_2_]], [[CST_128_]] : index
+// CHECK: scf.if [[VAR_3_]] {
+// CHECK: linalg.fill ins([[CST_0_1_]] : bf16) outs([[RES_]] : memref<128xbf16>)
// CHECK: }
-// CHECK-DAG: [[VAR_subview_:%.+]] = memref.subview [[VAR_reinterpret_cast_]][0] {{.}}[[VAR_1_]]{{.}} [1] : memref<128xbf16, strided<[1]>> to memref<?xbf16, strided<[1]>>
-// CHECK-DAG: [[VAR_subview_1_:%.+]] = memref.subview [[RES_]][0] {{.}}[[VAR_1_]]{{.}} [1] : memref<128xbf16> to memref<?xbf16, strided<[1]>>
+// CHECK-DAG: [[VAR_subview_:%.+]] = memref.subview [[VAR_reinterpret_cast_]][0] {{.}}[[VAR_2_]]{{.}} [1] : memref<128xbf16, strided<[1]>> to memref<?xbf16, strided<[1]>>
+// CHECK-DAG: [[VAR_subview_1_:%.+]] = memref.subview [[RES_]][0] {{.}}[[VAR_2_]]{{.}} [1] : memref<128xbf16> to memref<?xbf16, strided<[1]>>
// CHECK: memref.copy [[VAR_subview_]], [[VAR_subview_1_]] : memref<?xbf16, strided<[1]>> to memref<?xbf16, strided<[1]>>
-// CHECK: [[VAR_3_:%.+]] = bufferization.to_tensor [[RES_]] restrict writable : memref<128xbf16>
-// CHECK-DAG: [[VAR_extracted_slice_:%.+]] = tensor.extract_slice [[VAR_3_]][0] {{.}}[[VAR_1_]]{{.}} [1] : tensor<128xbf16> to tensor<?xbf16>
-// CHECK-DAG: [[VAR_subview_2_:%.+]] = memref.subview [[VAR_reinterpret_cast_0_]][0] {{.}}[[VAR_1_]]{{.}} [1] : memref<128xbf16, strided<[1]>> to memref<?xbf16, strided<[1]>>
+// CHECK: [[VAR_4_:%.+]] = bufferization.to_tensor [[RES_]] restrict writable : memref<128xbf16>
+// CHECK-DAG: [[VAR_extracted_slice_:%.+]] = tensor.extract_slice [[VAR_4_]][0] {{.}}[[VAR_2_]]{{.}} [1] : tensor<128xbf16> to tensor<?xbf16>
+// CHECK-DAG: [[VAR_subview_2_:%.+]] = memref.subview [[VAR_reinterpret_cast_0_]][0] {{.}}[[VAR_2_]]{{.}} [1] : memref<128xbf16, strided<[1]>> to memref<?xbf16, strided<[1]>>
// CHECK: bufferization.materialize_in_destination [[VAR_extracted_slice_]] in writable [[VAR_subview_2_]] : (tensor<?xbf16>, memref<?xbf16, strided<[1]>>) -> ()
// CHECK: return
// CHECK: }
diff --git a/test/Conversion/StructuredToMemref/masked_ldst_2d.mlir b/test/Conversion/StructuredToMemref/masked_ldst_2d.mlir
index a0f5dc68..6f74d021 100644
--- a/test/Conversion/StructuredToMemref/masked_ldst_2d.mlir
+++ b/test/Conversion/StructuredToMemref/masked_ldst_2d.mlir
@@ -64,40 +64,42 @@ module {
// CHECK-LABEL: func.func @kernel
// CHECK-SAME: ([[PARAM_0_:%.+]]: memref<*xbf16>, [[PARAM_1_:%.+]]: memref<*xbf16>, [[PARAM_2_:%.+]]: i32, [[PARAM_3_:%.+]]: i32, [[PARAM_4_:%.+]]: i32, [[PARAM_5_:%.+]]: i32, [[PARAM_6_:%.+]]: i32, [[PARAM_7_:%.+]]: i32, [[PARAM_8_:%.+]]: i32, [[PARAM_9_:%.+]]: i32) {
+// CHECK-DAG: [[CST_2_:%.+]] = arith.constant 2 : index
+// CHECK-DAG: [[CST_3_:%.+]] = arith.constant 3 : index
// CHECK-DAG: [[CST_1024_:%.+]] = arith.constant 1024 : index
-// CHECK-DAG: [[CST_0_:%.+]] = arith.constant 0xFF80 : bf16
// CHECK-DAG: [[CST_130_:%.+]] = arith.constant 130 : index
-// CHECK-DAG: [[CST_2_:%.+]] = arith.constant 2 : index
// CHECK-DAG: [[CST_259_:%.+]] = arith.constant 259 : index
-// CHECK-DAG: [[CST_3_:%.+]] = arith.constant 3 : index
// CHECK-DAG: [[CST_128_:%.+]] = arith.constant 128 : index
// CHECK-DAG: [[CST_256_:%.+]] = arith.constant 256 : index
+// CHECK-DAG: [[CST_0_:%.+]] = arith.constant 0xFF80 : bf16
// CHECK-DAG: [[CST_3074_:%.+]] = arith.constant 3074 : index
// CHECK-NOT: separator of consecutive DAGs
// CHECK-DAG: [[VAR_reinterpret_cast_:%.+]] = memref.reinterpret_cast [[PARAM_0_]] to offset: {{.}}[[CST_3074_]]{{.}}, sizes: [128, 256], strides: [1, [[CST_1024_]]{{.}} : memref<*xbf16> to memref<128x256xbf16, strided<[1, ?], offset: ?>>
// CHECK-DAG: [[VAR_reinterpret_cast_0_:%.+]] = memref.reinterpret_cast [[PARAM_1_]] to offset: {{.}}[[CST_3074_]]{{.}}, sizes: [128, 256], strides: [1, [[CST_1024_]]{{.}} : memref<*xbf16> to memref<128x256xbf16, strided<[1, ?], offset: ?>>
// CHECK-DAG: [[VAR_0_:%.+]] = arith.index_cast [[PARAM_2_]] : i32 to index
// CHECK: [[VAR_1_:%.+]] = arith.minsi [[VAR_0_]], [[CST_130_]] : index
-// CHECK-DAG: [[VAR_2_:%.+]] = arith.subi [[VAR_1_]], [[CST_2_]] : index
-// CHECK-DAG: [[VAR_3_:%.+]] = arith.index_cast [[PARAM_3_]] : i32 to index
-// CHECK: [[VAR_4_:%.+]] = arith.minsi [[VAR_3_]], [[CST_259_]] : index
-// CHECK-DAG: [[VAR_5_:%.+]] = arith.subi [[VAR_4_]], [[CST_3_]] : index
-// CHECK-DAG: [[VAR_6_:%.+]] = arith.minsi [[VAR_2_]], [[CST_128_]] : index
+// CHECK: [[VAR_2_:%.+]] = arith.maxsi [[VAR_1_]], [[CST_2_]] : index
+// CHECK-DAG: [[VAR_3_:%.+]] = arith.subi [[VAR_2_]], [[CST_2_]] : index
+// CHECK-DAG: [[VAR_4_:%.+]] = arith.index_cast [[PARAM_3_]] : i32 to index
+// CHECK: [[VAR_5_:%.+]] = arith.minsi [[VAR_4_]], [[CST_259_]] : index
+// CHECK: [[VAR_6_:%.+]] = arith.maxsi [[VAR_5_]], [[CST_3_]] : index
+// CHECK-DAG: [[VAR_7_:%.+]] = arith.subi [[VAR_6_]], [[CST_3_]] : index
+// CHECK-DAG: [[VAR_8_:%.+]] = arith.minsi [[VAR_3_]], [[CST_128_]] : index
// CHECK-NOT: separator of consecutive DAGs
-// CHECK-DAG: [[VAR_7_:%.+]] = arith.minsi [[VAR_5_]], [[CST_256_]] : index
+// CHECK-DAG: [[VAR_9_:%.+]] = arith.minsi [[VAR_7_]], [[CST_256_]] : index
// CHECK-DAG: [[RES_:%.+]] = memref.alloc() : memref<128x256xbf16>
-// CHECK-DAG: [[VAR_8_:%.+]] = arith.cmpi slt, [[VAR_6_]], [[CST_128_]] : index
-// CHECK: [[VAR_9_:%.+]] = arith.cmpi slt, [[VAR_7_]], [[CST_256_]] : index
-// CHECK: [[VAR_10_:%.+]] = arith.ori [[VAR_8_]], [[VAR_9_]] : i1
-// CHECK: scf.if [[VAR_10_]] {
+// CHECK-DAG: [[VAR_10_:%.+]] = arith.cmpi slt, [[VAR_8_]], [[CST_128_]] : index
+// CHECK: [[VAR_11_:%.+]] = arith.cmpi slt, [[VAR_9_]], [[CST_256_]] : index
+// CHECK: [[VAR_12_:%.+]] = arith.ori [[VAR_10_]], [[VAR_11_]] : i1
+// CHECK: scf.if [[VAR_12_]] {
// CHECK: linalg.fill ins([[CST_0_]] : bf16) outs([[RES_]] : memref<128x256xbf16>)
// CHECK: }
-// CHECK-DAG: [[VAR_subview_:%.+]] = memref.subview [[VAR_reinterpret_cast_]][0, 0] {{.}}[[VAR_6_]], [[VAR_7_]]{{.}} [1, 1] : memref<128x256xbf16, strided<[1, ?], offset: ?>> to memref<?x?xbf16, strided<[1, ?], offset: ?>>
-// CHECK-DAG: [[VAR_subview_1_:%.+]] = memref.subview [[RES_]][0, 0] {{.}}[[VAR_6_]], [[VAR_7_]]{{.}} [1, 1] : memref<128x256xbf16> to memref<?x?xbf16, strided<[256, 1]>>
+// CHECK-DAG: [[VAR_subview_:%.+]] = memref.subview [[VAR_reinterpret_cast_]][0, 0] {{.}}[[VAR_8_]], [[VAR_9_]]{{.}} [1, 1] : memref<128x256xbf16, strided<[1, ?], offset: ?>> to memref<?x?xbf16, strided<[1, ?], offset: ?>>
+// CHECK-DAG: [[VAR_subview_1_:%.+]] = memref.subview [[RES_]][0, 0] {{.}}[[VAR_8_]], [[VAR_9_]]{{.}} [1, 1] : memref<128x256xbf16> to memref<?x?xbf16, strided<[256, 1]>>
// CHECK: memref.copy [[VAR_subview_]], [[VAR_subview_1_]] : memref<?x?xbf16, strided<[1, ?], offset: ?>> to memref<?x?xbf16, strided<[256, 1]>>
-// CHECK: [[VAR_11_:%.+]] = bufferization.to_tensor [[RES_]] restrict writable : memref<128x256xbf16>
-// CHECK-DAG: [[VAR_extracted_slice_:%.+]] = tensor.extract_slice [[VAR_11_]][0, 0] {{.}}[[VAR_6_]], [[VAR_7_]]{{.}} [1, 1] : tensor<128x256xbf16> to tensor<?x?xbf16>
-// CHECK-DAG: [[VAR_subview_2_:%.+]] = memref.subview [[VAR_reinterpret_cast_0_]][0, 0] {{.}}[[VAR_6_]], [[VAR_7_]]{{.}} [1, 1] : memref<128x256xbf16, strided<[1, ?], offset: ?>> to memref<?x?xbf16, strided<[1, ?], offset: ?>>
+// CHECK: [[VAR_13_:%.+]] = bufferization.to_tensor [[RES_]] restrict writable : memref<128x256xbf16>
+// CHECK-DAG: [[VAR_extracted_slice_:%.+]] = tensor.extract_slice [[VAR_13_]][0, 0] {{.}}[[VAR_8_]], [[VAR_9_]]{{.}} [1, 1] : tensor<128x256xbf16> to tensor<?x?xbf16>
+// CHECK-DAG: [[VAR_subview_2_:%.+]] = memref.subview [[VAR_reinterpret_cast_0_]][0, 0] {{.}}[[VAR_8_]], [[VAR_9_]]{{.}} [1, 1] : memref<128x256xbf16, strided<[1, ?], offset: ?>> to memref<?x?xbf16, strided<[1, ?], offset: ?>>
// CHECK: bufferization.materialize_in_destination [[VAR_extracted_slice_]] in writable [[VAR_subview_2_]] : (tensor<?x?xbf16>, memref<?x?xbf16, strided<[1, ?], offset: ?>>) -> ()
// CHECK: return
// CHECK: }
diff --git a/test/Conversion/StructuredToMemref/masked_ldst_sitofp_other.mlir b/test/Conversion/StructuredToMemref/masked_ldst_sitofp_other.mlir
index 3f8d2cc8..d0a9ba89 100644
--- a/test/Conversion/StructuredToMemref/masked_ldst_sitofp_other.mlir
+++ b/test/Conversion/StructuredToMemref/masked_ldst_sitofp_other.mlir
@@ -25,23 +25,24 @@ module {
// CHECK-LABEL: func.func @kernel
// CHECK-SAME: ([[PARAM_0_:%.+]]: memref<*xbf16>, [[PARAM_1_:%.+]]: memref<*xbf16>, [[PARAM_2_:%.+]]: i32, [[PARAM_3_:%.+]]: i32, [[PARAM_4_:%.+]]: i32, [[PARAM_5_:%.+]]: i32, [[PARAM_6_:%.+]]: i32, [[PARAM_7_:%.+]]: i32, [[PARAM_8_:%.+]]: i32) {
// CHECK-DAG: [[CST_128_:%.+]] = arith.constant 128 : index
+// CHECK-DAG: [[CST_0_:%.+]] = arith.constant 0 : index
// CHECK-DAG: [[CST_7_dot_000000_:%.+]] = arith.constant 7.000000e+00 : bf16
// CHECK-DAG: [[VAR_reinterpret_cast_:%.+]] = memref.reinterpret_cast [[PARAM_0_]] to offset: [0], sizes: [128], strides: [1] : memref<*xbf16> to memref<128xbf16, strided<[1]>>
// CHECK-DAG: [[VAR_reinterpret_cast_0_:%.+]] = memref.reinterpret_cast [[PARAM_1_]] to offset: [0], sizes: [128], strides: [1] : memref<*xbf16> to memref<128xbf16, strided<[1]>>
// CHECK-DAG: [[VAR_0_:%.+]] = arith.index_cast [[PARAM_2_]] : i32 to index
-// CHECK-NOT: separator of consecutive DAGs
-// CHECK-DAG: [[VAR_1_:%.+]] = arith.minsi [[VAR_0_]], [[CST_128_]] : index
+// CHECK: [[VAR_1_:%.+]] = arith.minsi [[VAR_0_]], [[CST_128_]] : index
+// CHECK-DAG: [[VAR_2_:%.+]] = arith.maxsi [[VAR_1_]], [[CST_0_]] : index
// CHECK-DAG: [[RES_:%.+]] = memref.alloc() : memref<128xbf16>
-// CHECK: [[VAR_2_:%.+]] = arith.cmpi slt, [[VAR_1_]], [[CST_128_]] : index
-// CHECK: scf.if [[VAR_2_]] {
+// CHECK: [[VAR_3_:%.+]] = arith.cmpi slt, [[VAR_2_]], [[CST_128_]] : index
+// CHECK: scf.if [[VAR_3_]] {
// CHECK: linalg.fill ins([[CST_7_dot_000000_]] : bf16) outs([[RES_]] : memref<128xbf16>)
// CHECK: }
-// CHECK-DAG: [[VAR_subview_:%.+]] = memref.subview [[VAR_reinterpret_cast_]][0] {{.}}[[VAR_1_]]{{.}} [1] : memref<128xbf16, strided<[1]>> to memref<?xbf16, strided<[1]>>
-// CHECK-DAG: [[VAR_subview_1_:%.+]] = memref.subview [[RES_]][0] {{.}}[[VAR_1_]]{{.}} [1] : memref<128xbf16> to memref<?xbf16, strided<[1]>>
+// CHECK-DAG: [[VAR_subview_:%.+]] = memref.subview [[VAR_reinterpret_cast_]][0] {{.}}[[VAR_2_]]{{.}} [1] : memref<128xbf16, strided<[1]>> to memref<?xbf16, strided<[1]>>
+// CHECK-DAG: [[VAR_subview_1_:%.+]] = memref.subview [[RES_]][0] {{.}}[[VAR_2_]]{{.}} [1] : memref<128xbf16> to memref<?xbf16, strided<[1]>>
// CHECK: memref.copy [[VAR_subview_]], [[VAR_subview_1_]] : memref<?xbf16, strided<[1]>> to memref<?xbf16, strided<[1]>>
-// CHECK: [[VAR_3_:%.+]] = bufferization.to_tensor [[RES_]] restrict writable : memref<128xbf16>
-// CHECK-DAG: [[VAR_extracted_slice_:%.+]] = tensor.extract_slice [[VAR_3_]][0] {{.}}[[VAR_1_]]{{.}} [1] : tensor<128xbf16> to tensor<?xbf16>
-// CHECK-DAG: [[VAR_subview_2_:%.+]] = memref.subview [[VAR_reinterpret_cast_0_]][0] {{.}}[[VAR_1_]]{{.}} [1] : memref<128xbf16, strided<[1]>> to memref<?xbf16, strided<[1]>>
+// CHECK: [[VAR_4_:%.+]] = bufferization.to_tensor [[RES_]] restrict writable : memref<128xbf16>
+// CHECK-DAG: [[VAR_extracted_slice_:%.+]] = tensor.extract_slice [[VAR_4_]][0] {{.}}[[VAR_2_]]{{.}} [1] : tensor<128xbf16> to tensor<?xbf16>
+// CHECK-DAG: [[VAR_subview_2_:%.+]] = memref.subview [[VAR_reinterpret_cast_0_]][0] {{.}}[[VAR_2_]]{{.}} [1] : memref<128xbf16, strided<[1]>> to memref<?xbf16, strided<[1]>>
// CHECK: bufferization.materialize_in_destination [[VAR_extracted_slice_]] in writable [[VAR_subview_2_]] : (tensor<?xbf16>, memref<?xbf16, strided<[1]>>) -> ()
// CHECK: return
// CHECK: }
diff --git a/test/Conversion/TritonToLinalg/kernel-01-vector-add.mlir b/test/Conversion/TritonToLinalg/kernel-01-vector-add.mlir
deleted file mode 100644
index 1a6a5eb4..00000000
--- a/test/Conversion/TritonToLinalg/kernel-01-vector-add.mlir
+++ /dev/null
@@ -1,77 +0,0 @@
-// RUN: triton-shared-opt --triton-to-linalg %s | FileCheck %s
-
-module {
- tt.func public @add_kernel_01234(%arg0: !tt.ptr<f32>, %arg1: !tt.ptr<f32>, %arg2: !tt.ptr<f32>, %arg3: i32) {
- %c1024_i32 = arith.constant 1024 : i32
- %0 = tt.get_program_id x : i32
- %1 = arith.muli %0, %c1024_i32 : i32
- %2 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32>
- %3 = tt.splat %1 : i32 -> tensor<1024xi32>
- %4 = arith.addi %3, %2 : tensor<1024xi32>
- %5 = tt.splat %arg3 : i32 -> tensor<1024xi32>
- %6 = arith.cmpi slt, %4, %5 : tensor<1024xi32>
- %7 = tt.splat %arg0 : !tt.ptr<f32> -> tensor<1024x!tt.ptr<f32>>
- %8 = tt.addptr %7, %4 : tensor<1024x!tt.ptr<f32>>, tensor<1024xi32>
- %9 = tt.load %8, %6 : tensor<1024x!tt.ptr<f32>>
- %10 = tt.splat %arg1 : !tt.ptr<f32> -> tensor<1024x!tt.ptr<f32>>
- %11 = tt.addptr %10, %4 : tensor<1024x!tt.ptr<f32>>, tensor<1024xi32>
- %12 = tt.load %11, %6 : tensor<1024x!tt.ptr<f32>>
- %13 = arith.addf %9, %12 : tensor<1024xf32>
- %14 = tt.splat %arg2 : !tt.ptr<f32> -> tensor<1024x!tt.ptr<f32>>
- %15 = tt.addptr %14, %4 : tensor<1024x!tt.ptr<f32>>, tensor<1024xi32>
- tt.store %15, %13, %6 : tensor<1024x!tt.ptr<f32>>
- tt.return
- }
-}
-
-// CHECK-DAG: [[MAP_0_:#.+]] = affine_map<(d0) -> (d0)>
-// CHECK-LABEL: func.func @add_kernel_01234
-// CHECK-SAME: ([[PARAM_0_:%.+]]: memref<*xf32>, [[PARAM_1_:%.+]]: memref<*xf32>, [[PARAM_2_:%.+]]: memref<*xf32>, [[PARAM_3_:%.+]]: i32, [[PARAM_4_:%.+]]: i32, [[PARAM_5_:%.+]]: i32, [[PARAM_6_:%.+]]: i32, [[PARAM_7_:%.+]]: i32, [[PARAM_8_:%.+]]: i32, [[PARAM_9_:%.+]]: i32) {
-// CHECK-DAG: [[CST_1024_:%.+]] = arith.constant 1024 : index
-// CHECK-DAG: [[CST_1024_1_:%.+]] = arith.constant 1024 : i32
-// CHECK: [[VAR_0_:%.+]] = arith.muli [[PARAM_7_]], [[CST_1024_1_]] : i32
-// CHECK: [[VAR_1_:%.+]] = arith.index_cast [[VAR_0_]] : i32 to index
-// CHECK-DAG: [[VAR_reinterpret_cast_:%.+]] = memref.reinterpret_cast [[PARAM_0_]] to offset: {{.}}[[VAR_1_]]{{.}}, sizes: [1024], strides: [1]{{.*}} : memref<*xf32> to memref<1024xf32, strided<[1], offset: ?>>
-// CHECK-DAG: [[RES_:%.+]] = memref.alloc() : memref<1024xf32>
-// CHECK-DAG: [[VAR_2_:%.+]] = arith.index_cast [[VAR_0_]] : i32 to index
-// CHECK-NOT: separator of consecutive DAGs
-// CHECK-DAG: [[VAR_3_:%.+]] = arith.addi [[VAR_2_]], [[CST_1024_]] : index
-// CHECK-DAG: [[VAR_4_:%.+]] = arith.index_cast [[PARAM_3_]] : i32 to index
-// CHECK: [[VAR_5_:%.+]] = arith.minsi [[VAR_3_]], [[VAR_4_]] : index
-// CHECK: [[VAR_6_:%.+]] = arith.subi [[VAR_5_]], [[VAR_2_]] : index
-// CHECK-DAG: [[VAR_subview_:%.+]] = memref.subview [[VAR_reinterpret_cast_]][0] {{.}}[[VAR_6_]]{{.}} [1]{{.*}} : memref<1024xf32, strided<[1], offset: ?>> to memref<?xf32, strided<[1], offset: ?>>
-// CHECK-DAG: [[VAR_subview_0_:%.+]] = memref.subview [[RES_]][0] {{.}}[[VAR_6_]]{{.}} [1] : memref<1024xf32> to memref<?xf32, strided<[1]>>
-// CHECK: memref.copy [[VAR_subview_]], [[VAR_subview_]]_0 : memref<?xf32, strided<[1], offset: ?>> to memref<?xf32, strided<[1]>>
-// CHECK-DAG: [[VAR_7_:%.+]] = bufferization.to_tensor [[RES_]] restrict writable : memref<1024xf32>
-// CHECK-DAG: [[VAR_8_:%.+]] = arith.index_cast [[VAR_0_]] : i32 to index
-// CHECK-NOT: separator of consecutive DAGs
-// CHECK-DAG: [[VAR_reinterpret_cast_1_:%.+]] = memref.reinterpret_cast [[PARAM_1_]] to offset: {{.}}[[VAR_8_]]{{.}}, sizes: [1024], strides: [1]{{.*}} : memref<*xf32> to memref<1024xf32, strided<[1], offset: ?>>
-// CHECK-DAG: [[RES_1_:%.+]] = memref.alloc() : memref<1024xf32>
-// CHECK-DAG: [[VAR_9_:%.+]] = arith.index_cast [[VAR_0_]] : i32 to index
-// CHECK-NOT: separator of consecutive DAGs
-// CHECK-DAG: [[VAR_10_:%.+]] = arith.addi [[VAR_9_]], [[CST_1024_]] : index
-// CHECK-DAG: [[VAR_11_:%.+]] = arith.index_cast [[PARAM_3_]] : i32 to index
-// CHECK: [[VAR_12_:%.+]] = arith.minsi [[VAR_10_]], [[VAR_11_]] : index
-// CHECK: [[VAR_13_:%.+]] = arith.subi [[VAR_12_]], [[VAR_9_]] : index
-// CHECK-DAG: [[VAR_subview_3_:%.+]] = memref.subview [[VAR_reinterpret_cast_1_]][0] {{.}}[[VAR_13_]]{{.}} [1]{{.*}} : memref<1024xf32, strided<[1], offset: ?>> to memref<?xf32, strided<[1], offset: ?>>
-// CHECK-DAG: [[VAR_subview_4_:%.+]] = memref.subview [[RES_1_]][0] {{.}}[[VAR_13_]]{{.}} [1] : memref<1024xf32> to memref<?xf32, strided<[1]>>
-// CHECK: memref.copy [[VAR_subview_3_]], [[VAR_subview_4_]] : memref<?xf32, strided<[1], offset: ?>> to memref<?xf32, strided<[1]>>
-// CHECK: [[VAR_14_:%.+]] = bufferization.to_tensor [[RES_1_]] restrict writable : memref<1024xf32>
-// CHECK: [[VAR_15_:%.+]] = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]} ins([[VAR_7_]], [[VAR_14_]] : tensor<1024xf32>, tensor<1024xf32>) outs([[VAR_7_]] : tensor<1024xf32>) {
-// CHECK: ^bb0([[in_1:%.+]]: f32, [[in_2:%.+]]: f32, [[out:%.+]]: f32):
-// CHECK: [[VAR_22_:%.+]] = arith.addf [[in_1]], [[in_2]] : f32
-// CHECK: linalg.yield [[VAR_22_]] : f32
-// CHECK: } -> tensor<1024xf32>
-// CHECK: [[VAR_16_:%.+]] = arith.index_cast [[VAR_0_]] : i32 to index
-// CHECK-DAG: [[VAR_reinterpret_cast_5_:%.+]] = memref.reinterpret_cast [[PARAM_2_]] to offset: {{.}}[[VAR_16_]]{{.}}, sizes: [1024], strides: [1]{{.*}} : memref<*xf32> to memref<1024xf32, strided<[1], offset: ?>>
-// CHECK-DAG: [[VAR_17_:%.+]] = arith.index_cast [[VAR_0_]] : i32 to index
-// CHECK-NOT: separator of consecutive DAGs
-// CHECK-DAG: [[VAR_18_:%.+]] = arith.addi [[VAR_17_]], [[CST_1024_]] : index
-// CHECK-DAG: [[VAR_19_:%.+]] = arith.index_cast [[PARAM_3_]] : i32 to index
-// CHECK: [[VAR_20_:%.+]] = arith.minsi [[VAR_18_]], [[VAR_19_]] : index
-// CHECK: [[VAR_21_:%.+]] = arith.subi [[VAR_20_]], [[VAR_17_]] : index
-// CHECK-DAG: [[VAR_extracted_slice_:%.+]] = tensor.extract_slice [[VAR_15_]][0] {{.}}[[VAR_21_]]{{.}} [1] : tensor<1024xf32> to tensor<?xf32>
-// CHECK-DAG: [[VAR_subview_6_:%.+]] = memref.subview [[VAR_reinterpret_cast_5_]][0] {{.}}[[VAR_21_]]{{.}} [1]{{.*}} : memref<1024xf32, strided<[1], offset: ?>> to memref<?xf32, strided<[1], offset: ?>>
-// CHECK: bufferization.materialize_in_destination [[VAR_extracted_slice_]] in writable [[VAR_subview_6_]]
-// CHECK: return
-// CHECK: }
diff --git a/test/Conversion/TritonToLinalg/kernel-02-fused-softmax.mlir b/test/Conversion/TritonToLinalg/kernel-02-fused-softmax.mlir
deleted file mode 100644
index a9a8b476..00000000
--- a/test/Conversion/TritonToLinalg/kernel-02-fused-softmax.mlir
+++ /dev/null
@@ -1,105 +0,0 @@
-// RUN: triton-shared-opt --triton-to-linalg %s | FileCheck %s
-
-module {
- tt.func public @softmax_kernel_012345(%arg0: !tt.ptr<f32>, %arg1: !tt.ptr<f32>, %arg2: i32, %arg3: i32, %arg4: i32) {
- %cst = arith.constant 0xFF800000 : f32
- %0 = tt.get_program_id x : i32
- %1 = arith.muli %0, %arg2 : i32
- %2 = tt.addptr %arg1, %1 : !tt.ptr<f32>, i32
- %3 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32>
- %4 = tt.splat %2 : !tt.ptr<f32> -> tensor<128x!tt.ptr<f32>>
- %5 = tt.addptr %4, %3 : tensor<128x!tt.ptr<f32>>, tensor<128xi32>
- %6 = tt.splat %arg4 : i32 -> tensor<128xi32>
- %7 = arith.cmpi slt, %3, %6 : tensor<128xi32>
- %8 = tt.splat %cst : f32 -> tensor<128xf32>
- %9 = tt.load %5, %7, %8 : tensor<128x!tt.ptr<f32>>
- %10 = "tt.reduce"(%9) ({
- ^bb0(%arg5: f32, %arg6: f32):
- %21 = arith.cmpf ogt, %arg5, %arg6 : f32
- %22 = arith.select %21, %arg5, %arg6 : f32
- tt.reduce.return %22 : f32
- }) {axis = 0 : i32} : (tensor<128xf32>) -> f32
- %11 = tt.splat %10 : f32 -> tensor<128xf32>
- %12 = arith.subf %9, %11 : tensor<128xf32>
- %13 = math.exp %12 : tensor<128xf32>
- %14 = "tt.reduce"(%13) ({
- ^bb0(%arg5: f32, %arg6: f32):
- %21 = arith.addf %arg5, %arg6 : f32
- tt.reduce.return %21 : f32
- }) {axis = 0 : i32} : (tensor<128xf32>) -> f32
- %15 = tt.splat %14 : f32 -> tensor<128xf32>
- %16 = arith.divf %13, %15 : tensor<128xf32>
- %17 = arith.muli %0, %arg3 : i32
- %18 =
tt.addptr %arg0, %17 : !tt.ptr, i32 - %19 = tt.splat %18 : !tt.ptr -> tensor<128x!tt.ptr> - %20 = tt.addptr %19, %3 : tensor<128x!tt.ptr>, tensor<128xi32> - tt.store %20, %16, %7 : tensor<128x!tt.ptr> - tt.return - } -} - -// CHECK-DAG: [[MAP_0_:#.+]] = affine_map<(d0) -> (d0)> -// CHECK-LABEL: func.func @softmax_kernel_012345 -// CHECK-SAME: ([[PARAM_0_:%.+]]: memref<*xf32>, [[PARAM_1_:%.+]]: memref<*xf32>, [[PARAM_2_:%.+]]: i32, [[PARAM_3_:%.+]]: i32, [[PARAM_4_:%.+]]: i32, [[PARAM_5_:%.+]]: i32, [[PARAM_6_:%.+]]: i32, [[PARAM_7_:%.+]]: i32, [[PARAM_8_:%.+]]: i32, [[PARAM_9_:%.+]]: i32, [[PARAM_10_:%.+]]: i32) { -// CHECK-DAG: [[CST_0_dot_000000_:%.+]] = arith.constant 0.000000e+00 : f32 -// CHECK-DAG: [[CST_128_:%.+]] = arith.constant 128 : index -// CHECK-DAG: [[CST_0_:%.+]] = arith.constant 0xFF800000 : f32 -// CHECK-DAG: [[VAR_0_:%.+]] = arith.muli [[PARAM_8_]], [[PARAM_2_]] : i32 -// CHECK: [[VAR_1_:%.+]] = arith.index_cast [[VAR_0_]] : i32 to index -// CHECK-DAG: [[VAR_reinterpret_cast_:%.+]] = memref.reinterpret_cast [[PARAM_1_]] to offset: {{.}}[[VAR_1_]]{{.}}, sizes: [128], strides: [1]{{.*}} : memref<*xf32> to memref<128xf32, strided<[1], offset: ?>> -// CHECK-DAG: [[RES_:%.+]] = memref.alloc() : memref<128xf32> -// CHECK-DAG: [[VAR_2_:%.+]] = arith.index_cast [[PARAM_4_]] : i32 to index -// CHECK: [[VAR_3_:%.+]] = arith.minsi [[VAR_2_]], [[CST_128_]] : index -// CHECK-DAG: [[VAR_4_:%.+]] = arith.cmpi slt, [[VAR_3_]], [[CST_128_]] : index -// CHECK: scf.if [[VAR_4_]] { -// CHECK: linalg.fill ins([[CST_0_]] : f32) outs([[RES_]] : memref<128xf32>) -// CHECK: } -// CHECK-DAG: [[VAR_subview_:%.+]] = memref.subview [[VAR_reinterpret_cast_]][0] {{.}}[[VAR_3_]]{{.}} [1]{{.*}} : memref<128xf32, strided<[1], offset: ?>> to memref> -// CHECK-DAG: [[VAR_subview_1_:%.+]] = memref.subview [[RES_]][0] {{.}}[[VAR_3_]]{{.}} [1] : memref<128xf32> to memref> -// CHECK: memref.copy [[VAR_subview_]], [[VAR_subview_]]_1 : memref> to memref> -// CHECK-DAG: [[VAR_5_:%.+]] = bufferization.to_tensor [[RES_]] restrict writable : memref<128xf32> -// CHECK-DAG: [[VAR_6_:%.+]] = bufferization.alloc_tensor() : tensor -// CHECK: [[VAR_inserted_:%.+]] = tensor.insert [[CST_0_]] into [[VAR_6_]][] : tensor -// CHECK: [[VAR_reduced_:%.+]] = linalg.reduce ins([[VAR_5_]] : tensor<128xf32>) outs([[VAR_inserted_]] : tensor) dimensions = [0] -// CHECK: ([[in_1:%.+]]: f32, [[init_1:%.+]]: f32) { -// CHECK: [[VAR_19_:%.+]] = arith.maximumf [[in_1]], [[init_1]] : f32 -// CHECK: linalg.yield [[VAR_19_]] : f32 -// CHECK: } -// CHECK-DAG: [[VAR_extracted_:%.+]] = tensor.extract [[VAR_reduced_]][] : tensor -// CHECK-DAG: [[VAR_7_:%.+]] = tensor.empty() : tensor<128xf32> -// CHECK: [[VAR_8_:%.+]] = linalg.fill ins([[VAR_extracted_]] : f32) outs([[VAR_7_]] : tensor<128xf32>) -> tensor<128xf32> -// CHECK: [[VAR_9_:%.+]] = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]} ins([[VAR_5_]], [[VAR_8_]] : tensor<128xf32>, tensor<128xf32>) outs([[VAR_5_]] : tensor<128xf32>) { -// CHECK: ^bb0([[in_1:%.+]]: f32, [[in_2:%.+]]: f32, [[out:%.+]]: f32): -// CHECK: [[VAR_19_1_:%.+]] = arith.subf [[in_1]], [[in_2]] : f32 -// CHECK: linalg.yield [[VAR_19_1_]] : f32 -// CHECK: } -> tensor<128xf32> -// CHECK: [[VAR_10_:%.+]] = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel"]} ins([[VAR_9_]] : tensor<128xf32>) outs([[VAR_9_]] : tensor<128xf32>) { -// CHECK: ^bb0([[in_1:%.+]]: f32, [[out_1:%.+]]: f32): -// CHECK: [[VAR_19_2_:%.+]] = math.exp [[in_1]] : f32 -// CHECK: linalg.yield 
[[VAR_19_2_]] : f32 -// CHECK: } -> tensor<128xf32> -// CHECK: [[VAR_11_:%.+]] = bufferization.alloc_tensor() : tensor -// CHECK: [[VAR_inserted_2_:%.+]] = tensor.insert [[CST_0_dot_000000_]] into [[VAR_11_]][] : tensor -// CHECK: [[VAR_reduced_3_:%.+]] = linalg.reduce ins([[VAR_10_]] : tensor<128xf32>) outs([[VAR_inserted_2_]] : tensor) dimensions = [0] -// CHECK: ([[in_1:%.+]]: f32, [[init_1:%.+]]: f32) { -// CHECK: [[VAR_19_3_:%.+]] = arith.addf [[in_1]], [[init_1]] : f32 -// CHECK: linalg.yield [[VAR_19_3_]] : f32 -// CHECK: } -// CHECK-DAG: [[VAR_extracted_4_:%.+]] = tensor.extract [[VAR_reduced_3_]][] : tensor -// CHECK-DAG: [[VAR_12_:%.+]] = tensor.empty() : tensor<128xf32> -// CHECK: [[VAR_13_:%.+]] = linalg.fill ins([[VAR_extracted_4_]] : f32) outs([[VAR_12_]] : tensor<128xf32>) -> tensor<128xf32> -// CHECK: [[VAR_14_:%.+]] = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]} ins([[VAR_10_]], [[VAR_13_]] : tensor<128xf32>, tensor<128xf32>) outs([[VAR_10_]] : tensor<128xf32>) { -// CHECK: ^bb0([[in_1:%.+]]: f32, [[in_2:%.+]]: f32, [[out_1:%.+]]: f32): -// CHECK: [[VAR_19_4_:%.+]] = arith.divf [[in_1]], [[in_2]] : f32 -// CHECK: linalg.yield [[VAR_19_4_]] : f32 -// CHECK: } -> tensor<128xf32> -// CHECK: [[VAR_15_:%.+]] = arith.muli [[PARAM_8_]], [[PARAM_3_]] : i32 -// CHECK: [[VAR_16_:%.+]] = arith.index_cast [[VAR_15_]] : i32 to index -// CHECK-DAG: [[VAR_reinterpret_cast_5_:%.+]] = memref.reinterpret_cast [[PARAM_0_]] to offset: {{.}}[[VAR_16_]]{{.}}, sizes: [128], strides: [1]{{.*}} : memref<*xf32> to memref<128xf32, strided<[1], offset: ?>> -// CHECK-DAG: [[VAR_17_:%.+]] = arith.index_cast [[PARAM_4_]] : i32 to index -// CHECK: [[VAR_18_:%.+]] = arith.minsi [[VAR_17_]], [[CST_128_]] : index -// CHECK-DAG: [[VAR_extracted_slice_:%.+]] = tensor.extract_slice [[VAR_14_]][0] {{.}}[[VAR_18_]]{{.}} [1] : tensor<128xf32> to tensor -// CHECK-DAG: [[VAR_subview_6_:%.+]] = memref.subview [[VAR_reinterpret_cast_5_]][0] {{.}}[[VAR_18_]]{{.}} [1]{{.*}} : memref<128xf32, strided<[1], offset: ?>> to memref> -// CHECK: bufferization.materialize_in_destination [[VAR_extracted_slice_]] in writable [[VAR_subview_6_]] -// CHECK: return -// CHECK: } diff --git a/test/Conversion/TritonToLinalg/kernel-03-matrix-multiplication.mlir b/test/Conversion/TritonToLinalg/kernel-03-matrix-multiplication.mlir deleted file mode 100644 index 032adfe6..00000000 --- a/test/Conversion/TritonToLinalg/kernel-03-matrix-multiplication.mlir +++ /dev/null @@ -1,217 +0,0 @@ -// RUN: triton-shared-opt --triton-to-linalg %s | FileCheck %s - -module { - tt.func public @matmul_kernel_0123456789101112131415(%arg0: !tt.ptr, %arg1: !tt.ptr, %arg2: !tt.ptr, %arg3: i32, %arg4: i32, %arg5: i32, %arg6: i32, %arg7: i32, %arg8: i32, %arg9: i32, %arg10: i32, %arg11: i32) { - %c63_i32 = arith.constant 63 : i32 - %c255_i32 = arith.constant 255 : i32 - %c127_i32 = arith.constant 127 : i32 - %c1_i32 = arith.constant 1 : i32 - %c0_i32 = arith.constant 0 : i32 - %c64_i32 = arith.constant 64 : i32 - %cst = arith.constant 0.000000e+00 : f32 - %c256_i32 = arith.constant 256 : i32 - %c128_i32 = arith.constant 128 : i32 - %c8_i32 = arith.constant 8 : i32 - %0 = tt.get_program_id x : i32 - %1 = arith.addi %arg3, %c127_i32 : i32 - %2 = arith.divsi %1, %c128_i32 : i32 - %3 = arith.addi %arg4, %c255_i32 : i32 - %4 = arith.divsi %3, %c256_i32 : i32 - %5 = arith.addi %arg5, %c63_i32 : i32 - %6 = arith.divsi %5, %c64_i32 : i32 - %7 = arith.muli %4, %c8_i32 : i32 - %8 = arith.divsi %0, %7 : i32 - %9 = arith.muli %8, %c8_i32 
: i32 - %10 = arith.subi %2, %9 : i32 - %11 = arith.cmpi slt, %10, %c8_i32 : i32 - %12 = arith.select %11, %10, %c8_i32 : i32 - %13 = arith.remsi %0, %12 : i32 - %14 = arith.addi %9, %13 : i32 - %15 = arith.remsi %0, %7 : i32 - %16 = arith.divsi %15, %12 : i32 - %17 = arith.muli %14, %c128_i32 : i32 - %18 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32> - %19 = tt.splat %17 : i32 -> tensor<128xi32> - %20 = arith.addi %19, %18 : tensor<128xi32> - %21 = arith.muli %16, %c256_i32 : i32 - %22 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32> - %23 = tt.splat %21 : i32 -> tensor<256xi32> - %24 = arith.addi %23, %22 : tensor<256xi32> - %25 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32> - %26 = tt.expand_dims %20 {axis = 1 : i32} : tensor<128xi32> -> tensor<128x1xi32> - %27 = tt.splat %arg6 : i32 -> tensor<128x1xi32> - %28 = arith.muli %26, %27 : tensor<128x1xi32> - %29 = tt.expand_dims %25 {axis = 0 : i32} : tensor<64xi32> -> tensor<1x64xi32> - %30 = tt.splat %arg7 : i32 -> tensor<1x64xi32> - %31 = arith.muli %29, %30 : tensor<1x64xi32> - %32 = tt.broadcast %28 : tensor<128x1xi32> -> tensor<128x64xi32> - %33 = tt.broadcast %31 : tensor<1x64xi32> -> tensor<128x64xi32> - %34 = arith.addi %32, %33 : tensor<128x64xi32> - %35 = tt.splat %arg0 : !tt.ptr -> tensor<128x64x!tt.ptr> - %36 = tt.addptr %35, %34 : tensor<128x64x!tt.ptr>, tensor<128x64xi32> - %37 = tt.expand_dims %25 {axis = 1 : i32} : tensor<64xi32> -> tensor<64x1xi32> - %38 = tt.splat %arg8 : i32 -> tensor<64x1xi32> - %39 = arith.muli %37, %38 : tensor<64x1xi32> - %40 = tt.expand_dims %24 {axis = 0 : i32} : tensor<256xi32> -> tensor<1x256xi32> - %41 = tt.splat %arg9 : i32 -> tensor<1x256xi32> - %42 = arith.muli %40, %41 : tensor<1x256xi32> - %43 = tt.broadcast %39 : tensor<64x1xi32> -> tensor<64x256xi32> - %44 = tt.broadcast %42 : tensor<1x256xi32> -> tensor<64x256xi32> - %45 = arith.addi %43, %44 : tensor<64x256xi32> - %46 = tt.splat %arg1 : !tt.ptr -> tensor<64x256x!tt.ptr> - %47 = tt.addptr %46, %45 : tensor<64x256x!tt.ptr>, tensor<64x256xi32> - %48 = tt.splat %cst : f32 -> tensor<128x256xf32> - %49 = arith.muli %arg7, %c64_i32 : i32 - %50 = tt.splat %49 : i32 -> tensor<128x64xi32> - %51 = arith.muli %arg8, %c64_i32 : i32 - %52 = tt.splat %51 : i32 -> tensor<64x256xi32> - %53:3 = scf.for %arg12 = %c0_i32 to %6 step %c1_i32 iter_args(%arg13 = %48, %arg14 = %36, %arg15 = %47) -> (tensor<128x256xf32>, tensor<128x64x!tt.ptr>, tensor<64x256x!tt.ptr>) : i32 { - %71 = tt.load %arg14 : tensor<128x64x!tt.ptr> - %72 = tt.load %arg15 : tensor<64x256x!tt.ptr> - %73 = tt.dot %71, %72, %48 {inputPrecision = 0 : i32, maxNumImpreciseAcc = 0 : i32} : tensor<128x64xbf16> * tensor<64x256xbf16> -> tensor<128x256xf32> - %74 = arith.addf %arg13, %73 : tensor<128x256xf32> - %75 = tt.addptr %arg14, %50 : tensor<128x64x!tt.ptr>, tensor<128x64xi32> - %76 = tt.addptr %arg15, %52 : tensor<64x256x!tt.ptr>, tensor<64x256xi32> - scf.yield %74, %75, %76 : tensor<128x256xf32>, tensor<128x64x!tt.ptr>, tensor<64x256x!tt.ptr> - } - %54 = arith.truncf %53#0 : tensor<128x256xf32> to tensor<128x256xbf16> - %55 = tt.splat %arg10 : i32 -> tensor<128x1xi32> - %56 = arith.muli %55, %26 : tensor<128x1xi32> - %57 = tt.splat %arg2 : !tt.ptr -> tensor<128x1x!tt.ptr> - %58 = tt.addptr %57, %56 : tensor<128x1x!tt.ptr>, tensor<128x1xi32> - %59 = tt.splat %arg11 : i32 -> tensor<1x256xi32> - %60 = arith.muli %59, %40 : tensor<1x256xi32> - %61 = tt.broadcast %58 : tensor<128x1x!tt.ptr> -> tensor<128x256x!tt.ptr> - %62 = 
tt.broadcast %60 : tensor<1x256xi32> -> tensor<128x256xi32> - %63 = tt.addptr %61, %62 : tensor<128x256x!tt.ptr>, tensor<128x256xi32> - %64 = tt.splat %arg3 : i32 -> tensor<128x1xi32> - %65 = arith.cmpi slt, %26, %64 : tensor<128x1xi32> - %66 = tt.splat %arg4 : i32 -> tensor<1x256xi32> - %67 = arith.cmpi slt, %40, %66 : tensor<1x256xi32> - %68 = tt.broadcast %65 : tensor<128x1xi1> -> tensor<128x256xi1> - %69 = tt.broadcast %67 : tensor<1x256xi1> -> tensor<128x256xi1> - %70 = arith.andi %68, %69 : tensor<128x256xi1> - tt.store %63, %54, %70 : tensor<128x256x!tt.ptr> - tt.return - } -} - -// CHECK-DAG: [[MAP_0_:#.+]] = affine_map<(d0, d1) -> (d0, d1)> -// CHECK-LABEL: func.func @matmul_kernel_0123456789101112131415 -// CHECK-SAME: ([[PARAM_0_:%.+]]: memref<*xbf16>, [[PARAM_1_:%.+]]: memref<*xbf16>, [[PARAM_2_:%.+]]: memref<*xbf16>, [[PARAM_3_:%.+]]: i32, [[PARAM_4_:%.+]]: i32, [[PARAM_5_:%.+]]: i32, [[PARAM_6_:%.+]]: i32, [[PARAM_7_:%.+]]: i32, [[PARAM_8_:%.+]]: i32, [[PARAM_9_:%.+]]: i32, [[PARAM_10_:%.+]]: i32, [[PARAM_11_:%.+]]: i32, [[PARAM_12_:%.+]]: i32, [[PARAM_13_:%.+]]: i32, [[PARAM_14_:%.+]]: i32, [[PARAM_15_:%.+]]: i32, [[PARAM_16_:%.+]]: i32, [[PARAM_17_:%.+]]: i32) { -// CHECK-DAG: [[CST_256_:%.+]] = arith.constant 256 : index -// CHECK-DAG: [[CST_128_:%.+]] = arith.constant 128 : index -// CHECK-DAG: [[CST_0_:%.+]] = arith.constant 0 : index -// CHECK-DAG: [[CST_8_:%.+]] = arith.constant 8 : i32 -// CHECK-DAG: [[CST_128_1_:%.+]] = arith.constant 128 : i32 -// CHECK-DAG: [[CST_256_1_:%.+]] = arith.constant 256 : i32 -// CHECK-DAG: [[CST_64_:%.+]] = arith.constant 64 : i32 -// CHECK-DAG: [[CST_0_1_:%.+]] = arith.constant 0 : i32 -// CHECK-DAG: [[CST_1_:%.+]] = arith.constant 1 : i32 -// CHECK-DAG: [[CST_127_:%.+]] = arith.constant 127 : i32 -// CHECK-DAG: [[CST_255_:%.+]] = arith.constant 255 : i32 -// CHECK-DAG: [[CST_63_:%.+]] = arith.constant 63 : i32 -// CHECK-DAG: [[CST_0_dot_000000_:%.+]] = arith.constant 0.000000e+00 : f32 -// CHECK-DAG: [[VAR_0_:%.+]] = tensor.empty() : tensor<128x256xf32> -// CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: [[VAR_1_:%.+]] = linalg.fill ins([[CST_0_dot_000000_]] : f32) outs([[VAR_0_]] : tensor<128x256xf32>) -> tensor<128x256xf32> -// CHECK-DAG: [[VAR_2_:%.+]] = arith.addi [[PARAM_3_]], [[CST_127_]] : i32 -// CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: [[VAR_3_:%.+]] = arith.divsi [[VAR_2_]], [[CST_128_1_]] : i32 -// CHECK-DAG: [[VAR_4_:%.+]] = arith.addi [[PARAM_4_]], [[CST_255_]] : i32 -// CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: [[VAR_5_:%.+]] = arith.divsi [[VAR_4_]], [[CST_256_1_]] : i32 -// CHECK-DAG: [[VAR_6_:%.+]] = arith.addi [[PARAM_5_]], [[CST_63_]] : i32 -// CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: [[VAR_7_:%.+]] = arith.divsi [[VAR_6_]], [[CST_64_]] : i32 -// CHECK-DAG: [[VAR_8_:%.+]] = arith.muli [[VAR_5_]], [[CST_8_]] : i32 -// CHECK: [[VAR_9_:%.+]] = arith.divsi [[PARAM_15_]], [[VAR_8_]] : i32 -// CHECK: [[VAR_10_:%.+]] = arith.muli [[VAR_9_]], [[CST_8_]] : i32 -// CHECK: [[VAR_11_:%.+]] = arith.subi [[VAR_3_]], [[VAR_10_]] : i32 -// CHECK: [[VAR_12_:%.+]] = arith.minsi [[VAR_11_]], [[CST_8_]] : i32 -// CHECK: [[VAR_13_:%.+]] = arith.remsi [[PARAM_15_]], [[VAR_12_]] : i32 -// CHECK-DAG: [[VAR_14_:%.+]] = arith.addi [[VAR_10_]], [[VAR_13_]] : i32 -// CHECK-DAG: [[VAR_15_:%.+]] = arith.remsi [[PARAM_15_]], [[VAR_8_]] : i32 -// CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: [[VAR_16_:%.+]] = arith.divsi [[VAR_15_]], [[VAR_12_]] : i32 -// CHECK-DAG: [[VAR_17_:%.+]] = 
arith.muli [[VAR_14_]], [[CST_128_1_]] : i32
-// CHECK-NOT: separator of consecutive DAGs
-// CHECK-DAG: [[VAR_18_:%.+]] = arith.muli [[VAR_16_]], [[CST_256_1_]] : i32
-// CHECK-DAG: [[VAR_19_:%.+]] = arith.index_cast [[VAR_17_]] : i32 to index
-// CHECK-DAG: [[VAR_20_:%.+]] = arith.index_cast [[PARAM_6_]] : i32 to index
-// CHECK-NOT: separator of consecutive DAGs
-// CHECK-DAG: [[VAR_21_:%.+]] = arith.muli [[VAR_19_]], [[VAR_20_]] : index
-// CHECK-DAG: [[VAR_22_:%.+]] = arith.index_cast [[PARAM_7_]] : i32 to index
-// CHECK-DAG: [[VAR_23_:%.+]] = arith.index_cast [[PARAM_8_]] : i32 to index
-// CHECK-DAG: [[VAR_24_:%.+]] = arith.index_cast [[VAR_18_]] : i32 to index
-// CHECK-DAG: [[VAR_25_:%.+]] = arith.index_cast [[PARAM_9_]] : i32 to index
-// CHECK-NOT: separator of consecutive DAGs
-// CHECK-DAG: [[VAR_26_:%.+]] = arith.muli [[VAR_24_]], [[VAR_25_]] : index
-// CHECK-DAG: [[VAR_27_:%.+]] = arith.muli [[PARAM_7_]], [[CST_64_]] : i32
-// CHECK-DAG: [[VAR_28_:%.+]] = arith.muli [[PARAM_8_]], [[CST_64_]] : i32
-// CHECK-DAG: [[VAR_reinterpret_cast_:%.+]] = memref.reinterpret_cast [[PARAM_0_]] to offset: {{.}}[[VAR_21_]]{{.}}, sizes: [128, 64], strides: {{.}}[[VAR_20_]], [[VAR_22_]]{{.}} : memref<*xbf16> to memref<128x64xbf16, strided<[?, ?], offset: ?>>
-// CHECK: [[VAR_reinterpret_cast_0_:%.+]] = memref.reinterpret_cast [[PARAM_1_]] to offset: {{.}}[[VAR_26_]]{{.}}, sizes: [64, 256], strides: {{.}}[[VAR_23_]], [[VAR_25_]]{{.}} : memref<*xbf16> to memref<64x256xbf16, strided<[?, ?], offset: ?>>
-// CHECK-DAG: [[VAR_29_:%.+]]:7 = scf.for [[VAR_arg18_:%.+]] = [[CST_0_1_]] to [[VAR_7_]] step [[CST_1_]] iter_args([[VAR_arg19_:%.+]] = [[VAR_1_]], [[VAR_arg20_:%.+]] = [[VAR_reinterpret_cast_]], [[VAR_arg21_:%.+]] = [[VAR_reinterpret_cast_]]_0, [[VAR_arg22_:%.+]] = [[VAR_21_]], [[VAR_arg23_:%.+]] = [[CST_0_]], [[VAR_arg24_:%.+]] = [[VAR_26_]], [[VAR_arg25_:%.+]] = [[CST_0_]]) -> (tensor<128x256xf32>, memref<128x64xbf16, strided<[?, ?], offset: ?>>, memref<64x256xbf16, strided<[?, ?], offset: ?>>, index, index, index, index) : i32 {
-// CHECK-DAG: [[RES_:%.+]] = memref.alloc() : memref<128x64xbf16>
-// CHECK: memref.copy [[VAR_arg20_]], [[RES_]] : memref<128x64xbf16, strided<[?, ?], offset: ?>> to memref<128x64xbf16>
-// CHECK-DAG: [[VAR_51_:%.+]] = bufferization.to_tensor [[RES_]] restrict writable : memref<128x64xbf16>
-// CHECK-DAG: [[RES_1_:%.+]] = memref.alloc() : memref<64x256xbf16>
-// CHECK: memref.copy [[VAR_arg21_]], [[RES_1_]] : memref<64x256xbf16, strided<[?, ?], offset: ?>> to memref<64x256xbf16>
-// CHECK-DAG: [[VAR_52_:%.+]] = bufferization.to_tensor [[RES_1_]] restrict writable : memref<64x256xbf16>
-// CHECK-DAG: [[VAR_53_:%.+]] = tensor.empty() : tensor<128x256xf32>
-// CHECK: [[VAR_54_:%.+]] = linalg.fill ins([[CST_0_dot_000000_]] : f32) outs([[VAR_53_]] : tensor<128x256xf32>) -> tensor<128x256xf32>
-// CHECK: [[VAR_55_:%.+]] = linalg.matmul ins([[VAR_51_]], [[VAR_52_]] : tensor<128x64xbf16>, tensor<64x256xbf16>) outs([[VAR_54_]] : tensor<128x256xf32>) -> tensor<128x256xf32>
-// CHECK: [[VAR_56_:%.+]] = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel"]} ins([[VAR_55_]], [[VAR_1_]] : tensor<128x256xf32>, tensor<128x256xf32>) outs([[VAR_55_]] : tensor<128x256xf32>) {
-// CHECK: ^bb0([[in_:.+]]: f32, [[in_1:.+]]: f32, [[out_:.+]]: f32):
-// CHECK: [[VAR_64_:%.+]] = arith.addf [[in_]], [[in_1]] : f32
-// CHECK: linalg.yield [[VAR_64_]] : f32
-// CHECK: } -> tensor<128x256xf32>
-// CHECK: [[VAR_57_:%.+]] = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel"]} ins([[VAR_arg19_]], [[VAR_56_]] : tensor<128x256xf32>, tensor<128x256xf32>) outs([[VAR_arg19_]] : tensor<128x256xf32>) {
-// CHECK: ^bb0([[in_]]: f32, [[in_1:.+]]: f32, [[out_:.+]]: f32):
-// CHECK: [[VAR_64_1_:%.+]] = arith.addf [[in_]], [[in_1]] : f32
-// CHECK: linalg.yield [[VAR_64_1_]] : f32
-// CHECK: } -> tensor<128x256xf32>
-// CHECK: [[VAR_58_:%.+]] = arith.index_cast [[VAR_27_]] : i32 to index
-// CHECK: [[VAR_59_:%.+]] = arith.addi [[VAR_arg22_]], [[VAR_58_]] : index
-// CHECK: [[VAR_60_:%.+]] = arith.addi [[VAR_59_]], [[VAR_arg23_]] : index
-// CHECK-DAG: [[VAR_reinterpret_cast_3_:%.+]] = memref.reinterpret_cast [[PARAM_0_]] to offset: {{.}}[[VAR_60_]]{{.}}, sizes: [128, 64], strides: {{.}}[[VAR_20_]], [[VAR_22_]]{{.}} : memref<*xbf16> to memref<128x64xbf16, strided<[?, ?], offset: ?>>
-// CHECK-DAG: [[VAR_61_:%.+]] = arith.index_cast [[VAR_28_]] : i32 to index
-// CHECK: [[VAR_62_:%.+]] = arith.addi [[VAR_arg24_]], [[VAR_61_]] : index
-// CHECK: [[VAR_63_:%.+]] = arith.addi [[VAR_62_]], [[VAR_arg25_]] : index
-// CHECK: [[VAR_reinterpret_cast_4_:%.+]] = memref.reinterpret_cast [[PARAM_1_]] to offset: {{.}}[[VAR_63_]]{{.}}, sizes: [64, 256], strides: {{.}}[[VAR_23_]], [[VAR_25_]]{{.}} : memref<*xbf16> to memref<64x256xbf16, strided<[?, ?], offset: ?>>
-// CHECK: scf.yield [[VAR_57_]], [[VAR_reinterpret_cast_3_]], [[VAR_reinterpret_cast_4_]], [[VAR_60_]], [[CST_0_]], [[VAR_63_]], [[CST_0_]] : tensor<128x256xf32>, memref<128x64xbf16, strided<[?, ?], offset: ?>>, memref<64x256xbf16, strided<[?, ?], offset: ?>>, index, index, index, index
-// CHECK: }
-// CHECK: [[VAR_30_:%.+]] = tensor.empty() : tensor<128x256xbf16>
-// CHECK: [[VAR_31_:%.+]] = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel"]} ins([[VAR_29_]]#0 : tensor<128x256xf32>) outs([[VAR_30_]] : tensor<128x256xbf16>) {
-// CHECK: ^bb0([[in_:.+]]: f32, [[out_:.+]]: bf16):
-// CHECK: [[VAR_51_1_:%.+]] = arith.truncf [[in_]] : f32 to bf16
-// CHECK: linalg.yield [[VAR_51_1_]] : bf16
-// CHECK: } -> tensor<128x256xbf16>
-// CHECK-DAG: [[VAR_32_:%.+]] = arith.index_cast [[PARAM_10_]] : i32 to index
-// CHECK-DAG: [[VAR_33_:%.+]] = arith.index_cast [[VAR_17_]] : i32 to index
-// CHECK-NOT: separator of consecutive DAGs
-// CHECK-DAG: [[VAR_34_:%.+]] = arith.muli [[VAR_33_]], [[VAR_32_]] : index
-// CHECK-DAG: [[VAR_35_:%.+]] = arith.index_cast [[PARAM_11_]] : i32 to index
-// CHECK-DAG: [[VAR_36_:%.+]] = arith.index_cast [[VAR_18_]] : i32 to index
-// CHECK: [[VAR_37_:%.+]] = arith.muli [[VAR_36_]], [[VAR_35_]] : index
-// CHECK: [[VAR_38_:%.+]] = arith.addi [[VAR_34_]], [[VAR_37_]] : index
-// CHECK-DAG: [[VAR_reinterpret_cast_1_:%.+]] = memref.reinterpret_cast [[PARAM_2_]] to offset: {{.}}[[VAR_38_]]{{.}}, sizes: [128, 256], strides: {{.}}[[VAR_32_]], [[VAR_35_]]{{.}} : memref<*xbf16> to memref<128x256xbf16, strided<[?, ?], offset: ?>>
-// CHECK-DAG: [[VAR_39_:%.+]] = arith.index_cast [[VAR_17_]] : i32 to index
-// CHECK-NOT: separator of consecutive DAGs
-// CHECK-DAG: [[VAR_40_:%.+]] = arith.addi [[VAR_39_]], [[CST_128_]] : index
-// CHECK-DAG: [[VAR_41_:%.+]] = arith.index_cast [[PARAM_3_]] : i32 to index
-// CHECK: [[VAR_42_:%.+]] = arith.minsi [[VAR_40_]], [[VAR_41_]] : index
-// CHECK-DAG: [[VAR_43_:%.+]] = arith.subi [[VAR_42_]], [[VAR_39_]] : index
-// CHECK-DAG: [[VAR_44_:%.+]] = arith.index_cast [[VAR_18_]] : i32 to index
-// CHECK-NOT: separator of consecutive DAGs
-// CHECK-DAG: [[VAR_45_:%.+]] = arith.addi [[VAR_44_]], [[CST_256_]] : index
-// CHECK-DAG: [[VAR_46_:%.+]] = arith.index_cast [[PARAM_4_]] : i32 to index
-// CHECK: [[VAR_47_:%.+]] = arith.minsi [[VAR_45_]], [[VAR_46_]] : index
-// CHECK-DAG: [[VAR_48_:%.+]] = arith.subi [[VAR_47_]], [[VAR_44_]] : index
-// CHECK-DAG: [[VAR_49_:%.+]] = arith.minsi [[VAR_43_]], [[CST_128_]] : index
-// CHECK: [[VAR_50_:%.+]] = arith.minsi [[VAR_48_]], [[CST_256_]] : index
-// CHECK-DAG: [[VAR_extracted_slice_:%.+]] = tensor.extract_slice [[VAR_31_]][0, 0] {{.}}[[VAR_49_]], [[VAR_50_]]{{.}} [1, 1] : tensor<128x256xbf16> to tensor
-// CHECK-DAG: [[VAR_subview_:%.+]] = memref.subview [[VAR_reinterpret_cast_1_]][0, 0] {{.}}[[VAR_49_]], [[VAR_50_]]{{.}} [1, 1] : memref<128x256xbf16, strided<[?, ?], offset: ?>> to memref>
-// CHECK: bufferization.materialize_in_destination [[VAR_extracted_slice_]] in writable [[VAR_subview_]]
-// CHECK: return
-// CHECK: }
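The retired checks above still encode the unclamped mask arithmetic: a `minsi` against the bound followed directly by a `subi` from the start offset. A minimal sketch of that computation, with hypothetical values rather than anything taken from the test:

```
# Unclamped mask-extent arithmetic, as encoded by the deleted checks above.
def old_mask_dim(start, tile, bound):
    new_end = min(start + tile, bound)  # arith.minsi
    return new_end - start              # arith.subi

assert old_mask_dim(start=0, tile=128, bound=64) == 64
assert old_mask_dim(start=128, tile=128, bound=64) == -64  # bound below start
```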
diff --git a/test/Conversion/TritonToLinalg/kernel-05-layer-norm-dwdb.mlir b/test/Conversion/TritonToLinalg/kernel-05-layer-norm-dwdb.mlir
deleted file mode 100644
index e1ff19df..00000000
--- a/test/Conversion/TritonToLinalg/kernel-05-layer-norm-dwdb.mlir
+++ /dev/null
@@ -1,189 +0,0 @@
-// RUN: triton-shared-opt --triton-to-linalg %s | FileCheck %s
-
-module {
-  tt.func public @_layer_norm_bwd_dwdb_0123456(%arg0: !tt.ptr, %arg1: !tt.ptr, %arg2: !tt.ptr, %arg3: !tt.ptr, %arg4: i32, %arg5: i32) {
-    %c0_i32 = arith.constant 0 : i32
-    %c256_i32 = arith.constant 256 : i32
-    %cst = arith.constant 0.000000e+00 : f32
-    %0 = tt.get_program_id x : i32
-    %1 = arith.muli %0, %c256_i32 : i32
-    %2 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32>
-    %3 = tt.splat %1 : i32 -> tensor<256xi32>
-    %4 = arith.addi %3, %2 : tensor<256xi32>
-    %5 = tt.splat %cst : f32 -> tensor<256x256xf32>
-    %6 = tt.splat %arg4 : i32 -> tensor<256x1xi32>
-    %7 = tt.expand_dims %4 {axis = 0 : i32} : tensor<256xi32> -> tensor<1x256xi32>
-    %8 = tt.splat %arg5 : i32 -> tensor<1x256xi32>
-    %9 = arith.cmpi slt, %7, %8 : tensor<1x256xi32>
-    %10 = tt.broadcast %9 : tensor<1x256xi1> -> tensor<256x256xi1>
-    %11 = tt.splat %arg5 : i32 -> tensor<256x1xi32>
-    %12 = tt.broadcast %7 : tensor<1x256xi32> -> tensor<256x256xi32>
-    %13 = tt.splat %arg0 : !tt.ptr -> tensor<256x256x!tt.ptr>
-    %14 = tt.splat %arg1 : !tt.ptr -> tensor<256x256x!tt.ptr>
-    %15:2 = scf.for %arg6 = %c0_i32 to %arg4 step %c256_i32 iter_args(%arg7 = %5, %arg8 = %5) -> (tensor<256x256xf32>, tensor<256x256xf32>) : i32 {
-      %24 = tt.splat %arg6 : i32 -> tensor<256xi32>
-      %25 = arith.addi %24, %2 : tensor<256xi32>
-      %26 = tt.expand_dims %25 {axis = 1 : i32} : tensor<256xi32> -> tensor<256x1xi32>
-      %27 = arith.cmpi slt, %26, %6 : tensor<256x1xi32>
-      %28 = tt.broadcast %27 : tensor<256x1xi1> -> tensor<256x256xi1>
-      %29 = arith.andi %28, %10 : tensor<256x256xi1>
-      %30 = arith.muli %26, %11 : tensor<256x1xi32>
-      %31 = tt.broadcast %30 : tensor<256x1xi32> -> tensor<256x256xi32>
-      %32 = arith.addi %31, %12 : tensor<256x256xi32>
-      %33 = tt.addptr %13, %32 : tensor<256x256x!tt.ptr>, tensor<256x256xi32>
-      %34 = tt.load %33, %29, %5 : tensor<256x256x!tt.ptr>
-      %35 = arith.addf %arg7, %34 : tensor<256x256xf32>
-      %36 = tt.addptr %14, %32 : tensor<256x256x!tt.ptr>, tensor<256x256xi32>
-      %37 = tt.load %36, %29, %5 : tensor<256x256x!tt.ptr>
-      %38 = arith.addf %arg8, %37 : tensor<256x256xf32>
-      scf.yield %35, %38 : tensor<256x256xf32>, tensor<256x256xf32>
-    }
-    %16 = "tt.reduce"(%15#0) ({
-    ^bb0(%arg6: f32, %arg7: f32):
-      %24 = arith.addf %arg6, %arg7 : f32
-      tt.reduce.return %24 : f32
-    }) {axis = 0 : i32} : (tensor<256x256xf32>) -> tensor<256xf32>
-    %17 = "tt.reduce"(%15#1) ({
-    ^bb0(%arg6: f32, %arg7: f32):
-      %24 = arith.addf %arg6, %arg7 : f32
-      tt.reduce.return %24 : f32
-    }) {axis = 0 : i32} : (tensor<256x256xf32>) -> tensor<256xf32>
-    %18 = tt.splat %arg5 : i32 -> tensor<256xi32>
-    %19 = arith.cmpi slt, %4, %18 : tensor<256xi32>
-    %20 = tt.splat %arg2 : !tt.ptr -> tensor<256x!tt.ptr>
-    %21 = tt.addptr %20, %4 : tensor<256x!tt.ptr>, tensor<256xi32>
-    tt.store %21, %16, %19 : tensor<256x!tt.ptr>
-    %22 = tt.splat %arg3 : !tt.ptr -> tensor<256x!tt.ptr>
-    %23 = tt.addptr %22, %4 : tensor<256x!tt.ptr>, tensor<256xi32>
-    tt.store %23, %17, %19 : tensor<256x!tt.ptr>
-    tt.return
-  }
-}
-
-// CHECK-DAG: [[MAP_0_:#.+]] = affine_map<(d0, d1) -> (d0, d1)>
-// CHECK-LABEL: func.func @_layer_norm_bwd_dwdb_0123456
-// CHECK-SAME: ([[PARAM_0_:%.+]]: memref<*xf32>, [[PARAM_1_:%.+]]: memref<*xf32>, [[PARAM_2_:%.+]]: memref<*xf32>, [[PARAM_3_:%.+]]: memref<*xf32>, [[PARAM_4_:%.+]]: i32, [[PARAM_5_:%.+]]: i32, [[PARAM_6_:%.+]]: i32, [[PARAM_7_:%.+]]: i32, [[PARAM_8_:%.+]]: i32, [[PARAM_9_:%.+]]: i32, [[PARAM_10_:%.+]]: i32, [[PARAM_11_:%.+]]: i32) {
-// CHECK-DAG: [[CST_256_:%.+]] = arith.constant 256 : index
-// CHECK-DAG: [[CST_256_1_:%.+]] = arith.constant 256 : i32
-// CHECK-DAG: [[CST_0_:%.+]] = arith.constant 0 : i32
-// CHECK-DAG: [[CST_0_dot_000000_:%.+]] = arith.constant 0.000000e+00 : f32
-// CHECK-DAG: [[VAR_0_:%.+]] = tensor.empty() : tensor<256x256xf32>
-// CHECK-NOT: separator of consecutive DAGs
-// CHECK-DAG: [[VAR_1_:%.+]] = linalg.fill ins([[CST_0_dot_000000_]] : f32) outs([[VAR_0_]] : tensor<256x256xf32>) -> tensor<256x256xf32>
-// CHECK-DAG: [[VAR_2_:%.+]] = arith.muli [[PARAM_9_]], [[CST_256_1_]] : i32
-// CHECK-NOT: separator of consecutive DAGs
-// CHECK-DAG: [[VAR_3_:%.+]]:2 = scf.for [[VAR_arg9_:%.+]] = [[CST_0_]] to [[PARAM_4_]] step [[CST_256_1_]] iter_args([[VAR_arg10_:%.+]] = [[VAR_1_]], [[VAR_arg11_:%.+]] = [[VAR_1_]]) -> (tensor<256x256xf32>, tensor<256x256xf32>) : i32 {
-// CHECK-DAG: [[VAR_20_:%.+]] = arith.index_cast [[VAR_arg9_]] : i32 to index
-// CHECK-DAG: [[VAR_21_:%.+]] = arith.index_cast [[PARAM_5_]] : i32 to index
-// CHECK-NOT: separator of consecutive DAGs
-// CHECK-DAG: [[VAR_22_:%.+]] = arith.muli [[VAR_20_]], [[VAR_21_]] : index
-// CHECK-DAG: [[VAR_23_:%.+]] = arith.index_cast [[VAR_2_]] : i32 to index
-// CHECK: [[VAR_24_:%.+]] = arith.addi [[VAR_22_]], [[VAR_23_]] : index
-// CHECK-DAG: [[VAR_reinterpret_cast_4_:%.+]] = memref.reinterpret_cast [[PARAM_0_]] to offset: {{.}}[[VAR_24_]]{{.}}, sizes: [256, 256], strides: {{.}}[[VAR_21_]], 1] : memref<*xf32> to memref<256x256xf32, strided<[?, 1], offset: ?>>
-// CHECK-DAG: [[RES_:%.+]] = memref.alloc() : memref<256x256xf32>
-// CHECK-DAG: [[VAR_25_:%.+]] = arith.index_cast [[VAR_arg9_]] : i32 to index
-// CHECK-NOT: separator of consecutive DAGs
-// CHECK-DAG: [[VAR_26_:%.+]] = arith.addi [[VAR_25_]], [[CST_256_]] : index
-// CHECK-DAG: [[VAR_27_:%.+]] = arith.index_cast [[PARAM_4_]] : i32 to index
-// CHECK: [[VAR_28_:%.+]] = arith.minsi [[VAR_26_]], [[VAR_27_]] : index
-// CHECK-DAG: [[VAR_29_:%.+]] = arith.subi [[VAR_28_]], [[VAR_25_]] : index
-// CHECK-DAG: [[VAR_30_:%.+]] = arith.index_cast [[VAR_2_]] : i32 to index
-// CHECK-NOT: separator of consecutive DAGs
-// CHECK-DAG: [[VAR_31_:%.+]] = arith.addi [[VAR_30_]], [[CST_256_]] : index
-// CHECK-DAG: [[VAR_32_:%.+]] = arith.index_cast [[PARAM_5_]] : i32 to index
-// CHECK: [[VAR_33_:%.+]] = arith.minsi [[VAR_31_]], [[VAR_32_]] : index
-// CHECK-DAG: [[VAR_34_:%.+]] = arith.subi [[VAR_33_]], [[VAR_30_]] : index
-// CHECK-DAG: [[VAR_35_:%.+]] = arith.minsi [[VAR_29_]], [[CST_256_]] : index
-// CHECK: [[VAR_36_:%.+]] = arith.minsi [[VAR_34_]], [[CST_256_]] : index
-// CHECK-DAG: [[VAR_37_:%.+]] = arith.cmpi slt, [[VAR_35_]], [[CST_256_]] : index
-// CHECK-DAG: [[VAR_38_:%.+]] = arith.cmpi slt, [[VAR_36_]], [[CST_256_]] : index
-// CHECK: [[VAR_39_:%.+]] = arith.ori [[VAR_37_]], [[VAR_38_]] : i1
-// CHECK: scf.if [[VAR_39_]] {
-// CHECK: linalg.fill ins([[CST_0_dot_000000_]] : f32) outs([[RES_]] : memref<256x256xf32>)
-// CHECK: }
-// CHECK-DAG: [[VAR_subview_5_:%.+]] = memref.subview [[VAR_reinterpret_cast_4_]][0, 0] {{.}}[[VAR_35_]], [[VAR_36_]]{{.}} [1, 1] : memref<256x256xf32, strided<[?, 1], offset: ?>> to memref>
-// CHECK-DAG: [[VAR_subview_6_:%.+]] = memref.subview [[RES_]][0, 0] {{.}}[[VAR_35_]], [[VAR_36_]]{{.}} [1, 1] : memref<256x256xf32> to memref>
-// CHECK: memref.copy [[VAR_subview_5_]], [[VAR_subview_6_]] : memref> to memref>
-// CHECK: [[VAR_40_:%.+]] = bufferization.to_tensor [[RES_]] restrict writable : memref<256x256xf32>
-// CHECK: [[VAR_41_:%.+]] = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel"]} ins([[VAR_arg10_]], [[VAR_40_]] : tensor<256x256xf32>, tensor<256x256xf32>) outs([[VAR_arg10_]] : tensor<256x256xf32>) {
-// CHECK: ^bb0([[in_0:.+]]: f32, [[in_1:.+]]: f32, [[out:.+]]: f32):
-// CHECK: [[VAR_64_:%.+]] = arith.addf [[in_0]], [[in_1]] : f32
-// CHECK: linalg.yield [[VAR_64_]] : f32
-// CHECK: } -> tensor<256x256xf32>
-// CHECK-DAG: [[VAR_42_:%.+]] = arith.index_cast [[VAR_arg9_]] : i32 to index
-// CHECK-DAG: [[VAR_43_:%.+]] = arith.index_cast [[PARAM_5_]] : i32 to index
-// CHECK-NOT: separator of consecutive DAGs
-// CHECK-DAG: [[VAR_44_:%.+]] = arith.muli [[VAR_42_]], [[VAR_43_]] : index
-// CHECK-DAG: [[VAR_45_:%.+]] = arith.index_cast [[VAR_2_]] : i32 to index
-// CHECK: [[VAR_46_:%.+]] = arith.addi [[VAR_44_]], [[VAR_45_]] : index
-// CHECK-DAG: [[VAR_reinterpret_cast_7_:%.+]] = memref.reinterpret_cast [[PARAM_1_]] to offset: {{.}}[[VAR_46_]]{{.}}, sizes: [256, 256], strides: {{.}}[[VAR_43_]], 1] : memref<*xf32> to memref<256x256xf32, strided<[?, 1], offset: ?>>
-// CHECK-DAG: [[RES_1_:%.+]] = memref.alloc() : memref<256x256xf32>
-// CHECK-DAG: [[VAR_47_:%.+]] = arith.index_cast [[VAR_arg9_]] : i32 to index
-// CHECK-NOT: separator of consecutive DAGs
-// CHECK-DAG: [[VAR_48_:%.+]] = arith.addi [[VAR_47_]], [[CST_256_]] : index
-// CHECK-DAG: [[VAR_49_:%.+]] = arith.index_cast [[PARAM_4_]] : i32 to index
-// CHECK: [[VAR_50_:%.+]] = arith.minsi [[VAR_48_]], [[VAR_49_]] : index
-// CHECK-DAG: [[VAR_51_:%.+]] = arith.subi [[VAR_50_]], [[VAR_47_]] : index
-// CHECK-DAG: [[VAR_52_:%.+]] = arith.index_cast [[VAR_2_]] : i32 to index
-// CHECK-NOT: separator of consecutive DAGs
-// CHECK-DAG: [[VAR_53_:%.+]] = arith.addi [[VAR_52_]], [[CST_256_]] : index
-// CHECK-DAG: [[VAR_54_:%.+]] = arith.index_cast [[PARAM_5_]] : i32 to index
-// CHECK: [[VAR_55_:%.+]] = arith.minsi [[VAR_53_]], [[VAR_54_]] : index
-// CHECK-DAG: [[VAR_56_:%.+]] = arith.subi [[VAR_55_]], [[VAR_52_]] : index
-// CHECK-DAG: [[VAR_57_:%.+]] = arith.minsi [[VAR_51_]], [[CST_256_]] : index
-// CHECK: [[VAR_58_:%.+]] = arith.minsi [[VAR_56_]], [[CST_256_]] : index
-// CHECK-DAG: [[VAR_59_:%.+]] = arith.cmpi slt, [[VAR_57_]], [[CST_256_]] : index
-// CHECK-DAG: [[VAR_60_:%.+]] = arith.cmpi slt, [[VAR_58_]], [[CST_256_]] : index
-// CHECK: [[VAR_61_:%.+]] = arith.ori [[VAR_59_]], [[VAR_60_]] : i1
-// CHECK: scf.if [[VAR_61_]] {
-// CHECK: linalg.fill ins([[CST_0_dot_000000_]] : f32) outs([[RES_1_]] : memref<256x256xf32>)
-// CHECK: }
-// CHECK-DAG: [[VAR_subview_9_:%.+]] = memref.subview [[VAR_reinterpret_cast_7_]][0, 0] {{.}}[[VAR_57_]], [[VAR_58_]]{{.}} [1, 1] : memref<256x256xf32, strided<[?, 1], offset: ?>> to memref>
-// CHECK-DAG: [[VAR_subview_10_:%.+]] = memref.subview [[RES_1_]][0, 0] {{.}}[[VAR_57_]], [[VAR_58_]]{{.}} [1, 1] : memref<256x256xf32> to memref>
-// CHECK: memref.copy [[VAR_subview_9_]], [[VAR_subview_10_]] : memref> to memref>
-// CHECK: [[VAR_62_:%.+]] = bufferization.to_tensor [[RES_1_]] restrict writable : memref<256x256xf32>
-// CHECK: [[VAR_63_:%.+]] = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel"]} ins([[VAR_arg11_]], [[VAR_62_]] : tensor<256x256xf32>, tensor<256x256xf32>) outs([[VAR_arg11_]] : tensor<256x256xf32>) {
-// CHECK: ^bb0([[in_0:.+]]: f32, [[in_1:.+]]: f32, [[out:.+]]: f32):
-// CHECK: [[VAR_64_1_:%.+]] = arith.addf [[in_0]], [[in_1]] : f32
-// CHECK: linalg.yield [[VAR_64_1_]] : f32
-// CHECK: } -> tensor<256x256xf32>
-// CHECK: scf.yield [[VAR_41_]], [[VAR_63_]] : tensor<256x256xf32>, tensor<256x256xf32>
-// CHECK: }
-// CHECK: [[VAR_4_:%.+]] = tensor.empty() : tensor<256xf32>
-// CHECK: [[VAR_5_:%.+]] = linalg.fill ins([[CST_0_dot_000000_]] : f32) outs([[VAR_4_]] : tensor<256xf32>) -> tensor<256xf32>
-// CHECK: [[VAR_reduced_:%.+]] = linalg.reduce ins([[VAR_3_]]#0 : tensor<256x256xf32>) outs([[VAR_5_]] : tensor<256xf32>) dimensions = [0]
-// CHECK: ([[in_0:.+]]: f32, [[in_1:.+]]: f32) {
-// CHECK: [[VAR_20_1_:%.+]] = arith.addf [[in_0]], [[in_1]] : f32
-// CHECK: linalg.yield [[VAR_20_1_]] : f32
-// CHECK: }
-// CHECK: [[VAR_6_:%.+]] = tensor.empty() : tensor<256xf32>
-// CHECK: [[VAR_7_:%.+]] = linalg.fill ins([[CST_0_dot_000000_]] : f32) outs([[VAR_6_]] : tensor<256xf32>) -> tensor<256xf32>
-// CHECK: [[VAR_reduced_0_:%.+]] = linalg.reduce ins([[VAR_3_]]#1 : tensor<256x256xf32>) outs([[VAR_7_]] : tensor<256xf32>) dimensions = [0]
-// CHECK: ([[in_0:.+]]: f32, [[in_1:.+]]: f32) {
-// CHECK: [[VAR_20_2_:%.+]] = arith.addf [[in_0]], [[in_1]] : f32
-// CHECK: linalg.yield [[VAR_20_2_]] : f32
-// CHECK: }
-// CHECK: [[VAR_8_:%.+]] = arith.index_cast [[VAR_2_]] : i32 to index
-// CHECK-DAG: [[VAR_reinterpret_cast_:%.+]] = memref.reinterpret_cast [[PARAM_2_]] to offset: {{.}}[[VAR_8_]]{{.}}, sizes: [256], strides: [1] : memref<*xf32> to memref<256xf32, strided<[1], offset: ?>>
-// CHECK-DAG: [[VAR_9_:%.+]] = arith.index_cast [[VAR_2_]] : i32 to index
-// CHECK-NOT: separator of consecutive DAGs
-// CHECK-DAG: [[VAR_10_:%.+]] = arith.addi [[VAR_9_]], [[CST_256_]] : index
-// CHECK-DAG: [[VAR_11_:%.+]] = arith.index_cast [[PARAM_5_]] : i32 to index
-// CHECK: [[VAR_12_:%.+]] = arith.minsi [[VAR_10_]], [[VAR_11_]] : index
-// CHECK: [[VAR_13_:%.+]] = arith.subi [[VAR_12_]], [[VAR_9_]] : index
-// CHECK-DAG: [[VAR_extracted_slice_:%.+]] = tensor.extract_slice [[VAR_reduced_]][0] {{.}}[[VAR_13_]]{{.}} [1] : tensor<256xf32> to tensor
-// CHECK-DAG: [[VAR_subview_:%.+]] = memref.subview [[VAR_reinterpret_cast_]][0] {{.}}[[VAR_13_]]{{.}} [1] : memref<256xf32, strided<[1], offset: ?>> to memref>
-// CHECK: bufferization.materialize_in_destination [[VAR_extracted_slice_]] in writable [[VAR_subview_]]
-// CHECK: [[VAR_14_:%.+]] = arith.index_cast [[VAR_2_]] : i32 to index
-// CHECK-DAG: [[VAR_reinterpret_cast_1_:%.+]] = memref.reinterpret_cast [[PARAM_3_]] to offset: {{.}}[[VAR_14_]]{{.}}, sizes: [256], strides: [1] : memref<*xf32> to memref<256xf32, strided<[1], offset: ?>>
-// CHECK-DAG: [[VAR_15_:%.+]] = arith.index_cast [[VAR_2_]] : i32 to index
-// CHECK-NOT: separator of consecutive DAGs
-// CHECK-DAG: [[VAR_16_:%.+]] = arith.addi [[VAR_15_]], [[CST_256_]] : index
-// CHECK-DAG: [[VAR_17_:%.+]] = arith.index_cast [[PARAM_5_]] : i32 to index
-// CHECK: [[VAR_18_:%.+]] = arith.minsi [[VAR_16_]], [[VAR_17_]] : index
-// CHECK: [[VAR_19_:%.+]] = arith.subi [[VAR_18_]], [[VAR_15_]] : index
-// CHECK-DAG: [[VAR_extracted_slice_2_:%.+]] = tensor.extract_slice [[VAR_reduced_0_]][0] {{.}}[[VAR_19_]]{{.}} [1] : tensor<256xf32> to tensor
-// CHECK-DAG: [[VAR_subview_3_:%.+]] = memref.subview [[VAR_reinterpret_cast_1_]][0] {{.}}[[VAR_19_]]{{.}} [1] : memref<256xf32, strided<[1], offset: ?>> to memref>
-// CHECK: bufferization.materialize_in_destination [[VAR_extracted_slice_2_]] in writable [[VAR_subview_3_]]
-// CHECK: return
-// CHECK: }
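The dwdb kernel deleted above combines a row mask and a column mask with `arith.andi` and accumulates with `other = 0.0`. A rough NumPy rendering of the lowering those checks describe (names and the helper are mine, not from the pass):

```
import numpy as np

def load_tile_2d(src, r0, c0, T=256, other=0.0):
    # Per-dimension in-bounds extents, mirroring the minsi/subi/minsi chain above.
    M, N = src.shape
    h = min(min(r0 + T, M) - r0, T)
    w = min(min(c0 + T, N) - c0, T)
    out = np.empty((T, T), dtype=src.dtype)
    if h < T or w < T:                           # the scf.if guarding the linalg.fill
        out[:] = other
    out[:h, :w] = src[r0:r0 + h, c0:c0 + w]      # memref.copy of the subview
    return out
```

Note that `h` and `w` go negative here when `M < r0` or `N < c0`, which is the degenerate case the clamping in the updated tests further below guards against.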
diff --git a/test/Conversion/TritonToLinalg/kernel-05-layer-norm-fwd.mlir b/test/Conversion/TritonToLinalg/kernel-05-layer-norm-fwd.mlir
deleted file mode 100644
index 76cc0da3..00000000
--- a/test/Conversion/TritonToLinalg/kernel-05-layer-norm-fwd.mlir
+++ /dev/null
@@ -1,313 +0,0 @@
-// RUN: triton-shared-opt --triton-to-linalg %s | FileCheck %s
-
-module {
-  tt.func public @_layer_norm_fwd_fused_0123456789(%arg0: !tt.ptr, %arg1: !tt.ptr, %arg2: !tt.ptr, %arg3: !tt.ptr, %arg4: !tt.ptr, %arg5: !tt.ptr, %arg6: i32, %arg7: i32, %arg8: f32) {
-    %c256_i32 = arith.constant 256 : i32
-    %c0_i32 = arith.constant 0 : i32
-    %cst = arith.constant 1.000000e+00 : f32
-    %cst_0 = arith.constant 0.000000e+00 : f32
-    %0 = tt.get_program_id x : i32
-    %1 = arith.muli %0, %arg6 : i32
-    %2 = tt.addptr %arg1, %1 : !tt.ptr, i32
-    %3 = tt.addptr %arg0, %1 : !tt.ptr, i32
-    %4 = tt.splat %cst_0 : f32 -> tensor<256xf32>
-    %5 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32>
-    %6 = tt.splat %arg7 : i32 -> tensor<256xi32>
-    %7 = tt.splat %3 : !tt.ptr -> tensor<256x!tt.ptr>
-    %8 = scf.for %arg9 = %c0_i32 to %arg7 step %c256_i32 iter_args(%arg10 = %4) -> (tensor<256xf32>) : i32 {
-      %32 = tt.splat %arg9 : i32 -> tensor<256xi32>
-      %33 = arith.addi %32, %5 : tensor<256xi32>
-      %34 = arith.cmpi slt, %33, %6 : tensor<256xi32>
-      %35 = tt.addptr %7, %33 : tensor<256x!tt.ptr>, tensor<256xi32>
-      %36 = tt.load %35, %34, %4 : tensor<256x!tt.ptr>
-      %37 = arith.addf %arg10, %36 : tensor<256xf32>
-      scf.yield %37 : tensor<256xf32>
-    }
-    %9 = "tt.reduce"(%8) ({
-    ^bb0(%arg9: f32, %arg10: f32):
-      %32 = arith.addf %arg9, %arg10 : f32
-      tt.reduce.return %32 : f32
-    }) {axis = 0 : i32} : (tensor<256xf32>) -> f32
-    %10 = arith.sitofp %arg7 : i32 to f32
-    %11 = arith.divf %9, %10 : f32
-    %12 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32>
-    %13 = tt.splat %arg7 : i32 -> tensor<256xi32>
-    %14 = tt.splat %3 : !tt.ptr -> tensor<256x!tt.ptr>
-    %15 = tt.splat %11 : f32 -> tensor<256xf32>
-    %16 = scf.for %arg9 = %c0_i32 to %arg7 step %c256_i32 iter_args(%arg10 = %4) -> (tensor<256xf32>) : i32 {
-      %32 = tt.splat %arg9 : i32 -> tensor<256xi32>
-      %33 = arith.addi %32, %12 : tensor<256xi32>
-      %34 = arith.cmpi slt, %33, %13 : tensor<256xi32>
-      %35 = tt.addptr %14, %33 : tensor<256x!tt.ptr>, tensor<256xi32>
-      %36 = tt.load %35, %34, %4 : tensor<256x!tt.ptr>
-      %37 = arith.subf %36, %15 : tensor<256xf32>
-      %38 = arith.select %34, %37, %4 : tensor<256xi1>, tensor<256xf32>
-      %39 = arith.mulf %38, %38 : tensor<256xf32>
-      %40 = arith.addf %arg10, %39 : tensor<256xf32>
-      scf.yield %40 : tensor<256xf32>
-    }
-    %17 = "tt.reduce"(%16) ({
-    ^bb0(%arg9: f32, %arg10: f32):
-      %32 = arith.addf %arg9, %arg10 : f32
-      tt.reduce.return %32 : f32
-    }) {axis = 0 : i32} : (tensor<256xf32>) -> f32
-    %18 = arith.divf %17, %10 : f32
-    %19 = arith.addf %18, %arg8 : f32
-    %20 = math.sqrt %19 : f32
-    %21 = arith.divf %cst, %20 : f32
-    %22 = tt.addptr %arg4, %0 : !tt.ptr, i32
-    tt.store %22, %11 : !tt.ptr
-    %23 = tt.addptr %arg5, %0 : !tt.ptr, i32
-    tt.store %23, %21 : !tt.ptr
-    %24 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32>
-    %25 = tt.splat %arg7 : i32 -> tensor<256xi32>
-    %26 = tt.splat %arg2 : !tt.ptr -> tensor<256x!tt.ptr>
-    %27 = tt.splat %arg3 : !tt.ptr -> tensor<256x!tt.ptr>
-    %28 = tt.splat %3 : !tt.ptr -> tensor<256x!tt.ptr>
-    %29 = tt.splat %11 : f32 -> tensor<256xf32>
-    %30 = tt.splat %21 : f32 -> tensor<256xf32>
-    %31 = tt.splat %2 : !tt.ptr -> tensor<256x!tt.ptr>
-    scf.for %arg9 = %c0_i32 to %arg7 step %c256_i32 : i32 {
-      %32 = tt.splat %arg9 : i32 -> tensor<256xi32>
-      %33 = arith.addi %32, %24 : tensor<256xi32>
-      %34 = arith.cmpi slt, %33, %25 : tensor<256xi32>
-      %35 = tt.addptr %26, %33 : tensor<256x!tt.ptr>, tensor<256xi32>
-      %36 = tt.load %35, %34 : tensor<256x!tt.ptr>
-      %37 = tt.addptr %27, %33 : tensor<256x!tt.ptr>, tensor<256xi32>
-      %38 = tt.load %37, %34 : tensor<256x!tt.ptr>
-      %39 = tt.addptr %28, %33 : tensor<256x!tt.ptr>, tensor<256xi32>
-      %40 = tt.load %39, %34, %4 : tensor<256x!tt.ptr>
-      %41 = arith.subf %40, %29 : tensor<256xf32>
-      %42 = arith.mulf %41, %30 : tensor<256xf32>
-      %43 = arith.mulf %42, %36 : tensor<256xf32>
-      %44 = arith.addf %43, %38 : tensor<256xf32>
-      %45 = tt.addptr %31, %33 : tensor<256x!tt.ptr>, tensor<256xi32>
-      tt.store %45, %44, %34 : tensor<256x!tt.ptr>
-    }
-    tt.return
-  }
-}
-
-// CHECK-DAG: [[MAP_0_:#.+]] = affine_map<(d0) -> (d0)>
-// CHECK-LABEL: func.func @_layer_norm_fwd_fused_0123456789
-// CHECK-SAME: ([[PARAM_0_:%.+]]: memref<*xf32>, [[PARAM_1_:%.+]]: memref<*xf32>, [[PARAM_2_:%.+]]: memref<*xf32>, [[PARAM_3_:%.+]]: memref<*xf32>, [[PARAM_4_:%.+]]: memref<*xf32>, [[PARAM_5_:%.+]]: memref<*xf32>, [[PARAM_6_:%.+]]: i32, [[PARAM_7_:%.+]]: i32, [[PARAM_8_:%.+]]: f32, [[PARAM_9_:%.+]]: i32, [[PARAM_10_:%.+]]: i32, [[PARAM_11_:%.+]]: i32, [[PARAM_12_:%.+]]: i32, [[PARAM_13_:%.+]]: i32, [[PARAM_14_:%.+]]: i32) {
-// CHECK-DAG: [[CST_256_:%.+]] = arith.constant 256 : index
-// CHECK-DAG: [[CST_1_dot_000000_:%.+]] = arith.constant 1.000000e+00 : f32
-// CHECK-DAG: [[CST_0_:%.+]] = arith.constant 0 : i32
-// CHECK-DAG: [[CST_256_1_:%.+]] = arith.constant 256 : i32
-// CHECK-DAG: [[CST_0_dot_000000_:%.+]] = arith.constant 0.000000e+00 : f32
-// CHECK-DAG: [[VAR_0_:%.+]] = tensor.empty() : tensor<256xf32>
-// CHECK-NOT: separator of consecutive DAGs
-// CHECK-DAG: [[VAR_1_:%.+]] = linalg.fill ins([[CST_0_dot_000000_]] : f32) outs([[VAR_0_]] : tensor<256xf32>) -> tensor<256xf32>
-// CHECK-DAG: [[VAR_2_:%.+]] = arith.muli [[PARAM_12_]], [[PARAM_6_]] : i32
-// CHECK-NOT: separator of consecutive DAGs
-// CHECK-DAG: [[VAR_3_:%.+]] = scf.for [[VAR_arg12_:%.+]] = [[CST_0_]] to [[PARAM_7_]] step [[CST_256_1_]] iter_args([[VAR_arg13_:%.+]] = [[VAR_1_]]) -> (tensor<256xf32>) : i32 {
-// CHECK-DAG: [[VAR_25_:%.+]] = arith.index_cast [[VAR_2_]] : i32 to index
-// CHECK-DAG: [[VAR_26_:%.+]] = arith.index_cast [[VAR_arg12_]] : i32 to index
-// CHECK: [[VAR_27_:%.+]] = arith.addi [[VAR_25_]], [[VAR_26_]] : index
-// CHECK-DAG: [[VAR_reinterpret_cast_5_:%.+]] = memref.reinterpret_cast [[PARAM_0_]] to offset: {{.}}[[VAR_27_]]{{.}}, sizes: [256], strides: [1] : memref<*xf32> to memref<256xf32, strided<[1], offset: ?>>
-// CHECK-DAG: [[RES_:%.+]] = memref.alloc() : memref<256xf32>
-// CHECK-DAG: [[VAR_28_:%.+]] = arith.index_cast [[VAR_arg12_]] : i32 to index
-// CHECK-NOT: separator of consecutive DAGs
-// CHECK-DAG: [[VAR_29_:%.+]] = arith.addi [[VAR_28_]], [[CST_256_]] : index
-// CHECK-DAG: [[VAR_30_:%.+]] = arith.index_cast [[PARAM_7_]] : i32 to index
-// CHECK: [[VAR_31_:%.+]] = arith.minsi [[VAR_29_]], [[VAR_30_]] : index
-// CHECK: [[VAR_32_:%.+]] = arith.subi [[VAR_31_]], [[VAR_28_]] : index
-// CHECK-DAG: [[VAR_33_:%.+]] = arith.cmpi slt, [[VAR_32_]], [[CST_256_]] : index
-// CHECK: scf.if [[VAR_33_]] {
-// CHECK: linalg.fill ins([[CST_0_dot_000000_]] : f32) outs([[RES_]] : memref<256xf32>)
-// CHECK: }
-// CHECK-DAG: [[VAR_subview_:%.+]] = memref.subview [[VAR_reinterpret_cast_5_]][0] {{.}}[[VAR_32_]]{{.}} [1] : memref<256xf32, strided<[1], offset: ?>> to memref>
-// CHECK-DAG: [[VAR_subview_6_:%.+]] = memref.subview [[RES_]][0] {{.}}[[VAR_32_]]{{.}} [1] : memref<256xf32> to memref>
-// CHECK: memref.copy [[VAR_subview_]], [[VAR_subview_]]_6 : memref> to memref>
-// CHECK: [[VAR_34_:%.+]] = bufferization.to_tensor [[RES_]] restrict writable : memref<256xf32>
-// CHECK: [[VAR_35_:%.+]] = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]} ins([[VAR_arg13_]], [[VAR_34_]] : tensor<256xf32>, tensor<256xf32>) outs([[VAR_arg13_]] : tensor<256xf32>) {
-// CHECK: ^bb0([[in_0:.+]]: f32, [[in_1:.+]]: f32, [[out:.+]]: f32):
-// CHECK: [[VAR_36_:%.+]] = arith.addf [[in_0]], [[in_1]] : f32
-// CHECK: linalg.yield [[VAR_36_]] : f32
-// CHECK: } -> tensor<256xf32>
-// CHECK: scf.yield [[VAR_35_]] : tensor<256xf32>
-// CHECK: }
-// CHECK: [[VAR_4_:%.+]] = bufferization.alloc_tensor() : tensor
-// CHECK: [[VAR_inserted_:%.+]] = tensor.insert [[CST_0_dot_000000_]] into [[VAR_4_]][] : tensor
-// CHECK: [[VAR_reduced_:%.+]] = linalg.reduce ins([[VAR_3_]] : tensor<256xf32>) outs([[VAR_inserted_]] : tensor) dimensions = [0]
-// CHECK: ([[in_0:.+]]: f32, [[in_1:.+]]: f32) {
-// CHECK: [[VAR_25_1_:%.+]] = arith.addf [[in_0]], [[in_1]] : f32
-// CHECK: linalg.yield [[VAR_25_1_]] : f32
-// CHECK: }
-// CHECK-DAG: [[VAR_extracted_:%.+]] = tensor.extract [[VAR_reduced_]][] : tensor
-// CHECK-DAG: [[VAR_5_:%.+]] = arith.sitofp [[PARAM_7_]] : i32 to f32
-// CHECK-NOT: separator of consecutive DAGs
-// CHECK-DAG: [[VAR_6_:%.+]] = arith.divf [[VAR_extracted_]], [[VAR_5_]] : f32
-// CHECK-DAG: [[VAR_7_:%.+]] = tensor.empty() : tensor<256xi32>
-// CHECK: [[VAR_8_:%.+]] = linalg.generic {indexing_maps = [#map], iterator_types = ["parallel"]} outs([[VAR_7_]] : tensor<256xi32>) {
-// CHECK: ^bb0([[out:.+]]: i32):
-// CHECK: [[VAR_25_2_:%.+]] = linalg.index 0 : index
-// CHECK: [[VAR_26_1_:%.+]] = arith.index_cast [[VAR_25_2_]] : index to i32
-// CHECK: linalg.yield [[VAR_26_1_]] : i32
-// CHECK: } -> tensor<256xi32>
-// CHECK: [[VAR_9_:%.+]] = tensor.empty() : tensor<256xi32>
-// CHECK-DAG: [[VAR_10_:%.+]] = linalg.fill ins([[PARAM_7_]] : i32) outs([[VAR_9_]] : tensor<256xi32>) -> tensor<256xi32>
-// CHECK-DAG: [[VAR_11_:%.+]] = tensor.empty() : tensor<256xf32>
-// CHECK-NOT: separator of consecutive DAGs
-// CHECK-DAG: [[VAR_12_:%.+]] = linalg.fill ins([[VAR_6_]] : f32) outs([[VAR_11_]] : tensor<256xf32>) -> tensor<256xf32>
-// CHECK-DAG: [[VAR_13_:%.+]] = scf.for [[VAR_arg12_1_:%.+]] = [[CST_0_]] to [[PARAM_7_]] step [[CST_256_1_]] iter_args([[VAR_arg13_1_:%.+]] = [[VAR_1_]]) -> (tensor<256xf32>) : i32 {
-// CHECK-DAG: [[VAR_25_3_:%.+]] = tensor.empty() : tensor<256xi32>
-// CHECK: [[VAR_26_2_:%.+]] = linalg.fill ins([[VAR_arg12_1_]] : i32) outs([[VAR_25_3_]] : tensor<256xi32>) -> tensor<256xi32>
-// CHECK: [[VAR_27_1_:%.+]] = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]} ins([[VAR_26_2_]], [[VAR_8_]] : tensor<256xi32>, tensor<256xi32>) outs([[VAR_26_2_]] : tensor<256xi32>) {
-// CHECK: ^bb0([[in_0:.+]]: i32, [[in_1:.+]]: i32, [[out:.+]]: i32):
-// CHECK: [[VAR_44_:%.+]] = arith.addi [[in_0]], [[in_1]] : i32
-// CHECK: linalg.yield [[VAR_44_]] : i32
-// CHECK: } -> tensor<256xi32>
-// CHECK: [[VAR_28_1_:%.+]] = tensor.empty() : tensor<256xi1>
-// CHECK: [[VAR_29_1_:%.+]] = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]} ins([[VAR_27_1_]], [[VAR_10_]] : tensor<256xi32>, tensor<256xi32>) outs([[VAR_28_1_]] : tensor<256xi1>) {
-// CHECK: ^bb0([[in_0:.+]]: i32, [[in_1:.+]]: i32, [[out:.+]]: i1):
-// CHECK: [[VAR_44_1_:%.+]] = arith.cmpi slt, [[in_0]], [[in_1]] : i32
-// CHECK: linalg.yield [[VAR_44_1_]] : i1
-// CHECK: } -> tensor<256xi1>
-// CHECK-DAG: [[VAR_30_1_:%.+]] = arith.index_cast [[VAR_2_]] : i32 to index
-// CHECK-DAG: [[VAR_31_1_:%.+]] = arith.index_cast [[VAR_arg12_1_]] : i32 to index
-// CHECK: [[VAR_32_1_:%.+]] = arith.addi [[VAR_30_1_]], [[VAR_31_1_]] : index
-// CHECK-DAG: [[VAR_reinterpret_cast_5_1_:%.+]] = memref.reinterpret_cast [[PARAM_0_]] to offset: {{.}}[[VAR_32_1_]]{{.}}, sizes: [256], strides: [1] : memref<*xf32> to memref<256xf32, strided<[1], offset: ?>>
-// CHECK-DAG: [[RES_1_:%.+]] = memref.alloc() : memref<256xf32>
-// CHECK-DAG: [[VAR_33_1_:%.+]] = arith.index_cast [[VAR_arg12_1_]] : i32 to index
-// CHECK-NOT: separator of consecutive DAGs
-// CHECK-DAG: [[VAR_34_1_:%.+]] = arith.addi [[VAR_33_1_]], [[CST_256_]] : index
-// CHECK-DAG: [[VAR_35_1_:%.+]] = arith.index_cast [[PARAM_7_]] : i32 to index
-// CHECK: [[VAR_36_1_:%.+]] = arith.minsi [[VAR_34_1_]], [[VAR_35_1_]] : index
-// CHECK: [[VAR_37_:%.+]] = arith.subi [[VAR_36_1_]], [[VAR_33_1_]] : index
-// CHECK-DAG: [[VAR_38_:%.+]] = arith.cmpi slt, [[VAR_37_]], [[CST_256_]] : index
-// CHECK: scf.if [[VAR_38_]] {
-// CHECK: linalg.fill ins([[CST_0_dot_000000_]] : f32) outs([[RES_1_]] : memref<256xf32>)
-// CHECK: }
-// CHECK-DAG: [[VAR_subview_1_:%.+]] = memref.subview [[VAR_reinterpret_cast_5_1_]][0] {{.}}[[VAR_37_]]{{.}} [1] : memref<256xf32, strided<[1], offset: ?>> to memref>
-// CHECK-DAG: [[VAR_subview_6_1_:%.+]] = memref.subview [[RES_1_]][0] {{.}}[[VAR_37_]]{{.}} [1] : memref<256xf32> to memref>
-// CHECK: memref.copy [[VAR_subview_1_]], [[VAR_subview_1_]]_6 : memref> to memref>
-// CHECK: [[VAR_39_:%.+]] = bufferization.to_tensor [[RES_1_]] restrict writable : memref<256xf32>
-// CHECK: [[VAR_40_:%.+]] = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]} ins([[VAR_39_]], [[VAR_12_]] : tensor<256xf32>, tensor<256xf32>) outs([[VAR_39_]] : tensor<256xf32>) {
-// CHECK: ^bb0([[in_0:.+]]: f32, [[in_1:.+]]: f32, [[out:.+]]: f32):
-// CHECK: [[VAR_44_2_:%.+]] = arith.subf [[in_0]], [[in_1]] : f32
-// CHECK: linalg.yield [[VAR_44_2_]] : f32
-// CHECK: } -> tensor<256xf32>
-// CHECK: [[VAR_41_:%.+]] = linalg.generic {indexing_maps = [#map, #map, #map, #map], iterator_types = ["parallel"]} ins([[VAR_29_1_]], [[VAR_40_]], [[VAR_1_]] : tensor<256xi1>, tensor<256xf32>, tensor<256xf32>) outs([[VAR_40_]] : tensor<256xf32>) {
-// CHECK: ^bb0([[in_0:.+]]: i1, [[in_1:.+]]: f32, [[in_2:.+]]: f32, [[out:.+]]: f32):
-// CHECK: [[VAR_44_3_:%.+]] = arith.select [[in_0]], [[in_1]], [[in_2]] : f32
-// CHECK: linalg.yield [[VAR_44_3_]] : f32
-// CHECK: } -> tensor<256xf32>
-// CHECK: [[VAR_42_:%.+]] = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]} ins([[VAR_41_]], [[VAR_41_]] : tensor<256xf32>, tensor<256xf32>) outs([[VAR_41_]] : tensor<256xf32>) {
-// CHECK: ^bb0([[in_0:.+]]: f32, [[in_1:.+]]: f32, [[out:.+]]: f32):
-// CHECK: [[VAR_44_4_:%.+]] = arith.mulf [[in_0]], [[in_1]] : f32
-// CHECK: linalg.yield [[VAR_44_4_]] : f32
-// CHECK: } -> tensor<256xf32>
-// CHECK: [[VAR_43_:%.+]] = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]} ins([[VAR_arg13_1_]], [[VAR_42_]] : tensor<256xf32>, tensor<256xf32>) outs([[VAR_arg13_1_]] : tensor<256xf32>) {
-// CHECK: ^bb0([[in_0:.+]]: f32, [[in_1:.+]]: f32, [[out:.+]]: f32):
-// CHECK: [[VAR_44_5_:%.+]] = arith.addf [[in_0]], [[in_1]] : f32
-// CHECK: linalg.yield [[VAR_44_5_]] : f32
-// CHECK: } -> tensor<256xf32>
-// CHECK: scf.yield [[VAR_43_]] : tensor<256xf32>
-// CHECK: }
-// CHECK: [[VAR_14_:%.+]] = bufferization.alloc_tensor() : tensor
-// CHECK: [[VAR_inserted_1_:%.+]] = tensor.insert [[CST_0_dot_000000_]] into [[VAR_14_]][] : tensor
-// CHECK: [[VAR_reduced_2_:%.+]] = linalg.reduce ins([[VAR_13_]] : tensor<256xf32>) outs([[VAR_inserted_1_]] : tensor) dimensions = [0]
-// CHECK: ([[in_0:.+]]: f32, [[in_1:.+]]: f32) {
-// CHECK: [[VAR_25_4_:%.+]] = arith.addf [[in_0]], [[in_1]] : f32
-// CHECK: linalg.yield [[VAR_25_4_]] : f32
-// CHECK: }
-// CHECK: [[VAR_extracted_3_:%.+]] = tensor.extract [[VAR_reduced_2_]][] : tensor
-// CHECK: [[VAR_15_:%.+]] = arith.divf [[VAR_extracted_3_]], [[VAR_5_]] : f32
-// CHECK: [[VAR_16_:%.+]] = arith.addf [[VAR_15_]], [[PARAM_8_]] : f32
-// CHECK: [[VAR_17_:%.+]] = math.sqrt [[VAR_16_]] : f32
-// CHECK-DAG: [[VAR_18_:%.+]] = arith.divf [[CST_1_dot_000000_]], [[VAR_17_]] : f32
-// CHECK-DAG: [[VAR_19_:%.+]] = arith.index_cast [[PARAM_12_]] : i32 to index
-// CHECK: [[VAR_reinterpret_cast_:%.+]] = memref.reinterpret_cast [[PARAM_4_]] to offset: {{.}}[[VAR_19_]]{{.}}, sizes: [1], strides: [1] : memref<*xf32> to memref<1xf32, strided<[1], offset: ?>>
-// CHECK: affine.store [[VAR_6_]], [[VAR_reinterpret_cast_]][0] : memref<1xf32, strided<[1], offset: ?>>
-// CHECK: [[VAR_20_:%.+]] = arith.index_cast [[PARAM_12_]] : i32 to index
-// CHECK: [[VAR_reinterpret_cast_4_:%.+]] = memref.reinterpret_cast [[PARAM_5_]] to offset: {{.}}[[VAR_20_]]{{.}}, sizes: [1], strides: [1] : memref<*xf32> to memref<1xf32, strided<[1], offset: ?>>
-// CHECK: affine.store [[VAR_18_]], [[VAR_reinterpret_cast_4_]][0] : memref<1xf32, strided<[1], offset: ?>>
-// CHECK: [[VAR_21_:%.+]] = tensor.empty() : tensor<256xf32>
-// CHECK-DAG: [[VAR_22_:%.+]] = linalg.fill ins([[VAR_6_]] : f32) outs([[VAR_21_]] : tensor<256xf32>) -> tensor<256xf32>
-// CHECK-DAG: [[VAR_23_:%.+]] = tensor.empty() : tensor<256xf32>
-// CHECK: [[VAR_24_:%.+]] = linalg.fill ins([[VAR_18_]] : f32) outs([[VAR_23_]] : tensor<256xf32>) -> tensor<256xf32>
-// CHECK: scf.for [[VAR_arg12_1_:%.+]] = [[CST_0_]] to [[PARAM_7_]] step [[CST_256_1_]] : i32 {
-// CHECK: [[VAR_25_5_:%.+]] = arith.index_cast [[VAR_arg12_1_]] : i32 to index
-// CHECK-DAG: [[VAR_reinterpret_cast_5_2_:%.+]] = memref.reinterpret_cast [[PARAM_2_]] to offset: {{.}}[[VAR_25_5_]]{{.}}, sizes: [256], strides: [1] : memref<*xf32> to memref<256xf32, strided<[1], offset: ?>>
-// CHECK-DAG: [[RES_2_:%.+]] = memref.alloc() : memref<256xf32>
-// CHECK-DAG: [[VAR_26_3_:%.+]] = arith.index_cast [[VAR_arg12_1_]] : i32 to index
-// CHECK-NOT: separator of consecutive DAGs
-// CHECK-DAG: [[VAR_27_2_:%.+]] = arith.addi [[VAR_26_3_]], [[CST_256_]] : index
-// CHECK-DAG: [[VAR_28_2_:%.+]] = arith.index_cast [[PARAM_7_]] : i32 to index
-// CHECK: [[VAR_29_2_:%.+]] = arith.minsi [[VAR_27_2_]], [[VAR_28_2_]] : index
-// CHECK: [[VAR_30_2_:%.+]] = arith.subi [[VAR_29_2_]], [[VAR_26_3_]] : index
-// CHECK-DAG: [[VAR_subview_2_:%.+]] = memref.subview [[VAR_reinterpret_cast_5_2_]][0] {{.}}[[VAR_30_2_]]{{.}} [1] : memref<256xf32, strided<[1], offset: ?>> to memref>
-// CHECK-DAG: [[VAR_subview_6_2_:%.+]] = memref.subview [[RES_2_]][0] {{.}}[[VAR_30_2_]]{{.}} [1] : memref<256xf32> to memref>
-// CHECK: memref.copy [[VAR_subview_2_]], [[VAR_subview_2_]]_6 : memref> to memref>
-// CHECK-DAG: [[VAR_31_2_:%.+]] = bufferization.to_tensor [[RES_2_]] restrict writable : memref<256xf32>
-// CHECK-DAG: [[VAR_32_2_:%.+]] = arith.index_cast [[VAR_arg12_1_]] : i32 to index
-// CHECK-NOT: separator of consecutive DAGs
-// CHECK-DAG: [[VAR_reinterpret_cast_7_:%.+]] = memref.reinterpret_cast [[PARAM_3_]] to offset: {{.}}[[VAR_32_2_]]{{.}}, sizes: [256], strides: [1] : memref<*xf32> to memref<256xf32, strided<[1], offset: ?>>
-// CHECK-DAG: [[RES_3_:%.+]] = memref.alloc() : memref<256xf32>
-// CHECK-DAG: [[VAR_33_2_:%.+]] = arith.index_cast [[VAR_arg12_1_]] : i32 to index
-// CHECK-NOT: separator of consecutive DAGs
-// CHECK-DAG: [[VAR_34_2_:%.+]] = arith.addi [[VAR_33_2_]], [[CST_256_]] : index
-// CHECK-DAG: [[VAR_35_2_:%.+]] = arith.index_cast [[PARAM_7_]] : i32 to index
-// CHECK: [[VAR_36_2_:%.+]] = arith.minsi [[VAR_34_2_]], [[VAR_35_2_]] : index
-// CHECK: [[VAR_37_1_:%.+]] = arith.subi [[VAR_36_2_]], [[VAR_33_2_]] : index
-// CHECK-DAG: [[VAR_subview_9_:%.+]] = memref.subview [[VAR_reinterpret_cast_7_]][0] {{.}}[[VAR_37_1_]]{{.}} [1] : memref<256xf32, strided<[1], offset: ?>> to memref>
-// CHECK-DAG: [[VAR_subview_10_:%.+]] = memref.subview [[RES_3_]][0] {{.}}[[VAR_37_1_]]{{.}} [1] : memref<256xf32> to memref>
-// CHECK: memref.copy [[VAR_subview_9_]], [[VAR_subview_10_]] : memref> to memref>
-// CHECK-DAG: [[VAR_38_1_:%.+]] = bufferization.to_tensor [[RES_3_]] restrict writable : memref<256xf32>
-// CHECK-DAG: [[VAR_39_1_:%.+]] = arith.index_cast [[VAR_2_]] : i32 to index
-// CHECK-DAG: [[VAR_40_1_:%.+]] = arith.index_cast [[VAR_arg12_1_]] : i32 to index
-// CHECK: [[VAR_41_1_:%.+]] = arith.addi [[VAR_39_1_]], [[VAR_40_1_]] : index
-// CHECK-DAG: [[VAR_reinterpret_cast_11_:%.+]] = memref.reinterpret_cast [[PARAM_0_]] to offset: {{.}}[[VAR_41_1_]]{{.}}, sizes: [256], strides: [1] : memref<*xf32> to memref<256xf32, strided<[1], offset: ?>>
-// CHECK-DAG: [[RES_4_:%.+]] = memref.alloc() : memref<256xf32>
-// CHECK-DAG: [[VAR_42_1_:%.+]] = arith.index_cast [[VAR_arg12_1_]] : i32 to index
-// CHECK-NOT: separator of consecutive DAGs
-// CHECK-DAG: [[VAR_43_1_:%.+]] = arith.addi [[VAR_42_1_]], [[CST_256_]] : index
-// CHECK-DAG: [[VAR_44_6_:%.+]] = arith.index_cast [[PARAM_7_]] : i32 to index
-// CHECK: [[VAR_45_:%.+]] = arith.minsi [[VAR_43_1_]], [[VAR_44_6_]] : index
-// CHECK: [[VAR_46_:%.+]] = arith.subi [[VAR_45_]], [[VAR_42_1_]] : index
-// CHECK-DAG: [[VAR_47_:%.+]] = arith.cmpi slt, [[VAR_46_]], [[CST_256_]] : index
-// CHECK: scf.if [[VAR_47_]] {
-// CHECK: linalg.fill ins([[CST_0_dot_000000_]] : f32) outs([[RES_4_]] : memref<256xf32>)
-// CHECK: }
-// CHECK-DAG: [[VAR_subview_13_:%.+]] = memref.subview [[VAR_reinterpret_cast_11_]][0] {{.}}[[VAR_46_]]{{.}} [1] : memref<256xf32, strided<[1], offset: ?>> to memref>
-// CHECK-DAG: [[VAR_subview_14_:%.+]] = memref.subview [[RES_4_]][0] {{.}}[[VAR_46_]]{{.}} [1] : memref<256xf32> to memref>
-// CHECK: memref.copy [[VAR_subview_13_]], [[VAR_subview_14_]] : memref> to memref>
-// CHECK: [[VAR_48_:%.+]] = bufferization.to_tensor [[RES_4_]] restrict writable : memref<256xf32>
-// CHECK: [[VAR_49_:%.+]] = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]} ins([[VAR_48_]], [[VAR_22_]] : tensor<256xf32>, tensor<256xf32>) outs([[VAR_48_]] : tensor<256xf32>) {
-// CHECK: ^bb0([[in_0:.+]]: f32, [[in_1:.+]]: f32, [[out:.+]]: f32):
-// CHECK: [[VAR_61_:%.+]] = arith.subf [[in_0]], [[in_1]] : f32
-// CHECK: linalg.yield [[VAR_61_]] : f32
-// CHECK: } -> tensor<256xf32>
-// CHECK: [[VAR_50_:%.+]] = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]} ins([[VAR_49_]], [[VAR_24_]] : tensor<256xf32>, tensor<256xf32>) outs([[VAR_49_]] : tensor<256xf32>) {
-// CHECK: ^bb0([[in_0:.+]]: f32, [[in_1:.+]]: f32, [[out:.+]]: f32):
-// CHECK: [[VAR_61_1_:%.+]] = arith.mulf [[in_0]], [[in_1]] : f32
-// CHECK: linalg.yield [[VAR_61_1_]] : f32
-// CHECK: } -> tensor<256xf32>
-// CHECK: [[VAR_51_:%.+]] = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]} ins([[VAR_50_]], [[VAR_31_2_]] : tensor<256xf32>, tensor<256xf32>) outs([[VAR_50_]] : tensor<256xf32>) {
-// CHECK: ^bb0([[in_0:.+]]: f32, [[in_1:.+]]: f32, [[out:.+]]: f32):
-// CHECK: [[VAR_61_2_:%.+]] = arith.mulf [[in_0]], [[in_1]] : f32
-// CHECK: linalg.yield [[VAR_61_2_]] : f32
-// CHECK: } -> tensor<256xf32>
-// CHECK: [[VAR_52_:%.+]] = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]} ins([[VAR_51_]], [[VAR_38_1_]] : tensor<256xf32>, tensor<256xf32>) outs([[VAR_51_]] : tensor<256xf32>) {
-// CHECK: ^bb0([[in_0:.+]]: f32, [[in_1:.+]]: f32, [[out:.+]]: f32):
-// CHECK: [[VAR_61_3_:%.+]] = arith.addf [[in_0]], [[in_1]] : f32
-// CHECK: linalg.yield [[VAR_61_3_]] : f32
-// CHECK: } -> tensor<256xf32>
-// CHECK-DAG: [[VAR_53_:%.+]] = arith.index_cast [[VAR_2_]] : i32 to index
-// CHECK-DAG: [[VAR_54_:%.+]] = arith.index_cast [[VAR_arg12_1_]] : i32 to index
-// CHECK: [[VAR_55_:%.+]] = arith.addi [[VAR_53_]], [[VAR_54_]] : index
-// CHECK-DAG: [[VAR_reinterpret_cast_15_:%.+]] = memref.reinterpret_cast [[PARAM_1_]] to offset: {{.}}[[VAR_55_]]{{.}}, sizes: [256], strides: [1] : memref<*xf32> to memref<256xf32, strided<[1], offset: ?>>
-// CHECK-DAG: [[VAR_56_:%.+]] = arith.index_cast [[VAR_arg12_1_]] : i32 to index
-// CHECK-NOT: separator of consecutive DAGs
-// CHECK-DAG: [[VAR_57_:%.+]] = arith.addi [[VAR_56_]], [[CST_256_]] : index
-// CHECK-DAG: [[VAR_58_:%.+]] = arith.index_cast [[PARAM_7_]] : i32 to index
-// CHECK: [[VAR_59_:%.+]] = arith.minsi [[VAR_57_]], [[VAR_58_]] : index
-// CHECK: [[VAR_60_:%.+]] = arith.subi [[VAR_59_]], [[VAR_56_]] : index
-// CHECK-DAG: [[VAR_extracted_slice_:%.+]] = tensor.extract_slice [[VAR_52_]][0] {{.}}[[VAR_60_]]{{.}} [1] : tensor<256xf32> to tensor
-// CHECK-DAG: [[VAR_subview_16_:%.+]] = memref.subview [[VAR_reinterpret_cast_15_]][0] {{.}}[[VAR_60_]]{{.}} [1] : memref<256xf32, strided<[1], offset: ?>> to memref>
-// CHECK: bufferization.materialize_in_destination [[VAR_extracted_slice_]] in writable [[VAR_subview_16_]]
-// CHECK: }
-// CHECK: return
-// CHECK: }
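The fused layer-norm forward kernel deleted above accumulates a masked sum for the mean and then a masked, `select`-guarded sum of squares for the variance. A compact Python reference of those two loops (shapes assumed, not taken from the test):

```
import numpy as np

def mean_var(row, n, block=256):
    mean = sum(float(row[i:min(i + block, n)].sum())
               for i in range(0, n, block)) / n
    sq = 0.0
    for i in range(0, n, block):
        chunk = row[i:min(i + block, n)] - mean  # arith.subf
        sq += float((chunk * chunk).sum())       # arith.select zeroes padding lanes
    return mean, sq / n
```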
diff --git a/test/Conversion/TritonToLinalg/masked_ldst_1d.mlir b/test/Conversion/TritonToLinalg/masked_ldst_1d.mlir
deleted file mode 100644
index f0c99c7f..00000000
--- a/test/Conversion/TritonToLinalg/masked_ldst_1d.mlir
+++ /dev/null
@@ -1,45 +0,0 @@
-// RUN: triton-shared-opt --triton-to-linalg %s | FileCheck %s
-module {
-  tt.func @kernel(
-    %arg0 : !tt.ptr,
-    %arg1 : !tt.ptr,
-    %arg2 : i32
-  )
-  {
-    %0 = tt.splat %arg0 : !tt.ptr -> tensor<128x!tt.ptr>
-    %1 = tt.splat %arg1 : !tt.ptr -> tensor<128x!tt.ptr>
-    %2 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32>
-    %ldptr = tt.addptr %0, %2 : tensor<128x!tt.ptr>, tensor<128xi32>
-    %stptr = tt.addptr %1, %2 : tensor<128x!tt.ptr>, tensor<128xi32>
-    %nans = arith.constant dense<0xFF80> : tensor<128xbf16>
-    %5 = tt.splat %arg2 : i32 -> tensor<128xi32>
-    %mask = arith.cmpi slt, %2, %5 : tensor<128xi32>
-    %buff = tt.load %ldptr, %mask, %nans : tensor<128x!tt.ptr>
-    tt.store %stptr, %buff, %mask : tensor<128x!tt.ptr>
-    tt.return
-  }
-}
-// CHECK-LABEL: func.func @kernel(
-// CHECK-SAME: %[[VAL_0:.*]]: memref<*xbf16>, %[[VAL_1:.*]]: memref<*xbf16>, %[[VAL_2:.*]]: i32, %[[ARG_3:.*]]: i32, %[[ARG_4:.*]]: i32, %[[ARG_5:.*]]: i32, %[[ARG_6:.*]]: i32, %[[ARG_7:.*]]: i32, %[[ARG_8:.*]]: i32) {
-// CHECK-DAG: %[[VAL_6:.*]] = arith.constant 0xFF80 : bf16
-// CHECK-DAG: %[[VAL_7:.*]] = arith.constant 128 : index
-// CHECK: %[[VAL_8:.*]] = memref.reinterpret_cast %[[VAL_0]] to offset: [0], sizes: [128], strides: [1] : memref<*xbf16> to memref<128xbf16, strided<[1]>>
-// CHECK: %[[VAL_9:.*]] = memref.reinterpret_cast %[[VAL_1]] to offset: [0], sizes: [128], strides: [1] : memref<*xbf16> to memref<128xbf16, strided<[1]>>
-// CHECK: %[[VAL_10:.*]] = memref.alloc() : memref<128xbf16>
-// CHECK: %[[VAL_11:.*]] = arith.index_cast %[[VAL_2]] : i32 to index
-// CHECK: %[[VAL_12:.*]] = arith.minsi %[[VAL_11]], %[[VAL_7]] : index
-// CHECK: %[[VAL_15:.*]] = arith.cmpi slt, %[[VAL_12]], %[[VAL_7]] : index
-// CHECK: scf.if %[[VAL_15]] {
-// CHECK: linalg.fill ins(%[[VAL_6]] : bf16) outs(%[[VAL_10]] : memref<128xbf16>)
-// CHECK: }
-// CHECK: %[[VAL_13:.*]] = memref.subview %[[VAL_8]][0] {{\[}}%[[VAL_12]]] [1] : memref<128xbf16, strided<[1]>> to memref>
-// CHECK: %[[VAL_14:.*]] = memref.subview %[[VAL_10]][0] {{\[}}%[[VAL_12]]] [1] : memref<128xbf16> to memref>
-// CHECK: memref.copy %[[VAL_13]], %[[VAL_14]] : memref> to memref>
-// CHECK: %[[VAL_16:.*]] = bufferization.to_tensor %[[VAL_10]] restrict writable : memref<128xbf16>
-// CHECK: %[[VAL_17:.*]] = arith.index_cast %[[VAL_2]] : i32 to index
-// CHECK: %[[VAL_18:.*]] = arith.minsi %[[VAL_17]], %[[VAL_7]] : index
-// CHECK: %[[VAL_19:.*]] = tensor.extract_slice %[[VAL_16]][0] {{\[}}%[[VAL_18]]] [1] : tensor<128xbf16> to tensor
-// CHECK: %[[VAL_20:.*]] = memref.subview %[[VAL_9]][0] {{\[}}%[[VAL_18]]] [1] : memref<128xbf16, strided<[1]>> to memref>
-// CHECK: bufferization.materialize_in_destination %[[VAL_19]] in writable %[[VAL_20]]
-// CHECK: return
-// CHECK: }
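The checks in the file just deleted spell out the masked load/store lowering in one dimension: clamp, conditionally fill with `other`, copy the valid prefix, and store only the unmasked lanes. The same recipe as a NumPy stand-in (bf16 fidelity ignored):

```
import numpy as np

def masked_ldst_1d(src, dst, n, other=float("-inf")):
    k = min(n, 128)              # arith.minsi; no lower clamp in this old pass
    buf = np.empty(128, dtype=src.dtype)
    if k < 128:                  # the scf.if around linalg.fill
        buf[:] = other           # 0xFF80 is bf16 -inf in the test
    buf[:k] = src[:k]            # memref.copy of the subview
    dst[:k] = buf[:k]            # the store writes only the valid prefix
```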
diff --git a/test/Conversion/TritonToLinalg/masked_ldst_2d.mlir b/test/Conversion/TritonToLinalg/masked_ldst_2d.mlir
deleted file mode 100644
index 06cd3f0a..00000000
--- a/test/Conversion/TritonToLinalg/masked_ldst_2d.mlir
+++ /dev/null
@@ -1,108 +0,0 @@
-// RUN: triton-shared-opt --triton-to-linalg %s | FileCheck %s
-module {
-  tt.func @kernel(
-    %arg0 : !tt.ptr,
-    %arg1 : !tt.ptr,
-    %arg2 : i32,
-    %arg3 : i32
-  )
-  {
-    // Mimic a scenario where the raw pointer points to a buffer with dimension (1024, 1024)
-    // in row-major, but the actual tensor size is (arg2, arg3).
-    // We are trying to load a 128x256 sub-buffer starting at (2, 3).
-    // The resulting memref:
-    //  offset = 3074
-    //  size[1] = 128
-    //  size[0] = 256
-    //  stride[0] = 1024
-    //  stride[1] = 1
-    %0 = tt.splat %arg0 : !tt.ptr -> tensor<128x256x!tt.ptr>
-    %1 = tt.splat %arg1 : !tt.ptr -> tensor<128x256x!tt.ptr>
-    // horizontal index
-    %2 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32>
-    %c2 = arith.constant 2 : i32
-    %c2tensor = tt.splat %c2 : i32 -> tensor<128xi32>
-    %offset2 = arith.addi %2, %c2tensor : tensor<128xi32>
-    %3 = tt.expand_dims %offset2 {axis = 1 : i32} : tensor<128xi32> -> tensor<128x1xi32>
-    %4 = tt.broadcast %3 : tensor<128x1xi32> -> tensor<128x256xi32>
-    // vertical index
-    %5 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32>
-    %c3 = arith.constant 3 : i32
-    %c3tensor = tt.splat %c3 : i32 -> tensor<256xi32>
-    %offset5 = arith.addi %5, %c3tensor : tensor<256xi32>
-    %c1024 = arith.constant 1024 : i32
-    %c1024tensor = tt.splat %c1024 : i32 -> tensor<256xi32>
-    %scale5 = arith.muli %offset5, %c1024tensor : tensor<256xi32>
-    %6 = tt.expand_dims %scale5 {axis = 0 : i32} : tensor<256xi32> -> tensor<1x256xi32>
-    %7 = tt.broadcast %6 : tensor<1x256xi32> -> tensor<128x256xi32>
-    // combined index
-    %index = arith.addi %4, %7 : tensor<128x256xi32>
-    %ldptr = tt.addptr %0, %index : tensor<128x256x!tt.ptr>, tensor<128x256xi32>
-    %stptr = tt.addptr %1, %index : tensor<128x256x!tt.ptr>, tensor<128x256xi32>
-    // other value for masked load
-    %cnan = arith.constant 0xFF80 : bf16
-    %nans = tt.splat %cnan : bf16 -> tensor<128x256xbf16>
-    // horizontal mask
-    %8 = tt.splat %arg2 : i32 -> tensor<128xi32>
-    %9 = arith.cmpi slt, %offset2, %8 : tensor<128xi32>
-    %10 = tt.expand_dims %9 {axis = 1 : i32} : tensor<128xi1> -> tensor<128x1xi1>
-    %11 = tt.broadcast %10 : tensor<128x1xi1> -> tensor<128x256xi1>
-    // vertical mask
-    %12 = tt.splat %arg3 : i32 -> tensor<256xi32>
-    %13 = arith.cmpi slt, %offset5, %12 : tensor<256xi32>
-    %14 = tt.expand_dims %13 {axis = 0 : i32} : tensor<256xi1> -> tensor<1x256xi1>
-    %15 = tt.broadcast %14 : tensor<1x256xi1> -> tensor<128x256xi1>
-    // combined mask
-    %mask = arith.andi %11, %15 : tensor<128x256xi1>
-    // dim0 = min(%arg2, 128), dim1 = min(%arg3, 256)
-    // TODO: need reinterpret cast
-    %buff = tt.load %ldptr, %mask, %nans : tensor<128x256x!tt.ptr>
-    tt.store %stptr, %buff, %mask : tensor<128x256x!tt.ptr>
-    tt.return
-  }
-}
-// CHECK-LABEL: func.func @kernel(
-// CHECK-SAME: %[[VAL_0:.*]]: memref<*xbf16>, %[[VAL_1:.*]]: memref<*xbf16>, %[[VAL_2:.*]]: i32, %[[VAL_3:.*]]: i32, %[[ARG_4:.*]]: i32, %[[ARG_5:.*]]: i32, %[[ARG_6:.*]]: i32, %[[ARG_7:.*]]: i32, %[[ARG_8:.*]]: i32, %[[ARG_9:.*]]: i32) {
-// CHECK-DAG: %[[VAL_7:.*]] = arith.constant 3074 : index
-// CHECK-DAG: %[[VAL_8:.*]] = arith.constant 1024 : index
-// CHECK-DAG: %[[VAL_9:.*]] = arith.constant 3 : index
-// CHECK-DAG: %[[VAL_10:.*]] = arith.constant 2 : index
-// CHECK-DAG: %[[VAL_11:.*]] = arith.constant 256 : index
-// CHECK-DAG: %[[VAL_12:.*]] = arith.constant 128 : index
-// CHECK-DAG: %[[VAL_13:.*]] = arith.constant 259 : index
-// CHECK-DAG: %[[VAL_14:.*]] = arith.constant 130 : index
-// CHECK-DAG: %[[VAL_15:.*]] = arith.constant 0xFF80 : bf16
-// CHECK: %[[VAL_16:.*]] = memref.reinterpret_cast %[[VAL_0]] to offset: {{\[}}%[[VAL_7]]], sizes: [128, 256], strides: [1, %[[VAL_8]]] : memref<*xbf16> to memref<128x256xbf16, strided<[1, ?], offset: ?>>
-// CHECK: %[[VAL_17:.*]] = memref.reinterpret_cast %[[VAL_1]] to offset: {{\[}}%[[VAL_7]]], sizes: [128, 256], strides: [1, %[[VAL_8]]] : memref<*xbf16> to memref<128x256xbf16, strided<[1, ?], offset: ?>>
-// CHECK: %[[VAL_18:.*]] = memref.alloc() : memref<128x256xbf16>
-// CHECK: %[[VAL_19:.*]] = arith.index_cast %[[VAL_2]] : i32 to index
-// CHECK: %[[VAL_20:.*]] = arith.minsi %[[VAL_19]], %[[VAL_14]] : index
-// CHECK: %[[VAL_21:.*]] = arith.subi %[[VAL_20]], %[[VAL_10]] : index
-// CHECK: %[[VAL_22:.*]] = arith.index_cast %[[VAL_3]] : i32 to index
-// CHECK: %[[VAL_23:.*]] = arith.minsi %[[VAL_22]], %[[VAL_13]] : index
-// CHECK: %[[VAL_24:.*]] = arith.subi %[[VAL_23]], %[[VAL_9]] : index
-// CHECK: %[[VAL_25:.*]] = arith.minsi %[[VAL_21]], %[[VAL_12]] : index
-// CHECK: %[[VAL_26:.*]] = arith.minsi %[[VAL_24]], %[[VAL_11]] : index
-// CHECK: %[[VAL_29:.*]] = arith.cmpi slt, %[[VAL_25]], %[[VAL_12]] : index
-// CHECK: %[[VAL_30:.*]] = arith.cmpi slt, %[[VAL_26]], %[[VAL_11]] : index
-// CHECK: %[[VAL_31:.*]] = arith.ori %[[VAL_29]], %[[VAL_30]] : i1
-// CHECK: scf.if %[[VAL_31]] {
-// CHECK: linalg.fill ins(%[[VAL_15]] : bf16) outs(%[[VAL_18]] : memref<128x256xbf16>)
-// CHECK: }
-// CHECK: %[[VAL_27:.*]] = memref.subview %[[VAL_16]][0, 0] {{\[}}%[[VAL_25]], %[[VAL_26]]] [1, 1] : memref<128x256xbf16, strided<[1, ?], offset: ?>> to memref>
-// CHECK: %[[VAL_28:.*]] = memref.subview %[[VAL_18]][0, 0] {{\[}}%[[VAL_25]], %[[VAL_26]]] [1, 1] : memref<128x256xbf16> to memref>
-// CHECK: memref.copy %[[VAL_27]], %[[VAL_28]] : memref> to memref>
-// CHECK: %[[VAL_32:.*]] = bufferization.to_tensor %[[VAL_18]] restrict writable : memref<128x256xbf16>
-// CHECK: %[[VAL_33:.*]] = arith.index_cast %[[VAL_2]] : i32 to index
-// CHECK: %[[VAL_34:.*]] = arith.minsi %[[VAL_33]], %[[VAL_14]] : index
-// CHECK: %[[VAL_35:.*]] = arith.subi %[[VAL_34]], %[[VAL_10]] : index
-// CHECK: %[[VAL_36:.*]] = arith.index_cast %[[VAL_3]] : i32 to index
-// CHECK: %[[VAL_37:.*]] = arith.minsi %[[VAL_36]], %[[VAL_13]] : index
-// CHECK: %[[VAL_38:.*]] = arith.subi %[[VAL_37]], %[[VAL_9]] : index
-// CHECK: %[[VAL_39:.*]] = arith.minsi %[[VAL_35]], %[[VAL_12]] : index
-// CHECK: %[[VAL_40:.*]] = arith.minsi %[[VAL_38]], %[[VAL_11]] : index
-// CHECK: %[[VAL_41:.*]] = tensor.extract_slice %[[VAL_32]][0, 0] {{\[}}%[[VAL_39]], %[[VAL_40]]] [1, 1] : tensor<128x256xbf16> to tensor
-// CHECK: %[[VAL_42:.*]] = memref.subview %[[VAL_17]][0, 0] {{\[}}%[[VAL_39]], %[[VAL_40]]] [1, 1] : memref<128x256xbf16, strided<[1, ?], offset: ?>> to memref>
-// CHECK: bufferization.materialize_in_destination %[[VAL_41]] in writable %[[VAL_42]]
-// CHECK: return
-// CHECK: }
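The constants in those checks follow directly from the comment in the deleted source: the (2, 3) starting point in a row-major 1024x1024 buffer, plus clamped extents. In plain arithmetic (the runtime sizes here are hypothetical):

```
offset = 2 * 1 + 3 * 1024            # = 3074, the reinterpret_cast offset
arg2, arg3 = 100, 200                # hypothetical runtime tensor sizes
dim0 = min(arg2, 2 + 128) - 2        # minsi against the 130 constant, then subi
dim1 = min(arg3, 3 + 256) - 3        # minsi against the 259 constant, then subi
assert (offset, dim0, dim1) == (3074, 98, 197)
```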
diff --git a/test/Conversion/TritonToLinalg/masked_ldst_sitofp_other.mlir b/test/Conversion/TritonToLinalg/masked_ldst_sitofp_other.mlir
deleted file mode 100644
index 475d5327..00000000
--- a/test/Conversion/TritonToLinalg/masked_ldst_sitofp_other.mlir
+++ /dev/null
@@ -1,47 +0,0 @@
-// RUN: triton-shared-opt --triton-to-linalg %s | FileCheck %s
-module {
-  tt.func @kernel(
-    %arg0 : !tt.ptr,
-    %arg1 : !tt.ptr,
-    %arg2 : i32
-  )
-  {
-    %0 = tt.splat %arg0 : !tt.ptr -> tensor<128x!tt.ptr>
-    %1 = tt.splat %arg1 : !tt.ptr -> tensor<128x!tt.ptr>
-    %2 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32>
-    %ldptr = tt.addptr %0, %2 : tensor<128x!tt.ptr>, tensor<128xi32>
-    %stptr = tt.addptr %1, %2 : tensor<128x!tt.ptr>, tensor<128xi32>
-    %c7_i32 = arith.constant 7 : i32
-    %splat_c7_i32 = tt.splat %c7_i32 : i32 -> tensor<128xi32>
-    %splat_c7_bf16 = arith.sitofp %splat_c7_i32 : tensor<128xi32> to tensor<128xbf16>
-    %5 = tt.splat %arg2 : i32 -> tensor<128xi32>
-    %mask = arith.cmpi slt, %2, %5 : tensor<128xi32>
-    %buff = tt.load %ldptr, %mask, %splat_c7_bf16 : tensor<128x!tt.ptr>
-    tt.store %stptr, %buff, %mask : tensor<128x!tt.ptr>
-    tt.return
-  }
-}
-// CHECK-LABEL: func.func @kernel(
-// CHECK-SAME: %[[VAL_0:.*]]: memref<*xbf16>, %[[VAL_1:.*]]: memref<*xbf16>, %[[VAL_2:.*]]: i32, %[[ARG_3:.*]]: i32, %[[ARG_4:.*]]: i32, %[[ARG_5:.*]]: i32, %[[ARG_6:.*]]: i32, %[[ARG_7:.*]]: i32, %[[ARG_8:.*]]: i32) {
-// CHECK-DAG: %[[VAL_6:.*]] = arith.constant 7.000000e+00 : bf16
-// CHECK-DAG: %[[VAL_7:.*]] = arith.constant 128 : index
-// CHECK: %[[VAL_8:.*]] = memref.reinterpret_cast %[[VAL_0]] to offset: [0], sizes: [128], strides: [1] : memref<*xbf16> to memref<128xbf16, strided<[1]>>
-// CHECK: %[[VAL_9:.*]] = memref.reinterpret_cast %[[VAL_1]] to offset: [0], sizes: [128], strides: [1] : memref<*xbf16> to memref<128xbf16, strided<[1]>>
-// CHECK: %[[VAL_10:.*]] = memref.alloc() : memref<128xbf16>
-// CHECK: %[[VAL_11:.*]] = arith.index_cast %[[VAL_2]] : i32 to index
-// CHECK: %[[VAL_12:.*]] = arith.minsi %[[VAL_11]], %[[VAL_7]] : index
-// CHECK: %[[VAL_15:.*]] = arith.cmpi slt, %[[VAL_12]], %[[VAL_7]] : index
-// CHECK: scf.if %[[VAL_15]] {
-// CHECK: linalg.fill ins(%[[VAL_6]] : bf16) outs(%[[VAL_10]] : memref<128xbf16>)
-// CHECK: }
-// CHECK: %[[VAL_13:.*]] = memref.subview %[[VAL_8]][0] {{\[}}%[[VAL_12]]] [1] : memref<128xbf16, strided<[1]>> to memref>
-// CHECK: %[[VAL_14:.*]] = memref.subview %[[VAL_10]][0] {{\[}}%[[VAL_12]]] [1] : memref<128xbf16> to memref>
-// CHECK: memref.copy %[[VAL_13]], %[[VAL_14]] : memref> to memref>
-// CHECK: %[[VAL_16:.*]] = bufferization.to_tensor %[[VAL_10]] restrict writable : memref<128xbf16>
-// CHECK: %[[VAL_17:.*]] = arith.index_cast %[[VAL_2]] : i32 to index
-// CHECK: %[[VAL_18:.*]] = arith.minsi %[[VAL_17]], %[[VAL_7]] : index
-// CHECK: %[[VAL_19:.*]] = tensor.extract_slice %[[VAL_16]][0] {{\[}}%[[VAL_18]]] [1] : tensor<128xbf16> to tensor
-// CHECK: %[[VAL_20:.*]] = memref.subview %[[VAL_9]][0] {{\[}}%[[VAL_18]]] [1] : memref<128xbf16, strided<[1]>> to memref>
-// CHECK: bufferization.materialize_in_destination %[[VAL_19]] in writable %[[VAL_20]]
-// CHECK: return
-// CHECK: }
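Worth noting: the `other` operand in that test is an `arith.sitofp` of an i32 splat, and the checks show it already folded to a single bf16 constant feeding `linalg.fill`:

```
other = float(7)       # arith.sitofp applied to the splat i32 constant 7
assert other == 7.0    # the 7.000000e+00 : bf16 constant in the checks
```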
diff --git a/test/Conversion/TritonToLinalg/tensor_indices_loop_iterarg_with_masks.mlir b/test/Conversion/TritonToLinalg/tensor_indices_loop_iterarg_with_masks.mlir
deleted file mode 100644
index e2134d84..00000000
--- a/test/Conversion/TritonToLinalg/tensor_indices_loop_iterarg_with_masks.mlir
+++ /dev/null
@@ -1,79 +0,0 @@
-// RUN: triton-shared-opt --triton-to-linalg-experimental %s | FileCheck %s
-
-// IR from python/examples/test_tensor_index_iterargs.py
-module {
-  tt.func public @addptr_with_masks(%arg0: !tt.ptr, %arg1: !tt.ptr, %arg2: i32) attributes {noinline = false} {
-    %cst = arith.constant dense<-1.100000e+01> : tensor<4xf32>
-    %c1_i32 = arith.constant 1 : i32
-    %c4_i32 = arith.constant 4 : i32
-    %c0_i32 = arith.constant 0 : i32
-    %cst_0 = arith.constant dense<4> : tensor<4xi32>
-    %0 = tt.make_range {end = 4 : i32, start = 0 : i32} : tensor<4xi32>
-    %1 = tt.splat %arg2 : i32 -> tensor<4xi32>
-    %2 = tt.splat %arg0 : !tt.ptr -> tensor<4x!tt.ptr>
-    %3 = tt.splat %arg1 : !tt.ptr -> tensor<4x!tt.ptr>
-    %4:2 = scf.for %arg3 = %c0_i32 to %c4_i32 step %c1_i32 iter_args(%arg4 = %0, %arg5 = %0) -> (tensor<4xi32>, tensor<4xi32>) : i32 {
-      %5 = arith.cmpi slt, %arg4, %1 : tensor<4xi32>
-      %6 = tt.addptr %2, %arg4 : tensor<4x!tt.ptr>, tensor<4xi32>
-      %7 = tt.load %6, %5, %cst : tensor<4x!tt.ptr>
-      %8 = tt.addptr %3, %arg5 : tensor<4x!tt.ptr>, tensor<4xi32>
-      tt.store %8, %7 : tensor<4x!tt.ptr>
-      %9 = arith.addi %arg4, %cst_0 : tensor<4xi32>
-      %10 = arith.addi %arg5, %cst_0 : tensor<4xi32>
-      scf.yield %9, %10 : tensor<4xi32>, tensor<4xi32>
-    }
-    tt.return
-  }
-}
-
-// CHECK-DAG: [[MAP_0_:#.+]] = affine_map<(d0) -> (d0)>
-// CHECK-LABEL: func.func @addptr_with_masks
-// CHECK-SAME: ([[PARAM_0_:%.+]]: memref<*xf32>, [[PARAM_1_:%.+]]: memref<*xf32>, [[PARAM_2_:%.+]]: i32, [[PARAM_3_:%.+]]: i32, [[PARAM_4_:%.+]]: i32, [[PARAM_5_:%.+]]: i32, [[PARAM_6_:%.+]]: i32, [[PARAM_7_:%.+]]: i32, [[PARAM_8_:%.+]]: i32) {
-// CHECK-DAG: [[CST_0_:%.+]] = arith.constant 0 : i32
-// CHECK-DAG: [[CST_4_:%.+]] = arith.constant 4 : i32
-// CHECK-DAG: [[CST_1_:%.+]] = arith.constant 1 : i32
-// CHECK-DAG: [[CST_0_1_:%.+]] = arith.constant 0 : index
-// CHECK-DAG: [[CST_1_1_:%.+]] = arith.constant 1 : index
-// CHECK-DAG: [[CST_4_1_:%.+]] = arith.constant 4 : index
-// CHECK-DAG: [[CST_minus_1_dot_100000_:%.+]] = arith.constant -1.100000e+01 : f32
-// CHECK-DAG: [[VAR_0_:%.+]] = tensor.empty() : tensor<4xi32>
-// CHECK-NOT: separator of consecutive DAGs
-// CHECK-DAG: [[VAR_1_:%.+]] = linalg.fill ins([[CST_4_]] : i32) outs([[VAR_0_]] : tensor<4xi32>) -> tensor<4xi32>
-// CHECK-DAG: [[VAR_2_:%.+]] = linalg.generic {indexing_maps = [#map], iterator_types = ["parallel"]} outs([[VAR_0_]] : tensor<4xi32>) {
-// CHECK: ^bb0([[IN_0_:%.+]]: i32):
-// CHECK: [[VAR_4_:%.+]] = linalg.index 0 : index
-// CHECK: [[VAR_5_:%.+]] = arith.index_cast [[VAR_4_]] : index to i32
-// CHECK: linalg.yield [[VAR_5_]] : i32
-// CHECK: } -> tensor<4xi32>
-// CHECK-DAG: [[VAR_3_:%.+]]:4 = scf.for [[VAR_arg9_:%.+]] = [[CST_0_]] to [[CST_4_]] step [[CST_1_]] iter_args([[VAR_arg10_:%.+]] = [[VAR_2_]], [[VAR_arg11_:%.+]] = [[CST_0_1_]], [[VAR_arg12_:%.+]] = [[VAR_2_]], [[VAR_arg13_:%.+]] = [[CST_0_1_]]) -> (tensor<4xi32>, index, tensor<4xi32>, index) : i32 {
-// CHECK-DAG: [[VAR_reinterpret_cast_:%.+]] = memref.reinterpret_cast [[PARAM_0_]] to offset: {{.}}[[VAR_arg11_]]{{.}}, sizes: [4], strides: {{.}}[[CST_1_1_]]{{.}} : memref<*xf32> to memref<4xf32, strided<[?], offset: ?>>
-// CHECK-DAG: [[VAR_4_1_:%.+]] = arith.addi [[VAR_arg11_]], [[CST_4_1_]] : index
-// CHECK-DAG: [[VAR_5_1_:%.+]] = arith.index_cast [[PARAM_2_]] : i32 to index
-// CHECK: [[VAR_6_:%.+]] = arith.minsi [[VAR_4_1_]], [[VAR_5_1_]] : index
-// CHECK-DAG: [[VAR_7_:%.+]] = arith.subi [[VAR_6_]], [[VAR_arg11_]] : index
-// CHECK-DAG: [[RES_:%.+]] = memref.alloc() : memref<4xf32>
-// CHECK: [[VAR_8_:%.+]] = arith.cmpi slt, [[VAR_7_]], [[CST_4_1_]] : index
-// CHECK: scf.if [[VAR_8_]] {
-// CHECK: linalg.fill ins([[CST_minus_1_dot_100000_]] : f32) outs([[RES_]] : memref<4xf32>)
-// CHECK: }
-// CHECK-DAG: [[VAR_subview_:%.+]] = memref.subview [[VAR_reinterpret_cast_]][0] {{.}}[[VAR_7_]]{{.}} [1] : memref<4xf32, strided<[?], offset: ?>> to memref>
-// CHECK-DAG: [[VAR_subview_0_:%.+]] = memref.subview [[RES_]][0] {{.}}[[VAR_7_]]{{.}} [1] : memref<4xf32> to memref>
-// CHECK: memref.copy [[VAR_subview_]], [[VAR_subview_0_]] : memref> to memref>
-// CHECK-DAG: [[VAR_9_:%.+]] = bufferization.to_tensor [[RES_]] restrict writable : memref<4xf32>
-// CHECK-DAG: [[VAR_reinterpret_cast_1_:%.+]] = memref.reinterpret_cast [[PARAM_1_]] to offset: {{.}}[[VAR_arg13_]]{{.}}, sizes: [4], strides: {{.}}[[CST_1_1_]]{{.}} : memref<*xf32> to memref<4xf32, strided<[?], offset: ?>>
-// CHECK: bufferization.materialize_in_destination [[VAR_9_]] in writable [[VAR_reinterpret_cast_1_]] : (tensor<4xf32>, memref<4xf32, strided<[?], offset: ?>>) -> ()
-// CHECK: [[VAR_10_:%.+]] = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]} ins([[VAR_arg10_]], [[VAR_1_]] : tensor<4xi32>, tensor<4xi32>) outs([[VAR_arg10_]] : tensor<4xi32>) {
-// CHECK: ^bb0([[IN_1_:%.+]]: i32, [[IN_2_:%.+]]: i32, [[IN_3_:%.+]]: i32):
-// CHECK: [[VAR_13_:%.+]] = arith.addi [[IN_1_]], [[IN_2_]] : i32
-// CHECK: linalg.yield [[VAR_13_]] : i32
-// CHECK: } -> tensor<4xi32>
-// CHECK: [[VAR_11_:%.+]] = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]} ins([[VAR_arg12_]], [[VAR_1_]] : tensor<4xi32>, tensor<4xi32>) outs([[VAR_arg12_]] : tensor<4xi32>) {
-// CHECK: ^bb0([[IN_4_:%.+]]: i32, [[IN_5_:%.+]]: i32, [[IN_6_:%.+]]: i32):
-// CHECK: [[VAR_13_1_:%.+]] = arith.addi [[IN_4_]], [[IN_5_]] : i32
-// CHECK: linalg.yield [[VAR_13_1_]] : i32
-// CHECK: } -> tensor<4xi32>
-// CHECK: [[VAR_12_:%.+]] = arith.addi [[VAR_arg13_]], [[CST_4_1_]] : index
-// CHECK: scf.yield [[VAR_10_]], [[VAR_4_1_]], [[VAR_11_]], [[VAR_12_]] : tensor<4xi32>, index, tensor<4xi32>, index
-// CHECK: }
-// CHECK: return
-// CHECK: }
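The deleted loop test advances the offset tensor by 4 each iteration and recomputes the mask against `%arg2`; the lowering carries a scalar offset as an `iter_arg`. Reference semantics in Python (assuming `len(a) >= 16`; the store itself is unmasked, so masked-off lanes carry the `other` value):

```
def addptr_with_masks_ref(a, out, n, other=-11.0):
    for it in range(4):
        start = it * 4
        dim = min(start + 4, n) - start        # the minsi/subi pair in the checks
        for j in range(4):
            out[start + j] = a[start + j] if j < dim else other
```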
diff --git a/test/Conversion/TritonToStructured/kernel-01-vector-add.mlir b/test/Conversion/TritonToStructured/kernel-01-vector-add.mlir
index 0c7fc293..9a0bbfb7 100644
--- a/test/Conversion/TritonToStructured/kernel-01-vector-add.mlir
+++ b/test/Conversion/TritonToStructured/kernel-01-vector-add.mlir
@@ -39,24 +39,27 @@ module {
 // CHECK-DAG: [[VAR_7_:%.+]] = arith.addi [[VAR_6_]], [[CST_1024_]] : index
 // CHECK-DAG: [[VAR_8_:%.+]] = arith.index_cast [[PARAM_3_]] : i32 to index
 // CHECK: [[VAR_9_:%.+]] = arith.minsi [[VAR_7_]], [[VAR_8_]] : index
-// CHECK: [[VAR_10_:%.+]] = arith.subi [[VAR_9_]], [[VAR_6_]] : index
-// CHECK-DAG: [[VAR_11_:%.+]] = "tts.load"([[VAR_5_]], [[VAR_10_]]) <{operandSegmentSizes = array, static_mask_dims = array}> : (tensor<1024x!tt.ptr>, index) -> tensor<1024xf32>
-// CHECK-DAG: [[VAR_12_:%.+]] = tts.make_tptr [[PARAM_1_]] to sizes: [1024], strides: [1], offsets: {{.}}[[VAR_3_]]{{.}}, shape: [0], order: [] : to tensor<1024x!tt.ptr>
-// CHECK-DAG: [[VAR_13_:%.+]] = arith.index_cast [[VAR_1_]] : i32 to index
+// CHECK: [[VAR_10_:%.+]] = arith.maxsi [[VAR_9_]], [[VAR_6_]] : index
+// CHECK: [[VAR_11_:%.+]] = arith.subi [[VAR_10_]], [[VAR_6_]] : index
+// CHECK-DAG: [[VAR_12_:%.+]] = "tts.load"([[VAR_5_]], [[VAR_11_]]) <{operandSegmentSizes = array, static_mask_dims = array}> : (tensor<1024x!tt.ptr>, index) -> tensor<1024xf32>
+// CHECK-DAG: [[VAR_13_:%.+]] = tts.make_tptr [[PARAM_1_]] to sizes: [1024], strides: [1], offsets: {{.}}[[VAR_3_]]{{.}}, shape: [0], order: [] : to tensor<1024x!tt.ptr>
+// CHECK-DAG: [[VAR_14_:%.+]] = arith.index_cast [[VAR_1_]] : i32 to index
 // CHECK-NOT: separator of consecutive DAGs
-// CHECK-DAG: [[VAR_14_:%.+]] = arith.addi [[VAR_13_]], [[CST_1024_]] : index
-// CHECK-DAG: [[VAR_15_:%.+]] = arith.index_cast [[PARAM_3_]] : i32 to index
-// CHECK: [[VAR_16_:%.+]] = arith.minsi [[VAR_14_]], [[VAR_15_]] : index
-// CHECK: [[VAR_17_:%.+]] = arith.subi [[VAR_16_]], [[VAR_13_]] : index
-// CHECK: [[VAR_18_:%.+]] = "tts.load"([[VAR_12_]], [[VAR_17_]]) <{operandSegmentSizes = array, static_mask_dims = array}> : (tensor<1024x!tt.ptr>, index) -> tensor<1024xf32>
-// CHECK-DAG: [[VAR_19_:%.+]] = arith.addf [[VAR_11_]], [[VAR_18_]] : tensor<1024xf32>
-// CHECK-DAG: [[VAR_20_:%.+]] = tts.make_tptr [[PARAM_2_]] to sizes: [1024], strides: [1], offsets: {{.}}[[VAR_2_]]{{.}}, shape: [0], order: [] : to tensor<1024x!tt.ptr>
-// CHECK-DAG: [[VAR_21_:%.+]] = arith.index_cast [[VAR_1_]] : i32 to index
+// CHECK-DAG: [[VAR_15_:%.+]] = arith.addi [[VAR_14_]], [[CST_1024_]] : index
+// CHECK-DAG: [[VAR_16_:%.+]] = arith.index_cast [[PARAM_3_]] : i32 to index
+// CHECK: [[VAR_17_:%.+]] = arith.minsi [[VAR_15_]], [[VAR_16_]] : index
+// CHECK: [[VAR_18_:%.+]] = arith.maxsi [[VAR_17_]], [[VAR_14_]] : index
+// CHECK: [[VAR_19_:%.+]] = arith.subi [[VAR_18_]], [[VAR_14_]] : index
+// CHECK: [[VAR_20_:%.+]] = "tts.load"([[VAR_13_]], [[VAR_19_]]) <{operandSegmentSizes = array, static_mask_dims = array}> : (tensor<1024x!tt.ptr>, index) -> tensor<1024xf32>
+// CHECK-DAG: [[VAR_21_:%.+]] = arith.addf [[VAR_12_]], [[VAR_20_]] : tensor<1024xf32>
+// CHECK-DAG: [[VAR_22_:%.+]] = tts.make_tptr [[PARAM_2_]] to sizes: [1024], strides: [1], offsets: {{.}}[[VAR_2_]]{{.}}, shape: [0], order: [] : to tensor<1024x!tt.ptr>
+// CHECK-DAG: [[VAR_23_:%.+]] = arith.index_cast [[VAR_1_]] : i32 to index
 // CHECK-NOT: separator of consecutive DAGs
-// CHECK-DAG: [[VAR_22_:%.+]] = arith.addi [[VAR_21_]], [[CST_1024_]] : index
-// CHECK-DAG: [[VAR_23_:%.+]] = arith.index_cast [[PARAM_3_]] : i32 to index
-// CHECK: [[VAR_24_:%.+]] = arith.minsi [[VAR_22_]], [[VAR_23_]] : index
-// CHECK: [[VAR_25_:%.+]] = arith.subi [[VAR_24_]], [[VAR_21_]] : index
-// CHECK: "tts.store"([[VAR_20_]], [[VAR_19_]], [[VAR_25_]]) <{static_mask_dims = array}> : (tensor<1024x!tt.ptr>, tensor<1024xf32>, index) -> ()
+// CHECK-DAG: [[VAR_24_:%.+]] = arith.addi [[VAR_23_]], [[CST_1024_]] : index
+// CHECK-DAG: [[VAR_25_:%.+]] = arith.index_cast [[PARAM_3_]] : i32 to index
+// CHECK: [[VAR_26_:%.+]] = arith.minsi [[VAR_24_]], [[VAR_25_]] : index
+// CHECK: [[VAR_27_:%.+]] = arith.maxsi [[VAR_26_]], [[VAR_23_]] : index
+// CHECK: [[VAR_28_:%.+]] = arith.subi [[VAR_27_]], [[VAR_23_]] : index
+// CHECK: "tts.store"([[VAR_22_]], [[VAR_21_]], [[VAR_28_]]) <{static_mask_dims = array}> : (tensor<1024x!tt.ptr>, tensor<1024xf32>, index) -> ()
 // CHECK: tt.return
 // CHECK: }
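This hunk shows the shape of the fix wherever a mask dimension is materialized: after the `minsi` against the bound, a new `arith.maxsi` clamps the end back to the start, so the following `subi` can no longer go negative. A minimal sketch of the clamped computation (my own illustration, not code from the pass):

```
def mask_dim(start, tile, bound):
    new_end = max(min(start + tile, bound), start)  # minsi, then the new maxsi
    return new_end - start                          # subi, now always >= 0

assert mask_dim(start=1024, tile=1024, bound=512) == 0   # fully masked off
assert mask_dim(start=0, tile=1024, bound=700) == 700    # partial tile unchanged
```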
diff --git a/test/Conversion/TritonToStructured/kernel-02-fused-softmax.mlir b/test/Conversion/TritonToStructured/kernel-02-fused-softmax.mlir
index 69bd1202..f7ce37c0 100644
--- a/test/Conversion/TritonToStructured/kernel-02-fused-softmax.mlir
+++ b/test/Conversion/TritonToStructured/kernel-02-fused-softmax.mlir
@@ -39,36 +39,39 @@ module {
 }
 // CHECK: tt.func public @softmax_kernel_012345([[PARAM_0_:%.+]]: !tt.ptr, [[PARAM_1_:%.+]]: !tt.ptr, [[PARAM_2_:%.+]]: i32, [[PARAM_3_:%.+]]: i32, [[PARAM_4_:%.+]]: i32) {
-// CHECK-DAG: [[CST_128_:%.+]] = arith.constant 128 : index
 // CHECK-DAG: [[CST_0_:%.+]] = arith.constant 0xFF800000 : f32
+// CHECK-DAG: [[CST_0_1_:%.+]] = arith.constant 0 : index
+// CHECK-DAG: [[CST_128_:%.+]] = arith.constant 128 : index
 // CHECK-DAG: [[VAR_0_:%.+]] = tt.get_program_id x : i32
 // CHECK: [[VAR_1_:%.+]] = arith.muli [[VAR_0_]], [[PARAM_2_]] : i32
 // CHECK: [[VAR_2_:%.+]] = arith.index_cast [[VAR_1_]] : i32 to index
 // CHECK-DAG: [[VAR_3_:%.+]] = tts.make_tptr [[PARAM_1_]] to sizes: [128], strides: [1], offsets: {{.}}[[VAR_2_]]{{.}}, shape: [0], order: [] : to tensor<128x!tt.ptr>
 // CHECK-DAG: [[VAR_4_:%.+]] = arith.index_cast [[PARAM_4_]] : i32 to index
 // CHECK: [[VAR_5_:%.+]] = arith.minsi [[VAR_4_]], [[CST_128_]] : index
-// CHECK: [[VAR_6_:%.+]] = "tts.load"([[VAR_3_]], [[VAR_5_]], [[CST_0_]]) <{operandSegmentSizes = array, static_mask_dims = array}> : (tensor<128x!tt.ptr>, index, f32) -> tensor<128xf32>
-// CHECK: [[VAR_7_:%.+]] = "tt.reduce"([[VAR_6_]]) <{axis = 0 : i32}> ({
-// CHECK: ^bb0([[arg5_:%.+]]: f32, [[arg6_:%.+]]: f32):
-// CHECK: [[VAR_19_:%.+]] = arith.cmpf ogt, [[arg5_]], [[arg6_]] : f32
-// CHECK: [[VAR_20_:%.+]] = arith.select [[VAR_19_]], [[arg5_]], [[arg6_]] : f32
-// CHECK: tt.reduce.return [[VAR_20_]] : f32
+// CHECK: [[VAR_6_:%.+]] = arith.maxsi [[VAR_5_]], [[CST_0_1_]] : index
+// CHECK: [[VAR_7_:%.+]] = "tts.load"([[VAR_3_]], [[VAR_6_]], [[CST_0_]]) <{operandSegmentSizes = array, static_mask_dims = array}> : (tensor<128x!tt.ptr>, index, f32) -> tensor<128xf32>
+// CHECK: [[VAR_8_:%.+]] = "tt.reduce"([[VAR_7_]]) <{axis = 0 : i32}> ({
+// CHECK: ^bb0([[IN_0_:%.+]]: f32, [[IN_1_:%.+]]: f32):
+// CHECK: [[VAR_21_:%.+]] = arith.cmpf ogt, [[IN_0_]], [[IN_1_]] : f32
+// CHECK: [[VAR_22_:%.+]] = arith.select [[VAR_21_]], [[IN_0_]], [[IN_1_]] : f32
+// CHECK: tt.reduce.return [[VAR_22_]] : f32
 // CHECK: }) : (tensor<128xf32>) -> f32
-// CHECK: [[VAR_8_:%.+]] = tt.splat [[VAR_7_]] : f32 -> tensor<128xf32>
-// CHECK: [[VAR_9_:%.+]] = arith.subf [[VAR_6_]], [[VAR_8_]] : tensor<128xf32>
-// CHECK: [[VAR_10_:%.+]] = math.exp [[VAR_9_]] : tensor<128xf32>
-// CHECK: [[VAR_11_:%.+]] = "tt.reduce"([[VAR_10_]]) <{axis = 0 : i32}> ({
-// CHECK: ^bb0([[arg5_]]: f32, [[arg6_]]: f32):
-// CHECK: [[VAR_19_1_:%.+]] = arith.addf [[arg5_]], [[arg6_]] : f32
-// CHECK: tt.reduce.return [[VAR_19_1_]] : f32
+// CHECK: [[VAR_9_:%.+]] = tt.splat [[VAR_8_]] : f32 -> tensor<128xf32>
+// CHECK: [[VAR_10_:%.+]] = arith.subf [[VAR_7_]], [[VAR_9_]] : tensor<128xf32>
+// CHECK: [[VAR_11_:%.+]] = math.exp [[VAR_10_]] : tensor<128xf32>
+// CHECK: [[VAR_12_:%.+]] = "tt.reduce"([[VAR_11_]]) <{axis = 0 : i32}> ({
+// CHECK: ^bb0([[IN_2_:%.+]]: f32, [[IN_3_:%.+]]: f32):
+// CHECK: [[VAR_21_1_:%.+]] = arith.addf [[IN_2_]], [[IN_3_]] : f32
+// CHECK: tt.reduce.return [[VAR_21_1_]] : f32
 // CHECK: }) : (tensor<128xf32>) -> f32
-// CHECK: [[VAR_12_:%.+]] = tt.splat [[VAR_11_]] : f32 -> tensor<128xf32>
-// CHECK-DAG: [[VAR_13_:%.+]] = arith.divf [[VAR_10_]], [[VAR_12_]] : tensor<128xf32>
-// CHECK-DAG: [[VAR_14_:%.+]] = arith.muli [[VAR_0_]], [[PARAM_3_]] : i32
-// CHECK: [[VAR_15_:%.+]] = arith.index_cast [[VAR_14_]] : i32 to index
-// CHECK-DAG: [[VAR_16_:%.+]] = tts.make_tptr [[PARAM_0_]] to sizes: [128], strides: [1], offsets: {{.}}[[VAR_15_]]{{.}}, shape: [0], order: [] : to tensor<128x!tt.ptr>
-// CHECK-DAG: [[VAR_17_:%.+]] = arith.index_cast [[PARAM_4_]] : i32 to index
-// CHECK: [[VAR_18_:%.+]] = arith.minsi [[VAR_17_]], [[CST_128_]] : index
-// CHECK: "tts.store"([[VAR_16_]], [[VAR_13_]], [[VAR_18_]]) <{static_mask_dims = array}> : (tensor<128x!tt.ptr>, tensor<128xf32>, index) -> ()
+// CHECK: [[VAR_13_:%.+]] = tt.splat [[VAR_12_]] : f32 -> tensor<128xf32>
+// CHECK-DAG: [[VAR_14_:%.+]] = arith.divf [[VAR_11_]], [[VAR_13_]] : tensor<128xf32>
+// CHECK-DAG: [[VAR_15_:%.+]] = arith.muli [[VAR_0_]], [[PARAM_3_]] : i32
+// CHECK: [[VAR_16_:%.+]] = arith.index_cast [[VAR_15_]] : i32 to index
+// CHECK-DAG: [[VAR_17_:%.+]] = tts.make_tptr [[PARAM_0_]] to sizes: [128], strides: [1], offsets: {{.}}[[VAR_16_]]{{.}}, shape: [0], order: [] : to tensor<128x!tt.ptr>
+// CHECK-DAG: [[VAR_18_:%.+]] = arith.index_cast [[PARAM_4_]] : i32 to index
+// CHECK: [[VAR_19_:%.+]] = arith.minsi [[VAR_18_]], [[CST_128_]] : index
+// CHECK: [[VAR_20_:%.+]] = arith.maxsi [[VAR_19_]], [[CST_0_1_]] : index
+// CHECK: "tts.store"([[VAR_17_]], [[VAR_14_]], [[VAR_20_]]) <{static_mask_dims = array}> : (tensor<128x!tt.ptr>, tensor<128xf32>, index) -> ()
 // CHECK: tt.return
 // CHECK: }
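In the softmax kernel the row starts at offset zero, so the clamp degenerates to a `maxsi` against the zero constant: a non-positive column count now yields a zero-sized load and the whole 128-element row keeps the `-inf` fill value (`0xFF800000` above). The same arithmetic:

```
def row_mask_dim(n_cols, tile=128):
    return max(min(n_cols, tile), 0)   # minsi against 128, maxsi against 0

assert row_mask_dim(-5) == 0
assert row_mask_dim(64) == 64
```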
[[VAR_15_:%.+]] = arith.addi [[VAR_14_]], [[CST_1024_]] : index +// CHECK-DAG: [[VAR_16_:%.+]] = arith.index_cast [[PARAM_3_]] : i32 to index +// CHECK: [[VAR_17_:%.+]] = arith.minsi [[VAR_15_]], [[VAR_16_]] : index +// CHECK: [[VAR_18_:%.+]] = arith.maxsi [[VAR_17_]], [[VAR_14_]] : index +// CHECK: [[VAR_19_:%.+]] = arith.subi [[VAR_18_]], [[VAR_14_]] : index +// CHECK: [[VAR_20_:%.+]] = "tts.load"([[VAR_13_]], [[VAR_19_]]) <{operandSegmentSizes = array, static_mask_dims = array}> : (tensor<1024x!tt.ptr>, index) -> tensor<1024xf32> +// CHECK-DAG: [[VAR_21_:%.+]] = arith.addf [[VAR_12_]], [[VAR_20_]] : tensor<1024xf32> +// CHECK-DAG: [[VAR_22_:%.+]] = tts.make_tptr [[PARAM_2_]] to sizes: [1024], strides: [1], offsets: {{.}}[[VAR_2_]]{{.}}, shape: [0], order: [] : to tensor<1024x!tt.ptr> +// CHECK-DAG: [[VAR_23_:%.+]] = arith.index_cast [[VAR_1_]] : i32 to index // CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: [[VAR_22_:%.+]] = arith.addi [[VAR_21_]], [[CST_1024_]] : index -// CHECK-DAG: [[VAR_23_:%.+]] = arith.index_cast [[PARAM_3_]] : i32 to index -// CHECK: [[VAR_24_:%.+]] = arith.minsi [[VAR_22_]], [[VAR_23_]] : index -// CHECK: [[VAR_25_:%.+]] = arith.subi [[VAR_24_]], [[VAR_21_]] : index -// CHECK: "tts.store"([[VAR_20_]], [[VAR_19_]], [[VAR_25_]]) <{static_mask_dims = array}> : (tensor<1024x!tt.ptr>, tensor<1024xf32>, index) -> () +// CHECK-DAG: [[VAR_24_:%.+]] = arith.addi [[VAR_23_]], [[CST_1024_]] : index +// CHECK-DAG: [[VAR_25_:%.+]] = arith.index_cast [[PARAM_3_]] : i32 to index +// CHECK: [[VAR_26_:%.+]] = arith.minsi [[VAR_24_]], [[VAR_25_]] : index +// CHECK: [[VAR_27_:%.+]] = arith.maxsi [[VAR_26_]], [[VAR_23_]] : index +// CHECK: [[VAR_28_:%.+]] = arith.subi [[VAR_27_]], [[VAR_23_]] : index +// CHECK: "tts.store"([[VAR_22_]], [[VAR_21_]], [[VAR_28_]]) <{static_mask_dims = array}> : (tensor<1024x!tt.ptr>, tensor<1024xf32>, index) -> () // CHECK: tt.return // CHECK: } diff --git a/test/Conversion/TritonToStructured/kernel-02-fused-softmax.mlir b/test/Conversion/TritonToStructured/kernel-02-fused-softmax.mlir index 69bd1202..f7ce37c0 100644 --- a/test/Conversion/TritonToStructured/kernel-02-fused-softmax.mlir +++ b/test/Conversion/TritonToStructured/kernel-02-fused-softmax.mlir @@ -39,36 +39,39 @@ module { } // CHECK: tt.func public @softmax_kernel_012345([[PARAM_0_:%.+]]: !tt.ptr, [[PARAM_1_:%.+]]: !tt.ptr, [[PARAM_2_:%.+]]: i32, [[PARAM_3_:%.+]]: i32, [[PARAM_4_:%.+]]: i32) { -// CHECK-DAG: [[CST_128_:%.+]] = arith.constant 128 : index // CHECK-DAG: [[CST_0_:%.+]] = arith.constant 0xFF800000 : f32 +// CHECK-DAG: [[CST_0_1_:%.+]] = arith.constant 0 : index +// CHECK-DAG: [[CST_128_:%.+]] = arith.constant 128 : index // CHECK-DAG: [[VAR_0_:%.+]] = tt.get_program_id x : i32 // CHECK: [[VAR_1_:%.+]] = arith.muli [[VAR_0_]], [[PARAM_2_]] : i32 // CHECK: [[VAR_2_:%.+]] = arith.index_cast [[VAR_1_]] : i32 to index // CHECK-DAG: [[VAR_3_:%.+]] = tts.make_tptr [[PARAM_1_]] to sizes: [128], strides: [1], offsets: {{.}}[[VAR_2_]]{{.}}, shape: [0], order: [] : to tensor<128x!tt.ptr> // CHECK-DAG: [[VAR_4_:%.+]] = arith.index_cast [[PARAM_4_]] : i32 to index // CHECK: [[VAR_5_:%.+]] = arith.minsi [[VAR_4_]], [[CST_128_]] : index -// CHECK: [[VAR_6_:%.+]] = "tts.load"([[VAR_3_]], [[VAR_5_]], [[CST_0_]]) <{operandSegmentSizes = array, static_mask_dims = array}> : (tensor<128x!tt.ptr>, index, f32) -> tensor<128xf32> -// CHECK: [[VAR_7_:%.+]] = "tt.reduce"([[VAR_6_]]) <{axis = 0 : i32}> ({ -// CHECK: ^bb0([[arg5_:%.+]]: f32, [[arg6_:%.+]]: f32): -// CHECK: [[VAR_19_:%.+]] = 
arith.cmpf ogt, [[arg5_]], [[arg6_]] : f32 -// CHECK: [[VAR_20_:%.+]] = arith.select [[VAR_19_]], [[arg5_]], [[arg6_]] : f32 -// CHECK: tt.reduce.return [[VAR_20_]] : f32 +// CHECK: [[VAR_6_:%.+]] = arith.maxsi [[VAR_5_]], [[CST_0_1_]] : index +// CHECK: [[VAR_7_:%.+]] = "tts.load"([[VAR_3_]], [[VAR_6_]], [[CST_0_]]) <{operandSegmentSizes = array, static_mask_dims = array}> : (tensor<128x!tt.ptr>, index, f32) -> tensor<128xf32> +// CHECK: [[VAR_8_:%.+]] = "tt.reduce"([[VAR_7_]]) <{axis = 0 : i32}> ({ +// CHECK: ^bb0([[IN_0_:%.+]]: f32, [[IN_1_:%.+]]: f32): +// CHECK: [[VAR_21_:%.+]] = arith.cmpf ogt, [[IN_0_]], [[IN_1_]] : f32 +// CHECK: [[VAR_22_:%.+]] = arith.select [[VAR_21_]], [[IN_0_]], [[IN_1_]] : f32 +// CHECK: tt.reduce.return [[VAR_22_]] : f32 // CHECK: }) : (tensor<128xf32>) -> f32 -// CHECK: [[VAR_8_:%.+]] = tt.splat [[VAR_7_]] : f32 -> tensor<128xf32> -// CHECK: [[VAR_9_:%.+]] = arith.subf [[VAR_6_]], [[VAR_8_]] : tensor<128xf32> -// CHECK: [[VAR_10_:%.+]] = math.exp [[VAR_9_]] : tensor<128xf32> -// CHECK: [[VAR_11_:%.+]] = "tt.reduce"([[VAR_10_]]) <{axis = 0 : i32}> ({ -// CHECK: ^bb0([[arg5_]]: f32, [[arg6_]]: f32): -// CHECK: [[VAR_19_1_:%.+]] = arith.addf [[arg5_]], [[arg6_]] : f32 -// CHECK: tt.reduce.return [[VAR_19_1_]] : f32 +// CHECK: [[VAR_9_:%.+]] = tt.splat [[VAR_8_]] : f32 -> tensor<128xf32> +// CHECK: [[VAR_10_:%.+]] = arith.subf [[VAR_7_]], [[VAR_9_]] : tensor<128xf32> +// CHECK: [[VAR_11_:%.+]] = math.exp [[VAR_10_]] : tensor<128xf32> +// CHECK: [[VAR_12_:%.+]] = "tt.reduce"([[VAR_11_]]) <{axis = 0 : i32}> ({ +// CHECK: ^bb0([[IN_2_:%.+]]: f32, [[IN_3_:%.+]]: f32): +// CHECK: [[VAR_21_1_:%.+]] = arith.addf [[IN_2_]], [[IN_3_]] : f32 +// CHECK: tt.reduce.return [[VAR_21_1_]] : f32 // CHECK: }) : (tensor<128xf32>) -> f32 -// CHECK: [[VAR_12_:%.+]] = tt.splat [[VAR_11_]] : f32 -> tensor<128xf32> -// CHECK-DAG: [[VAR_13_:%.+]] = arith.divf [[VAR_10_]], [[VAR_12_]] : tensor<128xf32> -// CHECK-DAG: [[VAR_14_:%.+]] = arith.muli [[VAR_0_]], [[PARAM_3_]] : i32 -// CHECK: [[VAR_15_:%.+]] = arith.index_cast [[VAR_14_]] : i32 to index -// CHECK-DAG: [[VAR_16_:%.+]] = tts.make_tptr [[PARAM_0_]] to sizes: [128], strides: [1], offsets: {{.}}[[VAR_15_]]{{.}}, shape: [0], order: [] : to tensor<128x!tt.ptr> -// CHECK-DAG: [[VAR_17_:%.+]] = arith.index_cast [[PARAM_4_]] : i32 to index -// CHECK: [[VAR_18_:%.+]] = arith.minsi [[VAR_17_]], [[CST_128_]] : index -// CHECK: "tts.store"([[VAR_16_]], [[VAR_13_]], [[VAR_18_]]) <{static_mask_dims = array}> : (tensor<128x!tt.ptr>, tensor<128xf32>, index) -> () +// CHECK: [[VAR_13_:%.+]] = tt.splat [[VAR_12_]] : f32 -> tensor<128xf32> +// CHECK-DAG: [[VAR_14_:%.+]] = arith.divf [[VAR_11_]], [[VAR_13_]] : tensor<128xf32> +// CHECK-DAG: [[VAR_15_:%.+]] = arith.muli [[VAR_0_]], [[PARAM_3_]] : i32 +// CHECK: [[VAR_16_:%.+]] = arith.index_cast [[VAR_15_]] : i32 to index +// CHECK-DAG: [[VAR_17_:%.+]] = tts.make_tptr [[PARAM_0_]] to sizes: [128], strides: [1], offsets: {{.}}[[VAR_16_]]{{.}}, shape: [0], order: [] : to tensor<128x!tt.ptr> +// CHECK-DAG: [[VAR_18_:%.+]] = arith.index_cast [[PARAM_4_]] : i32 to index +// CHECK: [[VAR_19_:%.+]] = arith.minsi [[VAR_18_]], [[CST_128_]] : index +// CHECK: [[VAR_20_:%.+]] = arith.maxsi [[VAR_19_]], [[CST_0_1_]] : index +// CHECK: "tts.store"([[VAR_17_]], [[VAR_14_]], [[VAR_20_]]) <{static_mask_dims = array}> : (tensor<128x!tt.ptr>, tensor<128xf32>, index) -> () // CHECK: tt.return // CHECK: } diff --git a/test/Conversion/TritonToStructured/kernel-03-matrix-multiplication.mlir 
b/test/Conversion/TritonToStructured/kernel-03-matrix-multiplication.mlir index d54977da..39e93253 100644 --- a/test/Conversion/TritonToStructured/kernel-03-matrix-multiplication.mlir +++ b/test/Conversion/TritonToStructured/kernel-03-matrix-multiplication.mlir @@ -97,10 +97,10 @@ module { } // CHECK: tt.func public @matmul_kernel_0123456789101112131415([[PARAM_0_:%.+]]: !tt.ptr, [[PARAM_1_:%.+]]: !tt.ptr, [[PARAM_2_:%.+]]: !tt.ptr, [[PARAM_3_:%.+]]: i32, [[PARAM_4_:%.+]]: i32, [[PARAM_5_:%.+]]: i32, [[PARAM_6_:%.+]]: i32, [[PARAM_7_:%.+]]: i32, [[PARAM_8_:%.+]]: i32, [[PARAM_9_:%.+]]: i32, [[PARAM_10_:%.+]]: i32, [[PARAM_11_:%.+]]: i32) { -// CHECK-DAG: [[VAR_cst_:%.+]] = arith.constant dense<0.000000e+00> : tensor<128x256xf32> // CHECK-DAG: [[CST_256_:%.+]] = arith.constant 256 : index // CHECK-DAG: [[CST_128_:%.+]] = arith.constant 128 : index // CHECK-DAG: [[CST_0_:%.+]] = arith.constant 0 : index +// CHECK-DAG: [[VAR_cst_:%.+]] = arith.constant dense<0.000000e+00> : tensor<128x256xf32> // CHECK-DAG: [[CST_63_:%.+]] = arith.constant 63 : i32 // CHECK-DAG: [[CST_255_:%.+]] = arith.constant 255 : i32 // CHECK-DAG: [[CST_127_:%.+]] = arith.constant 127 : i32 @@ -153,16 +153,16 @@ module { // CHECK-NOT: separator of consecutive DAGs // CHECK-DAG: [[VAR_32_:%.+]] = arith.index_cast [[VAR_31_]] : i32 to index // CHECK-DAG: [[VAR_33_:%.+]]:3 = scf.for [[VAR_arg12_:%.+]] = [[CST_0_1_]] to [[VAR_6_]] step [[CST_1_]] iter_args([[VAR_arg13_:%.+]] = [[VAR_cst_]], [[VAR_arg14_:%.+]] = [[VAR_24_]], [[VAR_arg15_:%.+]] = [[CST_0_]]) -> (tensor<128x256xf32>, index, index) : i32 { -// CHECK-DAG: [[VAR_52_:%.+]] = tts.make_tptr [[PARAM_1_]] to sizes: [64, 256], strides: {{.}}[[VAR_26_]], [[VAR_27_]]{{.}}, offsets: {{.}}[[PARAM_1_]]5, [[VAR_28_]]{{.}}, shape: [0, 0], order: [] : to tensor<64x256x!tt.ptr> -// CHECK-DAG: [[VAR_53_:%.+]] = tts.make_tptr [[PARAM_0_]] to sizes: [128, 64], strides: {{.}}[[VAR_23_]], [[VAR_25_]]{{.}}, offsets: {{.}}[[VAR_arg14_]], [[CST_0_]]{{.}}, shape: [0, 0], order: [] : to tensor<128x64x!tt.ptr> +// CHECK-DAG: [[VAR_54_:%.+]] = tts.make_tptr [[PARAM_1_]] to sizes: [64, 256], strides: {{.}}[[VAR_26_]], [[VAR_27_]]{{.}}, offsets: {{.}}[[VAR_arg15_]], [[VAR_28_]]{{.}}, shape: [0, 0], order: [] : to tensor<64x256x!tt.ptr> +// CHECK-DAG: [[VAR_55_:%.+]] = tts.make_tptr [[PARAM_0_]] to sizes: [128, 64], strides: {{.}}[[VAR_23_]], [[VAR_25_]]{{.}}, offsets: {{.}}[[VAR_arg14_]], [[CST_0_]]{{.}}, shape: [0, 0], order: [] : to tensor<128x64x!tt.ptr> // CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: [[VAR_54_:%.+]] = "tts.load"([[VAR_53_]]) <{operandSegmentSizes = array, static_mask_dims = array}> : (tensor<128x64x!tt.ptr>) -> tensor<128x64xbf16> -// CHECK-DAG: [[VAR_55_:%.+]] = "tts.load"([[VAR_52_]]) <{operandSegmentSizes = array, static_mask_dims = array}> : (tensor<64x256x!tt.ptr>) -> tensor<64x256xbf16> -// CHECK: [[VAR_56_:%.+]] = tt.dot [[VAR_54_]], [[VAR_55_]], [[VAR_cst_]], inputPrecision = tf32 : tensor<128x64xbf16> * tensor<64x256xbf16> -> tensor<128x256xf32> -// CHECK-DAG: [[VAR_57_:%.+]] = arith.addf [[VAR_arg13_]], [[VAR_56_]] : tensor<128x256xf32> -// CHECK-DAG: [[VAR_58_:%.+]] = arith.addi [[VAR_arg14_]], [[VAR_30_]] : index -// CHECK-DAG: [[VAR_59_:%.+]] = arith.addi [[VAR_arg15_]], [[VAR_32_]] : index -// CHECK: scf.yield [[VAR_57_]], [[VAR_58_]], [[VAR_59_]] : tensor<128x256xf32>, index, index +// CHECK-DAG: [[VAR_56_:%.+]] = "tts.load"([[VAR_55_]]) <{operandSegmentSizes = array, static_mask_dims = array}> : (tensor<128x64x!tt.ptr>) -> tensor<128x64xbf16> 
+// CHECK-DAG: [[VAR_57_:%.+]] = "tts.load"([[VAR_54_]]) <{operandSegmentSizes = array, static_mask_dims = array}> : (tensor<64x256x!tt.ptr>) -> tensor<64x256xbf16> +// CHECK: [[VAR_58_:%.+]] = tt.dot [[VAR_56_]], [[VAR_57_]], [[VAR_cst_]], inputPrecision = tf32 : tensor<128x64xbf16> * tensor<64x256xbf16> -> tensor<128x256xf32> +// CHECK-DAG: [[VAR_59_:%.+]] = arith.addf [[VAR_arg13_]], [[VAR_58_]] : tensor<128x256xf32> +// CHECK-DAG: [[VAR_60_:%.+]] = arith.addi [[VAR_arg14_]], [[VAR_30_]] : index +// CHECK-DAG: [[VAR_61_:%.+]] = arith.addi [[VAR_arg15_]], [[VAR_32_]] : index +// CHECK: scf.yield [[VAR_59_]], [[VAR_60_]], [[VAR_61_]] : tensor<128x256xf32>, index, index // CHECK: } // CHECK-DAG: [[VAR_34_:%.+]] = arith.truncf [[VAR_33_]]#0 : tensor<128x256xf32> to tensor<128x256xbf16> // CHECK-DAG: [[VAR_35_:%.+]] = arith.index_cast [[PARAM_10_]] : i32 to index @@ -176,15 +176,17 @@ module { // CHECK-DAG: [[VAR_41_:%.+]] = arith.addi [[VAR_40_]], [[CST_128_]] : index // CHECK-DAG: [[VAR_42_:%.+]] = arith.index_cast [[PARAM_3_]] : i32 to index // CHECK: [[VAR_43_:%.+]] = arith.minsi [[VAR_41_]], [[VAR_42_]] : index -// CHECK-DAG: [[VAR_44_:%.+]] = arith.subi [[VAR_43_]], [[VAR_40_]] : index -// CHECK-DAG: [[VAR_45_:%.+]] = arith.index_cast [[VAR_20_]] : i32 to index +// CHECK: [[VAR_44_:%.+]] = arith.maxsi [[VAR_43_]], [[VAR_40_]] : index +// CHECK-DAG: [[VAR_45_:%.+]] = arith.subi [[VAR_44_]], [[VAR_40_]] : index +// CHECK-DAG: [[VAR_46_:%.+]] = arith.index_cast [[VAR_20_]] : i32 to index // CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: [[VAR_46_:%.+]] = arith.addi [[VAR_45_]], [[CST_256_]] : index -// CHECK-DAG: [[VAR_47_:%.+]] = arith.index_cast [[PARAM_4_]] : i32 to index -// CHECK: [[VAR_48_:%.+]] = arith.minsi [[VAR_46_]], [[VAR_47_]] : index -// CHECK-DAG: [[VAR_49_:%.+]] = arith.subi [[VAR_48_]], [[VAR_45_]] : index -// CHECK-DAG: [[VAR_50_:%.+]] = arith.minsi [[VAR_44_]], [[CST_128_]] : index -// CHECK: [[VAR_51_:%.+]] = arith.minsi [[VAR_49_]], [[CST_256_]] : index -// CHECK: "tts.store"([[VAR_39_]], [[VAR_34_]], [[VAR_50_]], [[VAR_51_]]) <{static_mask_dims = array}> : (tensor<128x256x!tt.ptr>, tensor<128x256xbf16>, index, index) -> () +// CHECK-DAG: [[VAR_47_:%.+]] = arith.addi [[VAR_46_]], [[CST_256_]] : index +// CHECK-DAG: [[VAR_48_:%.+]] = arith.index_cast [[PARAM_4_]] : i32 to index +// CHECK: [[VAR_49_:%.+]] = arith.minsi [[VAR_47_]], [[VAR_48_]] : index +// CHECK: [[VAR_50_:%.+]] = arith.maxsi [[VAR_49_]], [[VAR_46_]] : index +// CHECK-DAG: [[VAR_51_:%.+]] = arith.subi [[VAR_50_]], [[VAR_46_]] : index +// CHECK-DAG: [[VAR_52_:%.+]] = arith.minsi [[VAR_45_]], [[CST_128_]] : index +// CHECK: [[VAR_53_:%.+]] = arith.minsi [[VAR_51_]], [[CST_256_]] : index +// CHECK: "tts.store"([[VAR_39_]], [[VAR_34_]], [[VAR_52_]], [[VAR_53_]]) <{static_mask_dims = array}> : (tensor<128x256x!tt.ptr>, tensor<128x256xbf16>, index, index) -> () // CHECK: tt.return // CHECK: } diff --git a/test/Conversion/TritonToStructured/kernel-05-layer-norm-dwdb.mlir b/test/Conversion/TritonToStructured/kernel-05-layer-norm-dwdb.mlir index 5e2788fc..39d2491a 100644 --- a/test/Conversion/TritonToStructured/kernel-05-layer-norm-dwdb.mlir +++ b/test/Conversion/TritonToStructured/kernel-05-layer-norm-dwdb.mlir @@ -61,69 +61,73 @@ module { } // CHECK: tt.func public @_layer_norm_bwd_dwdb_0123456([[PARAM_0_:%.+]]: !tt.ptr, [[PARAM_1_:%.+]]: !tt.ptr, [[PARAM_2_:%.+]]: !tt.ptr, [[PARAM_3_:%.+]]: !tt.ptr, [[PARAM_4_:%.+]]: i32, [[PARAM_5_:%.+]]: i32) { -// CHECK-DAG: [[VAR_cst_:%.+]] = arith.constant 
dense<0.000000e+00> : tensor<256x256xf32> +// CHECK-DAG: [[CST_0_dot_000000_:%.+]] = arith.constant 0.000000e+00 : f32 // CHECK-DAG: [[CST_256_:%.+]] = arith.constant 256 : index +// CHECK-DAG: [[VAR_cst_0_:%.+]] = arith.constant dense<0.000000e+00> : tensor<256x256xf32> // CHECK-DAG: [[CST_0_:%.+]] = arith.constant 0 : i32 // CHECK-DAG: [[CST_256_1_:%.+]] = arith.constant 256 : i32 -// CHECK-DAG: [[CST_0_dot_000000_:%.+]] = arith.constant 0.000000e+00 : f32 // CHECK-DAG: [[VAR_0_:%.+]] = tt.get_program_id x : i32 // CHECK: [[VAR_1_:%.+]] = arith.muli [[VAR_0_]], [[CST_256_1_]] : i32 // CHECK-DAG: [[VAR_2_:%.+]] = arith.index_cast [[VAR_1_]] : i32 to index // CHECK-DAG: [[VAR_3_:%.+]] = arith.index_cast [[VAR_1_]] : i32 to index // CHECK-DAG: [[VAR_4_:%.+]] = arith.index_cast [[VAR_1_]] : i32 to index // CHECK-DAG: [[VAR_5_:%.+]] = arith.index_cast [[VAR_1_]] : i32 to index -// CHECK-DAG: [[VAR_6_:%.+]]:2 = scf.for [[VAR_arg6_:%.+]] = [[CST_0_]] to [[PARAM_4_]] step [[CST_256_1_]] iter_args([[VAR_arg7_:%.+]] = [[VAR_cst_]], [[VAR_arg8_:%.+]] = [[VAR_cst_]]) -> (tensor<256x256xf32>, tensor<256x256xf32>) : i32 { -// CHECK-DAG: [[VAR_21_:%.+]] = arith.index_cast [[VAR_arg6_]] : i32 to index -// CHECK-DAG: [[VAR_22_:%.+]] = arith.index_cast [[PARAM_5_]] : i32 to index -// CHECK: [[VAR_23_:%.+]] = arith.muli [[VAR_21_]], [[VAR_22_]] : index -// CHECK-DAG: [[VAR_24_:%.+]] = tts.make_tptr [[PARAM_0_]] to sizes: [256, 256], strides: {{.}}[[VAR_22_]], 1], offsets: {{.}}[[VAR_23_]], [[VAR_5_]]{{.}}, shape: [0, 0], order: [] : to tensor<256x256x!tt.ptr> -// CHECK-DAG: [[VAR_25_:%.+]] = arith.index_cast [[VAR_arg6_]] : i32 to index +// CHECK-DAG: [[VAR_6_:%.+]]:2 = scf.for [[VAR_arg6_:%.+]] = [[CST_0_]] to [[PARAM_4_]] step [[CST_256_1_]] iter_args([[VAR_arg7_:%.+]] = [[VAR_cst_0_]], [[VAR_arg8_:%.+]] = [[VAR_cst_0_]]) -> (tensor<256x256xf32>, tensor<256x256xf32>) : i32 { +// CHECK-DAG: [[VAR_23_:%.+]] = arith.index_cast [[VAR_arg6_]] : i32 to index +// CHECK-DAG: [[VAR_24_:%.+]] = arith.index_cast [[PARAM_5_]] : i32 to index +// CHECK: [[VAR_25_:%.+]] = arith.muli [[VAR_23_]], [[VAR_24_]] : index +// CHECK-DAG: [[VAR_26_:%.+]] = tts.make_tptr [[PARAM_0_]] to sizes: [256, 256], strides: {{.}}[[VAR_24_]], 1], offsets: {{.}}[[VAR_25_]], [[VAR_5_]]{{.}}, shape: [0, 0], order: [] : to tensor<256x256x!tt.ptr> +// CHECK-DAG: [[VAR_27_:%.+]] = arith.index_cast [[VAR_arg6_]] : i32 to index // CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: [[VAR_26_:%.+]] = arith.addi [[VAR_25_]], [[CST_256_]] : index -// CHECK-DAG: [[VAR_27_:%.+]] = arith.index_cast [[PARAM_4_]] : i32 to index -// CHECK: [[VAR_28_:%.+]] = arith.minsi [[VAR_26_]], [[VAR_27_]] : index -// CHECK-DAG: [[VAR_29_:%.+]] = arith.subi [[VAR_28_]], [[VAR_25_]] : index -// CHECK-DAG: [[VAR_30_:%.+]] = arith.index_cast [[VAR_1_]] : i32 to index +// CHECK-DAG: [[VAR_28_:%.+]] = arith.addi [[VAR_27_]], [[CST_256_]] : index +// CHECK-DAG: [[VAR_29_:%.+]] = arith.index_cast [[PARAM_4_]] : i32 to index +// CHECK: [[VAR_30_:%.+]] = arith.minsi [[VAR_28_]], [[VAR_29_]] : index +// CHECK: [[VAR_31_:%.+]] = arith.maxsi [[VAR_30_]], [[VAR_27_]] : index +// CHECK-DAG: [[VAR_32_:%.+]] = arith.subi [[VAR_31_]], [[VAR_27_]] : index +// CHECK-DAG: [[VAR_33_:%.+]] = arith.index_cast [[VAR_1_]] : i32 to index // CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: [[VAR_31_:%.+]] = arith.addi [[VAR_30_]], [[CST_256_]] : index -// CHECK-DAG: [[VAR_32_:%.+]] = arith.index_cast [[PARAM_5_]] : i32 to index -// CHECK: [[VAR_33_:%.+]] = arith.minsi [[VAR_31_]], 
[[VAR_32_]] : index -// CHECK-DAG: [[VAR_34_:%.+]] = arith.subi [[VAR_33_]], [[VAR_30_]] : index -// CHECK-DAG: [[VAR_35_:%.+]] = arith.minsi [[VAR_29_]], [[CST_256_]] : index -// CHECK: [[VAR_36_:%.+]] = arith.minsi [[VAR_34_]], [[CST_256_]] : index -// CHECK: [[VAR_37_:%.+]] = "tts.load"([[VAR_24_]], [[VAR_35_]], [[VAR_36_]], [[CST_0_dot_000000_]]) <{operandSegmentSizes = array, static_mask_dims = array}> : (tensor<256x256x!tt.ptr>, index, index, f32) -> tensor<256x256xf32> -// CHECK-DAG: [[VAR_38_:%.+]] = arith.addf [[VAR_arg7_]], [[VAR_37_]] : tensor<256x256xf32> -// CHECK-DAG: [[VAR_39_:%.+]] = arith.index_cast [[VAR_arg6_]] : i32 to index -// CHECK-DAG: [[VAR_40_:%.+]] = arith.index_cast [[PARAM_5_]] : i32 to index -// CHECK: [[VAR_41_:%.+]] = arith.muli [[VAR_39_]], [[VAR_40_]] : index -// CHECK-DAG: [[VAR_42_:%.+]] = tts.make_tptr [[PARAM_1_]] to sizes: [256, 256], strides: {{.}}[[VAR_40_]], 1], offsets: {{.}}[[VAR_41_]], [[VAR_4_]]{{.}}, shape: [0, 0], order: [] : to tensor<256x256x!tt.ptr> +// CHECK-DAG: [[VAR_34_:%.+]] = arith.addi [[VAR_33_]], [[CST_256_]] : index +// CHECK-DAG: [[VAR_35_:%.+]] = arith.index_cast [[PARAM_5_]] : i32 to index +// CHECK: [[VAR_36_:%.+]] = arith.minsi [[VAR_34_]], [[VAR_35_]] : index +// CHECK: [[VAR_37_:%.+]] = arith.maxsi [[VAR_36_]], [[VAR_33_]] : index +// CHECK-DAG: [[VAR_38_:%.+]] = arith.subi [[VAR_37_]], [[VAR_33_]] : index +// CHECK-DAG: [[VAR_39_:%.+]] = arith.minsi [[VAR_32_]], [[CST_256_]] : index +// CHECK: [[VAR_40_:%.+]] = arith.minsi [[VAR_38_]], [[CST_256_]] : index +// CHECK: [[VAR_41_:%.+]] = "tts.load"([[VAR_26_]], [[VAR_39_]], [[VAR_40_]], [[CST_0_dot_000000_]]) <{operandSegmentSizes = array, static_mask_dims = array}> : (tensor<256x256x!tt.ptr>, index, index, f32) -> tensor<256x256xf32> +// CHECK-DAG: [[VAR_42_:%.+]] = arith.addf [[VAR_arg7_]], [[VAR_41_]] : tensor<256x256xf32> // CHECK-DAG: [[VAR_43_:%.+]] = arith.index_cast [[VAR_arg6_]] : i32 to index +// CHECK-DAG: [[VAR_44_:%.+]] = arith.index_cast [[PARAM_5_]] : i32 to index +// CHECK: [[VAR_45_:%.+]] = arith.muli [[VAR_43_]], [[VAR_44_]] : index +// CHECK-DAG: [[VAR_46_:%.+]] = tts.make_tptr [[PARAM_1_]] to sizes: [256, 256], strides: {{.}}[[VAR_44_]], 1], offsets: {{.}}[[VAR_45_]], [[VAR_4_]]{{.}}, shape: [0, 0], order: [] : to tensor<256x256x!tt.ptr> +// CHECK-DAG: [[VAR_47_:%.+]] = arith.index_cast [[VAR_arg6_]] : i32 to index // CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: [[VAR_44_:%.+]] = arith.addi [[VAR_43_]], [[CST_256_]] : index -// CHECK-DAG: [[VAR_45_:%.+]] = arith.index_cast [[PARAM_4_]] : i32 to index -// CHECK: [[VAR_46_:%.+]] = arith.minsi [[VAR_44_]], [[VAR_45_]] : index -// CHECK-DAG: [[VAR_47_:%.+]] = arith.subi [[VAR_46_]], [[VAR_43_]] : index -// CHECK-DAG: [[VAR_48_:%.+]] = arith.index_cast [[VAR_1_]] : i32 to index +// CHECK-DAG: [[VAR_48_:%.+]] = arith.addi [[VAR_47_]], [[CST_256_]] : index +// CHECK-DAG: [[VAR_49_:%.+]] = arith.index_cast [[PARAM_4_]] : i32 to index +// CHECK: [[VAR_50_:%.+]] = arith.minsi [[VAR_48_]], [[VAR_49_]] : index +// CHECK: [[VAR_51_:%.+]] = arith.maxsi [[VAR_50_]], [[VAR_47_]] : index +// CHECK-DAG: [[VAR_52_:%.+]] = arith.subi [[VAR_51_]], [[VAR_47_]] : index +// CHECK-DAG: [[VAR_53_:%.+]] = arith.index_cast [[VAR_1_]] : i32 to index // CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: [[VAR_49_:%.+]] = arith.addi [[VAR_48_]], [[CST_256_]] : index -// CHECK-DAG: [[VAR_50_:%.+]] = arith.index_cast [[PARAM_5_]] : i32 to index -// CHECK: [[VAR_51_:%.+]] = arith.minsi [[VAR_49_]], [[VAR_50_]] : index -// 
CHECK-DAG: [[VAR_52_:%.+]] = arith.subi [[VAR_51_]], [[VAR_48_]] : index -// CHECK-DAG: [[VAR_53_:%.+]] = arith.minsi [[VAR_47_]], [[CST_256_]] : index -// CHECK: [[VAR_54_:%.+]] = arith.minsi [[VAR_52_]], [[CST_256_]] : index -// CHECK: [[VAR_55_:%.+]] = "tts.load"([[VAR_42_]], [[VAR_53_]], [[VAR_54_]], [[CST_0_dot_000000_]]) <{operandSegmentSizes = array, static_mask_dims = array}> : (tensor<256x256x!tt.ptr>, index, index, f32) -> tensor<256x256xf32> -// CHECK: [[VAR_56_:%.+]] = arith.addf [[VAR_arg8_]], [[VAR_55_]] : tensor<256x256xf32> -// CHECK: scf.yield [[VAR_38_]], [[VAR_56_]] : tensor<256x256xf32>, tensor<256x256xf32> +// CHECK-DAG: [[VAR_54_:%.+]] = arith.addi [[VAR_53_]], [[CST_256_]] : index +// CHECK-DAG: [[VAR_55_:%.+]] = arith.index_cast [[PARAM_5_]] : i32 to index +// CHECK: [[VAR_56_:%.+]] = arith.minsi [[VAR_54_]], [[VAR_55_]] : index +// CHECK: [[VAR_57_:%.+]] = arith.maxsi [[VAR_56_]], [[VAR_53_]] : index +// CHECK-DAG: [[VAR_58_:%.+]] = arith.subi [[VAR_57_]], [[VAR_53_]] : index +// CHECK-DAG: [[VAR_59_:%.+]] = arith.minsi [[VAR_52_]], [[CST_256_]] : index +// CHECK: [[VAR_60_:%.+]] = arith.minsi [[VAR_58_]], [[CST_256_]] : index +// CHECK: [[VAR_61_:%.+]] = "tts.load"([[VAR_46_]], [[VAR_59_]], [[VAR_60_]], [[CST_0_dot_000000_]]) <{operandSegmentSizes = array, static_mask_dims = array}> : (tensor<256x256x!tt.ptr>, index, index, f32) -> tensor<256x256xf32> +// CHECK: [[VAR_62_:%.+]] = arith.addf [[VAR_arg8_]], [[VAR_61_]] : tensor<256x256xf32> +// CHECK: scf.yield [[VAR_42_]], [[VAR_62_]] : tensor<256x256xf32>, tensor<256x256xf32> // CHECK: } // CHECK: [[VAR_7_:%.+]] = "tt.reduce"([[VAR_6_]]#0) <{axis = 0 : i32}> ({ -// CHECK: ^bb0([[VAR_arg6_]]: f32, [[VAR_arg7_]]: f32): -// CHECK: [[VAR_21_1_:%.+]] = arith.addf [[VAR_arg6_]], [[VAR_arg7_]] : f32 -// CHECK: tt.reduce.return [[VAR_21_1_]] : f32 +// CHECK: ^bb0([[VAR_arg6_1_:%.+]]: f32, [[VAR_arg7_1_:%.+]]: f32): +// CHECK: [[VAR_23_1_:%.+]] = arith.addf [[VAR_arg6_1_]], [[VAR_arg7_1_]] : f32 +// CHECK: tt.reduce.return [[VAR_23_1_]] : f32 // CHECK: }) : (tensor<256x256xf32>) -> tensor<256xf32> // CHECK: [[VAR_8_:%.+]] = "tt.reduce"([[VAR_6_]]#1) <{axis = 0 : i32}> ({ -// CHECK: ^bb0([[VAR_arg6_]]: f32, [[VAR_arg7_]]: f32): -// CHECK: [[VAR_21_2_:%.+]] = arith.addf [[VAR_arg6_]], [[VAR_arg7_]] : f32 -// CHECK: tt.reduce.return [[VAR_21_2_]] : f32 +// CHECK: ^bb0([[VAR_arg6_1_:%.+]]: f32, [[VAR_arg7_1_:%.+]]: f32): +// CHECK: [[VAR_23_2_:%.+]] = arith.addf [[VAR_arg6_1_]], [[VAR_arg7_1_]] : f32 +// CHECK: tt.reduce.return [[VAR_23_2_]] : f32 // CHECK: }) : (tensor<256x256xf32>) -> tensor<256xf32> // CHECK-DAG: [[VAR_9_:%.+]] = tts.make_tptr [[PARAM_2_]] to sizes: [256], strides: [1], offsets: {{.}}[[VAR_3_]]{{.}}, shape: [0], order: [] : to tensor<256x!tt.ptr> // CHECK-DAG: [[VAR_10_:%.+]] = arith.index_cast [[VAR_1_]] : i32 to index @@ -131,15 +135,17 @@ module { // CHECK-DAG: [[VAR_11_:%.+]] = arith.addi [[VAR_10_]], [[CST_256_]] : index // CHECK-DAG: [[VAR_12_:%.+]] = arith.index_cast [[PARAM_5_]] : i32 to index // CHECK: [[VAR_13_:%.+]] = arith.minsi [[VAR_11_]], [[VAR_12_]] : index -// CHECK: [[VAR_14_:%.+]] = arith.subi [[VAR_13_]], [[VAR_10_]] : index -// CHECK: "tts.store"([[VAR_9_]], [[VAR_7_]], [[VAR_14_]]) <{static_mask_dims = array}> : (tensor<256x!tt.ptr>, tensor<256xf32>, index) -> () -// CHECK-DAG: [[VAR_15_:%.+]] = tts.make_tptr [[PARAM_3_]] to sizes: [256], strides: [1], offsets: {{.}}[[VAR_2_]]{{.}}, shape: [0], order: [] : to tensor<256x!tt.ptr> -// CHECK-DAG: [[VAR_16_:%.+]] = arith.index_cast 
[[VAR_1_]] : i32 to index +// CHECK: [[VAR_14_:%.+]] = arith.maxsi [[VAR_13_]], [[VAR_10_]] : index +// CHECK: [[VAR_15_:%.+]] = arith.subi [[VAR_14_]], [[VAR_10_]] : index +// CHECK: "tts.store"([[VAR_9_]], [[VAR_7_]], [[VAR_15_]]) <{static_mask_dims = array}> : (tensor<256x!tt.ptr>, tensor<256xf32>, index) -> () +// CHECK-DAG: [[VAR_16_:%.+]] = tts.make_tptr [[PARAM_3_]] to sizes: [256], strides: [1], offsets: {{.}}[[VAR_2_]]{{.}}, shape: [0], order: [] : to tensor<256x!tt.ptr> +// CHECK-DAG: [[VAR_17_:%.+]] = arith.index_cast [[VAR_1_]] : i32 to index // CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: [[VAR_17_:%.+]] = arith.addi [[VAR_16_]], [[CST_256_]] : index -// CHECK-DAG: [[VAR_18_:%.+]] = arith.index_cast [[PARAM_5_]] : i32 to index -// CHECK: [[VAR_19_:%.+]] = arith.minsi [[VAR_17_]], [[VAR_18_]] : index -// CHECK: [[VAR_20_:%.+]] = arith.subi [[VAR_19_]], [[VAR_16_]] : index -// CHECK: "tts.store"([[VAR_15_]], [[VAR_8_]], [[VAR_20_]]) <{static_mask_dims = array}> : (tensor<256x!tt.ptr>, tensor<256xf32>, index) -> () +// CHECK-DAG: [[VAR_18_:%.+]] = arith.addi [[VAR_17_]], [[CST_256_]] : index +// CHECK-DAG: [[VAR_19_:%.+]] = arith.index_cast [[PARAM_5_]] : i32 to index +// CHECK: [[VAR_20_:%.+]] = arith.minsi [[VAR_18_]], [[VAR_19_]] : index +// CHECK: [[VAR_21_:%.+]] = arith.maxsi [[VAR_20_]], [[VAR_17_]] : index +// CHECK: [[VAR_22_:%.+]] = arith.subi [[VAR_21_]], [[VAR_17_]] : index +// CHECK: "tts.store"([[VAR_16_]], [[VAR_8_]], [[VAR_22_]]) <{static_mask_dims = array}> : (tensor<256x!tt.ptr>, tensor<256xf32>, index) -> () // CHECK: tt.return // CHECK: } diff --git a/test/Conversion/TritonToStructured/kernel-05-layer-norm-fwd.mlir b/test/Conversion/TritonToStructured/kernel-05-layer-norm-fwd.mlir index 3bc80f4f..ac6f22df 100644 --- a/test/Conversion/TritonToStructured/kernel-05-layer-norm-fwd.mlir +++ b/test/Conversion/TritonToStructured/kernel-05-layer-norm-fwd.mlir @@ -89,33 +89,34 @@ module { } // CHECK: tt.func public @_layer_norm_fwd_fused_0123456789([[PARAM_0_:%.+]]: !tt.ptr, [[PARAM_1_:%.+]]: !tt.ptr, [[PARAM_2_:%.+]]: !tt.ptr, [[PARAM_3_:%.+]]: !tt.ptr, [[PARAM_4_:%.+]]: !tt.ptr, [[PARAM_5_:%.+]]: !tt.ptr, [[PARAM_6_:%.+]]: i32, [[PARAM_7_:%.+]]: i32, [[PARAM_8_:%.+]]: f32) { -// CHECK-DAG: [[VAR_cst_:%.+]] = arith.constant dense<0.000000e+00> : tensor<256xf32> +// CHECK-DAG: [[CST_0_dot_000000_:%.+]] = arith.constant 0.000000e+00 : f32 // CHECK-DAG: [[CST_256_:%.+]] = arith.constant 256 : index +// CHECK-DAG: [[VAR_cst_0_:%.+]] = arith.constant dense<0.000000e+00> : tensor<256xf32> // CHECK-DAG: [[CST_256_1_:%.+]] = arith.constant 256 : i32 // CHECK-DAG: [[CST_0_:%.+]] = arith.constant 0 : i32 // CHECK-DAG: [[CST_1_dot_000000_:%.+]] = arith.constant 1.000000e+00 : f32 -// CHECK-DAG: [[CST_0_dot_000000_:%.+]] = arith.constant 0.000000e+00 : f32 // CHECK-DAG: [[VAR_0_:%.+]] = tt.get_program_id x : i32 // CHECK: [[VAR_1_:%.+]] = arith.muli [[VAR_0_]], [[PARAM_6_]] : i32 // CHECK-DAG: [[VAR_2_:%.+]] = arith.index_cast [[VAR_1_]] : i32 to index // CHECK-DAG: [[VAR_3_:%.+]] = arith.index_cast [[VAR_1_]] : i32 to index -// CHECK-DAG: [[VAR_4_:%.+]] = scf.for [[VAR_arg9_:%.+]] = [[CST_0_]] to [[PARAM_7_]] step [[CST_256_1_]] iter_args([[VAR_arg10_:%.+]] = [[VAR_cst_]]) -> (tensor<256xf32>) : i32 { +// CHECK-DAG: [[VAR_4_:%.+]] = scf.for [[VAR_arg9_:%.+]] = [[CST_0_]] to [[PARAM_7_]] step [[CST_256_1_]] iter_args([[VAR_arg10_:%.+]] = [[VAR_cst_0_]]) -> (tensor<256xf32>) : i32 { // CHECK-DAG: [[VAR_21_:%.+]] = arith.index_cast [[VAR_arg9_]] : i32 to index -// CHECK: 
[[VAR_22_:%.+]] = arith.addi [[VAR_2_]], [[VAR_2_]]1 : index +// CHECK: [[VAR_22_:%.+]] = arith.addi [[VAR_2_]], [[VAR_21_]] : index // CHECK-DAG: [[VAR_23_:%.+]] = tts.make_tptr [[PARAM_0_]] to sizes: [256], strides: [1], offsets: {{.}}[[VAR_22_]]{{.}}, shape: [0], order: [] : to tensor<256x!tt.ptr> // CHECK-DAG: [[VAR_24_:%.+]] = arith.index_cast [[VAR_arg9_]] : i32 to index // CHECK-NOT: separator of consecutive DAGs // CHECK-DAG: [[VAR_25_:%.+]] = arith.addi [[VAR_24_]], [[CST_256_]] : index // CHECK-DAG: [[VAR_26_:%.+]] = arith.index_cast [[PARAM_7_]] : i32 to index // CHECK: [[VAR_27_:%.+]] = arith.minsi [[VAR_25_]], [[VAR_26_]] : index -// CHECK: [[VAR_28_:%.+]] = arith.subi [[VAR_27_]], [[VAR_24_]] : index -// CHECK: [[VAR_29_:%.+]] = "tts.load"([[VAR_23_]], [[VAR_28_]], [[CST_0_dot_000000_]]) <{operandSegmentSizes = array, static_mask_dims = array}> : (tensor<256x!tt.ptr>, index, f32) -> tensor<256xf32> -// CHECK: [[VAR_30_:%.+]] = arith.addf [[VAR_arg10_]], [[VAR_29_]] : tensor<256xf32> -// CHECK: scf.yield [[VAR_30_]] : tensor<256xf32> +// CHECK: [[VAR_28_:%.+]] = arith.maxsi [[VAR_27_]], [[VAR_24_]] : index +// CHECK: [[VAR_29_:%.+]] = arith.subi [[VAR_28_]], [[VAR_24_]] : index +// CHECK: [[VAR_30_:%.+]] = "tts.load"([[VAR_23_]], [[VAR_29_]], [[CST_0_dot_000000_]]) <{operandSegmentSizes = array, static_mask_dims = array}> : (tensor<256x!tt.ptr>, index, f32) -> tensor<256xf32> +// CHECK: [[VAR_31_:%.+]] = arith.addf [[VAR_arg10_]], [[VAR_30_]] : tensor<256xf32> +// CHECK: scf.yield [[VAR_31_]] : tensor<256xf32> // CHECK: } // CHECK: [[VAR_5_:%.+]] = "tt.reduce"([[VAR_4_]]) <{axis = 0 : i32}> ({ -// CHECK: ^bb0([[VAR_arg9_]]: f32, [[VAR_arg10_]]: f32): -// CHECK: [[VAR_21_1_:%.+]] = arith.addf [[VAR_arg9_]], [[VAR_arg10_]] : f32 +// CHECK: ^bb0([[VAR_arg9_1_:%.+]]: f32, [[VAR_arg10_1_:%.+]]: f32): +// CHECK: [[VAR_21_1_:%.+]] = arith.addf [[VAR_arg9_1_]], [[VAR_arg10_1_]] : f32 // CHECK: tt.reduce.return [[VAR_21_1_]] : f32 // CHECK: }) : (tensor<256xf32>) -> f32 // CHECK: [[VAR_6_:%.+]] = arith.sitofp [[PARAM_7_]] : i32 to f32 @@ -124,29 +125,30 @@ module { // CHECK-DAG: [[VAR_9_:%.+]] = tt.splat [[PARAM_7_]] : i32 -> tensor<256xi32> // CHECK-NOT: separator of consecutive DAGs // CHECK-DAG: [[VAR_10_:%.+]] = tt.splat [[VAR_7_]] : f32 -> tensor<256xf32> -// CHECK-DAG: [[VAR_11_:%.+]] = scf.for [[VAR_arg9_1_:%.+]] = [[CST_0_]] to [[PARAM_7_]] step [[CST_256_1_]] iter_args([[VAR_arg10_1_:%.+]] = [[VAR_cst_]]) -> (tensor<256xf32>) : i32 { -// CHECK-DAG: [[VAR_21_2_:%.+]] = tt.splat [[VAR_arg9_1_]] : i32 -> tensor<256xi32> +// CHECK-DAG: [[VAR_11_:%.+]] = scf.for [[VAR_arg9_2_:%.+]] = [[CST_0_]] to [[PARAM_7_]] step [[CST_256_1_]] iter_args([[VAR_arg10_2_:%.+]] = [[VAR_cst_0_]]) -> (tensor<256xf32>) : i32 { +// CHECK-DAG: [[VAR_21_2_:%.+]] = tt.splat [[VAR_arg9_2_]] : i32 -> tensor<256xi32> // CHECK: [[VAR_22_1_:%.+]] = arith.addi [[VAR_21_2_]], [[VAR_8_]] : tensor<256xi32> // CHECK-DAG: [[VAR_23_1_:%.+]] = arith.cmpi slt, [[VAR_22_1_]], [[VAR_9_]] : tensor<256xi32> -// CHECK-DAG: [[VAR_24_1_:%.+]] = arith.index_cast [[VAR_arg9_1_]] : i32 to index -// CHECK: [[VAR_25_1_:%.+]] = arith.addi [[VAR_2_]], [[VAR_2_]]4 : index +// CHECK-DAG: [[VAR_24_1_:%.+]] = arith.index_cast [[VAR_arg9_2_]] : i32 to index +// CHECK: [[VAR_25_1_:%.+]] = arith.addi [[VAR_2_]], [[VAR_24_1_]] : index // CHECK-DAG: [[VAR_26_1_:%.+]] = tts.make_tptr [[PARAM_0_]] to sizes: [256], strides: [1], offsets: {{.}}[[VAR_25_1_]]{{.}}, shape: [0], order: [] : to tensor<256x!tt.ptr> -// CHECK-DAG: [[VAR_27_1_:%.+]] = 
arith.index_cast [[VAR_arg9_1_]] : i32 to index +// CHECK-DAG: [[VAR_27_1_:%.+]] = arith.index_cast [[VAR_arg9_2_]] : i32 to index // CHECK-NOT: separator of consecutive DAGs // CHECK-DAG: [[VAR_28_1_:%.+]] = arith.addi [[VAR_27_1_]], [[CST_256_]] : index // CHECK-DAG: [[VAR_29_1_:%.+]] = arith.index_cast [[PARAM_7_]] : i32 to index // CHECK: [[VAR_30_1_:%.+]] = arith.minsi [[VAR_28_1_]], [[VAR_29_1_]] : index -// CHECK: [[VAR_31_:%.+]] = arith.subi [[VAR_30_1_]], [[VAR_27_1_]] : index -// CHECK: [[VAR_32_:%.+]] = "tts.load"([[VAR_26_1_]], [[VAR_31_]], [[CST_0_dot_000000_]]) <{operandSegmentSizes = array, static_mask_dims = array}> : (tensor<256x!tt.ptr>, index, f32) -> tensor<256xf32> -// CHECK: [[VAR_33_:%.+]] = arith.subf [[VAR_32_]], [[VAR_10_]] : tensor<256xf32> -// CHECK: [[VAR_34_:%.+]] = arith.select [[VAR_23_1_]], [[VAR_33_]], [[VAR_cst_]] : tensor<256xi1>, tensor<256xf32> -// CHECK: [[VAR_35_:%.+]] = arith.mulf [[VAR_34_]], [[VAR_34_]] : tensor<256xf32> -// CHECK: [[VAR_36_:%.+]] = arith.addf [[VAR_arg10_1_]], [[VAR_35_]] : tensor<256xf32> -// CHECK: scf.yield [[VAR_36_]] : tensor<256xf32> +// CHECK: [[VAR_31_1_:%.+]] = arith.maxsi [[VAR_30_1_]], [[VAR_27_1_]] : index +// CHECK: [[VAR_32_:%.+]] = arith.subi [[VAR_31_1_]], [[VAR_27_1_]] : index +// CHECK: [[VAR_33_:%.+]] = "tts.load"([[VAR_26_1_]], [[VAR_32_]], [[CST_0_dot_000000_]]) <{operandSegmentSizes = array, static_mask_dims = array}> : (tensor<256x!tt.ptr>, index, f32) -> tensor<256xf32> +// CHECK: [[VAR_34_:%.+]] = arith.subf [[VAR_33_]], [[VAR_10_]] : tensor<256xf32> +// CHECK: [[VAR_35_:%.+]] = arith.select [[VAR_23_1_]], [[VAR_34_]], [[VAR_cst_0_]] : tensor<256xi1>, tensor<256xf32> +// CHECK: [[VAR_36_:%.+]] = arith.mulf [[VAR_35_]], [[VAR_35_]] : tensor<256xf32> +// CHECK: [[VAR_37_:%.+]] = arith.addf [[VAR_arg10_2_]], [[VAR_36_]] : tensor<256xf32> +// CHECK: scf.yield [[VAR_37_]] : tensor<256xf32> // CHECK: } // CHECK: [[VAR_12_:%.+]] = "tt.reduce"([[VAR_11_]]) <{axis = 0 : i32}> ({ -// CHECK: ^bb0([[VAR_arg9_1_]]: f32, [[VAR_arg10_1_]]: f32): -// CHECK: [[VAR_21_3_:%.+]] = arith.addf [[VAR_arg9_1_]], [[VAR_arg10_1_]] : f32 +// CHECK: ^bb0([[VAR_arg9_2_:%.+]]: f32, [[VAR_arg10_2_:%.+]]: f32): +// CHECK: [[VAR_21_3_:%.+]] = arith.addf [[VAR_arg9_2_]], [[VAR_arg10_2_]] : f32 // CHECK: tt.reduce.return [[VAR_21_3_]] : f32 // CHECK: }) : (tensor<256xf32>) -> f32 // CHECK: [[VAR_13_:%.+]] = arith.divf [[VAR_12_]], [[VAR_6_]] : f32 @@ -159,50 +161,54 @@ module { // CHECK: tt.store [[VAR_18_]], [[VAR_16_]] : !tt.ptr // CHECK-DAG: [[VAR_19_:%.+]] = tt.splat [[VAR_7_]] : f32 -> tensor<256xf32> // CHECK-DAG: [[VAR_20_:%.+]] = tt.splat [[VAR_16_]] : f32 -> tensor<256xf32> -// CHECK: scf.for [[VAR_arg9_1_:%.+]] = [[CST_0_]] to [[PARAM_7_]] step [[CST_256_1_]] : i32 { -// CHECK: [[VAR_21_4_:%.+]] = arith.index_cast [[VAR_arg9_1_]] : i32 to index +// CHECK: scf.for [[VAR_arg9_2_1_:%.+]] = [[CST_0_]] to [[PARAM_7_]] step [[CST_256_1_]] : i32 { +// CHECK: [[VAR_21_4_:%.+]] = arith.index_cast [[VAR_arg9_2_1_]] : i32 to index // CHECK-DAG: [[VAR_22_2_:%.+]] = tts.make_tptr [[PARAM_2_]] to sizes: [256], strides: [1], offsets: {{.}}[[VAR_21_4_]]{{.}}, shape: [0], order: [] : to tensor<256x!tt.ptr> -// CHECK-DAG: [[VAR_23_2_:%.+]] = arith.index_cast [[VAR_arg9_1_]] : i32 to index +// CHECK-DAG: [[VAR_23_2_:%.+]] = arith.index_cast [[VAR_arg9_2_1_]] : i32 to index // CHECK-NOT: separator of consecutive DAGs // CHECK-DAG: [[VAR_24_2_:%.+]] = arith.addi [[VAR_23_2_]], [[CST_256_]] : index // CHECK-DAG: [[VAR_25_2_:%.+]] = arith.index_cast 
[[PARAM_7_]] : i32 to index // CHECK: [[VAR_26_2_:%.+]] = arith.minsi [[VAR_24_2_]], [[VAR_25_2_]] : index -// CHECK: [[VAR_27_2_:%.+]] = arith.subi [[VAR_26_2_]], [[VAR_23_2_]] : index -// CHECK-DAG: [[VAR_28_2_:%.+]] = "tts.load"([[VAR_22_2_]], [[VAR_27_2_]]) <{operandSegmentSizes = array, static_mask_dims = array}> : (tensor<256x!tt.ptr>, index) -> tensor<256xf32> -// CHECK-DAG: [[VAR_29_2_:%.+]] = arith.index_cast [[VAR_arg9_1_]] : i32 to index +// CHECK: [[VAR_27_2_:%.+]] = arith.maxsi [[VAR_26_2_]], [[VAR_23_2_]] : index +// CHECK: [[VAR_28_2_:%.+]] = arith.subi [[VAR_27_2_]], [[VAR_23_2_]] : index +// CHECK-DAG: [[VAR_29_2_:%.+]] = "tts.load"([[VAR_22_2_]], [[VAR_28_2_]]) <{operandSegmentSizes = array, static_mask_dims = array}> : (tensor<256x!tt.ptr>, index) -> tensor<256xf32> +// CHECK-DAG: [[VAR_30_2_:%.+]] = arith.index_cast [[VAR_arg9_2_1_]] : i32 to index // CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: [[VAR_30_2_:%.+]] = tts.make_tptr [[PARAM_3_]] to sizes: [256], strides: [1], offsets: {{.}}[[VAR_29_2_]]{{.}}, shape: [0], order: [] : to tensor<256x!tt.ptr> -// CHECK-DAG: [[VAR_31_1_:%.+]] = arith.index_cast [[VAR_arg9_1_]] : i32 to index +// CHECK-DAG: [[VAR_31_2_:%.+]] = tts.make_tptr [[PARAM_3_]] to sizes: [256], strides: [1], offsets: {{.}}[[VAR_30_2_]]{{.}}, shape: [0], order: [] : to tensor<256x!tt.ptr> +// CHECK-DAG: [[VAR_32_1_:%.+]] = arith.index_cast [[VAR_arg9_2_1_]] : i32 to index // CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: [[VAR_32_1_:%.+]] = arith.addi [[VAR_31_1_]], [[CST_256_]] : index -// CHECK-DAG: [[VAR_33_1_:%.+]] = arith.index_cast [[PARAM_7_]] : i32 to index -// CHECK: [[VAR_34_1_:%.+]] = arith.minsi [[VAR_32_1_]], [[VAR_33_1_]] : index -// CHECK: [[VAR_35_1_:%.+]] = arith.subi [[VAR_34_1_]], [[VAR_31_1_]] : index -// CHECK-DAG: [[VAR_36_1_:%.+]] = "tts.load"([[VAR_30_2_]], [[VAR_35_1_]]) <{operandSegmentSizes = array, static_mask_dims = array}> : (tensor<256x!tt.ptr>, index) -> tensor<256xf32> -// CHECK-DAG: [[VAR_37_:%.+]] = arith.index_cast [[VAR_arg9_1_]] : i32 to index -// CHECK: [[VAR_38_:%.+]] = arith.addi [[VAR_2_]], [[VAR_37_]] : index -// CHECK-DAG: [[VAR_39_:%.+]] = tts.make_tptr [[PARAM_0_]] to sizes: [256], strides: [1], offsets: {{.}}[[VAR_38_]]{{.}}, shape: [0], order: [] : to tensor<256x!tt.ptr> -// CHECK-DAG: [[VAR_40_:%.+]] = arith.index_cast [[VAR_arg9_1_]] : i32 to index +// CHECK-DAG: [[VAR_33_1_:%.+]] = arith.addi [[VAR_32_1_]], [[CST_256_]] : index +// CHECK-DAG: [[VAR_34_1_:%.+]] = arith.index_cast [[PARAM_7_]] : i32 to index +// CHECK: [[VAR_35_1_:%.+]] = arith.minsi [[VAR_33_1_]], [[VAR_34_1_]] : index +// CHECK: [[VAR_36_1_:%.+]] = arith.maxsi [[VAR_35_1_]], [[VAR_32_1_]] : index +// CHECK: [[VAR_37_1_:%.+]] = arith.subi [[VAR_36_1_]], [[VAR_32_1_]] : index +// CHECK-DAG: [[VAR_38_:%.+]] = "tts.load"([[VAR_31_2_]], [[VAR_37_1_]]) <{operandSegmentSizes = array, static_mask_dims = array}> : (tensor<256x!tt.ptr>, index) -> tensor<256xf32> +// CHECK-DAG: [[VAR_39_:%.+]] = arith.index_cast [[VAR_arg9_2_1_]] : i32 to index +// CHECK: [[VAR_40_:%.+]] = arith.addi [[VAR_2_]], [[VAR_39_]] : index +// CHECK-DAG: [[VAR_41_:%.+]] = tts.make_tptr [[PARAM_0_]] to sizes: [256], strides: [1], offsets: {{.}}[[VAR_40_]]{{.}}, shape: [0], order: [] : to tensor<256x!tt.ptr> +// CHECK-DAG: [[VAR_42_:%.+]] = arith.index_cast [[VAR_arg9_2_1_]] : i32 to index // CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: [[VAR_41_:%.+]] = arith.addi [[VAR_40_]], [[CST_256_]] : index -// CHECK-DAG: [[VAR_42_:%.+]] = 
arith.index_cast [[PARAM_7_]] : i32 to index -// CHECK: [[VAR_43_:%.+]] = arith.minsi [[VAR_41_]], [[VAR_42_]] : index -// CHECK: [[VAR_44_:%.+]] = arith.subi [[VAR_43_]], [[VAR_40_]] : index -// CHECK: [[VAR_45_:%.+]] = "tts.load"([[VAR_39_]], [[VAR_44_]], [[CST_0_dot_000000_]]) <{operandSegmentSizes = array, static_mask_dims = array}> : (tensor<256x!tt.ptr>, index, f32) -> tensor<256xf32> -// CHECK: [[VAR_46_:%.+]] = arith.subf [[VAR_45_]], [[VAR_19_]] : tensor<256xf32> -// CHECK: [[VAR_47_:%.+]] = arith.mulf [[VAR_46_]], [[VAR_20_]] : tensor<256xf32> -// CHECK: [[VAR_48_:%.+]] = arith.mulf [[VAR_47_]], [[VAR_28_2_]] : tensor<256xf32> -// CHECK-DAG: [[VAR_49_:%.+]] = arith.addf [[VAR_48_]], [[VAR_36_1_]] : tensor<256xf32> -// CHECK-DAG: [[VAR_50_:%.+]] = arith.index_cast [[VAR_arg9_1_]] : i32 to index -// CHECK: [[VAR_51_:%.+]] = arith.addi [[VAR_3_]], [[VAR_50_]] : index -// CHECK-DAG: [[VAR_52_:%.+]] = tts.make_tptr [[PARAM_1_]] to sizes: [256], strides: [1], offsets: {{.}}[[VAR_51_]]{{.}}, shape: [0], order: [] : to tensor<256x!tt.ptr> -// CHECK-DAG: [[VAR_53_:%.+]] = arith.index_cast [[VAR_arg9_1_]] : i32 to index +// CHECK-DAG: [[VAR_43_:%.+]] = arith.addi [[VAR_42_]], [[CST_256_]] : index +// CHECK-DAG: [[VAR_44_:%.+]] = arith.index_cast [[PARAM_7_]] : i32 to index +// CHECK: [[VAR_45_:%.+]] = arith.minsi [[VAR_43_]], [[VAR_44_]] : index +// CHECK: [[VAR_46_:%.+]] = arith.maxsi [[VAR_45_]], [[VAR_42_]] : index +// CHECK: [[VAR_47_:%.+]] = arith.subi [[VAR_46_]], [[VAR_42_]] : index +// CHECK: [[VAR_48_:%.+]] = "tts.load"([[VAR_41_]], [[VAR_47_]], [[CST_0_dot_000000_]]) <{operandSegmentSizes = array, static_mask_dims = array}> : (tensor<256x!tt.ptr>, index, f32) -> tensor<256xf32> +// CHECK: [[VAR_49_:%.+]] = arith.subf [[VAR_48_]], [[VAR_19_]] : tensor<256xf32> +// CHECK: [[VAR_50_:%.+]] = arith.mulf [[VAR_49_]], [[VAR_20_]] : tensor<256xf32> +// CHECK: [[VAR_51_:%.+]] = arith.mulf [[VAR_50_]], [[VAR_29_2_]] : tensor<256xf32> +// CHECK-DAG: [[VAR_52_:%.+]] = arith.addf [[VAR_51_]], [[VAR_38_]] : tensor<256xf32> +// CHECK-DAG: [[VAR_53_:%.+]] = arith.index_cast [[VAR_arg9_2_1_]] : i32 to index +// CHECK: [[VAR_54_:%.+]] = arith.addi [[VAR_3_]], [[VAR_53_]] : index +// CHECK-DAG: [[VAR_55_:%.+]] = tts.make_tptr [[PARAM_1_]] to sizes: [256], strides: [1], offsets: {{.}}[[VAR_54_]]{{.}}, shape: [0], order: [] : to tensor<256x!tt.ptr> +// CHECK-DAG: [[VAR_56_:%.+]] = arith.index_cast [[VAR_arg9_2_1_]] : i32 to index // CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: [[VAR_54_:%.+]] = arith.addi [[VAR_53_]], [[CST_256_]] : index -// CHECK-DAG: [[VAR_55_:%.+]] = arith.index_cast [[PARAM_7_]] : i32 to index -// CHECK: [[VAR_56_:%.+]] = arith.minsi [[VAR_54_]], [[VAR_55_]] : index -// CHECK: [[VAR_57_:%.+]] = arith.subi [[VAR_56_]], [[VAR_53_]] : index -// CHECK: "tts.store"([[VAR_52_]], [[VAR_49_]], [[VAR_57_]]) <{static_mask_dims = array}> : (tensor<256x!tt.ptr>, tensor<256xf32>, index) -> () +// CHECK-DAG: [[VAR_57_:%.+]] = arith.addi [[VAR_56_]], [[CST_256_]] : index +// CHECK-DAG: [[VAR_58_:%.+]] = arith.index_cast [[PARAM_7_]] : i32 to index +// CHECK: [[VAR_59_:%.+]] = arith.minsi [[VAR_57_]], [[VAR_58_]] : index +// CHECK: [[VAR_60_:%.+]] = arith.maxsi [[VAR_59_]], [[VAR_56_]] : index +// CHECK: [[VAR_61_:%.+]] = arith.subi [[VAR_60_]], [[VAR_56_]] : index +// CHECK: "tts.store"([[VAR_55_]], [[VAR_52_]], [[VAR_61_]]) <{static_mask_dims = array}> : (tensor<256x!tt.ptr>, tensor<256xf32>, index) -> () // CHECK: } // CHECK: tt.return // CHECK: } diff --git 
a/test/Conversion/TritonToStructured/masked_ldst_1d.mlir b/test/Conversion/TritonToStructured/masked_ldst_1d.mlir index accb97f0..9ed8f81d 100644 --- a/test/Conversion/TritonToStructured/masked_ldst_1d.mlir +++ b/test/Conversion/TritonToStructured/masked_ldst_1d.mlir @@ -23,14 +23,17 @@ module { // CHECK: tt.func @kernel([[PARAM_0_:%.+]]: !tt.ptr, [[PARAM_1_:%.+]]: !tt.ptr, [[PARAM_2_:%.+]]: i32) { // CHECK-DAG: [[CST_0_:%.+]] = arith.constant 0xFF80 : bf16 +// CHECK-DAG: [[CST_0_1_:%.+]] = arith.constant 0 : index // CHECK-DAG: [[CST_128_:%.+]] = arith.constant 128 : index // CHECK-DAG: [[VAR_0_:%.+]] = tts.make_tptr [[PARAM_0_]] to sizes: [128], strides: [1], offsets: [0], shape: [0], order: [] : to tensor<128x!tt.ptr> // CHECK-DAG: [[VAR_1_:%.+]] = tts.make_tptr [[PARAM_1_]] to sizes: [128], strides: [1], offsets: [0], shape: [0], order: [] : to tensor<128x!tt.ptr> // CHECK-DAG: [[VAR_2_:%.+]] = arith.index_cast [[PARAM_2_]] : i32 to index // CHECK: [[VAR_3_:%.+]] = arith.minsi [[VAR_2_]], [[CST_128_]] : index -// CHECK-DAG: [[VAR_4_:%.+]] = "tts.load"([[VAR_0_]], [[VAR_3_]], [[CST_0_]]) <{operandSegmentSizes = array, static_mask_dims = array}> : (tensor<128x!tt.ptr>, index, bf16) -> tensor<128xbf16> -// CHECK-DAG: [[VAR_5_:%.+]] = arith.index_cast [[PARAM_2_]] : i32 to index -// CHECK: [[VAR_6_:%.+]] = arith.minsi [[VAR_5_]], [[CST_128_]] : index -// CHECK: "tts.store"([[VAR_1_]], [[VAR_4_]], [[VAR_6_]]) <{static_mask_dims = array}> : (tensor<128x!tt.ptr>, tensor<128xbf16>, index) -> () +// CHECK: [[VAR_4_:%.+]] = arith.maxsi [[VAR_3_]], [[CST_0_1_]] : index +// CHECK-DAG: [[VAR_5_:%.+]] = "tts.load"([[VAR_0_]], [[VAR_4_]], [[CST_0_]]) <{operandSegmentSizes = array, static_mask_dims = array}> : (tensor<128x!tt.ptr>, index, bf16) -> tensor<128xbf16> +// CHECK-DAG: [[VAR_6_:%.+]] = arith.index_cast [[PARAM_2_]] : i32 to index +// CHECK: [[VAR_7_:%.+]] = arith.minsi [[VAR_6_]], [[CST_128_]] : index +// CHECK: [[VAR_8_:%.+]] = arith.maxsi [[VAR_7_]], [[CST_0_1_]] : index +// CHECK: "tts.store"([[VAR_1_]], [[VAR_5_]], [[VAR_8_]]) <{static_mask_dims = array}> : (tensor<128x!tt.ptr>, tensor<128xbf16>, index) -> () // CHECK: tt.return // CHECK: } diff --git a/test/Conversion/TritonToStructured/masked_ldst_2d.mlir b/test/Conversion/TritonToStructured/masked_ldst_2d.mlir index b79d6958..642b591b 100644 --- a/test/Conversion/TritonToStructured/masked_ldst_2d.mlir +++ b/test/Conversion/TritonToStructured/masked_ldst_2d.mlir @@ -63,35 +63,39 @@ module { } // CHECK: tt.func @kernel([[PARAM_0_:%.+]]: !tt.ptr, [[PARAM_1_:%.+]]: !tt.ptr, [[PARAM_2_:%.+]]: i32, [[PARAM_3_:%.+]]: i32) { -// CHECK-DAG: [[CST_3072_:%.+]] = arith.constant 3072 : index -// CHECK-DAG: [[CST_1024_:%.+]] = arith.constant 1024 : index -// CHECK-DAG: [[CST_3_:%.+]] = arith.constant 3 : index -// CHECK-DAG: [[CST_2_:%.+]] = arith.constant 2 : index +// CHECK-DAG: [[CST_0_:%.+]] = arith.constant 0xFF80 : bf16 // CHECK-DAG: [[CST_256_:%.+]] = arith.constant 256 : index // CHECK-DAG: [[CST_128_:%.+]] = arith.constant 128 : index // CHECK-DAG: [[CST_259_:%.+]] = arith.constant 259 : index // CHECK-DAG: [[CST_130_:%.+]] = arith.constant 130 : index -// CHECK-DAG: [[CST_0_:%.+]] = arith.constant 0xFF80 : bf16 +// CHECK-DAG: [[CST_2_:%.+]] = arith.constant 2 : index +// CHECK-DAG: [[CST_3_:%.+]] = arith.constant 3 : index +// CHECK-DAG: [[CST_1024_:%.+]] = arith.constant 1024 : index +// CHECK-DAG: [[CST_3072_:%.+]] = arith.constant 3072 : index // CHECK-NOT: separator of consecutive DAGs // CHECK-DAG: [[VAR_0_:%.+]] = tts.make_tptr 
[[PARAM_0_]] to sizes: [128, 256], strides: [1, [[CST_1024_]]{{.}}, offsets: {{.}}[[CST_2_]], [[CST_3072_]]{{.}}, shape: [0, 0], order: [] : to tensor<128x256x!tt.ptr> // CHECK-DAG: [[VAR_1_:%.+]] = tts.make_tptr [[PARAM_1_]] to sizes: [128, 256], strides: [1, [[CST_1024_]]{{.}}, offsets: {{.}}[[CST_2_]], [[CST_3072_]]{{.}}, shape: [0, 0], order: [] : to tensor<128x256x!tt.ptr> // CHECK-DAG: [[VAR_2_:%.+]] = arith.index_cast [[PARAM_2_]] : i32 to index // CHECK: [[VAR_3_:%.+]] = arith.minsi [[VAR_2_]], [[CST_130_]] : index -// CHECK-DAG: [[VAR_4_:%.+]] = arith.subi [[VAR_3_]], [[CST_2_]] : index -// CHECK-DAG: [[VAR_5_:%.+]] = arith.index_cast [[PARAM_3_]] : i32 to index -// CHECK: [[VAR_6_:%.+]] = arith.minsi [[VAR_5_]], [[CST_259_]] : index -// CHECK-DAG: [[VAR_7_:%.+]] = arith.subi [[VAR_6_]], [[CST_3_]] : index -// CHECK-DAG: [[VAR_8_:%.+]] = arith.minsi [[VAR_4_]], [[CST_128_]] : index -// CHECK: [[VAR_9_:%.+]] = arith.minsi [[VAR_7_]], [[CST_256_]] : index -// CHECK-DAG: [[VAR_10_:%.+]] = "tts.load"([[VAR_0_]], [[VAR_8_]], [[VAR_9_]], [[CST_0_]]) <{operandSegmentSizes = array, static_mask_dims = array}> : (tensor<128x256x!tt.ptr>, index, index, bf16) -> tensor<128x256xbf16> -// CHECK-DAG: [[VAR_11_:%.+]] = arith.index_cast [[PARAM_2_]] : i32 to index -// CHECK: [[VAR_12_:%.+]] = arith.minsi [[VAR_11_]], [[CST_130_]] : index -// CHECK-DAG: [[VAR_13_:%.+]] = arith.subi [[VAR_12_]], [[CST_2_]] : index -// CHECK-DAG: [[VAR_14_:%.+]] = arith.index_cast [[PARAM_3_]] : i32 to index -// CHECK: [[VAR_15_:%.+]] = arith.minsi [[VAR_14_]], [[CST_259_]] : index -// CHECK-DAG: [[VAR_16_:%.+]] = arith.subi [[VAR_15_]], [[CST_3_]] : index -// CHECK-DAG: [[VAR_17_:%.+]] = arith.minsi [[VAR_13_]], [[CST_128_]] : index -// CHECK: [[VAR_18_:%.+]] = arith.minsi [[VAR_16_]], [[CST_256_]] : index -// CHECK: "tts.store"([[VAR_1_]], [[VAR_1_]]0, [[VAR_1_]]7, [[VAR_1_]]8) <{static_mask_dims = array}> : (tensor<128x256x!tt.ptr>, tensor<128x256xbf16>, index, index) -> () +// CHECK: [[VAR_4_:%.+]] = arith.maxsi [[VAR_3_]], [[CST_2_]] : index +// CHECK-DAG: [[VAR_5_:%.+]] = arith.subi [[VAR_4_]], [[CST_2_]] : index +// CHECK-DAG: [[VAR_6_:%.+]] = arith.index_cast [[PARAM_3_]] : i32 to index +// CHECK: [[VAR_7_:%.+]] = arith.minsi [[VAR_6_]], [[CST_259_]] : index +// CHECK: [[VAR_8_:%.+]] = arith.maxsi [[VAR_7_]], [[CST_3_]] : index +// CHECK-DAG: [[VAR_9_:%.+]] = arith.subi [[VAR_8_]], [[CST_3_]] : index +// CHECK-DAG: [[VAR_10_:%.+]] = arith.minsi [[VAR_5_]], [[CST_128_]] : index +// CHECK: [[VAR_11_:%.+]] = arith.minsi [[VAR_9_]], [[CST_256_]] : index +// CHECK-DAG: [[VAR_12_:%.+]] = "tts.load"([[VAR_0_]], [[VAR_10_]], [[VAR_11_]], [[CST_0_]]) <{operandSegmentSizes = array, static_mask_dims = array}> : (tensor<128x256x!tt.ptr>, index, index, bf16) -> tensor<128x256xbf16> +// CHECK-DAG: [[VAR_13_:%.+]] = arith.index_cast [[PARAM_2_]] : i32 to index +// CHECK: [[VAR_14_:%.+]] = arith.minsi [[VAR_13_]], [[CST_130_]] : index +// CHECK: [[VAR_15_:%.+]] = arith.maxsi [[VAR_14_]], [[CST_2_]] : index +// CHECK-DAG: [[VAR_16_:%.+]] = arith.subi [[VAR_15_]], [[CST_2_]] : index +// CHECK-DAG: [[VAR_17_:%.+]] = arith.index_cast [[PARAM_3_]] : i32 to index +// CHECK: [[VAR_18_:%.+]] = arith.minsi [[VAR_17_]], [[CST_259_]] : index +// CHECK: [[VAR_19_:%.+]] = arith.maxsi [[VAR_18_]], [[CST_3_]] : index +// CHECK-DAG: [[VAR_20_:%.+]] = arith.subi [[VAR_19_]], [[CST_3_]] : index +// CHECK-DAG: [[VAR_21_:%.+]] = arith.minsi [[VAR_16_]], [[CST_128_]] : index +// CHECK: [[VAR_22_:%.+]] = arith.minsi [[VAR_20_]], [[CST_256_]] : 
index +// CHECK: "tts.store"([[VAR_1_]], [[VAR_12_]], [[VAR_21_]], [[VAR_22_]]) <{static_mask_dims = array}> : (tensor<128x256x!tt.ptr>, tensor<128x256xbf16>, index, index) -> () // CHECK: tt.return // CHECK: } diff --git a/test/Conversion/TritonToStructured/masked_ldst_sitofp_other.mlir b/test/Conversion/TritonToStructured/masked_ldst_sitofp_other.mlir index f485d511..0afb1ed4 100644 --- a/test/Conversion/TritonToStructured/masked_ldst_sitofp_other.mlir +++ b/test/Conversion/TritonToStructured/masked_ldst_sitofp_other.mlir @@ -25,14 +25,17 @@ module { // CHECK: tt.func @kernel([[PARAM_0_:%.+]]: !tt.ptr, [[PARAM_1_:%.+]]: !tt.ptr, [[PARAM_2_:%.+]]: i32) { // CHECK-DAG: [[CST_7_dot_000000_:%.+]] = arith.constant 7.000000e+00 : bf16 +// CHECK-DAG: [[CST_0_:%.+]] = arith.constant 0 : index // CHECK-DAG: [[CST_128_:%.+]] = arith.constant 128 : index // CHECK-DAG: [[VAR_0_:%.+]] = tts.make_tptr [[PARAM_0_]] to sizes: [128], strides: [1], offsets: [0], shape: [0], order: [] : to tensor<128x!tt.ptr> // CHECK-DAG: [[VAR_1_:%.+]] = tts.make_tptr [[PARAM_1_]] to sizes: [128], strides: [1], offsets: [0], shape: [0], order: [] : to tensor<128x!tt.ptr> // CHECK-DAG: [[VAR_2_:%.+]] = arith.index_cast [[PARAM_2_]] : i32 to index // CHECK: [[VAR_3_:%.+]] = arith.minsi [[VAR_2_]], [[CST_128_]] : index -// CHECK-DAG: [[VAR_4_:%.+]] = "tts.load"([[VAR_0_]], [[VAR_3_]], [[CST_7_dot_000000_]]) <{operandSegmentSizes = array, static_mask_dims = array}> : (tensor<128x!tt.ptr>, index, bf16) -> tensor<128xbf16> -// CHECK-DAG: [[VAR_5_:%.+]] = arith.index_cast [[PARAM_2_]] : i32 to index -// CHECK: [[VAR_6_:%.+]] = arith.minsi [[VAR_5_]], [[CST_128_]] : index -// CHECK: "tts.store"([[VAR_1_]], [[VAR_4_]], [[VAR_6_]]) <{static_mask_dims = array}> : (tensor<128x!tt.ptr>, tensor<128xbf16>, index) -> () +// CHECK: [[VAR_4_:%.+]] = arith.maxsi [[VAR_3_]], [[CST_0_]] : index +// CHECK-DAG: [[VAR_5_:%.+]] = "tts.load"([[VAR_0_]], [[VAR_4_]], [[CST_7_dot_000000_]]) <{operandSegmentSizes = array, static_mask_dims = array}> : (tensor<128x!tt.ptr>, index, bf16) -> tensor<128xbf16> +// CHECK-DAG: [[VAR_6_:%.+]] = arith.index_cast [[PARAM_2_]] : i32 to index +// CHECK: [[VAR_7_:%.+]] = arith.minsi [[VAR_6_]], [[CST_128_]] : index +// CHECK: [[VAR_8_:%.+]] = arith.maxsi [[VAR_7_]], [[CST_0_]] : index +// CHECK: "tts.store"([[VAR_1_]], [[VAR_5_]], [[VAR_8_]]) <{static_mask_dims = array}> : (tensor<128x!tt.ptr>, tensor<128xbf16>, index) -> () // CHECK: tt.return // CHECK: } diff --git a/test/Conversion/TritonToStructured/sign_extend_i32_to_i64.mlir b/test/Conversion/TritonToStructured/sign_extend_i32_to_i64.mlir index b3899a95..27d6ab35 100644 --- a/test/Conversion/TritonToStructured/sign_extend_i32_to_i64.mlir +++ b/test/Conversion/TritonToStructured/sign_extend_i32_to_i64.mlir @@ -22,18 +22,19 @@ module { } } -// CHECK: tt.func public @sign_extend([[arg0_:.+]]: !tt.ptr, [[arg1_:.+]]: !tt.ptr, [[arg2_:.+]]: !tt.ptr, [[arg3_:.+]]: i32) attributes {noinline = false} { +// CHECK: tt.func public @sign_extend([[PARAM_0_:%.+]]: !tt.ptr, [[PARAM_1_:%.+]]: !tt.ptr, [[PARAM_2_:%.+]]: !tt.ptr, [[PARAM_3_:%.+]]: i32) attributes {noinline = false} { // CHECK-DAG: [[CST_1_dot_100000_:%.+]] = arith.constant 1.100000e+01 : f32 // CHECK-DAG: [[CST_4_:%.+]] = arith.constant 4 : index -// CHECK-DAG: [[LOAD_arg0_MEM_:%.+]] = tt.load [[arg0_]] : !tt.ptr -// CHECK: [[VAR_1_:%.+]] = arith.index_cast [[LOAD_arg0_MEM_]] : i32 to index -// CHECK-DAG: [[VAR_2_:%.+]] = tts.make_tptr [[arg1_]] to sizes: [4], strides: [1], offsets: {{.}}[[VAR_1_]]{{.}}, 
shape: [0], order: [] : to tensor<4x!tt.ptr> +// CHECK-DAG: [[LOAD_PARAM_0_MEM_:%.+]] = tt.load [[PARAM_0_]] : !tt.ptr +// CHECK: [[VAR_1_:%.+]] = arith.index_cast [[LOAD_PARAM_0_MEM_]] : i32 to index +// CHECK-DAG: [[VAR_2_:%.+]] = tts.make_tptr [[PARAM_1_]] to sizes: [4], strides: [1], offsets: {{.}}[[VAR_1_]]{{.}}, shape: [0], order: [] : to tensor<4x!tt.ptr> // CHECK-DAG: [[VAR_3_:%.+]] = arith.addi [[VAR_1_]], [[CST_4_]] : index -// CHECK-DAG: [[VAR_4_:%.+]] = arith.index_cast [[arg3_]] : i32 to index +// CHECK-DAG: [[VAR_4_:%.+]] = arith.index_cast [[PARAM_3_]] : i32 to index // CHECK: [[VAR_5_:%.+]] = arith.minsi [[VAR_3_]], [[VAR_4_]] : index -// CHECK: [[VAR_6_:%.+]] = arith.subi [[VAR_5_]], [[VAR_1_]] : index -// CHECK-DAG: [[VAR_7_:%.+]] = "tts.load"([[VAR_2_]], [[VAR_6_]], [[CST_1_dot_100000_]]) <{operandSegmentSizes = array, static_mask_dims = array}> : (tensor<4x!tt.ptr>, index, f32) -> tensor<4xf32> -// CHECK-DAG: [[VAR_8_:%.+]] = tts.make_tptr [[arg2_]] to sizes: [4], strides: [1], offsets: [0], shape: [0], order: [] : to tensor<4x!tt.ptr> -// CHECK: "tts.store"([[VAR_8_]], [[VAR_7_]]) <{static_mask_dims = array}> : (tensor<4x!tt.ptr>, tensor<4xf32>) -> () +// CHECK: [[VAR_6_:%.+]] = arith.maxsi [[VAR_5_]], [[VAR_1_]] : index +// CHECK: [[VAR_7_:%.+]] = arith.subi [[VAR_6_]], [[VAR_1_]] : index +// CHECK-DAG: [[VAR_8_:%.+]] = "tts.load"([[VAR_2_]], [[VAR_7_]], [[CST_1_dot_100000_]]) <{operandSegmentSizes = array, static_mask_dims = array}> : (tensor<4x!tt.ptr>, index, f32) -> tensor<4xf32> +// CHECK-DAG: [[VAR_9_:%.+]] = tts.make_tptr [[PARAM_2_]] to sizes: [4], strides: [1], offsets: [0], shape: [0], order: [] : to tensor<4x!tt.ptr> +// CHECK: "tts.store"([[VAR_9_]], [[VAR_8_]]) <{static_mask_dims = array}> : (tensor<4x!tt.ptr>, tensor<4xf32>) -> () // CHECK: tt.return // CHECK: } diff --git a/test/Conversion/TritonToStructured/tensor_indices_loop_iterarg_with_masks.mlir b/test/Conversion/TritonToStructured/tensor_indices_loop_iterarg_with_masks.mlir index 3c96c2d3..b904016f 100644 --- a/test/Conversion/TritonToStructured/tensor_indices_loop_iterarg_with_masks.mlir +++ b/test/Conversion/TritonToStructured/tensor_indices_loop_iterarg_with_masks.mlir @@ -26,7 +26,7 @@ module { } } -// CHECK: tt.func public @addptr_with_masks([[arg0_:.+]]: !tt.ptr, [[arg1_:.+]]: !tt.ptr, [[arg2_:.+]]: i32) attributes {noinline = false} { +// CHECK: tt.func public @addptr_with_masks([[PARAM_0_:%.+]]: !tt.ptr, [[PARAM_1_:%.+]]: !tt.ptr, [[PARAM_2_:%.+]]: i32) attributes {noinline = false} { // CHECK-DAG: [[CST_minus_1_dot_100000_:%.+]] = arith.constant -1.100000e+01 : f32 // CHECK-DAG: [[CST_4_:%.+]] = arith.constant 4 : index // CHECK-DAG: [[CST_1_:%.+]] = arith.constant 1 : i32 @@ -36,16 +36,17 @@ module { // CHECK-DAG: [[CST_1_1_:%.+]] = arith.constant 1 : index // CHECK-NOT: separator of consecutive DAGs // CHECK-DAG: [[VAR_0_:%.+]]:2 = scf.for [[VAR_arg3_:%.+]] = [[CST_0_]] to [[CST_4_1_]] step [[CST_1_]] iter_args([[VAR_arg4_:%.+]] = [[CST_0_1_]], [[VAR_arg5_:%.+]] = [[CST_0_1_]]) -> (index, index) : i32 { -// CHECK-DAG: [[VAR_1_:%.+]] = tts.make_tptr [[arg0_]] to sizes: [4], strides: {{.}}[[CST_1_1_]]{{.}}, offsets: {{.}}[[VAR_arg4_]]{{.}}, shape: [0], order: [] : to tensor<4x!tt.ptr> +// CHECK-DAG: [[VAR_1_:%.+]] = tts.make_tptr [[PARAM_0_]] to sizes: [4], strides: {{.}}[[CST_1_1_]]{{.}}, offsets: {{.}}[[VAR_arg4_]]{{.}}, shape: [0], order: [] : to tensor<4x!tt.ptr> // CHECK-DAG: [[VAR_2_:%.+]] = arith.addi [[VAR_arg4_]], [[CST_4_]] : index -// CHECK-DAG: [[VAR_3_:%.+]] = 
arith.index_cast [[arg2_]] : i32 to index +// CHECK-DAG: [[VAR_3_:%.+]] = arith.index_cast [[PARAM_2_]] : i32 to index // CHECK: [[VAR_4_:%.+]] = arith.minsi [[VAR_2_]], [[VAR_3_]] : index -// CHECK: [[VAR_5_:%.+]] = arith.subi [[VAR_4_]], [[VAR_arg4_]] : index -// CHECK-DAG: [[VAR_6_:%.+]] = "tts.load"([[VAR_1_]], [[VAR_5_]], [[CST_minus_1_dot_100000_]]) <{operandSegmentSizes = array, static_mask_dims = array}> : (tensor<4x!tt.ptr>, index, f32) -> tensor<4xf32> -// CHECK-DAG: [[VAR_7_:%.+]] = tts.make_tptr [[arg1_]] to sizes: [4], strides: {{.}}[[CST_1_1_]]{{.}}, offsets: {{.}}[[VAR_arg5_]]{{.}}, shape: [0], order: [] : to tensor<4x!tt.ptr> -// CHECK: "tts.store"([[VAR_7_]], [[VAR_6_]]) <{static_mask_dims = array}> : (tensor<4x!tt.ptr>, tensor<4xf32>) -> () -// CHECK: [[VAR_8_:%.+]] = arith.addi [[VAR_arg5_]], [[CST_4_]] : index -// CHECK: scf.yield [[VAR_2_]], [[VAR_8_]] : index, index +// CHECK: [[VAR_5_:%.+]] = arith.maxsi [[VAR_4_]], [[VAR_arg4_]] : index +// CHECK: [[VAR_6_:%.+]] = arith.subi [[VAR_5_]], [[VAR_arg4_]] : index +// CHECK-DAG: [[VAR_7_:%.+]] = "tts.load"([[VAR_1_]], [[VAR_6_]], [[CST_minus_1_dot_100000_]]) <{operandSegmentSizes = array, static_mask_dims = array}> : (tensor<4x!tt.ptr>, index, f32) -> tensor<4xf32> +// CHECK-DAG: [[VAR_8_:%.+]] = tts.make_tptr [[PARAM_1_]] to sizes: [4], strides: {{.}}[[CST_1_1_]]{{.}}, offsets: {{.}}[[VAR_arg5_]]{{.}}, shape: [0], order: [] : to tensor<4x!tt.ptr> +// CHECK: "tts.store"([[VAR_8_]], [[VAR_7_]]) <{static_mask_dims = array}> : (tensor<4x!tt.ptr>, tensor<4xf32>) -> () +// CHECK: [[VAR_9_:%.+]] = arith.addi [[VAR_arg5_]], [[CST_4_]] : index +// CHECK: scf.yield [[VAR_2_]], [[VAR_9_]] : index, index // CHECK: } // CHECK: tt.return // CHECK: }
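
All of the updated CHECK sequences above verify the same clamped mask computation on `index` values: `arith.minsi` against the bound, then `arith.maxsi` against the start, then `arith.subi` to produce the mask dimension. A minimal sketch of that arithmetic (plain Python; `mask_dim` and its argument names are hypothetical, not code from this patch):

```python
def mask_dim(start: int, end: int, bound: int) -> int:
    """Clamped mask size, mirroring the minsi -> maxsi -> subi
    sequence checked in the tests above."""
    new_end = max(min(end, bound), start)  # the max keeps new_end >= start
    return new_end - start

# Fully in-bounds 1024-element block starting at offset 0, bound 4096.
assert mask_dim(0, 1024, 4096) == 1024

# Partially masked block: only 512 of the 1024 elements are in bounds.
assert mask_dim(3584, 4608, 4096) == 512

# Entirely masked-off block (bound < start): the dimension clamps to 0
# instead of going negative.
assert mask_dim(4096, 5120, 4096) == 0
```

In the 2-D tests (`kernel-03`, `kernel-05-layer-norm-dwdb`, `masked_ldst_2d`), each clamped dimension is additionally `arith.minsi`-ed against the static tile size (128 or 256) before being passed to `tts.load`/`tts.store`.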