From 407981e24d577236b6ed5883e12adba32150191d Mon Sep 17 00:00:00 2001
From: Guo-Peilin
Date: Fri, 13 Oct 2023 07:36:11 +0000
Subject: [PATCH] update

---
 .../data/matmul_nn_s_f16_gpu_schedule_1.mlir  |  4 +-
 .../TransformOps/TransformOpsExt.cc           | 50 +++++++++----------
 .../TransformOps/TransformOpsExt.td           |  4 +-
 .../disc_lower_gpu_ops_to_nvvm_ops.cc         | 45 +++++++++++++++++
 4 files changed, 73 insertions(+), 30 deletions(-)

diff --git a/tao_compiler/mlir/disc/tests/disc-transform/data/matmul_nn_s_f16_gpu_schedule_1.mlir b/tao_compiler/mlir/disc/tests/disc-transform/data/matmul_nn_s_f16_gpu_schedule_1.mlir
index 6b3287fb46a..9759e5d7439 100644
--- a/tao_compiler/mlir/disc/tests/disc-transform/data/matmul_nn_s_f16_gpu_schedule_1.mlir
+++ b/tao_compiler/mlir/disc/tests/disc-transform/data/matmul_nn_s_f16_gpu_schedule_1.mlir
@@ -39,9 +39,9 @@ transform.sequence failures(propagate) {
   // 1. use register to cache the result of ldmatrix
   // 2. use register to cache the result of mma's accumulation result
   // 3. store the final result from reg to smem and to gmem
-  // 4. use padding for output smem matrix to avoid bank conflict`
+  // 4. use padding for output smem matrix to avoid bank conflict
   %mma = transform.structured.match ops{["nvgpu.mma.sync"]} in %5 : (!transform.any_op) -> !transform.any_op
-  transform.disc.move_data_to_register %mma by block_mn_shape = [128, 128] smem_padding = 8 bytes = 2: (!transform.any_op) -> ()
+  transform.disc.move_data_to_register %mma by block_mn_shape = [128, 128] smem_padding = 8 : (!transform.any_op) -> ()
   transform.disc.apply_licm %5 : !transform.any_op
   transform.disc.apply_dce %5 : !transform.any_op
   transform.disc.apply_cse %5 : !transform.any_op
diff --git a/tao_compiler/mlir/disc/tools/disc-transform/TransformOps/TransformOpsExt.cc b/tao_compiler/mlir/disc/tools/disc-transform/TransformOps/TransformOpsExt.cc
index fb371f52795..12f1b21477d 100644
--- a/tao_compiler/mlir/disc/tools/disc-transform/TransformOps/TransformOpsExt.cc
+++ b/tao_compiler/mlir/disc/tools/disc-transform/TransformOps/TransformOpsExt.cc
@@ -4506,7 +4506,7 @@ transform_dialect::DISCExpandTransferRWToMemrefCopy::applyToOne(
         /*dstIndices*/ ValueRange{offsetX, offsetY}, src,
         /*srcIndices*/ ValueRange{offsetX, offsetY},
         /*dstElements*/ b.getIndexAttr(kThreadCopyBytes / 2),
-        /*srcElements*/ nullptr,  // TODO: add support for dynamic shape
+        /*srcElements*/ srcElements,
         /*bypassL1*/ b.getUnitAttr());
     tokens.push_back(token);
   }
@@ -4973,10 +4973,6 @@ transform_dialect::DISCMoveDataToRegister::applyToOne(
       ldMatrixB.getRes().getType(), FloatAttr::get(elementType, 0.0)));
   Value matrixCReg = b.create(DenseElementsAttr::get(
       mma.getRes().getType(), FloatAttr::get(elementType, 0.0)));
-  Value matrixBReg1 = b.create(DenseElementsAttr::get(
-      ldMatrixB.getRes().getType(), FloatAttr::get(elementType, 0.0)));
-  Value matrixCReg1 = b.create(DenseElementsAttr::get(
-      mma.getRes().getType(), FloatAttr::get(elementType, 0.0)));
 
   SmallVector outputTypes(2, b.getIndexType());
   SmallVector shape;
