From 3bb634f65836f629f3fe31beff725f7101e48315 Mon Sep 17 00:00:00 2001
From: erwei-xilinx
Date: Tue, 22 Oct 2024 21:00:53 -0700
Subject: [PATCH 1/5] Purge condenseMemrefDataReorderingToAIRDma from
 -air-copy-to-dma

---
 mlir/lib/Conversion/ConvertToAIRPass.cpp      | 450 ------------------
 .../condense_memref_ops_to_air_memcpy.mlir    | 124 -----------------
 2 files changed, 574 deletions(-)
 delete mode 100644 mlir/test/Conversion/ConvertToAIR/condense_memref_ops_to_air_memcpy.mlir

diff --git a/mlir/lib/Conversion/ConvertToAIRPass.cpp b/mlir/lib/Conversion/ConvertToAIRPass.cpp
index 597f4a46a..86fd9e160 100644
--- a/mlir/lib/Conversion/ConvertToAIRPass.cpp
+++ b/mlir/lib/Conversion/ConvertToAIRPass.cpp
@@ -1100,374 +1100,6 @@ class ScfForallToLaunchConversion : public OpRewritePattern<scf::ForallOp> {
   bool generateSegment;
 };
 
-/// Build a strided memref type by applying `permutationMap` tp `memRefType`.
-static MemRefType inferTransposeResultType(MemRefType memRefType,
-                                           AffineMap permutationMap) {
-  auto rank = memRefType.getRank();
-  auto originalSizes = memRefType.getShape();
-  auto [originalStrides, offset] = getStridesAndOffset(memRefType);
-  assert(originalStrides.size() == static_cast<unsigned>(rank));
-
-  // Compute permuted sizes and strides.
-  SmallVector<int64_t> sizes(rank, 0);
-  SmallVector<int64_t> strides(rank, 1);
-  for (const auto &en : llvm::enumerate(permutationMap.getResults())) {
-    unsigned position = cast<AffineDimExpr>(en.value()).getPosition();
-    sizes[en.index()] = originalSizes[position];
-    strides[en.index()] = originalStrides[position];
-  }
-
-  return MemRefType::Builder(memRefType)
-      .setShape(sizes)
-      .setLayout(
-          StridedLayoutAttr::get(memRefType.getContext(), offset, strides));
-}
-
-static SmallVector<Value> extractStridesFromMemrefType(MemRefType memrefTy,
-                                                       OpBuilder &builder) {
-  // get the strides and offsets from the memref type
-  SmallVector<Value> strides;
-  int64_t offset;
-  SmallVector<int64_t> layout_strides;
-  auto successStrides = getStridesAndOffset(memrefTy, layout_strides, offset);
-  if (failed(successStrides)) {
-    llvm::outs() << "Failed to get strides\n";
-    return strides;
-  }
-
-  for (auto s : layout_strides)
-    strides.push_back(
-        builder.create<arith::ConstantIndexOp>(builder.getUnknownLoc(), s));
-
-  return strides;
-}
-
-static SmallVector<Value> extractSizesFromMemrefType(MemRefType memrefTy,
-                                                     OpBuilder &builder) {
-  SmallVector<Value> sizes;
-  for (auto s : memrefTy.getShape())
-    sizes.push_back(
-        builder.create<arith::ConstantIndexOp>(builder.getUnknownLoc(), s));
-  return sizes;
-}
-
-static void extractOffsetsFromSubview(memref::SubViewOp subview,
-                                      OpBuilder &builder,
-                                      SmallVector<Value> &offsets) {
-  auto subview_offsets = subview.getOffsets().begin();
-  auto static_offsets = subview.getStaticOffsets();
-  auto loc = subview.getLoc();
-
-  for (auto o : static_offsets) {
-    if (o >= 0)
-      offsets.push_back(builder.create<arith::ConstantIndexOp>(loc, o));
-    else
-      offsets.push_back(*subview_offsets++);
-  }
-}
-
-static LogicalResult canonicalizeAIRDmaOperands(OpBuilder builder,
-                                                SmallVector<Value> &offsets,
-                                                SmallVector<Value> &sizes,
-                                                SmallVector<Value> &strides,
-                                                MemRefType memref) {
-  // Increase vector sizes up to memref size. When offsets, sizes and strides
-  // are all empty, then it implies that the whole memref is accessed in the
-  // default order.
-  auto max_dim_size =
-      std::max(std::max(offsets.size(), sizes.size()), strides.size());
-  auto target_dim_size = std::max(max_dim_size, (size_t)memref.getRank());
-  if (max_dim_size && offsets.size() < target_dim_size) {
-    for (unsigned i = offsets.size(); i < target_dim_size; i++) {
-      offsets.insert(offsets.begin(), builder.create<arith::ConstantIndexOp>(
-                                          builder.getUnknownLoc(), 0));
-    }
-  }
-  if (max_dim_size && sizes.size() < target_dim_size) {
-    for (unsigned i = sizes.size(); i < target_dim_size; i++) {
-      sizes.insert(sizes.begin(), builder.create<arith::ConstantIndexOp>(
-                                      builder.getUnknownLoc(), 1));
-    }
-  }
-  int memref_size = 1;
-  for (auto size : memref.getShape())
-    memref_size *= size;
-  if (max_dim_size && strides.size() < target_dim_size) {
-    for (unsigned i = strides.size(); i < target_dim_size; i++) {
-      strides.insert(strides.begin(),
-                     builder.create<arith::ConstantIndexOp>(
-                         builder.getUnknownLoc(), memref_size));
-    }
-  }
-
-  // Reduce highest dimensions if more than memref size
-  while (strides.size() > target_dim_size && getConstantIntValue(strides[0]) &&
-         *getConstantIntValue(strides[0]) == memref_size) {
-    strides.erase(strides.begin());
-  }
-  while (sizes.size() > target_dim_size && getConstantIntValue(sizes[0]) &&
-         *getConstantIntValue(sizes[0]) == 1) {
-    sizes.erase(sizes.begin());
-  }
-  while (offsets.size() > std::min(sizes.size(), strides.size()) &&
-         getConstantIntValue(offsets[0]) &&
-         *getConstantIntValue(offsets[0]) == 0) {
-    offsets.erase(offsets.begin());
-  }
-
-  if (offsets.size() != sizes.size() || sizes.size() != strides.size())
-    return failure();
-
-  return success();
-}
-
-static LogicalResult condenseMemrefDataReorderingToAIRDma(
-    air::DmaMemcpyNdOp dmaOp, std::vector<Operation *> src_ancestor_memref_ops,
-    std::vector<Operation *> dst_ancestor_memref_ops) {
-  OpBuilder rewriter(dmaOp);
-  auto src = dmaOp.getSrcMemref();
-  auto dst = dmaOp.getDstMemref();
-  auto loc = dmaOp->getLoc();
-
-  // It must already be a memref
-  auto src_type = llvm::dyn_cast<MemRefType>(src.getType());
-  auto dst_type = llvm::dyn_cast<MemRefType>(dst.getType());
-  if (!src_type)
-    return failure();
-  if (!(src_type.hasStaticShape() || dst_type.hasStaticShape()))
-    return failure();
-
-  // Revert the vector of memref ops, as it was built with push_back.
-  std::reverse(src_ancestor_memref_ops.begin(), src_ancestor_memref_ops.end());
-  std::reverse(dst_ancestor_memref_ops.begin(), dst_ancestor_memref_ops.end());
-
-  SmallVector<Value> src_offsets, dst_offsets;
-  SmallVector<Value> src_strides, dst_strides;
-  SmallVector<Value> src_sizes, dst_sizes;
-  SmallVector<Value> empty;
-  auto constZero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
-
-  MemRefType src_memref_ty;
-  if (!src_ancestor_memref_ops.empty()) {
-    if (auto subviewOp =
-            dyn_cast<memref::SubViewOp>(src_ancestor_memref_ops[0])) {
-      // Init. offsets
-      extractOffsetsFromSubview(subviewOp, rewriter, src_offsets);
-      // Init. memref type
-      src_memref_ty = subviewOp.getSourceType();
-      src = subviewOp.getSource();
-    } else if (auto transposeOp =
-                   dyn_cast<memref::TransposeOp>(src_ancestor_memref_ops[0])) {
-      // Init. memref type
-      src_memref_ty = llvm::cast<MemRefType>(transposeOp.getIn().getType());
-      src = transposeOp.getIn();
-      // Init. offsets
-      src_offsets.clear();
-      for (unsigned i = 0; i < transposeOp.getPermutation().getNumInputs(); i++)
-        src_offsets.push_back(constZero);
-    } else if (auto expandShapeOp = dyn_cast<memref::ExpandShapeOp>(
-                   src_ancestor_memref_ops[0])) {
-      // Init. memref type
-      src_memref_ty =
-          llvm::cast<MemRefType>(expandShapeOp.getViewSource().getType());
-      src = expandShapeOp.getViewSource();
-      // Init. offsets
-      src_offsets.clear();
-      for (unsigned i = 0; i < expandShapeOp.getReassociationIndices().size();
-           i++)
-        src_offsets.push_back(constZero);
-    } else if (auto castOp =
-                   dyn_cast<memref::CastOp>(src_ancestor_memref_ops[0])) {
-      // Init. memref type
-      src_memref_ty = llvm::cast<MemRefType>(castOp.getViewSource().getType());
-      src = castOp.getViewSource();
-      // Init. offsets
-      src_offsets.clear();
-      for (unsigned i = 0; i < src_memref_ty.getRank(); i++)
-        src_offsets.push_back(constZero);
-    }
-  } else {
-    src_offsets = dmaOp.getSrcOffsets();
-    src_sizes = dmaOp.getSrcSizes();
-    src_strides = dmaOp.getSrcStrides();
-  }
-
-  MemRefType dst_memref_ty;
-  if (!dst_ancestor_memref_ops.empty()) {
-    if (auto subviewOp =
-            dyn_cast<memref::SubViewOp>(dst_ancestor_memref_ops[0])) {
-      // Init. offsets
-      extractOffsetsFromSubview(subviewOp, rewriter, dst_offsets);
-      // Init. memref type
-      dst_memref_ty = subviewOp.getSourceType();
-      dst = subviewOp.getSource();
-    } else if (auto transposeOp =
-                   dyn_cast<memref::TransposeOp>(dst_ancestor_memref_ops[0])) {
-      // Init. memref type
-      dst_memref_ty = llvm::cast<MemRefType>(transposeOp.getIn().getType());
-      dst = transposeOp.getIn();
-      // Init. offsets
-      dst_offsets.clear();
-      for (unsigned i = 0; i < transposeOp.getPermutation().getNumInputs(); i++)
-        dst_offsets.push_back(constZero);
-    } else if (auto expandShapeOp = dyn_cast<memref::ExpandShapeOp>(
-                   dst_ancestor_memref_ops[0])) {
-      // Init. memref type
-      dst_memref_ty =
-          llvm::cast<MemRefType>(expandShapeOp.getViewSource().getType());
-      dst = expandShapeOp.getViewSource();
-      // Init. offsets
-      dst_offsets.clear();
-      for (unsigned i = 0; i < expandShapeOp.getReassociationIndices().size();
-           i++)
-        dst_offsets.push_back(constZero);
-    } else if (auto castOp =
-                   dyn_cast<memref::CastOp>(dst_ancestor_memref_ops[0])) {
-      // Init. memref type
-      dst_memref_ty = llvm::cast<MemRefType>(castOp.getViewSource().getType());
-      dst = castOp.getViewSource();
-      // Init. offsets
-      dst_offsets.clear();
-      for (unsigned i = 0; i < dst_memref_ty.getRank(); i++)
-        dst_offsets.push_back(constZero);
-    }
-  } else {
-    dst_offsets = dmaOp.getDstOffsets();
-    dst_sizes = dmaOp.getDstSizes();
-    dst_strides = dmaOp.getDstStrides();
-  }
-
-  for (auto memrefOp : src_ancestor_memref_ops) {
-    if (auto transposeOp = dyn_cast<memref::TransposeOp>(memrefOp)) {
-      // Init. memref type
-      src_memref_ty =
-          inferTransposeResultType(src_memref_ty, transposeOp.getPermutation());
-      // Init. offsets
-      if (transposeOp.getPermutation().getNumInputs() != src_offsets.size())
-        continue;
-      src_offsets =
-          applyPermutationMap<Value>(transposeOp.getPermutation(), src_offsets);
-    } else if (auto expandShapeOp = dyn_cast<memref::ExpandShapeOp>(memrefOp)) {
-      // Init. offsets
-      for (int i = (int)expandShapeOp.getReassociationIndices().size() - 1;
-           i >= 0; i--) {
-        if (expandShapeOp.getReassociationIndices()[i].size() <= 1)
-          continue;
-        for (unsigned j = 1;
-             j < expandShapeOp.getReassociationIndices()[i].size(); j++)
-          src_offsets.insert(src_offsets.begin() + i,
-                             rewriter.create<arith::ConstantIndexOp>(loc, 0));
-      }
-      // Init. memref type
-      FailureOr<MemRefType> compute_expand =
-          memref::ExpandShapeOp::computeExpandedType(
-              src_memref_ty, expandShapeOp.getResultType().getShape(),
-              expandShapeOp.getReassociationIndices());
-      if (failed(compute_expand)) {
-        assert(false);
-      } else {
-        src_memref_ty = *compute_expand;
-      }
-    } else if (auto subviewOp = dyn_cast<memref::SubViewOp>(memrefOp)) {
-      // Check if subview is rank reduced
-      if (subviewOp.getSourceType().getRank() > subviewOp.getType().getRank())
-        src_memref_ty = llvm::cast<MemRefType>(
-            memref::SubViewOp::inferRankReducedResultType(
-                subviewOp.getType().getShape(), src_memref_ty,
-                subviewOp.getStaticOffsets(), subviewOp.getStaticSizes(),
-                subviewOp.getStaticStrides()));
-      else
-        src_memref_ty =
-            llvm::cast<MemRefType>(memref::SubViewOp::inferResultType(
-                src_memref_ty, subviewOp.getStaticOffsets(),
-                subviewOp.getStaticSizes(), subviewOp.getStaticStrides()));
-    } else if (auto castOp = dyn_cast<memref::CastOp>(memrefOp)) {
-      // Init. memref type
-      src_memref_ty = llvm::cast<MemRefType>(castOp.getResult().getType());
-    }
-  }
-
-  for (auto memrefOp : dst_ancestor_memref_ops) {
-    if (auto transposeOp = dyn_cast<memref::TransposeOp>(memrefOp)) {
-      // Init. memref type
-      dst_memref_ty =
-          inferTransposeResultType(dst_memref_ty, transposeOp.getPermutation());
-      // Init. offsets
-      if (transposeOp.getPermutation().getNumInputs() != dst_offsets.size())
-        continue;
-      dst_offsets =
-          applyPermutationMap<Value>(transposeOp.getPermutation(), dst_offsets);
-    } else if (auto expandShapeOp = dyn_cast<memref::ExpandShapeOp>(memrefOp)) {
-      // Init. offsets
-      for (int i = (int)expandShapeOp.getReassociationIndices().size() - 1;
-           i >= 0; i--) {
-        if (expandShapeOp.getReassociationIndices()[i].size() <= 1)
-          continue;
-        for (unsigned j = 1;
-             j < expandShapeOp.getReassociationIndices()[i].size(); j++)
-          dst_offsets.insert(dst_offsets.begin() + i,
-                             rewriter.create<arith::ConstantIndexOp>(loc, 0));
-      }
-      // Init. memref type
-      FailureOr<MemRefType> compute_expand =
-          memref::ExpandShapeOp::computeExpandedType(
-              dst_memref_ty, expandShapeOp.getResultType().getShape(),
-              expandShapeOp.getReassociationIndices());
-      if (failed(compute_expand)) {
-        assert(false);
-      } else {
-        dst_memref_ty = *compute_expand;
-      }
-    } else if (auto subviewOp = dyn_cast<memref::SubViewOp>(memrefOp)) {
-      if (subviewOp.getSourceType().getRank() > subviewOp.getType().getRank())
-        dst_memref_ty = llvm::cast<MemRefType>(
-            memref::SubViewOp::inferRankReducedResultType(
-                subviewOp.getType().getShape(), dst_memref_ty,
-                subviewOp.getStaticOffsets(), subviewOp.getStaticSizes(),
-                subviewOp.getStaticStrides()));
-      else
-        dst_memref_ty =
-            llvm::cast<MemRefType>(memref::SubViewOp::inferResultType(
-                dst_memref_ty, subviewOp.getStaticOffsets(),
-                subviewOp.getStaticSizes(), subviewOp.getStaticStrides()));
-    } else if (auto castOp = dyn_cast<memref::CastOp>(memrefOp)) {
-      // Init. memref type
-      dst_memref_ty = llvm::cast<MemRefType>(castOp.getResult().getType());
-    }
-  }
-
-  if (src_ancestor_memref_ops.size()) {
-    src_strides = extractStridesFromMemrefType(src_memref_ty, rewriter);
-    src_sizes = extractSizesFromMemrefType(src_memref_ty, rewriter);
-  }
-  if (dst_ancestor_memref_ops.size()) {
-    dst_strides = extractStridesFromMemrefType(dst_memref_ty, rewriter);
-    dst_sizes = extractSizesFromMemrefType(dst_memref_ty, rewriter);
-  }
-
-  SmallVector<Value> deps;
-  SmallVector<Type> tys;
-
-  if (failed(canonicalizeAIRDmaOperands(
-          rewriter, src_offsets, src_sizes, src_strides,
-          llvm::cast<MemRefType>(src.getType()))) ||
-      failed(canonicalizeAIRDmaOperands(
-          rewriter, dst_offsets, dst_sizes, dst_strides,
-          llvm::cast<MemRefType>(dst.getType())))) {
-    assert(false);
-  }
-  auto new_dma = rewriter.create<air::DmaMemcpyNdOp>(
-      loc, tys, deps, dst, dst_offsets, dst_sizes, dst_strides, src,
-      src_offsets, src_sizes, src_strides);
-
-  assert(!new_dma.getSrcMemref().getDefiningOp());
-  assert(!new_dma.getDstMemref().getDefiningOp());
-
-  dmaOp->erase();
-
-  return success();
-}
-
 struct CopyToDmaPass : public air::impl::CopyToDmaBase<CopyToDmaPass> {
 
   CopyToDmaPass() = default;
@@ -1539,88 +1171,6 @@ struct CopyToDmaPass : public air::impl::CopyToDmaBase<CopyToDmaPass> {
 
     LLVM_DEBUG(llvm::outs() << "output\n");
     LLVM_DEBUG(module.print(llvm::outs()));
-
-    // Condense memref data pattern reordering ops, including memref.subview,
-    // memref.tranpose and memref.expand_shape into air.dma_memcpy_nd op's
-    // offsets, sizes and strides fields.
-    auto scope = getOperation();
-    std::vector<std::tuple<air::DmaMemcpyNdOp, std::vector<Operation *>,
-                           std::vector<Operation *>>>
-        dma_ops;
-
-    scope->walk([&](xilinx::air::DmaMemcpyNdOp dmaOp) {
-      bool src_condense = false;
-      if (auto src_defop = dmaOp.getSrcMemref().getDefiningOp()) {
-        src_condense |= isa<memref::TransposeOp>(src_defop);
-        src_condense |= isa<memref::ExpandShapeOp>(src_defop);
-        src_condense |= isa<memref::SubViewOp>(src_defop);
-      }
-      bool dst_condense = false;
-      if (auto dst_defop = dmaOp.getDstMemref().getDefiningOp()) {
-        dst_condense |= isa<memref::TransposeOp>(dst_defop);
-        dst_condense |= isa<memref::ExpandShapeOp>(dst_defop);
-        dst_condense |= isa<memref::SubViewOp>(dst_defop);
-      }
-      if (src_condense || dst_condense) {
-        // Fields in the tuple: (1) dma op, (2) list of memref ops producing the
-        // src memref, and (3) list of memref ops producing the dst memref.
-        std::tuple<air::DmaMemcpyNdOp, std::vector<Operation *>,
-                   std::vector<Operation *>>
-            log_entry;
-        std::get<0>(log_entry) = dmaOp;
-        if (src_condense) {
-          Operation *ancestor = dmaOp.getSrcMemref().getDefiningOp();
-          bool exit = false;
-          while (ancestor && !exit) {
-            if (auto transpose_anc = dyn_cast<memref::TransposeOp>(ancestor)) {
-              std::get<1>(log_entry).push_back(ancestor);
-              ancestor = transpose_anc.getIn().getDefiningOp();
-            } else if (auto expand_anc =
-                           dyn_cast<memref::ExpandShapeOp>(ancestor)) {
-              std::get<1>(log_entry).push_back(ancestor);
-              ancestor = expand_anc.getSrc().getDefiningOp();
-            } else if (auto subview_anc =
-                           dyn_cast<memref::SubViewOp>(ancestor)) {
-              std::get<1>(log_entry).push_back(ancestor);
-              ancestor = subview_anc.getSource().getDefiningOp();
-            } else if (auto cast_anc = dyn_cast<memref::CastOp>(ancestor)) {
-              std::get<1>(log_entry).push_back(ancestor);
-              ancestor = cast_anc.getViewSource().getDefiningOp();
-            } else
-              exit = true;
-          }
-        }
-        if (dst_condense) {
-          Operation *ancestor = dmaOp.getDstMemref().getDefiningOp();
-          bool exit = false;
-          while (ancestor && !exit) {
-            if (auto transpose_anc = dyn_cast<memref::TransposeOp>(ancestor)) {
-              std::get<2>(log_entry).push_back(ancestor);
-              ancestor = transpose_anc.getIn().getDefiningOp();
-            } else if (auto expand_anc =
                           dyn_cast<memref::ExpandShapeOp>(ancestor)) {
-              std::get<2>(log_entry).push_back(ancestor);
-              ancestor = expand_anc.getSrc().getDefiningOp();
-            } else if (auto subview_anc =
-                           dyn_cast<memref::SubViewOp>(ancestor)) {
-              std::get<2>(log_entry).push_back(ancestor);
-              ancestor = subview_anc.getSource().getDefiningOp();
-            } else if (auto cast_anc = dyn_cast<memref::CastOp>(ancestor)) {
-              std::get<2>(log_entry).push_back(ancestor);
-              ancestor = cast_anc.getViewSource().getDefiningOp();
-            } else
-              exit = true;
-          }
-        }
-        dma_ops.push_back(log_entry);
-      }
-    });
-    for (auto dmaOp : dma_ops) {
-      if (failed(condenseMemrefDataReorderingToAIRDma(
-              std::get<0>(dmaOp), std::get<1>(dmaOp), std::get<2>(dmaOp)))) {
-        return signalPassFailure();
-      }
-    }
   }
 };
 
diff --git a/mlir/test/Conversion/ConvertToAIR/condense_memref_ops_to_air_memcpy.mlir b/mlir/test/Conversion/ConvertToAIR/condense_memref_ops_to_air_memcpy.mlir
deleted file mode 100644
index a7162173e..000000000
--- a/mlir/test/Conversion/ConvertToAIR/condense_memref_ops_to_air_memcpy.mlir
+++ /dev/null
@@ -1,124 +0,0 @@
-//===- condense_memref_ops_to_air_memcpy.mlir ------------------*- MLIR -*-===//
-//
-// Copyright (C) 2024, Advanced Micro Devices, Inc. All rights reserved.
-// SPDX-License-Identifier: MIT
-//
-//===----------------------------------------------------------------------===//
-
-// RUN: air-opt %s -air-copy-to-dma -canonicalize -cse | FileCheck %s
-
-// Memref::SubviewOp, memref::ExpandShapeOp and memref::TransposeOp folding.
-
-// CHECK: %[[CST128:.*]] = arith.constant 128 : index
-// CHECK: %[[CST32:.*]] = arith.constant 32 : index
-// CHECK: %[[CST8:.*]] = arith.constant 8 : index
-// CHECK: %[[CST16:.*]] = arith.constant 16 : index
-// CHECK: %[[CST0:.*]] = arith.constant 0 : index
-// CHECK: %[[CST1:.*]] = arith.constant 1 : index
-// CHECK: air.dma_memcpy_nd (%{{.*}}[] [] [], %{{.*}}[%{{.*}}, %[[CST0]]] [%[[CST8]], %[[CST16]]] [%[[CST16]], %[[CST1]]]) : (memref<1x1x8x16xi32, 1>, memref<8x16xi32>)
-// CHECK: air.dma_memcpy_nd (%{{.*}}[] [] [], %{{.*}}[%[[CST0]], %{{.*}}] [%[[CST16]], %[[CST16]]] [%[[CST32]], %[[CST1]]]) : (memref<1x1x16x16xi32, 1>, memref<16x32xi32>)
-// CHECK: air.herd @herd_0
-// CHECK: %[[CST32_0:.*]] = arith.constant 32 : index
-// CHECK: %[[CST256_0:.*]] = arith.constant 256 : index
-// CHECK: %[[CST4_0:.*]] = arith.constant 4 : index
-// CHECK: %[[CST2_0:.*]] = arith.constant 2 : index
-// CHECK: %[[CST1_0:.*]] = arith.constant 1 : index
-// CHECK: %[[CST16_0:.*]] = arith.constant 16 : index
-// CHECK: %[[CST64_0:.*]] = arith.constant 64 : index
-// CHECK: %[[CST8_0:.*]] = arith.constant 8 : index
-// CHECK: %[[CST128_0:.*]] = arith.constant 128 : index
-// CHECK: %[[CST0_0:.*]] = arith.constant 0 : index
-// CHECK: air.dma_memcpy_nd (%{{.*}}[] [] [], %{{.*}}[%{{.*}}, %[[CST0_0]], %[[CST0_0]], %[[CST0_0]], %[[CST0_0]], %[[CST0_0]]] [%[[CST1_0]], %[[CST1_0]], %[[CST2_0]], %[[CST2_0]], %[[CST4_0]], %[[CST8_0]]] [%[[CST128_0]], %[[CST128_0]], %[[CST8_0]], %[[CST64_0]], %[[CST16_0]], %[[CST1_0]]]) : (memref<1x1x2x2x4x8xi32, 2>, memref<1x1x8x16xi32, 1>)
-// CHECK: air.dma_memcpy_nd (%{{.*}}[] [] [], %{{.*}}[%[[CST0_0]], %{{.*}}, %[[CST0_0]], %[[CST0_0]], %[[CST0_0]], %[[CST0_0]]] [%[[CST1_0]], %[[CST1_0]], %[[CST2_0]], %[[CST2_0]], %[[CST8_0]], %[[CST8_0]]] [%[[CST256_0]], %[[CST256_0]], %[[CST8_0]], %[[CST128_0]], %[[CST16_0]], %[[CST1_0]]]) : (memref<1x1x2x2x8x8xi32, 2>, memref<1x1x16x16xi32, 1>)
-// CHECK: air.dma_memcpy_nd (%{{.*}}[%{{.*}}, %{{.*}}, %[[CST0_0]], %[[CST0_0]]] [%[[CST1_0]], %[[CST1_0]], %[[CST8_0]], %[[CST16_0]]] [%[[CST128_0]], %[[CST128_0]], %[[CST16_0]], %[[CST1_0]]], %{{.*}}[%[[CST0_0]], %[[CST0_0]], %[[CST0_0]], %[[CST0_0]], %[[CST0_0]], %[[CST0_0]]] [%[[CST1_0]], %[[CST1_0]], %[[CST2_0]], %[[CST4_0]], %[[CST2_0]], %[[CST8_0]]] [%[[CST128_0]], %[[CST128_0]], %[[CST32_0]], %[[CST8_0]], %[[CST64_0]], %[[CST1_0]]]) : (memref<1x1x8x16xi32, 1>, memref<1x1x2x2x4x8xi32, 2>)
-// CHECK: air.dma_memcpy_nd (%{{.*}}[%{{.*}}, %{{.*}}] [%[[CST8]], %[[CST16]]] [%[[CST32]], %[[CST1]]], %{{.*}}[%[[CST0]], %[[CST0]], %[[CST0]], %[[CST0]]] [%[[CST1]], %[[CST1]], %[[CST8]], %[[CST16]]] [%[[CST128]], %[[CST128]], %[[CST16]], %[[CST1]]]) : (memref<8x32xi32>, memref<1x1x8x16xi32, 1>)
-
-#map = affine_map<()[s0] -> (s0 * 8)>
-#map1 = affine_map<()[s0] -> (s0 * 16)>
-#map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d2, d5, d3, d6, d8)>
-#map3 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d2, d1, d4, d5, d8, d7)>
-#map4 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d1, d4, d3, d6, d7)>
-func.func @func0(%0 : memref<8x16xi32>, %1 : memref<16x32xi32>, %2 : memref<8x32xi32>) {
-  %c2 = arith.constant 2 : index
-  %c1 = arith.constant 1 : index
-  %c0 = arith.constant 0 : index
-  air.launch (%arg0, %arg1) in (%arg2=%c1, %arg3=%c2) args(%arg4=%0, %arg5=%1, %arg6=%2) : memref<8x16xi32>, memref<16x32xi32>, memref<8x32xi32> {
-    air.segment @segment_0 args(%arg7=%arg0, %arg8=%arg1, %arg9=%arg4, %arg10=%arg5, %arg11=%arg6) : index, index, memref<8x16xi32>, memref<16x32xi32>, memref<8x32xi32> {
-      %c1_0 = arith.constant 1 : index
-      %3 = affine.apply #map()[%arg7]
-      %4 = affine.apply #map1()[%arg8]
-      %subview = memref.subview %arg9[%3, 0] [8, 16] [1, 1] : memref<8x16xi32> to memref<8x16xi32, strided<[16, 1], offset: ?>>
-      %subview_1 = memref.subview %arg10[0, %4] [16, 16] [1, 1] : memref<16x32xi32> to memref<16x16xi32, strided<[32, 1], offset: ?>>
-      %subview_2 = memref.subview %arg11[%3, %4] [8, 16] [1, 1] : memref<8x32xi32> to memref<8x16xi32, strided<[32, 1], offset: ?>>
-      %alloc = memref.alloc() : memref<1x1x8x16xi32, 1>
-      %transpose = memref.transpose %subview (d0, d1) -> (d0, d1) : memref<8x16xi32, strided<[16, 1], offset: ?>> to memref<8x16xi32, strided<[16, 1], offset: ?>>
-      air.dma_memcpy_nd (%alloc[] [] [], %transpose[] [] []) : (memref<1x1x8x16xi32, 1>, memref<8x16xi32, strided<[16, 1], offset: ?>>)
-      %alloc_3 = memref.alloc() : memref<1x1x16x16xi32, 1>
-      %transpose_4 = memref.transpose %subview_1 (d0, d1) -> (d0, d1) : memref<16x16xi32, strided<[32, 1], offset: ?>> to memref<16x16xi32, strided<[32, 1], offset: ?>>
-      air.dma_memcpy_nd (%alloc_3[] [] [], %transpose_4[] [] []) : (memref<1x1x16x16xi32, 1>, memref<16x16xi32, strided<[32, 1], offset: ?>>)
-      %alloc_5 = memref.alloc() : memref<1x1x8x16xi32, 1>
-      air.herd @herd_0 tile (%arg12, %arg13) in (%arg14=%c1_0, %arg15=%c1_0) args(%arg16=%alloc, %arg17=%alloc_3, %arg18=%alloc_5) : memref<1x1x8x16xi32, 1>, memref<1x1x16x16xi32, 1>, memref<1x1x8x16xi32, 1> {
-        %c0_i32 = arith.constant 0 : i32
-        %subview_8 = memref.subview %arg16[%arg12, 0, 0, 0] [1, 1, 8, 16] [1, 1, 1, 1] : memref<1x1x8x16xi32, 1> to memref<1x1x8x16xi32, strided<[128, 128, 16, 1], offset: ?>, 1>
-        %subview_9 = memref.subview %arg17[0, %arg13, 0, 0] [1, 1, 16, 16] [1, 1, 1, 1] : memref<1x1x16x16xi32, 1> to memref<1x1x16x16xi32, strided<[256, 256, 16, 1], offset: ?>, 1>
-        %subview_10 = memref.subview %arg18[%arg12, %arg13, 0, 0] [1, 1, 8, 16] [1, 1, 1, 1] : memref<1x1x8x16xi32, 1> to memref<1x1x8x16xi32, strided<[128, 128, 16, 1], offset: ?>, 1>
-        %alloc_11 = memref.alloc() : memref<1x1x2x2x4x8xi32, 2>
-        %expand_shape = memref.expand_shape %subview_8 [[0], [1], [2, 3], [4, 5]] output_shape [1, 1, 2, 4, 2, 8] : memref<1x1x8x16xi32, strided<[128, 128, 16, 1], offset: ?>, 1> into memref<1x1x2x4x2x8xi32, strided<[128, 128, 64, 16, 8, 1], offset: ?>, 1>
-        %transpose_12 = memref.transpose %expand_shape (d0, d1, d2, d3, d4, d5) -> (d0, d1, d4, d2, d3, d5) : memref<1x1x2x4x2x8xi32, strided<[128, 128, 64, 16, 8, 1], offset: ?>, 1> to memref<1x1x2x2x4x8xi32, strided<[128, 128, 8, 64, 16, 1], offset: ?>, 1>
-        air.dma_memcpy_nd (%alloc_11[] [] [], %transpose_12[] [] []) : (memref<1x1x2x2x4x8xi32, 2>, memref<1x1x2x2x4x8xi32, strided<[128, 128, 8, 64, 16, 1], offset: ?>, 1>)
-        %alloc_13 = memref.alloc() : memref<1x1x2x2x8x8xi32, 2>
-        %expand_shape_14 = memref.expand_shape %subview_9 [[0], [1], [2, 3], [4, 5]] output_shape [1, 1, 2, 8, 2, 8] : memref<1x1x16x16xi32, strided<[256, 256, 16, 1], offset: ?>, 1> into memref<1x1x2x8x2x8xi32, strided<[256, 256, 128, 16, 8, 1], offset: ?>, 1>
-        %transpose_15 = memref.transpose %expand_shape_14 (d0, d1, d2, d3, d4, d5) -> (d0, d1, d4, d2, d3, d5) : memref<1x1x2x8x2x8xi32, strided<[256, 256, 128, 16, 8, 1], offset: ?>, 1> to memref<1x1x2x2x8x8xi32, strided<[256, 256, 8, 128, 16, 1], offset: ?>, 1>
-        air.dma_memcpy_nd (%alloc_13[] [] [], %transpose_15[] [] []) : (memref<1x1x2x2x8x8xi32, 2>, memref<1x1x2x2x8x8xi32, strided<[256, 256, 8, 128, 16, 1], offset: ?>, 1>)
-        %alloc_16 = memref.alloc() : memref<1x1x2x2x4x8xi32, 2>
-        %transpose_17 = memref.transpose %alloc_16 (d0, d1, d2, d3, d4, d5) -> (d0, d1, d3, d4, d2, d5) : memref<1x1x2x2x4x8xi32, 2> to memref<1x1x2x4x2x8xi32, strided<[128, 128, 32, 8, 64, 1]>, 2>
-        air.dma_memcpy_nd (%subview_10[] [] [], %transpose_17[] [] []) : (memref<1x1x8x16xi32, strided<[128, 128, 16, 1], offset: ?>, 1>, memref<1x1x2x4x2x8xi32, strided<[128, 128, 32, 8, 64, 1]>, 2>)
-        memref.dealloc %alloc_11 : memref<1x1x2x2x4x8xi32, 2>
-        memref.dealloc %alloc_13 : memref<1x1x2x2x8x8xi32, 2>
-        memref.dealloc %alloc_16 : memref<1x1x2x2x4x8xi32, 2>
-      }
-      %subview_6 = memref.subview %alloc_5[0, 0, 0, 0] [1, 1, 8, 16] [1, 1, 1, 1] : memref<1x1x8x16xi32, 1> to memref<8x16xi32, 1>
-      %transpose_7 = memref.transpose %subview_6 (d0, d1) -> (d0, d1) : memref<8x16xi32, 1> to memref<8x16xi32, strided<[16, 1]>, 1>
-      air.dma_memcpy_nd (%subview_2[] [] [], %transpose_7[] [] []) : (memref<8x16xi32, strided<[32, 1], offset: ?>>, memref<8x16xi32, strided<[16, 1]>, 1>)
-      memref.dealloc %alloc_3 : memref<1x1x16x16xi32, 1>
-      memref.dealloc %alloc : memref<1x1x8x16xi32, 1>
-      memref.dealloc %alloc_5 : memref<1x1x8x16xi32, 1>
-    }
-  }
-  return
-}
-
-// Memref::CastOp folding.
-
-// CHECK: air.herd @herd_0 {{.*}} args(%[[ARG0:.*]]=%{{.*}}, %[[ARG1:.*]]=%{{.*}})
-// CHECK-DAG: %[[CST4:.*]] = arith.constant 4 : index
-// CHECK-DAG: %[[CST3:.*]] = arith.constant 3 : index
-// CHECK-DAG: %[[CST1:.*]] = arith.constant 1 : index
-// CHECK-DAG: %[[CST8:.*]] = arith.constant 8 : index
-// CHECK-DAG: %[[CST64:.*]] = arith.constant 64 : index
-// CHECK-DAG: %[[CST256:.*]] = arith.constant 256 : index
-// CHECK-DAG: %[[CST768:.*]] = arith.constant 768 : index
-// CHECK-DAG: %[[CST0:.*]] = arith.constant 0 : index
-// CHECK: air.dma_memcpy_nd (%[[ARG1]][] [] [], %[[ARG0]][%[[CST0]], %[[CST0]], %[[CST0]], %[[CST0]], %[[CST0]], %[[CST0]]] [%[[CST3]], %[[CST3]], %[[CST4]], %[[CST1]], %[[CST8]], %[[CST8]]] [%[[CST768]], %[[CST256]], %[[CST64]], %[[CST8]], %[[CST8]], %[[CST1]]]) : (memref<3x3x4x1x8x8xi8, 2 : i32>, memref<3x3x32x8xi8, 1 : i32>)
-// CHECK: }
-
-func.func @func1() {
-  %c8 = arith.constant 8 : index
-  %c2 = arith.constant 2 : index
-  %c3 = arith.constant 3 : index
-  air.launch (%arg3, %arg4, %arg5, %arg6) in (%arg7=%c2, %arg8=%c3, %arg9=%c3, %arg10=%c8) {
-    air.segment @segment_0 {
-      %c1 = arith.constant 1 : index
-      %c4 = arith.constant 4 : index
-      %alloc = memref.alloc() : memref<3x3x4x1x8x8xi8, 2 : i32>
-      %alloc_0 = memref.alloc() : memref<3x3x32x8xi8, 1 : i32>
-      air.herd @herd_0 tile (%arg11, %arg12) in (%arg13=%c4, %arg14=%c1) args(%arg15=%alloc_0, %arg16=%alloc) : memref<3x3x32x8xi8, 1 : i32>, memref<3x3x4x1x8x8xi8, 2 : i32> {
-        %cast = memref.cast %arg15 : memref<3x3x32x8xi8, 1 : i32> to memref<3x3x32x8xi8, strided<[768, 256, 8, 1], offset: ?>, 1 : i32>
-        %expand_shape = memref.expand_shape %cast [[0], [1], [2, 3], [4, 5]] output_shape [3, 3, 4, 8, 1, 8] : memref<3x3x32x8xi8, strided<[768, 256, 8, 1], offset: ?>, 1 : i32> into memref<3x3x4x8x1x8xi8, strided<[768, 256, 64, 8, 8, 1], offset: ?>, 1 : i32>
-        %transpose = memref.transpose %expand_shape (d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d4, d3, d5) : memref<3x3x4x8x1x8xi8, strided<[768, 256, 64, 8, 8, 1], offset: ?>, 1 : i32> to memref<3x3x4x1x8x8xi8, strided<[768, 256, 64, 8, 8, 1], offset: ?>, 1 : i32>
-        air.dma_memcpy_nd (%arg16[] [] [], %transpose[] [] []) : (memref<3x3x4x1x8x8xi8, 2 : i32>, memref<3x3x4x1x8x8xi8, strided<[768, 256, 64, 8, 8, 1], offset: ?>, 1 : i32>)
-      }
-    }
-  }
-  return
-}

From f0bb9200e893c6fd27c18cf07316c8c3cf1ca322 Mon Sep 17 00:00:00 2001
From: erwei-xilinx
Date: Tue, 22 Oct 2024 21:02:03 -0700
Subject: [PATCH 2/5] Reimplement condenseMemrefDataReorderingToAIRDma as a
 canonicalizer for air.channel.put/get and air.dma_memcpy_nd

---
 mlir/include/air/Dialect/AIR/AIR.td           |   3 +
 mlir/lib/Dialect/AIR/IR/AIRDialect.cpp        | 385 ++++++++++++++++++
 .../tranpose_linalg_cpy_to_4d_air_memcpy.mlir |   4 +-
 mlir/test/Dialect/AIR/air_canonicalize.mlir   | 215 ++++++++++
 4 files changed, 605 insertions(+), 2 deletions(-)

diff --git a/mlir/include/air/Dialect/AIR/AIR.td b/mlir/include/air/Dialect/AIR/AIR.td
index a4ecee890..85612b079 100644
--- a/mlir/include/air/Dialect/AIR/AIR.td
+++ b/mlir/include/air/Dialect/AIR/AIR.td
@@ -281,6 +281,7 @@ def air_DmaMemcpyNdOp: air_Op<"dma_memcpy_nd",
       return -1;
     }
   }];
+  let hasCanonicalizer = 1;
 }
 
 def air_WaitAllOp: air_Op<"wait_all", [air_AsyncOpInterface]> {
@@ -404,6 +405,7 @@ def air_ChannelPutOp : air_Op<"channel.put", [air_AsyncOpInterface,
       return -1;
     }
   }];
+  let hasCanonicalizer = 1;
 }
 
 def air_ChannelGetOp : air_Op<"channel.get", [air_AsyncOpInterface,
@@ -445,6 +447,7 @@ def air_ChannelGetOp : air_Op<"channel.get", [air_AsyncOpInterface,
       return -1;
     }
   }];
+  let hasCanonicalizer = 1;
 }
 
 // AIR asynchronous region for dynamic event dispatching.
diff --git a/mlir/lib/Dialect/AIR/IR/AIRDialect.cpp b/mlir/lib/Dialect/AIR/IR/AIRDialect.cpp
index 15d1aba5a..09203abf4 100644
--- a/mlir/lib/Dialect/AIR/IR/AIRDialect.cpp
+++ b/mlir/lib/Dialect/AIR/IR/AIRDialect.cpp
@@ -9,6 +9,7 @@
 
 #include "air/Dialect/AIR/AIRDialect.h"
 #include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/IR/Attributes.h"
 #include "mlir/IR/BuiltinOps.h"
 #include "mlir/IR/DialectImplementation.h"
@@ -1108,6 +1109,390 @@ void WaitAllOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
   patterns.add(FoldWaitAll);
 }
 
+// Get strides from MemRefType.
+static SmallVector<Value> extractStridesFromMemrefType(MemRefType memrefTy,
+                                                       OpBuilder &builder) {
+  SmallVector<Value> strides;
+  int64_t offset;
+  SmallVector<int64_t> layout_strides;
+  auto successStrides = getStridesAndOffset(memrefTy, layout_strides, offset);
+  if (failed(successStrides)) {
+    llvm::outs() << "Failed to get strides\n";
+    return strides;
+  }
+
+  for (auto s : layout_strides)
+    strides.push_back(
+        builder.create<arith::ConstantIndexOp>(builder.getUnknownLoc(), s));
+
+  return strides;
+}
+
+// Get sizes from MemRefType.
+static SmallVector<Value> extractSizesFromMemrefType(MemRefType memrefTy,
+                                                     OpBuilder &builder) {
+  SmallVector<Value> sizes;
+  for (auto s : memrefTy.getShape())
+    sizes.push_back(
+        builder.create<arith::ConstantIndexOp>(builder.getUnknownLoc(), s));
+  return sizes;
+}
+
+// Get offsets from memref::SubviewOp.
+static void extractOffsetsFromSubview(memref::SubViewOp subview,
+                                      OpBuilder &builder,
+                                      SmallVector<Value> &offsets) {
+  auto subview_offsets = subview.getOffsets().begin();
+  auto static_offsets = subview.getStaticOffsets();
+  auto loc = subview.getLoc();
+
+  for (auto o : static_offsets) {
+    if (o >= 0)
+      offsets.push_back(builder.create<arith::ConstantIndexOp>(loc, o));
+    else
+      offsets.push_back(*subview_offsets++);
+  }
+}
+
+static LogicalResult canonicalizeAIRDmaOperands(OpBuilder builder,
+                                                SmallVector<Value> &offsets,
+                                                SmallVector<Value> &sizes,
+                                                SmallVector<Value> &strides,
+                                                MemRefType memref) {
+  // Increase vector sizes up to memref size. When offsets, sizes and strides
+  // are all empty, then it implies that the whole memref is accessed in the
+  // default order.
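+  // Illustrative example (hypothetical operand lists, not taken from a test):
+  // for a memref<8x16xi32> with offsets = [%c0], sizes = [%c8] and
+  // strides = [%c1], all three lists are padded at the front up to the memref
+  // rank, giving offsets = [0, %c0], sizes = [1, %c8] and
+  // strides = [128, %c1], where 128 is the memref's total element count.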
+  auto max_dim_size =
+      std::max(std::max(offsets.size(), sizes.size()), strides.size());
+  auto target_dim_size = std::max(max_dim_size, (size_t)memref.getRank());
+  if (max_dim_size && offsets.size() < target_dim_size) {
+    for (unsigned i = offsets.size(); i < target_dim_size; i++) {
+      offsets.insert(offsets.begin(), builder.create<arith::ConstantIndexOp>(
+                                          builder.getUnknownLoc(), 0));
+    }
+  }
+  if (max_dim_size && sizes.size() < target_dim_size) {
+    for (unsigned i = sizes.size(); i < target_dim_size; i++) {
+      sizes.insert(sizes.begin(), builder.create<arith::ConstantIndexOp>(
+                                      builder.getUnknownLoc(), 1));
+    }
+  }
+  int memref_size = 1;
+  for (auto size : memref.getShape())
+    memref_size *= size;
+  if (max_dim_size && strides.size() < target_dim_size) {
+    for (unsigned i = strides.size(); i < target_dim_size; i++) {
+      strides.insert(strides.begin(),
+                     builder.create<arith::ConstantIndexOp>(
+                         builder.getUnknownLoc(), memref_size));
+    }
+  }
+
+  // Reduce highest dimensions if more than memref size
+  while (strides.size() > target_dim_size && getConstantIntValue(strides[0]) &&
+         *getConstantIntValue(strides[0]) == memref_size) {
+    strides.erase(strides.begin());
+  }
+  while (sizes.size() > target_dim_size && getConstantIntValue(sizes[0]) &&
+         *getConstantIntValue(sizes[0]) == 1) {
+    sizes.erase(sizes.begin());
+  }
+  while (offsets.size() > std::min(sizes.size(), strides.size()) &&
+         getConstantIntValue(offsets[0]) &&
+         *getConstantIntValue(offsets[0]) == 0) {
+    offsets.erase(offsets.begin());
+  }
+
+  if (offsets.size() != sizes.size() || sizes.size() != strides.size())
+    return failure();
+
+  return success();
+}
+
+static LogicalResult ComposeMemrefOp(Value memref, PatternRewriter &rewriter,
+                                     Value &input_memref,
+                                     SmallVector<Value> &offsets,
+                                     SmallVector<Value> &sizes,
+                                     SmallVector<Value> &strides) {
+
+  auto memref_type = llvm::dyn_cast<MemRefType>(memref.getType());
+  if (!memref_type)
+    return failure();
+  auto defop = memref.getDefiningOp();
+  if (!defop)
+    return failure();
+  auto loc = defop->getLoc();
+
+  // Get a chain of memref ops that produce the memref consumed by the memcpy
+  // op.
+  std::vector<Operation *> memrefOpVec;
+  bool exit = false;
+  while (defop && !exit) {
+    if (auto transposeOp = dyn_cast<memref::TransposeOp>(defop)) {
+      memrefOpVec.push_back(defop);
+      defop = transposeOp.getIn().getDefiningOp();
+    } else if (auto viewLikeOp = dyn_cast<ViewLikeOpInterface>(defop)) {
+      memrefOpVec.push_back(defop);
+      defop = viewLikeOp.getViewSource().getDefiningOp();
+    } else
+      exit = true;
+  }
+  if (memrefOpVec.empty())
+    return failure();
+
+  // Reverse the vector of memref ops, as it was built with push_back.
+  std::reverse(memrefOpVec.begin(), memrefOpVec.end());
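+
+  // For instance (hypothetical IR), if the operand was produced by
+  //   %sv = memref.subview %alloc ...
+  //   %tp = memref.transpose %sv ...
+  // the walk above visits %tp then %sv, so after reversal memrefOpVec holds
+  // [subview, transpose] in producer order and %alloc becomes the composed
+  // input memref below.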
+
+  // Init. memref type and offsets at the front of the vector of memref ops.
+  auto constZero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
+  MemRefType input_ty;
+  if (auto subviewOp = dyn_cast<memref::SubViewOp>(memrefOpVec[0])) {
+    // Init. offsets
+    extractOffsetsFromSubview(subviewOp, rewriter, offsets);
+    // Init. memref type
+    input_ty = subviewOp.getSourceType();
+    input_memref = subviewOp.getViewSource();
+  } else if (auto transposeOp = dyn_cast<memref::TransposeOp>(memrefOpVec[0])) {
+    // Init. memref type
+    input_ty = llvm::cast<MemRefType>(transposeOp.getIn().getType());
+    input_memref = transposeOp.getIn();
+    // Init. offsets
+    offsets.clear();
+    for (unsigned i = 0; i < transposeOp.getPermutation().getNumInputs(); i++)
+      offsets.push_back(constZero);
+  } else if (auto viewLikeOp = dyn_cast<ViewLikeOpInterface>(memrefOpVec[0])) {
+    // Init. memref type
+    input_ty = llvm::cast<MemRefType>(viewLikeOp.getViewSource().getType());
+    input_memref = viewLikeOp.getViewSource();
+    // Init. offsets
+    offsets.clear();
+    for (unsigned i = 0; i < input_ty.getRank(); i++)
+      offsets.push_back(constZero);
+  } else
+    return failure();
+
+  // Compose memref type as the memref propagates through the chain of memref
+  // ops.
+  for (auto memrefOp : memrefOpVec) {
+    if (auto transposeOp = dyn_cast<memref::TransposeOp>(memrefOp)) {
+      if (transposeOp.getPermutation().getNumInputs() != offsets.size())
+        continue;
+      offsets =
+          applyPermutationMap<Value>(transposeOp.getPermutation(), offsets);
+    } else if (auto expandShapeOp = dyn_cast<memref::ExpandShapeOp>(memrefOp)) {
+      // Init. offsets
+      for (int i = (int)expandShapeOp.getReassociationIndices().size() - 1;
+           i >= 0; i--) {
+        if (expandShapeOp.getReassociationIndices()[i].size() <= 1)
+          continue;
+        for (unsigned j = 1;
+             j < expandShapeOp.getReassociationIndices()[i].size(); j++)
+          offsets.insert(offsets.begin() + i,
+                         rewriter.create<arith::ConstantIndexOp>(loc, 0));
+      }
+    } else if (auto subviewOp = dyn_cast<memref::SubViewOp>(memrefOp)) {
+      if (subviewOp != memrefOpVec.front() && !subviewOp.hasZeroOffset())
+        subviewOp->emitOpError(
+            "is not the source op in a chain of memref layout transformation "
+            "ops, but applies a non-zero offset. This feature is NYI, and "
+            "leads to unexpected behaviour.");
+    }
+  }
+
+  // Memref type at sink memref op.
+  input_ty =
+      llvm::cast<MemRefType>(memrefOpVec.back()->getResultTypes().front());
+
+  // Compose sizes and strides from the output memref type's layout.
+  strides = extractStridesFromMemrefType(input_ty, rewriter);
+  sizes = extractSizesFromMemrefType(input_ty, rewriter);
+
+  return canonicalizeAIRDmaOperands(rewriter, offsets, sizes, strides,
+                                    input_ty);
+}
+
+//
+// Dma op
+//
+
+static LogicalResult
+ComposeMemrefOpOnDmaMemcpyNdSrc(DmaMemcpyNdOp op, PatternRewriter &rewriter) {
+
+  Value memref = op.getSrcMemref();
+  if (!memref)
+    return failure();
+  auto loc = op->getLoc();
+  Value input_memref;
+  SmallVector<Value> offsets, sizes, strides;
+  offsets = op.getSrcOffsets();
+  if (!offsets.empty())
+    return failure();
+  sizes = op.getSrcSizes();
+  if (!sizes.empty())
+    return failure();
+  strides = op.getSrcStrides();
+  if (!strides.empty())
+    return failure();
+
+  if (failed(ComposeMemrefOp(memref, rewriter, input_memref, offsets, sizes,
+                             strides))) {
+    return failure();
+  }
+
+  auto newOp = rewriter.create<DmaMemcpyNdOp>(
+      loc, op->getResultTypes(), op.getAsyncDependencies(), op.getDstMemref(),
+      op.getDstOffsets(), op.getDstSizes(), op.getDstStrides(), input_memref,
+      offsets, sizes, strides);
+
+  for (unsigned i = 0; i < op->getNumResults(); i++) {
+    op->getResult(i).replaceAllUsesWith(newOp->getResult(i));
+  }
+
+  rewriter.eraseOp(op);
+
+  return success();
+}
+
+static LogicalResult
+ComposeMemrefOpOnDmaMemcpyNdDst(DmaMemcpyNdOp op, PatternRewriter &rewriter) {
+
+  Value memref = op.getDstMemref();
+  if (!memref)
+    return failure();
+  auto loc = op->getLoc();
+  Value input_memref;
+  SmallVector<Value> offsets, sizes, strides;
+  offsets = op.getDstOffsets();
+  if (!offsets.empty())
+    return failure();
+  sizes = op.getDstSizes();
+  if (!sizes.empty())
+    return failure();
+  strides = op.getDstStrides();
+  if (!strides.empty())
+    return failure();
+
+  if (failed(ComposeMemrefOp(memref, rewriter, input_memref, offsets, sizes,
+                             strides))) {
+    return failure();
+  }
+
+  auto newOp = rewriter.create<DmaMemcpyNdOp>(
+      loc, op->getResultTypes(), op.getAsyncDependencies(), input_memref,
+      offsets, sizes, strides, op.getSrcMemref(), op.getSrcOffsets(),
+      op.getSrcSizes(), op.getSrcStrides());
+
+  for (unsigned i = 0; i < op->getNumResults(); i++) {
+    op->getResult(i).replaceAllUsesWith(newOp->getResult(i));
+  }
+
+  rewriter.eraseOp(op);
+
+  return success();
+}
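+
+// Note: the two patterns above rewrite a dma_memcpy_nd whose src/dst memref
+// is produced by a chain of memref layout ops into one on the underlying
+// allocation. Sketch of the intended effect (cf. the dma_compose_subview
+// test added below):
+//   %sv = memref.subview %m[%c0, %c0] [32, 32] [1, 1] : memref<64x64xbf16, 1> ...
+//   air.dma_memcpy_nd (..., %sv[] [] [])
+// becomes
+//   air.dma_memcpy_nd (..., %m[%c0, %c0] [%c32, %c32] [%c64, %c1])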
+
+void DmaMemcpyNdOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
+                                                MLIRContext *context) {
+  patterns.add(ComposeMemrefOpOnDmaMemcpyNdSrc);
+  patterns.add(ComposeMemrefOpOnDmaMemcpyNdDst);
+}
+
+//
+// Channel put op
+//
+
+static LogicalResult ComposeMemrefOpOnChannelPut(ChannelPutOp op,
+                                                 PatternRewriter &rewriter) {
+
+  Value memref = op.getMemref();
+  if (!memref)
+    return failure();
+  auto loc = op->getLoc();
+
+  // Init. memref type and offsets from memref's defining op's input type
+  Value input_memref;
+  SmallVector<Value> offsets, sizes, strides;
+  offsets = op.getOffsets();
+  if (!offsets.empty())
+    return failure();
+  sizes = op.getSizes();
+  if (!sizes.empty())
+    return failure();
+  strides = op.getStrides();
+  if (!strides.empty())
+    return failure();
+
+  if (failed(ComposeMemrefOp(memref, rewriter, input_memref, offsets, sizes,
+                             strides))) {
+    return failure();
+  }
+
+  auto newOp = rewriter.create<ChannelPutOp>(
+      loc, op->getResultTypes(), op.getAsyncDependencies(), op.getChanName(),
+      op.getIndices(), input_memref, offsets, sizes, strides);
+
+  for (unsigned i = 0; i < op->getNumResults(); i++) {
+    op->getResult(i).replaceAllUsesWith(newOp->getResult(i));
+  }
+
+  rewriter.eraseOp(op);
+
+  return success();
+}
+
+void ChannelPutOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
+                                               MLIRContext *context) {
+  patterns.add(ComposeMemrefOpOnChannelPut);
+}
+
+//
+// Channel get op
+//
+
+static LogicalResult ComposeMemrefOpOnChannelGet(ChannelGetOp op,
+                                                 PatternRewriter &rewriter) {
+
+  Value memref = op.getMemref();
+  if (!memref)
+    return failure();
+  auto loc = op->getLoc();
+
+  // Init. memref type and offsets from memref's defining op's input type
+  Value input_memref;
+  SmallVector<Value> offsets, sizes, strides;
+  offsets = op.getOffsets();
+  if (!offsets.empty())
+    return failure();
+  sizes = op.getSizes();
+  if (!sizes.empty())
+    return failure();
+  strides = op.getStrides();
+  if (!strides.empty())
+    return failure();
+
+  if (failed(ComposeMemrefOp(memref, rewriter, input_memref, offsets, sizes,
+                             strides))) {
+    return failure();
+  }
+
+  auto newOp = rewriter.create<ChannelGetOp>(
+      loc, op->getResultTypes(), op.getAsyncDependencies(), op.getChanName(),
+      op.getIndices(), input_memref, offsets, sizes, strides);
+
+  for (unsigned i = 0; i < op->getNumResults(); i++) {
+    op->getResult(i).replaceAllUsesWith(newOp->getResult(i));
+  }
+
+  rewriter.eraseOp(op);
+
+  return success();
+}
+
+void ChannelGetOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
+                                               MLIRContext *context) {
+  patterns.add(ComposeMemrefOpOnChannelGet);
+}
+
 //
 // Channel op
 //
diff --git a/mlir/test/Conversion/ConvertToAIR/tranpose_linalg_cpy_to_4d_air_memcpy.mlir b/mlir/test/Conversion/ConvertToAIR/tranpose_linalg_cpy_to_4d_air_memcpy.mlir
index 3b7259c88..e52539281 100644
--- a/mlir/test/Conversion/ConvertToAIR/tranpose_linalg_cpy_to_4d_air_memcpy.mlir
+++ b/mlir/test/Conversion/ConvertToAIR/tranpose_linalg_cpy_to_4d_air_memcpy.mlir
@@ -6,11 +6,11 @@
 //
 //===-------------------------------------------------------------------------------===//
 
-// RUN: air-opt %s -air-copy-to-dma | FileCheck %s
+// RUN: air-opt %s -air-copy-to-dma -canonicalize | FileCheck %s
 
 // CHECK: func.func @test(%[[ARG0:.*]]
 // CHECK: scf.for %[[ARG1:.*]] = %c0 to %c128 step %c32 {
-// CHECK: air.dma_memcpy_nd (%alloc[%[[ARG1:.*]], %c0_{{.*}}, %c0_{{.*}}, %c0_{{.*}}] [%c32_{{.*}}, %c8_{{.*}}, %c8_{{.*}}, %c16] [%c1024, %c128_{{.*}}, %c16_{{.*}}, %c1], %[[ARG0:.*]][%[[ARG1:.*]], %c0_{{.*}}, %c0_{{.*}}, %0] [%c32_{{.*}}, %c8_{{.*}}, %c8_{{.*}}, %c16_{{.*}}] [%c4096, %c64, %c512, %c1_{{.*}}])
+// CHECK: air.dma_memcpy_nd (%alloc[%[[ARG1:.*]], %c0{{.*}}, %c0{{.*}}, %c0{{.*}}] [%c32{{.*}}, %c8{{.*}}, %c8{{.*}}, %c16] [%c1024, %c128{{.*}}, %c16{{.*}}, %c1], %[[ARG0:.*]][%[[ARG1:.*]], %c0{{.*}}, %c0{{.*}}, %0] [%c32{{.*}}, %c8{{.*}}, %c8{{.*}}, %c16{{.*}}] [%c4096, %c64, %c512, %c1{{.*}}])
 #map = affine_map<(d0) -> (d0 * 16)>
 module {
   func.func @test(%arg1: memref<128x8x8x64xbf16>) -> memref<128x8x8x64xbf16> {
diff --git a/mlir/test/Dialect/AIR/air_canonicalize.mlir b/mlir/test/Dialect/AIR/air_canonicalize.mlir
index 61ec3ebcb..44b606345 100644
--- a/mlir/test/Dialect/AIR/air_canonicalize.mlir
+++ b/mlir/test/Dialect/AIR/air_canonicalize.mlir
@@ -158,3 +158,218 @@ func.func @execute_4() -> (memref<1xi32>, !air.async.token) {
   }
   return %results, %t : memref<1xi32>, !air.async.token
 }
+
+// CHECK: func.func @dma_compose_subview
+// CHECK: air.dma_memcpy_nd (%{{.*}}[%c0{{.*}}, %c0{{.*}}] [%c32{{.*}}, %c32{{.*}}] [%c64{{.*}}, %c1{{.*}}], %{{.*}}[%c0{{.*}}, %c0{{.*}}] [%c32{{.*}}, %c32{{.*}}] [%c64{{.*}}, %c1{{.*}}]
+func.func @dma_compose_subview() {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %c2 = arith.constant 2 : index
+  %c32 = arith.constant 32 : index
+  %c64 = arith.constant 64 : index
+  %0 = memref.alloc() : memref<64x64xbf16, 1>
+  %1 = memref.alloc() : memref<64x64xbf16, 2>
+  %subview_4 = memref.subview %0[%c0, %c0] [32, 32] [1, 1] : memref<64x64xbf16, 1> to memref<32x32xbf16, strided<[64, 1], offset: ?>, 1>
+  %subview_5 = memref.subview %1[%c0, %c0] [32, 32] [1, 1] : memref<64x64xbf16, 2> to memref<32x32xbf16, strided<[64, 1], offset: ?>, 2>
+  air.dma_memcpy_nd (%subview_4[] [] [], %subview_5[] [] []) : (memref<32x32xbf16, strided<[64, 1], offset: ?>, 1>, memref<32x32xbf16, strided<[64, 1], offset: ?>, 2>)
+  return
+}
+
+// CHECK: func.func @dma_compose_transpose
+// CHECK: air.dma_memcpy_nd (%{{.*}}[%c0{{.*}}, %c0{{.*}}] [%c64{{.*}}, %c128{{.*}}] [%c1{{.*}}, %c64{{.*}}], %{{.*}}[%c0{{.*}}, %c0{{.*}}] [%c64{{.*}}, %c128{{.*}}] [%c1{{.*}}, %c64{{.*}}]
+func.func @dma_compose_transpose() {
+  %0 = memref.alloc() : memref<128x64xbf16, 1>
+  %1 = memref.alloc() : memref<128x64xbf16, 2>
+  %transpose_1 = memref.transpose %0 (d0, d1) -> (d1, d0) : memref<128x64xbf16, 1> to memref<64x128xbf16, affine_map<(d0, d1) -> (d0 + d1 * 64)>, 1>
+  %transpose_2 = memref.transpose %1 (d0, d1) -> (d1, d0) : memref<128x64xbf16, 2> to memref<64x128xbf16, affine_map<(d0, d1) -> (d0 + d1 * 64)>, 2>
+  air.dma_memcpy_nd (%transpose_1[] [] [], %transpose_2[] [] []) : (memref<64x128xbf16, affine_map<(d0, d1) -> (d0 + d1 * 64)>, 1>, memref<64x128xbf16, affine_map<(d0, d1) -> (d0 + d1 * 64)>, 2>)
+  return
+}
+
+// CHECK: func.func @dma_compose_expand_shape
+// CHECK: air.dma_memcpy_nd (%{{.*}}[%c0{{.*}}, %c0{{.*}}, %c0{{.*}}] [%c2{{.*}}, %c64{{.*}}, %c64{{.*}}] [%c4096{{.*}}, %c64{{.*}}, %c1{{.*}}], %{{.*}}[%c0{{.*}}, %c0{{.*}}, %c0{{.*}}] [%c2{{.*}}, %c64{{.*}}, %c64{{.*}}] [%c4096{{.*}}, %c64{{.*}}, %c1{{.*}}]
+func.func @dma_compose_expand_shape() {
+  %0 = memref.alloc() : memref<128x64xbf16, 1>
+  %1 = memref.alloc() : memref<128x64xbf16, 2>
+  %expand_shape_1 = memref.expand_shape %0 [[0, 1], [2]] output_shape [2, 64, 64] : memref<128x64xbf16, 1> into memref<2x64x64xbf16, 1>
+  %expand_shape_2 = memref.expand_shape %1 [[0, 1], [2]] output_shape [2, 64, 64] : memref<128x64xbf16, 2> into memref<2x64x64xbf16, 2>
+  air.dma_memcpy_nd (%expand_shape_1[] [] [], %expand_shape_2[] [] []) : (memref<2x64x64xbf16, 1>, memref<2x64x64xbf16, 2>)
+  return
+}
+
+// CHECK: func.func @dma_compose_cast
+// CHECK: air.dma_memcpy_nd (%{{.*}}[%c0{{.*}}, %c0{{.*}}] [%c128{{.*}}, %c64{{.*}}] [%c64{{.*}}, %c1{{.*}}], %{{.*}}[%c0{{.*}}, %c0{{.*}}] [%c128{{.*}}, %c64{{.*}}] [%c64{{.*}}, %c1{{.*}}]
+func.func @dma_compose_cast() {
+  %0 = memref.alloc() : memref<128x64xbf16, 1>
+  %1 = memref.alloc() : memref<128x64xbf16, 2>
+  %cast = memref.cast %0 : memref<128x64xbf16, 1> to memref<128x64xbf16, strided<[64, 1], offset: ?>, 1>
+  %cast_1 = memref.cast %1 : memref<128x64xbf16, 2> to memref<128x64xbf16, strided<[64, 1], offset: ?>, 2>
+  air.dma_memcpy_nd (%cast[] [] [], %cast_1[] [] []) : (memref<128x64xbf16, strided<[64, 1], offset: ?>, 1>, memref<128x64xbf16, strided<[64, 1], offset: ?>, 2>)
+  return
+}
+
+// CHECK: func.func @channel_compose_subview
+// CHECK: air.channel.put @channel[] ({{.*}}[{{.*}}, {{.*}}] [%c32{{.*}}, %c32{{.*}}] [%c64{{.*}}, %c1{{.*}}]
+// CHECK: %[[V2:.*]] = air.channel.get async @channel[] ({{.*}}[{{.*}}, {{.*}}] [%c32{{.*}}, %c32{{.*}}] [%c64{{.*}}, %c1{{.*}}]
+air.channel @channel[2,2]
+func.func @channel_compose_subview() {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %c2 = arith.constant 2 : index
+  %c32 = arith.constant 32 : index
+  %c64 = arith.constant 64 : index
+  %0 = memref.alloc() : memref<64x64xbf16, 1>
+  %1 = memref.alloc() : memref<64x64xbf16, 2>
+  %subview_4 = memref.subview %0[%c0, %c0] [32, 32] [1, 1] : memref<64x64xbf16, 1> to memref<32x32xbf16, strided<[64, 1], offset: ?>, 1>
+  air.channel.put @channel[] (%subview_4[] [] []) : (memref<32x32xbf16, strided<[64, 1], offset: ?>, 1>)
+  %subview_5 = memref.subview %1[%c0, %c0] [32, 32] [1, 1] : memref<64x64xbf16, 2> to memref<32x32xbf16, strided<[64, 1], offset: ?>, 2>
+  %5 = air.channel.get async @channel[] (%subview_5[] [] []) : (memref<32x32xbf16, strided<[64, 1], offset: ?>, 2>)
+  return
+}
+
+// CHECK: func.func @channel_compose_transpose
+// CHECK: air.channel.put @channel[] ({{.*}}[{{.*}}, {{.*}}] [%c64{{.*}}, %c128{{.*}}] [%c1{{.*}}, %c64{{.*}}]
+// CHECK: %[[V2:.*]] = air.channel.get async @channel[] ({{.*}}[{{.*}}, {{.*}}] [%c64{{.*}}, %c128{{.*}}] [%c1{{.*}}, %c64{{.*}}]
+func.func @channel_compose_transpose() {
+  %0 = memref.alloc() : memref<128x64xbf16, 1>
+  %1 = memref.alloc() : memref<128x64xbf16, 2>
+  %transpose_1 = memref.transpose %0 (d0, d1) -> (d1, d0) : memref<128x64xbf16, 1> to memref<64x128xbf16, affine_map<(d0, d1) -> (d0 + d1 * 64)>, 1>
+  air.channel.put @channel[] (%transpose_1[] [] []) : (memref<64x128xbf16, affine_map<(d0, d1) -> (d0 + d1 * 64)>, 1>)
+  %transpose_2 = memref.transpose %1 (d0, d1) -> (d1, d0) : memref<128x64xbf16, 2> to memref<64x128xbf16, affine_map<(d0, d1) -> (d0 + d1 * 64)>, 2>
+  %5 = air.channel.get async @channel[] (%transpose_2[] [] []) : (memref<64x128xbf16, affine_map<(d0, d1) -> (d0 + d1 * 64)>, 2>)
+  return
+}
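+
+// Explanatory note (not a FileCheck directive): each composed form above
+// encodes the folded layout op purely in the offsets/sizes/strides lists;
+// e.g. the transpose is expressed by the permuted strides [%c1, %c64], while
+// expand_shape below only reshapes the sizes/strides and adds no offset.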
+
+// CHECK: func.func @channel_compose_expand_shape
+// CHECK: air.channel.put @channel[] ({{.*}}[%c0{{.*}}, %c0{{.*}}, %c0{{.*}}] [%c2{{.*}}, %c64{{.*}}, %c64{{.*}}] [%c4096{{.*}}, %c64{{.*}}, %c1{{.*}}]
+// CHECK: %[[V2:.*]] = air.channel.get async @channel[] ({{.*}}[%c0{{.*}}, %c0{{.*}}, %c0{{.*}}] [%c2{{.*}}, %c64{{.*}}, %c64{{.*}}] [%c4096{{.*}}, %c64{{.*}}, %c1{{.*}}]
+func.func @channel_compose_expand_shape() {
+  %0 = memref.alloc() : memref<128x64xbf16, 1>
+  %1 = memref.alloc() : memref<128x64xbf16, 2>
+  %expand_shape_1 = memref.expand_shape %0 [[0, 1], [2]] output_shape [2, 64, 64] : memref<128x64xbf16, 1> into memref<2x64x64xbf16, 1>
+  air.channel.put @channel[] (%expand_shape_1[] [] []) : (memref<2x64x64xbf16, 1>)
+  %expand_shape_2 = memref.expand_shape %1 [[0, 1], [2]] output_shape [2, 64, 64] : memref<128x64xbf16, 2> into memref<2x64x64xbf16, 2>
+  %5 = air.channel.get async @channel[] (%expand_shape_2[] [] []) : (memref<2x64x64xbf16, 2>)
+  return
+}
+
+// CHECK: func.func @channel_compose_cast
+// CHECK: air.channel.put @channel[] ({{.*}}[%c0{{.*}}, %c0{{.*}}] [%c128{{.*}}, %c64{{.*}}] [%c64{{.*}}, %c1{{.*}}]
+// CHECK: %[[V2:.*]] = air.channel.get async @channel[] ({{.*}}[%c0{{.*}}, %c0{{.*}}] [%c128{{.*}}, %c64{{.*}}] [%c64{{.*}}, %c1{{.*}}]
+func.func @channel_compose_cast() {
+  %0 = memref.alloc() : memref<128x64xbf16, 1>
+  %1 = memref.alloc() : memref<128x64xbf16, 2>
+  %cast = memref.cast %0 : memref<128x64xbf16, 1> to memref<128x64xbf16, strided<[64, 1], offset: ?>, 1>
+  air.channel.put @channel[] (%cast[] [] []) : (memref<128x64xbf16, strided<[64, 1], offset: ?>, 1>)
+  %cast_1 = memref.cast %1 : memref<128x64xbf16, 2> to memref<128x64xbf16, strided<[64, 1], offset: ?>, 2>
+  %5 = air.channel.get async @channel[] (%cast_1[] [] []) : (memref<128x64xbf16, strided<[64, 1], offset: ?>, 2>)
+  return
+}
+
+// Memref op chain on air::DmaMemcpyNdOp's src/dst memrefs.
+// CHECK: func.func @func0
+// CHECK: %[[CST128:.*]] = arith.constant 128 : index
+// CHECK: %[[CST32:.*]] = arith.constant 32 : index
+// CHECK: %[[CST8:.*]] = arith.constant 8 : index
+// CHECK: %[[CST16:.*]] = arith.constant 16 : index
+// CHECK: %[[CST0:.*]] = arith.constant 0 : index
+// CHECK: %[[CST1:.*]] = arith.constant 1 : index
+// CHECK: air.dma_memcpy_nd (%{{.*}}[] [] [], %{{.*}}[%{{.*}}, %[[CST0]]] [%[[CST8]], %[[CST16]]] [%[[CST16]], %[[CST1]]]) : (memref<1x1x8x16xi32, 1>, memref<8x16xi32>)
+// CHECK: air.dma_memcpy_nd (%{{.*}}[] [] [], %{{.*}}[%[[CST0]], %{{.*}}] [%[[CST16]], %[[CST16]]] [%[[CST32]], %[[CST1]]]) : (memref<1x1x16x16xi32, 1>, memref<16x32xi32>)
+// CHECK: air.herd @herd_0
+// CHECK: %[[CST32_0:.*]] = arith.constant 32 : index
+// CHECK: %[[CST256_0:.*]] = arith.constant 256 : index
+// CHECK: %[[CST4_0:.*]] = arith.constant 4 : index
+// CHECK: %[[CST2_0:.*]] = arith.constant 2 : index
+// CHECK: %[[CST1_0:.*]] = arith.constant 1 : index
+// CHECK: %[[CST16_0:.*]] = arith.constant 16 : index
+// CHECK: %[[CST64_0:.*]] = arith.constant 64 : index
+// CHECK: %[[CST8_0:.*]] = arith.constant 8 : index
+// CHECK: %[[CST128_0:.*]] = arith.constant 128 : index
+// CHECK: %[[CST0_0:.*]] = arith.constant 0 : index
+// CHECK: air.dma_memcpy_nd (%{{.*}}[] [] [], %{{.*}}[%{{.*}}, %[[CST0_0]], %[[CST0_0]], %[[CST0_0]], %[[CST0_0]], %[[CST0_0]]] [%[[CST1_0]], %[[CST1_0]], %[[CST2_0]], %[[CST2_0]], %[[CST4_0]], %[[CST8_0]]] [%[[CST128_0]], %[[CST128_0]], %[[CST8_0]], %[[CST64_0]], %[[CST16_0]], %[[CST1_0]]]) : (memref<1x1x2x2x4x8xi32, 2>, memref<1x1x8x16xi32, 1>)
+// CHECK: air.dma_memcpy_nd (%{{.*}}[] [] [], %{{.*}}[%[[CST0_0]], %{{.*}}, %[[CST0_0]], %[[CST0_0]], %[[CST0_0]], %[[CST0_0]]] [%[[CST1_0]], %[[CST1_0]], %[[CST2_0]], %[[CST2_0]], %[[CST8_0]], %[[CST8_0]]] [%[[CST256_0]], %[[CST256_0]], %[[CST8_0]], %[[CST128_0]], %[[CST16_0]], %[[CST1_0]]]) : (memref<1x1x2x2x8x8xi32, 2>, memref<1x1x16x16xi32, 1>)
+// CHECK: air.dma_memcpy_nd (%{{.*}}[%{{.*}}, %{{.*}}, %[[CST0_0]], %[[CST0_0]]] [%[[CST1_0]], %[[CST1_0]], %[[CST8_0]], %[[CST16_0]]] [%[[CST128_0]], %[[CST128_0]], %[[CST16_0]], %[[CST1_0]]], %{{.*}}[%[[CST0_0]], %[[CST0_0]], %[[CST0_0]], %[[CST0_0]], %[[CST0_0]], %[[CST0_0]]] [%[[CST1_0]], %[[CST1_0]], %[[CST2_0]], %[[CST4_0]], %[[CST2_0]], %[[CST8_0]]] [%[[CST128_0]], %[[CST128_0]], %[[CST32_0]], %[[CST8_0]], %[[CST64_0]], %[[CST1_0]]]) : (memref<1x1x8x16xi32, 1>, memref<1x1x2x2x4x8xi32, 2>)
+// CHECK: air.dma_memcpy_nd (%{{.*}}[%{{.*}}, %{{.*}}] [%[[CST8]], %[[CST16]]] [%[[CST32]], %[[CST1]]], %{{.*}}[%[[CST0]], %[[CST0]], %[[CST0]], %[[CST0]]] [%[[CST1]], %[[CST1]], %[[CST8]], %[[CST16]]] [%[[CST128]], %[[CST128]], %[[CST16]], %[[CST1]]]) : (memref<8x32xi32>, memref<1x1x8x16xi32, 1>)
+
+#map = affine_map<()[s0] -> (s0 * 8)>
+#map1 = affine_map<()[s0] -> (s0 * 16)>
+func.func @func0(%arg0: memref<8x16xi32>, %arg1: memref<16x32xi32>, %arg2: memref<8x32xi32>) {
+  %c2 = arith.constant 2 : index
+  %c1 = arith.constant 1 : index
+  air.launch (%arg3, %arg4) in (%arg5=%c1, %arg6=%c2) args(%arg7=%arg0, %arg8=%arg1, %arg9=%arg2) : memref<8x16xi32>, memref<16x32xi32>, memref<8x32xi32> {
+    air.segment @segment_0 args(%arg10=%arg3, %arg11=%arg4, %arg12=%arg7, %arg13=%arg8, %arg14=%arg9) : index, index, memref<8x16xi32>, memref<16x32xi32>, memref<8x32xi32> {
+      %c1_0 = arith.constant 1 : index
+      %0 = affine.apply #map()[%arg10]
+      %1 = affine.apply #map1()[%arg11]
+      %subview = memref.subview %arg12[%0, 0] [8, 16] [1, 1] : memref<8x16xi32> to memref<8x16xi32, strided<[16, 1], offset: ?>>
+      %subview_1 = memref.subview %arg13[0, %1] [16, 16] [1, 1] : memref<16x32xi32> to memref<16x16xi32, strided<[32, 1], offset: ?>>
+      %subview_2 = memref.subview %arg14[%0, %1] [8, 16] [1, 1] : memref<8x32xi32> to memref<8x16xi32, strided<[32, 1], offset: ?>>
+      %alloc = memref.alloc() : memref<1x1x8x16xi32, 1>
+      air.dma_memcpy_nd (%alloc[] [] [], %subview[] [] []) : (memref<1x1x8x16xi32, 1>, memref<8x16xi32, strided<[16, 1], offset: ?>>)
+      %alloc_3 = memref.alloc() : memref<1x1x16x16xi32, 1>
+      air.dma_memcpy_nd (%alloc_3[] [] [], %subview_1[] [] []) : (memref<1x1x16x16xi32, 1>, memref<16x16xi32, strided<[32, 1], offset: ?>>)
+      %alloc_4 = memref.alloc() : memref<1x1x8x16xi32, 1>
+      air.herd @herd_0 tile (%arg15, %arg16) in (%arg17=%c1_0, %arg18=%c1_0) args(%arg19=%alloc, %arg20=%alloc_3, %arg21=%alloc_4) : memref<1x1x8x16xi32, 1>, memref<1x1x16x16xi32, 1>, memref<1x1x8x16xi32, 1> {
+        %subview_6 = memref.subview %arg19[%arg15, 0, 0, 0] [1, 1, 8, 16] [1, 1, 1, 1] : memref<1x1x8x16xi32, 1> to memref<1x1x8x16xi32, strided<[128, 128, 16, 1], offset: ?>, 1>
+        %subview_7 = memref.subview %arg20[0, %arg16, 0, 0] [1, 1, 16, 16] [1, 1, 1, 1] : memref<1x1x16x16xi32, 1> to memref<1x1x16x16xi32, strided<[256, 256, 16, 1], offset: ?>, 1>
+        %subview_8 = memref.subview %arg21[%arg15, %arg16, 0, 0] [1, 1, 8, 16] [1, 1, 1, 1] : memref<1x1x8x16xi32, 1> to memref<1x1x8x16xi32, strided<[128, 128, 16, 1], offset: ?>, 1>
+        %alloc_9 = memref.alloc() : memref<1x1x2x2x4x8xi32, 2>
+        %expand_shape = memref.expand_shape %subview_6 [[0], [1], [2, 3], [4, 5]] output_shape [1, 1, 2, 4, 2, 8] : memref<1x1x8x16xi32, strided<[128, 128, 16, 1], offset: ?>, 1> into memref<1x1x2x4x2x8xi32, strided<[128, 128, 64, 16, 8, 1], offset: ?>, 1>
+        %transpose_10 = memref.transpose %expand_shape (d0, d1, d2, d3, d4, d5) -> (d0, d1, d4, d2, d3, d5) : memref<1x1x2x4x2x8xi32, strided<[128, 128, 64, 16, 8, 1], offset: ?>, 1> to memref<1x1x2x2x4x8xi32, strided<[128, 128, 8, 64, 16, 1], offset: ?>, 1>
+        air.dma_memcpy_nd (%alloc_9[] [] [], %transpose_10[] [] []) : (memref<1x1x2x2x4x8xi32, 2>, memref<1x1x2x2x4x8xi32, strided<[128, 128, 8, 64, 16, 1], offset: ?>, 1>)
+        %alloc_11 = memref.alloc() : memref<1x1x2x2x8x8xi32, 2>
+        %expand_shape_12 = memref.expand_shape %subview_7 [[0], [1], [2, 3], [4, 5]] output_shape [1, 1, 2, 8, 2, 8] : memref<1x1x16x16xi32, strided<[256, 256, 16, 1], offset: ?>, 1> into memref<1x1x2x8x2x8xi32, strided<[256, 256, 128, 16, 8, 1], offset: ?>, 1>
+        %transpose_13 = memref.transpose %expand_shape_12 (d0, d1, d2, d3, d4, d5) -> (d0, d1, d4, d2, d3, d5) : memref<1x1x2x8x2x8xi32, strided<[256, 256, 128, 16, 8, 1], offset: ?>, 1> to memref<1x1x2x2x8x8xi32, strided<[256, 256, 8, 128, 16, 1], offset: ?>, 1>
+        air.dma_memcpy_nd (%alloc_11[] [] [], %transpose_13[] [] []) : (memref<1x1x2x2x8x8xi32, 2>, memref<1x1x2x2x8x8xi32, strided<[256, 256, 8, 128, 16, 1], offset: ?>, 1>)
+        %alloc_14 = memref.alloc() : memref<1x1x2x2x4x8xi32, 2>
+        %transpose_15 = memref.transpose %alloc_14 (d0, d1, d2, d3, d4, d5) -> (d0, d1, d3, d4, d2, d5) : memref<1x1x2x2x4x8xi32, 2> to memref<1x1x2x4x2x8xi32, strided<[128, 128, 32, 8, 64, 1]>, 2>
+        air.dma_memcpy_nd (%subview_8[] [] [], %transpose_15[] [] []) : (memref<1x1x8x16xi32, strided<[128, 128, 16, 1], offset: ?>, 1>, memref<1x1x2x4x2x8xi32, strided<[128, 128, 32, 8, 64, 1]>, 2>)
+        memref.dealloc %alloc_9 : memref<1x1x2x2x4x8xi32, 2>
+        memref.dealloc %alloc_11 : memref<1x1x2x2x8x8xi32, 2>
+        memref.dealloc %alloc_14 : memref<1x1x2x2x4x8xi32, 2>
+      }
+      %subview_5 = memref.subview %alloc_4[0, 0, 0, 0] [1, 1, 8, 16] [1, 1, 1, 1] : memref<1x1x8x16xi32, 1> to memref<8x16xi32, 1>
+      %transpose = memref.transpose %subview_5 (d0, d1) -> (d0, d1) : memref<8x16xi32, 1> to memref<8x16xi32, strided<[16, 1]>, 1>
+      air.dma_memcpy_nd (%subview_2[] [] [], %transpose[] [] []) : (memref<8x16xi32, strided<[32, 1], offset: ?>>, memref<8x16xi32, strided<[16, 1]>, 1>)
+      memref.dealloc %alloc_3 : memref<1x1x16x16xi32, 1>
+      memref.dealloc %alloc : memref<1x1x8x16xi32, 1>
+      memref.dealloc %alloc_4 : memref<1x1x8x16xi32, 1>
+    }
+  }
+  return
+}
+
+// CHECK: func.func @func1
+// CHECK: air.herd @herd_0 {{.*}} args(%[[ARG0:.*]]=%{{.*}}, %[[ARG1:.*]]=%{{.*}})
+// CHECK-DAG: %[[CST4:.*]] = arith.constant 4 : index
+// CHECK-DAG: %[[CST3:.*]] = arith.constant 3 : index
+// CHECK-DAG: %[[CST1:.*]] = arith.constant 1 : index
+// CHECK-DAG: %[[CST8:.*]] = arith.constant 8 : index
+// CHECK-DAG: %[[CST64:.*]] = arith.constant 64 : index
+// CHECK-DAG: %[[CST256:.*]] = arith.constant 256 : index
+// CHECK-DAG: %[[CST768:.*]] = arith.constant 768 : index
+// CHECK-DAG: %[[CST0:.*]] = arith.constant 0 : index
+// CHECK: air.dma_memcpy_nd (%[[ARG1]][] [] [], %[[ARG0]][%[[CST0]], %[[CST0]], %[[CST0]], %[[CST0]], %[[CST0]], %[[CST0]]] [%[[CST3]], %[[CST3]], %[[CST4]], %[[CST1]], %[[CST8]], %[[CST8]]] [%[[CST768]], %[[CST256]], %[[CST64]], %[[CST8]], %[[CST8]], %[[CST1]]]) : (memref<3x3x4x1x8x8xi8, 2 : i32>, memref<3x3x32x8xi8, 1 : i32>)
+// CHECK: }
+
+func.func @func1() {
+  %c8 = arith.constant 8 : index
+  %c2 = arith.constant 2 : index
+  %c3 = arith.constant 3 : index
+  air.launch (%arg0, %arg1, %arg2, %arg3) in (%arg4=%c2, %arg5=%c3, %arg6=%c3, %arg7=%c8) {
+    air.segment @segment_0 {
+      %c1 = arith.constant 1 : index
+      %c4 = arith.constant 4 : index
+      %alloc = memref.alloc() : memref<3x3x4x1x8x8xi8, 2 : i32>
+      %alloc_0 = memref.alloc() : memref<3x3x32x8xi8, 1 : i32>
+      air.herd @herd_0 tile (%arg8, %arg9) in (%arg10=%c4, %arg11=%c1) args(%arg12=%alloc_0, %arg13=%alloc) : memref<3x3x32x8xi8, 1 : i32>, memref<3x3x4x1x8x8xi8, 2 : i32> {
+        %cast = memref.cast %arg12 : memref<3x3x32x8xi8, 1 : i32> to memref<3x3x32x8xi8, strided<[768, 256, 8, 1], offset: ?>, 1 : i32>
+        %expand_shape = memref.expand_shape %cast [[0], [1], [2, 3], [4, 5]] output_shape [3, 3, 4, 8, 1, 8] : memref<3x3x32x8xi8, strided<[768, 256, 8, 1], offset: ?>, 1 : i32> into memref<3x3x4x8x1x8xi8, strided<[768, 256, 64, 8, 8, 1], offset: ?>, 1 : i32>
8, 1, 8] : memref<3x3x32x8xi8, strided<[768, 256, 8, 1], offset: ?>, 1 : i32> into memref<3x3x4x8x1x8xi8, strided<[768, 256, 64, 8, 8, 1], offset: ?>, 1 : i32> + %transpose = memref.transpose %expand_shape (d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d4, d3, d5) : memref<3x3x4x8x1x8xi8, strided<[768, 256, 64, 8, 8, 1], offset: ?>, 1 : i32> to memref<3x3x4x1x8x8xi8, strided<[768, 256, 64, 8, 8, 1], offset: ?>, 1 : i32> + air.dma_memcpy_nd (%arg13[] [] [], %transpose[] [] []) : (memref<3x3x4x1x8x8xi8, 2 : i32>, memref<3x3x4x1x8x8xi8, strided<[768, 256, 64, 8, 8, 1], offset: ?>, 1 : i32>) + } + } + } + return +} From 3a4d445402156313b5fcab64b922e221f25e9e3a Mon Sep 17 00:00:00 2001 From: erwei-xilinx Date: Wed, 23 Oct 2024 17:06:25 -0700 Subject: [PATCH 3/5] Switch to use replaceOpWithNewOp method for patterns; fixup comments --- mlir/lib/Dialect/AIR/IR/AIRDialect.cpp | 80 ++++++-------------------- 1 file changed, 19 insertions(+), 61 deletions(-) diff --git a/mlir/lib/Dialect/AIR/IR/AIRDialect.cpp b/mlir/lib/Dialect/AIR/IR/AIRDialect.cpp index 09203abf4..a35144090 100644 --- a/mlir/lib/Dialect/AIR/IR/AIRDialect.cpp +++ b/mlir/lib/Dialect/AIR/IR/AIRDialect.cpp @@ -1243,35 +1243,26 @@ static LogicalResult ComposeMemrefOp(Value memref, PatternRewriter &rewriter, // Revert the vector of memref ops, as it was built with push_back. std::reverse(memrefOpVec.begin(), memrefOpVec.end()); - // Init. memref type and offsets at the front of the vector of memref ops. + // Init. source memref and offsets at the front of the vector of memref ops. auto constZero = rewriter.create(loc, 0); - MemRefType input_ty; if (auto subviewOp = dyn_cast(memrefOpVec[0])) { - // Init. offsets extractOffsetsFromSubview(subviewOp, rewriter, offsets); - // Init. memref type - input_ty = subviewOp.getSourceType(); input_memref = subviewOp.getViewSource(); } else if (auto transposeOp = dyn_cast(memrefOpVec[0])) { - // Init. memref type - input_ty = llvm::cast(transposeOp.getIn().getType()); - input_memref = transposeOp.getIn(); - // Init. offsets offsets.clear(); for (unsigned i = 0; i < transposeOp.getPermutation().getNumInputs(); i++) offsets.push_back(constZero); + input_memref = transposeOp.getIn(); } else if (auto viewLikeOp = dyn_cast(memrefOpVec[0])) { - // Init. memref type - input_ty = llvm::cast(viewLikeOp.getViewSource().getType()); - input_memref = viewLikeOp.getViewSource(); - // Init. offsets offsets.clear(); - for (unsigned i = 0; i < input_ty.getRank(); i++) + for (unsigned i = 0; + i < llvm::cast(input_memref.getType()).getRank(); i++) offsets.push_back(constZero); + input_memref = viewLikeOp.getViewSource(); } else return failure(); - // Compose memref type as the memref propagates through the chain of memref + // Compose offsets as the memref type propagates through the chain of memref // ops. for (auto memrefOp : memrefOpVec) { if (auto transposeOp = dyn_cast(memrefOp)) { @@ -1280,7 +1271,6 @@ static LogicalResult ComposeMemrefOp(Value memref, PatternRewriter &rewriter, offsets = applyPermutationMap(transposeOp.getPermutation(), offsets); } else if (auto expandShapeOp = dyn_cast(memrefOp)) { - // Init. offsets for (int i = (int)expandShapeOp.getReassociationIndices().size() - 1; i >= 0; i--) { if (expandShapeOp.getReassociationIndices()[i].size() <= 1) @@ -1299,16 +1289,16 @@ static LogicalResult ComposeMemrefOp(Value memref, PatternRewriter &rewriter, } } - // Memref type at sink memref op. - input_ty = + // Memref type at the sink memref op. 
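
Previously each compose pattern spelled out the full create-new-op /
replace-all-uses / erase-old-op sequence by hand.
PatternRewriter::replaceOpWithNewOp performs the same create-then-replaceOp
steps in a single call and keeps the rewriter informed of the replacement.
Roughly, as a sketch of the idiom (shown here for the DMA pattern; the
channel put/get patterns change identically):

    // Before: create the new op, then manually RAUW results and erase.
    auto newOp = rewriter.create<DmaMemcpyNdOp>(loc, /*operands...*/);
    for (unsigned i = 0; i < op->getNumResults(); i++)
      op->getResult(i).replaceAllUsesWith(newOp->getResult(i));
    rewriter.eraseOp(op);

    // After: one call creates the op and replaces the original.
    rewriter.replaceOpWithNewOp<DmaMemcpyNdOp>(op, /*operands...*/);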
---
 mlir/lib/Dialect/AIR/IR/AIRDialect.cpp | 80 ++++++--------------------
 1 file changed, 19 insertions(+), 61 deletions(-)

diff --git a/mlir/lib/Dialect/AIR/IR/AIRDialect.cpp b/mlir/lib/Dialect/AIR/IR/AIRDialect.cpp
index 09203abf4..a35144090 100644
--- a/mlir/lib/Dialect/AIR/IR/AIRDialect.cpp
+++ b/mlir/lib/Dialect/AIR/IR/AIRDialect.cpp
@@ -1243,35 +1243,26 @@ static LogicalResult ComposeMemrefOp(Value memref, PatternRewriter &rewriter,
   // Revert the vector of memref ops, as it was built with push_back.
   std::reverse(memrefOpVec.begin(), memrefOpVec.end());
 
-  // Init. memref type and offsets at the front of the vector of memref ops.
+  // Init. source memref and offsets at the front of the vector of memref ops.
   auto constZero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
-  MemRefType input_ty;
   if (auto subviewOp = dyn_cast<memref::SubViewOp>(memrefOpVec[0])) {
-    // Init. offsets
     extractOffsetsFromSubview(subviewOp, rewriter, offsets);
-    // Init. memref type
-    input_ty = subviewOp.getSourceType();
     input_memref = subviewOp.getViewSource();
   } else if (auto transposeOp = dyn_cast<memref::TransposeOp>(memrefOpVec[0])) {
-    // Init. memref type
-    input_ty = llvm::cast<MemRefType>(transposeOp.getIn().getType());
-    input_memref = transposeOp.getIn();
-    // Init. offsets
     offsets.clear();
     for (unsigned i = 0; i < transposeOp.getPermutation().getNumInputs(); i++)
       offsets.push_back(constZero);
+    input_memref = transposeOp.getIn();
   } else if (auto viewLikeOp = dyn_cast<ViewLikeOpInterface>(memrefOpVec[0])) {
-    // Init. memref type
-    input_ty = llvm::cast<MemRefType>(viewLikeOp.getViewSource().getType());
-    input_memref = viewLikeOp.getViewSource();
-    // Init. offsets
     offsets.clear();
-    for (unsigned i = 0; i < input_ty.getRank(); i++)
+    for (unsigned i = 0;
+         i < llvm::cast<MemRefType>(input_memref.getType()).getRank(); i++)
       offsets.push_back(constZero);
+    input_memref = viewLikeOp.getViewSource();
   } else
     return failure();
 
-  // Compose memref type as the memref propagates through the chain of memref
+  // Compose offsets as the memref type propagates through the chain of memref
   // ops.
   for (auto memrefOp : memrefOpVec) {
     if (auto transposeOp = dyn_cast<memref::TransposeOp>(memrefOp)) {
@@ -1280,7 +1271,6 @@ static LogicalResult ComposeMemrefOp(Value memref, PatternRewriter &rewriter,
       offsets = applyPermutationMap(transposeOp.getPermutation(), offsets);
     } else if (auto expandShapeOp = dyn_cast<memref::ExpandShapeOp>(memrefOp)) {
-      // Init. offsets
       for (int i = (int)expandShapeOp.getReassociationIndices().size() - 1;
            i >= 0; i--) {
         if (expandShapeOp.getReassociationIndices()[i].size() <= 1)
           continue;
@@ -1299,16 +1289,16 @@ static LogicalResult ComposeMemrefOp(Value memref, PatternRewriter &rewriter,
     }
   }
 
-  // Memref type at sink memref op.
-  input_ty =
+  // Memref type at the sink memref op.
+  MemRefType sink_memref_ty =
       llvm::cast<MemRefType>(memrefOpVec.back()->getResultTypes().front());
 
   // Compose sizes and strides from the output memref type's layout.
-  strides = extractStridesFromMemrefType(input_ty, rewriter);
-  sizes = extractSizesFromMemrefType(input_ty, rewriter);
+  strides = extractStridesFromMemrefType(sink_memref_ty, rewriter);
+  sizes = extractSizesFromMemrefType(sink_memref_ty, rewriter);
 
   return canonicalizeAIRDmaOperands(rewriter, offsets, sizes, strides,
-                                    input_ty);
+                                    sink_memref_ty);
 }
 
 //
@@ -1321,7 +1311,6 @@ ComposeMemrefOpOnDmaMemcpyNdSrc(DmaMemcpyNdOp op, PatternRewriter &rewriter) {
   Value memref = op.getSrcMemref();
   if (!memref)
     return failure();
-  auto loc = op->getLoc();
   Value input_memref;
   SmallVector<Value> offsets, sizes, strides;
   offsets = op.getSrcOffsets();
@@ -1338,18 +1327,11 @@ ComposeMemrefOpOnDmaMemcpyNdSrc(DmaMemcpyNdOp op, PatternRewriter &rewriter) {
                                        strides))) {
     return failure();
   }
-
-  auto newOp = rewriter.create<DmaMemcpyNdOp>(
-      loc, op->getResultTypes(), op.getAsyncDependencies(), op.getDstMemref(),
+  rewriter.replaceOpWithNewOp<DmaMemcpyNdOp>(
+      op, op->getResultTypes(), op.getAsyncDependencies(), op.getDstMemref(),
       op.getDstOffsets(), op.getDstSizes(), op.getDstStrides(), input_memref,
       offsets, sizes, strides);
 
-  for (unsigned i = 0; i < op->getNumResults(); i++) {
-    op->getResult(i).replaceAllUsesWith(newOp->getResult(i));
-  }
-
-  rewriter.eraseOp(op);
-
   return success();
 }
 
@@ -1359,7 +1341,6 @@ ComposeMemrefOpOnDmaMemcpyNdDst(DmaMemcpyNdOp op, PatternRewriter &rewriter) {
   Value memref = op.getDstMemref();
   if (!memref)
     return failure();
-  auto loc = op->getLoc();
   Value input_memref;
   SmallVector<Value> offsets, sizes, strides;
   offsets = op.getDstOffsets();
@@ -1376,18 +1357,11 @@ ComposeMemrefOpOnDmaMemcpyNdDst(DmaMemcpyNdOp op, PatternRewriter &rewriter) {
                                        strides))) {
     return failure();
   }
-
-  auto newOp = rewriter.create<DmaMemcpyNdOp>(
-      loc, op->getResultTypes(), op.getAsyncDependencies(), input_memref,
+  rewriter.replaceOpWithNewOp<DmaMemcpyNdOp>(
+      op, op->getResultTypes(), op.getAsyncDependencies(), input_memref,
      offsets, sizes, strides, op.getSrcMemref(), op.getSrcOffsets(),
       op.getSrcSizes(), op.getSrcStrides());
 
-  for (unsigned i = 0; i < op->getNumResults(); i++) {
-    op->getResult(i).replaceAllUsesWith(newOp->getResult(i));
-  }
-
-  rewriter.eraseOp(op);
-
   return success();
 }
 
@@ -1407,7 +1381,6 @@ static LogicalResult ComposeMemrefOpOnChannelPut(ChannelPutOp op,
   Value memref = op.getMemref();
   if (!memref)
     return failure();
-  auto loc = op->getLoc();
 
   // Init. memref type and offsets from memref's defining op's input type
   Value input_memref;
@@ -1426,17 +1399,10 @@ static LogicalResult ComposeMemrefOpOnChannelPut(ChannelPutOp op,
                                        strides))) {
     return failure();
   }
-
-  auto newOp = rewriter.create<ChannelPutOp>(
-      loc, op->getResultTypes(), op.getAsyncDependencies(), op.getChanName(),
+  rewriter.replaceOpWithNewOp<ChannelPutOp>(
+      op, op->getResultTypes(), op.getAsyncDependencies(), op.getChanName(),
       op.getIndices(), input_memref, offsets, sizes, strides);
 
-  for (unsigned i = 0; i < op->getNumResults(); i++) {
-    op->getResult(i).replaceAllUsesWith(newOp->getResult(i));
-  }
-
-  rewriter.eraseOp(op);
-
   return success();
 }
 
@@ -1455,7 +1421,6 @@ static LogicalResult ComposeMemrefOpOnChannelGet(ChannelGetOp op,
   Value memref = op.getMemref();
   if (!memref)
     return failure();
-  auto loc = op->getLoc();
 
   // Init. memref type and offsets from memref's defining op's input type
   Value input_memref;
@@ -1474,17 +1439,10 @@ static LogicalResult ComposeMemrefOpOnChannelGet(ChannelGetOp op,
                                        strides))) {
     return failure();
   }
-
-  auto newOp = rewriter.create<ChannelGetOp>(
-      loc, op->getResultTypes(), op.getAsyncDependencies(), op.getChanName(),
+  rewriter.replaceOpWithNewOp<ChannelGetOp>(
+      op, op->getResultTypes(), op.getAsyncDependencies(), op.getChanName(),
      op.getIndices(), input_memref, offsets, sizes, strides);
 
-  for (unsigned i = 0; i < op->getNumResults(); i++) {
-    op->getResult(i).replaceAllUsesWith(newOp->getResult(i));
-  }
-
-  rewriter.eraseOp(op);
-
   return success();
 }

From 1af19c57570f2abf5a8e6e4c1fb00912e3af70cb Mon Sep 17 00:00:00 2001
From: erwei-xilinx
Date: Wed, 23 Oct 2024 17:18:29 -0700
Subject: [PATCH 4/5] Fixup typo
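
In ComposeMemrefOp, the ViewLikeOpInterface branch queried
llvm::cast<MemRefType>(input_memref.getType()).getRank() before
input_memref had been assigned from viewLikeOp.getViewSource().  Hoist
the input_memref assignment to the top of each branch so the rank query
reads the source memref's type, e.g. for the view-like case (mirroring
the hunk below):

    input_memref = viewLikeOp.getViewSource();
    offsets.clear();
    for (unsigned i = 0;
         i < llvm::cast<MemRefType>(input_memref.getType()).getRank(); i++)
      offsets.push_back(constZero);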
---
 mlir/lib/Dialect/AIR/IR/AIRDialect.cpp | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/mlir/lib/Dialect/AIR/IR/AIRDialect.cpp b/mlir/lib/Dialect/AIR/IR/AIRDialect.cpp
index a35144090..174350d5b 100644
--- a/mlir/lib/Dialect/AIR/IR/AIRDialect.cpp
+++ b/mlir/lib/Dialect/AIR/IR/AIRDialect.cpp
@@ -1246,19 +1246,19 @@ static LogicalResult ComposeMemrefOp(Value memref, PatternRewriter &rewriter,
   // Init. source memref and offsets at the front of the vector of memref ops.
   auto constZero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
   if (auto subviewOp = dyn_cast<memref::SubViewOp>(memrefOpVec[0])) {
-    extractOffsetsFromSubview(subviewOp, rewriter, offsets);
     input_memref = subviewOp.getViewSource();
+    extractOffsetsFromSubview(subviewOp, rewriter, offsets);
   } else if (auto transposeOp = dyn_cast<memref::TransposeOp>(memrefOpVec[0])) {
+    input_memref = transposeOp.getIn();
     offsets.clear();
     for (unsigned i = 0; i < transposeOp.getPermutation().getNumInputs(); i++)
       offsets.push_back(constZero);
-    input_memref = transposeOp.getIn();
   } else if (auto viewLikeOp = dyn_cast<ViewLikeOpInterface>(memrefOpVec[0])) {
+    input_memref = viewLikeOp.getViewSource();
     offsets.clear();
     for (unsigned i = 0;
          i < llvm::cast<MemRefType>(input_memref.getType()).getRank(); i++)
       offsets.push_back(constZero);
-    input_memref = viewLikeOp.getViewSource();
   } else
     return failure();

From 554624f5ec4165c78d10bf67ff1cca30c4601ae4 Mon Sep 17 00:00:00 2001
From: erwei-xilinx
Date: Wed, 23 Oct 2024 17:19:07 -0700
Subject: [PATCH 5/5] Call canonicalization pattern at the end of copyToDma

---
 mlir/lib/Conversion/ConvertToAIRPass.cpp | 4 ++++
 1 file changed, 4 insertions(+)

diff --git a/mlir/lib/Conversion/ConvertToAIRPass.cpp b/mlir/lib/Conversion/ConvertToAIRPass.cpp
index 86fd9e160..7134f62a9 100644
--- a/mlir/lib/Conversion/ConvertToAIRPass.cpp
+++ b/mlir/lib/Conversion/ConvertToAIRPass.cpp
@@ -1171,6 +1171,10 @@ struct CopyToDmaPass : public air::impl::CopyToDmaBase<CopyToDmaPass> {
 
     LLVM_DEBUG(llvm::outs() << "output\n");
     LLVM_DEBUG(module.print(llvm::outs()));
+
+    RewritePatternSet pattern(context);
+    air::DmaMemcpyNdOp::getCanonicalizationPatterns(pattern, context);
+    (void)applyPatternsAndFoldGreedily(module, std::move(pattern));
   }
 };