@@ -5091,6 +5087,15 @@ transform_dialect::DISCMoveDataToRegister::applyToOne(
   auto mmaMLoopIV = newMmaMLoop.getInductionVar();
   auto mmaMShapeValue = b.create(mmaMShape);
   auto mIndexValue = b.create(mmaMLoopIV, mmaMShapeValue);
+  // load A from reg
+  for (unsigned i = 0; i < 4; ++i) {
+    matrixAReg = b.create(
+        b.create(
+            VectorType::get({vector_width}, elementType), aWarpRegAlloc,
+            ValueRange{mIndexValue, kIndexValue,
+                       b.create(i), zeroIndex}),
+        matrixAReg, (int64_t[]){i});
+  }
   auto newMmaNLoop = b.create(
       zeroIndex,
       b.create(warpNShape), b.create(mmaNShape));
@@ -5107,15 +5112,6 @@ transform_dialect::DISCMoveDataToRegister::applyToOne(
                        b.create(i), zeroIndex}),
         matrixCReg, (int64_t[]){i});
   }
-  // load A from reg
-  for (unsigned i = 0; i < 4; ++i) {
-    matrixAReg = b.create(
-        b.create(
-            VectorType::get({vector_width}, elementType), aWarpRegAlloc,
-            ValueRange{mIndexValue, kIndexValue,
-                       b.create(i), zeroIndex}),
-        matrixAReg, (int64_t[]){i});
-  }
   // load B from reg
   for (unsigned i = 0; i < 2; ++i) {
     matrixBReg = b.create(
@@ -5139,6 +5135,7 @@ transform_dialect::DISCMoveDataToRegister::applyToOne(
   (void)mlir::loopUnrollByFactor(newMmaNLoop, warpNShape / mmaNShape);
   (void)mlir::loopUnrollByFactor(newMmaMLoop, warpMShape / mmaMShape);
   (void)mlir::loopUnrollByFactor(newMmaKLoop, warpKShape / mmaKShape);
+  // no longer need the original mmaMLoop
   for (auto user : mmaMLoop->getUsers()) user->erase();
   mmaMLoop.erase();
 
@@ -5196,6 +5193,7 @@ transform_dialect::DISCMoveDataToRegister::applyToOne(
   Value cWarpSmemCol = b.create(
       cWarpSmemColOffset,
       b.create(cWarpSmemColMap, ValueRange{mmaNLoopIV, laneId}));
+  // TODO: store 4xf16 rather than 2x2xf16
   b.create(
       b.create(
           VectorType::get({vector_width}, elementType), cWarpRegAlloc,
@@ -5232,10 +5230,7 @@ transform_dialect::DISCMoveDataToRegister::applyToOne(
   // expand genericOp to loop
   b.setInsertionPoint(genericOp);
 
-  // TODO: theoretically, each thread should copy 16 bytes to achieve best
-  // performance, however there is bug when lowering mlir code to ptx code
-  // during converting load/store-vector if we set vector_width to 16
-  int64_t cpBytesPerThread = getBytes();
+  int64_t cpBytesPerThread = 16;  // 128 bits
   int64_t cpElementsPerThread = cpBytesPerThread / 2;
   int64_t cpThreadsPerRow = blockNShape / cpElementsPerThread;
   int64_t cpRowsPerBlock = 128 * cpElementsPerThread / blockNShape;
@@ -5256,12 +5251,17 @@ transform_dialect::DISCMoveDataToRegister::applyToOne(
   auto offsetX = b.create(
       iv, b.create(
               threadId, b.create(cpThreadsPerRow)));
-  // TODO: enable mask store
-  b.create(
-      b.create(
-          VectorType::get({cpBytesPerThread / 2}, elementType), cBlockSmemAlloc,
-          ValueRange{offsetX, offsetY}),
-      outputSubView, ValueRange{offsetX, offsetY});
+  auto dimM = b.create(outputSubView.getSource(),
+                       b.create(0));
+  auto dimN = b.create(outputSubView.getSource(),
+                       b.create(1));
+  // TODO: enable mask vector store
+  auto vec8xf16 = b.create(
+      VectorType::get({cpElementsPerThread}, elementType), cBlockSmemAlloc,
+      ValueRange{offsetX, offsetY});
+  auto vecStore = b.create(vec8xf16, outputSubView,
+                           ValueRange{offsetX, offsetY});
+  vecStore->setAttr("alignment", IntegerAttr::get(b.getI32Type(), 16));
   b.setInsertionPointAfter(smemToGmemLoop);
   b.create();
   (void)mlir::loopUnrollByFactor(smemToGmemLoop, blockMShape / cpRowsPerBlock);
@@ -5276,7 +5276,7 @@ transform_dialect::DISCMoveDataToRegister::applyToOne(
                                      "greedy pattern applicatin failed");
   }
 
-  // Delete any transfer_write op
+  // Delete any remaining transfer_write op
   Operation* transferWrite = nullptr;
   parallelOp->walk([&](memref::AllocOp alloc) {
     if (llvm::hasSingleElement(alloc->getUsers())) {
diff --git a/tao_compiler/mlir/disc/tools/disc-transform/TransformOps/TransformOpsExt.td b/tao_compiler/mlir/disc/tools/disc-transform/TransformOps/TransformOpsExt.td
index 447f3d82005..0a0b8dc54b8 100644
--- a/tao_compiler/mlir/disc/tools/disc-transform/TransformOps/TransformOpsExt.td
+++ b/tao_compiler/mlir/disc/tools/disc-transform/TransformOps/TransformOpsExt.td
@@ -1668,8 +1668,7 @@ def DISCMoveDataToRegister : Op
+  using ConvertOpToLLVMPattern<vector::StoreOp>::ConvertOpToLLVMPattern;
+
+  LogicalResult matchAndRewrite(
+      vector::StoreOp storeOp, OpAdaptor adaptor,
+      ConversionPatternRewriter& rewriter) const override {
+    // Only 1-D vectors can be lowered to LLVM.
+    VectorType vectorTy = storeOp.getVectorType();
+    if (vectorTy.getRank() > 1) return failure();
+    auto alignAttr = storeOp->getAttrOfType<IntegerAttr>("alignment");
+    if (!alignAttr) return failure();
+    storeOp.dump();
+    unsigned align = alignAttr.getInt();
+
+    auto loc = storeOp->getLoc();
+    MemRefType memRefTy = storeOp.getMemRefType();
+
+    // Resolve address.
+    auto vtype = cast<VectorType>(
+        this->typeConverter->convertType(storeOp.getVectorType()));
+    Value dataPtr = this->getStridedElementPtr(loc, memRefTy, adaptor.getBase(),
+                                               adaptor.getIndices(), rewriter);
+    // Casts a strided element pointer to a vector pointer. The vector pointer
+    // will be in the same address space as the incoming memref type.
+    Value ptr;
+    if ((*this->getTypeConverter()).useOpaquePointers()) {
+      ptr = dataPtr;
+    } else {
+      unsigned addressSpace =
+          *(*this->getTypeConverter()).getMemRefAddressSpace(memRefTy);
+      auto pType = LLVM::LLVMPointerType::get(vtype, addressSpace);
+      ptr = rewriter.create<LLVM::BitcastOp>(loc, pType, dataPtr);
+    }
+
+    rewriter.replaceOpWithNewOp<LLVM::StoreOp>(
+        storeOp, adaptor.getValueToStore(), ptr, align);
+    return success();
+  }
+};
+
 /// A pass that replaces all occurrences of GPU device operations with their
 /// corresponding NVVM equivalent.
 ///
@@ -125,6 +168,8 @@ struct DiscLowerGpuOpsToNVVMOpsPass
     llvmPatterns.add(
         converter, /* PatternBenefit */ 3);
     llvmPatterns.add(converter);
+    llvmPatterns.add(converter,
+                     /* PatternBenefit */ 3);
     arith::populateArithToLLVMConversionPatterns(converter, llvmPatterns);
     cf::populateControlFlowToLLVMConversionPatterns(converter, llvmPatterns);
     populateVectorToLLVMConversionPatterns(converter, llvmPatterns);