diff --git a/include/triton/Dialect/TritonGPU/IR/Dialect.h b/include/triton/Dialect/TritonGPU/IR/Dialect.h
index e0288941c908..84025b2ad718 100644
--- a/include/triton/Dialect/TritonGPU/IR/Dialect.h
+++ b/include/triton/Dialect/TritonGPU/IR/Dialect.h
@@ -105,6 +105,11 @@ unsigned getNumCTAs(Attribute layout);
 
 bool isaDistributedLayout(Attribute layout);
 
+bool sameBlockedEncodings(BlockedEncodingAttr blockedA,
+                          BlockedEncodingAttr blockedB);
+
+bool sameMfmaEncodings(MfmaEncodingAttr mfmaA, MfmaEncodingAttr mfmaB);
+
 bool isSharedEncoding(Value value);
 
 bool isExpensiveCat(CatOp cat, Attribute targetEncoding);
diff --git a/include/triton/Dialect/TritonGPU/IR/Traits.h b/include/triton/Dialect/TritonGPU/IR/Traits.h
index 44def95804da..15c033c1086b 100644
--- a/include/triton/Dialect/TritonGPU/IR/Traits.h
+++ b/include/triton/Dialect/TritonGPU/IR/Traits.h
@@ -14,6 +14,7 @@ namespace OpTrait {
 // instantiated/duplicated.
 namespace impl {
 LogicalResult verifyResultsAreSharedEncoding(Operation *op);
+LogicalResult verifyOperandAndResultHaveSameEncoding(Operation *op);
 } // namespace impl
 
 template <typename ConcreteType>
@@ -25,6 +26,14 @@ class ResultsAreSharedEncoding
   }
 };
 
+template <typename ConcreteType>
+class OperandAndResultHaveSameEncoding
+    : public TraitBase<ConcreteType, OperandAndResultHaveSameEncoding> {
+public:
+  static LogicalResult verifyTrait(Operation *op) {
+    return impl::verifyOperandAndResultHaveSameEncoding(op);
+  }
+};
 } // namespace OpTrait
 } // namespace mlir
diff --git a/include/triton/Dialect/TritonGPU/IR/TritonGPUOps.td b/include/triton/Dialect/TritonGPU/IR/TritonGPUOps.td
index 9f8e782414b6..85b0588fcfa1 100644
--- a/include/triton/Dialect/TritonGPU/IR/TritonGPUOps.td
+++ b/include/triton/Dialect/TritonGPU/IR/TritonGPUOps.td
@@ -14,6 +14,7 @@ include "mlir/Interfaces/DestinationStyleOpInterface.td"
 include "mlir/Interfaces/ViewLikeInterface.td"
 
 def ResultsAreSharedEncoding: NativeOpTrait<"ResultsAreSharedEncoding">;
+def OperandAndResultHaveSameEncoding: NativeOpTrait<"OperandAndResultHaveSameEncoding">;
 
 class TTG_Op<string mnemonic, list<Trait> traits = []> :
     Op<TritonGPU_Dialect, mnemonic, traits>;
@@ -179,6 +180,67 @@ def TTG_InsertSliceOp : TTG_Op<"insert_slice",
 
 }
 
+def TTG_ViewSliceOp : TTG_Op<"view_slice",
+  [AttrSizedOperandSegments,
+   OperandAndResultHaveSameEncoding,
+   Pure,
+   OffsetSizeAndStrideOpInterface
+  ]> {
+  let summary = "view slice operation";
+  let description = [{
+    Represents a view of a slice of a tensor held in registers. The syntax of
+    the operation is the same as that of the extract_slice op; however, unlike
+    'extract_slice', which slices in shared memory, 'view_slice' slices within
+    registers. The slice is required to have the same layout as the original
+    tensor; in that sense the semantics of 'view_slice' combine those of the
+    'extract_slice' and 'view' operations.
+  }];
+
+  let arguments = (ins
+    AnyRankedTensor:$source,
+    Variadic<Index>:$offsets,
+    Variadic<Index>:$sizes,
+    Variadic<Index>:$strides,
+    DenseI64ArrayAttr:$static_offsets,
+    DenseI64ArrayAttr:$static_sizes,
+    DenseI64ArrayAttr:$static_strides
+  );
+  let results = (outs AnyRankedTensor:$result);
+
+  let builders = [
+    // Build a ViewSliceOp with mixed static and dynamic entries and a custom
+    // result type. If the type passed is nullptr, it is inferred.
+    OpBuilder<(ins "RankedTensorType":$resultType, "Value":$source,
+      "ArrayRef<OpFoldResult>":$offsets, "ArrayRef<OpFoldResult>":$sizes,
+      "ArrayRef<OpFoldResult>":$strides,
+      CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs)>,
+  ];
+
+  let extraClassDeclaration = [{
+    /// Return the number of leading operands before the `offsets`, `sizes`
+    /// and `strides` operands.
+    static unsigned getOffsetSizeAndStrideStartOperandIndex() { return 1; }
+
+    /// Returns the type of the base tensor operand.
+    RankedTensorType getSourceType() {
+      return getSource().getType().cast<RankedTensorType>();
+    }
+
+    std::array<unsigned, 3> getArrayAttrMaxRanks() {
+      unsigned rank = getSourceType().getRank();
+      return {rank, rank, rank};
+    }
+  }];
+
+  let assemblyFormat = [{
+    $source ``
+    custom<DynamicIndexList>($offsets, $static_offsets)
+    custom<DynamicIndexList>($sizes, $static_sizes)
+    custom<DynamicIndexList>($strides, $static_strides)
+    attr-dict `:` type($source) `to` type($result)
+  }];
+}
+
+
 def TTG_ExtractSliceOp : TTG_Op<"extract_slice",
   [AttrSizedOperandSegments,
    ResultsAreSharedEncoding,
diff --git a/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp
index 4a3260ac9931..d3f7f87400a1 100644
--- a/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp
+++ b/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp
@@ -849,6 +849,135 @@ struct ExtractSliceOpConversion
   }
 };
 
+// clang-format off
+/***
+  # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # #
+  # W0 # W1 #                                 |                              #
+  #    #    #                                 |                              #
+  # #  # #  #                                 |                              #
+  # W2 # W3 #   ....                          |                              #
+  #    #    #                                 |  SkipElems                   #
+  # #  # #  #                                 |                              #
+  #                                           |                              #
+  #                 Slice                     |                              #
+  #                  .  / \                   |                              #
+  #                  . /   \                  |                              #
+  #                  ./     \                 |                              #
+  #                  #  #  #  #               #                              #
+  #                  # W0 # W1 #              #                              #
+  #                  #    #    #              #                              #
+  #                  # #  # #  #   tensorStride                              #
+  #                  # W2 # W3 #  ------------------------------------------ #
+  #                  #    #    #                                             #
+  #                  # #  # #  #                                             #
+  #   tensorStride   # W0 # W1 #                                             #
+  #  --------------  #    #    #                                             #
+  #                  # #  # #  #                                             #
+  #                  # W2 # W3 #                                             #
+  #                  #    #    #                                             #
+  #                  # #  # #  #  ---> lastIdx                               #
+  #                      .                                                   #
+  #                      .                                                   #
+  #                      .                                                   #
+  #                                                                          #
+  # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # #
+***/
+// clang-format on
+struct ViewSliceOpConversion
+    : public ConvertTritonGPUOpToLLVMPattern<triton::gpu::ViewSliceOp> {
+  using OpAdaptor = typename triton::gpu::ViewSliceOp::Adaptor;
+  explicit ViewSliceOpConversion(TritonGPUToLLVMTypeConverter &typeConverter,
+                                 PatternBenefit benefit = 1)
+      : ConvertTritonGPUOpToLLVMPattern<triton::gpu::ViewSliceOp>(typeConverter,
+                                                                  benefit) {}
+
+  LogicalResult
+  processBlockedLayout(triton::gpu::ViewSliceOp op, OpAdaptor adaptor,
+                       ConversionPatternRewriter &rewriter) const {
+    Location loc = op->getLoc();
+    auto srcTy = op.getSource().getType().dyn_cast<RankedTensorType>();
+    auto srcLayout = srcTy.getEncoding().dyn_cast<BlockedEncodingAttr>();
+    assert(
+        srcLayout &&
+        "Currently only blocked layout is supported in view_slice instruction");
+    auto srcShape = srcTy.getShape();
+    auto resultTy = op.getType().template cast<RankedTensorType>();
+    auto vals = this->getTypeConverter()->unpackLLElements(
+        loc, adaptor.getSource(), rewriter, srcTy);
+
+    auto elemsPerThread = mlir::triton::gpu::getElemsPerThread(srcTy);
+    auto sizePerThread = srcLayout.getSizePerThread();
+    auto totalSizePerThread = sizePerThread[0] * sizePerThread[1];
+    auto order = srcLayout.getOrder();
+    auto shapePerCTA = getShapePerCTATile(srcLayout, srcShape);
+    shapePerCTA[0] = std::min(srcShape[0], (long)shapePerCTA[0]);
+    shapePerCTA[1] = std::min(srcShape[1], (long)shapePerCTA[1]);
+
+    auto offsets = op.getStaticOffsets();
+    auto sizes = op.getStaticSizes();
+
+    // ViewSlice only supports slicing where offsets and sizes are multiples of
+    // shapePerCTA. This condition ensures that the slice has the same layout
+    // as the original tensor.
+    assert(offsets[0] % shapePerCTA[0] == 0);
+    assert(offsets[1] % shapePerCTA[1] == 0);
+    assert(sizes[0] % shapePerCTA[0] == 0);
+    assert(sizes[1] % shapePerCTA[1] == 0);
+    assert(op.hasUnitStride() &&
+           "Only unit stride supported by ViewSliceOpConversion");
+
+    // Calculate offsets and sizes in terms of CTA units.
+    std::vector<int64_t> CTAOffsets{offsets[0] / shapePerCTA[0],
+                                    offsets[1] / shapePerCTA[1]};
+    std::vector<int64_t> CTASizes{sizes[0] / shapePerCTA[0],
+                                  sizes[1] / shapePerCTA[1]};
+    std::vector<int64_t> CTAPerShape{srcShape[0] / shapePerCTA[0],
+                                     srcShape[1] / shapePerCTA[1]};
+
+    SmallVector<Value> resultVals;
+    // The diagram above illustrates the graphical representation of the
+    // skipElems, tensorStride, and lastIdx variables.
+    auto skipElems = CTAOffsets[order[1]] *
+                         (elemsPerThread[order[0]] * sizePerThread[order[1]]) +
+                     CTAOffsets[order[0]] * totalSizePerThread;
+    auto tensorStride =
+        (CTAPerShape[order[0]] - CTASizes[order[0]]) * totalSizePerThread;
+    auto lastIdx =
+        (CTAOffsets[order[1]] + CTASizes[order[1]] - 1) *
+            elemsPerThread[order[0]] * sizePerThread[order[1]] +
+        (CTAOffsets[order[0]] + CTASizes[order[0]]) * totalSizePerThread;
+
+    assert(lastIdx <= vals.size());
+    for (int i = skipElems; i < lastIdx; i += tensorStride) {
+      for (int j = 0; j < totalSizePerThread * CTASizes[order[0]]; ++j, ++i) {
+        assert(i < lastIdx);
+        resultVals.push_back(vals[i]);
+      }
+    }
+
+    Value ret = this->getTypeConverter()->packLLElements(loc, resultVals,
+                                                         rewriter, resultTy);
+    rewriter.replaceOp(op, ret);
+    return success();
+  }
+
+  LogicalResult
+  matchAndRewrite(triton::gpu::ViewSliceOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto srcTy = op.getSource().getType().dyn_cast<RankedTensorType>();
+    if (srcTy.getEncoding().dyn_cast<BlockedEncodingAttr>()) {
+      return processBlockedLayout(op, adaptor, rewriter);
+    } else {
+      assert(false && "unsupported layout in viewSlice");
+      return failure();
+    }
+  }
+};
+
 struct AsyncWaitOpConversion
     : public ConvertTritonGPUOpToLLVMPattern<triton::gpu::AsyncWaitOp> {
   using ConvertTritonGPUOpToLLVMPattern<
@@ -954,6 +1083,7 @@ void populateTritonGPUToLLVMPatterns(
   patterns.add(typeConverter, benefit);
   patterns.add(typeConverter, moduleAllocation, benefit);
+  patterns.add<ViewSliceOpConversion>(typeConverter, benefit);
   patterns.add(typeConverter, benefit);
   patterns.add(typeConverter, benefit);
   patterns.add(typeConverter, benefit);
diff --git a/lib/Dialect/TritonGPU/IR/Dialect.cpp b/lib/Dialect/TritonGPU/IR/Dialect.cpp
index bd0e6e36a54e..d8e1556f1368 100644
--- a/lib/Dialect/TritonGPU/IR/Dialect.cpp
+++ b/lib/Dialect/TritonGPU/IR/Dialect.cpp
@@ -691,6 +691,59 @@ bool isaDistributedLayout(Attribute layout) {
          layout.isa() || layout.isa();
 }
 
+bool sameBlockedEncodings(BlockedEncodingAttr blockedA,
+                          BlockedEncodingAttr blockedB) {
+  auto sizePerThreadA = blockedA.getSizePerThread();
+  auto threadsPerWarpA = blockedA.getThreadsPerWarp();
+  auto warpsPerCTAA = blockedA.getWarpsPerCTA();
+  auto orderA = blockedA.getOrder();
+  size_t rankA = orderA.size();
+
+  auto sizePerThreadB = blockedB.getSizePerThread();
+  auto threadsPerWarpB = blockedB.getThreadsPerWarp();
+  auto warpsPerCTAB = blockedB.getWarpsPerCTA();
+  auto orderB = blockedB.getOrder();
+  size_t rankB = orderB.size();
+
+  if (rankA != rankB) {
+    return false;
+  }
+  for (size_t i = 0; i < rankA; ++i) {
+    if (sizePerThreadA[i] != sizePerThreadB[i] ||
+        threadsPerWarpA[i] != threadsPerWarpB[i] ||
+        warpsPerCTAA[i] != warpsPerCTAB[i] || orderA[i] != orderB[i]) {
+      return false;
+    }
+  }
+  return true;
+}
+
+bool sameMfmaEncodings(MfmaEncodingAttr mfmaA,
+                       MfmaEncodingAttr mfmaB) {
+  auto nonKDimA = mfmaA.getNonKDim();
+  auto warpsPerCTAA = mfmaA.getWarpsPerCTA();
+  auto isTransposedA = mfmaA.getIsTransposed();
+
+  auto nonKDimB = mfmaB.getNonKDim();
+  auto warpsPerCTAB = mfmaB.getWarpsPerCTA();
+  auto isTransposedB = mfmaB.getIsTransposed();
+
+  if (nonKDimA != nonKDimB || isTransposedA != isTransposedB) {
+    return false;
+  }
+
+  if (warpsPerCTAA.size() != warpsPerCTAB.size()) {
+    return false;
+  }
+
+  auto rank = warpsPerCTAA.size();
+  for (size_t i = 0; i < rank; ++i) {
+    if (warpsPerCTAA[i] != warpsPerCTAB[i]) {
+      return false;
+    }
+  }
+  return true;
+}
+
 bool isSharedEncoding(Value value) {
   auto type = value.getType();
   if (auto tensorType = type.dyn_cast<RankedTensorType>()) {
diff --git a/lib/Dialect/TritonGPU/IR/Traits.cpp b/lib/Dialect/TritonGPU/IR/Traits.cpp
index 5d5778ec3e68..0d5d42017fd4 100644
--- a/lib/Dialect/TritonGPU/IR/Traits.cpp
+++ b/lib/Dialect/TritonGPU/IR/Traits.cpp
@@ -12,3 +12,51 @@ mlir::OpTrait::impl::verifyResultsAreSharedEncoding(Operation *op) {
 
   return success();
 };
+
+mlir::LogicalResult
+mlir::OpTrait::impl::verifyOperandAndResultHaveSameEncoding(Operation *op) {
+  if (op->getNumOperands() != 1 || op->getNumResults() != 1) {
+    return failure();
+  }
+
+  auto operandType = op->getOperand(0).getType().dyn_cast<RankedTensorType>();
+  auto resultType = op->getResult(0).getType().dyn_cast<RankedTensorType>();
+
+  if (!operandType || !resultType) {
+    return failure();
+  }
+  auto operandLayout = operandType.getEncoding();
+  auto resultLayout = resultType.getEncoding();
+
+  if (auto blockedLayoutSrc =
+          dyn_cast<triton::gpu::BlockedEncodingAttr>(operandLayout)) {
+    auto blockedLayoutRes =
+        dyn_cast<triton::gpu::BlockedEncodingAttr>(resultLayout);
+    if (!blockedLayoutRes) {
+      return op->emitOpError()
+             << "requires operand and result to have same layout";
+    }
+
+    if (!triton::gpu::sameBlockedEncodings(blockedLayoutSrc,
+                                           blockedLayoutRes)) {
+      return op->emitOpError()
+             << "requires operand and result to have same layout";
+    }
+  } else if (auto mfmaLayoutSrc =
+                 dyn_cast<triton::gpu::MfmaEncodingAttr>(operandLayout)) {
+    auto mfmaLayoutRes = dyn_cast<triton::gpu::MfmaEncodingAttr>(resultLayout);
+    if (!mfmaLayoutRes) {
+      return op->emitOpError()
+             << "requires operand and result to have same layout";
+    }
+    if (!triton::gpu::sameMfmaEncodings(mfmaLayoutSrc, mfmaLayoutRes)) {
+      return op->emitOpError()
+             << "requires operand and result to have same layout";
+    }
+  } else {
+    assert(false &&
+           "Unexpected Layout in verifyOperandAndResultHaveSameEncoding");
+  }
+
+  return success();
+};
diff --git a/python/test/unit/language/test_core_amd.py b/python/test/unit/language/test_core_amd.py
index 314021c9a5de..f16a398c77a5 100644
--- a/python/test/unit/language/test_core_amd.py
+++ b/python/test/unit/language/test_core_amd.py
@@ -2933,6 +2933,71 @@ def test_convert2d(dtype, shape, src_layout, dst_layout, device='cuda'):
     assert torch.equal(z, x)
 
 
+layouts = [
+    BlockedLayout([2, 2], [16, 4], [2, 2], [1, 0], [1, 1], [1, 1], [0, 1]),
+    BlockedLayout([2, 2], [16, 4], [2, 2], [0, 1], [1, 1], [1, 1], [0, 1]),
+    BlockedLayout([1, 8], [16, 4], [4, 1], [1, 0], [1, 1], [1, 1], [0, 1]),
+    BlockedLayout([1, 8], [16, 4], [4, 1], [0, 1], [1, 1], [1, 1], [0, 1]),
+]
+
+@pytest.mark.parametrize("M, N, M_tile_size, N_tile_size, M_tile_offset, N_tile_offset", [[256, 128, 256, 32, 0, 0], [256, 256, 128, 64, 64, 128], [128, 128, 128, 32, 0, 0], [128, 128, 128, 32, 0, 64]])
+@pytest.mark.parametrize("dtype", ['float16'])
+@pytest.mark.parametrize("src_layout", layouts)
+def test_view_slice(dtype, M, N, M_tile_size, N_tile_size, M_tile_offset, N_tile_offset, src_layout, device='cuda'):
+    if torch.version.hip is None:
pytest.skip("view_slice is AMD specific instruction.") + + ir = f""" +#src = {src_layout} +""" + """ +module attributes {"triton_gpu.num-ctas" = 1, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = """ + str(_get_warp_size()) + f""" : i32}} {{ + tt.func public @kernel_0d1d(%arg0: !tt.ptr {{tt.divisibility = 16 : i32}}, %arg1: !tt.ptr {{tt.divisibility = 16 : i32}}) {{ + %cst = arith.constant dense<{N}> : tensor<{M}x1xi32, #src> + %cst_n = arith.constant dense<{N_tile_size}> : tensor<{M_tile_size}x1xi32, #src> + %0 = tt.make_range {{end = {M} : i32, start = 0 : i32}} : tensor<{M}xi32, #triton_gpu.slice<{{dim = 1, parent = #src}}>> + %42 = tt.make_range {{end = {M_tile_size} : i32, start = 0 : i32}} : tensor<{M_tile_size}xi32, #triton_gpu.slice<{{dim = 1, parent = #src}}>> + %1 = tt.make_range {{end = {M} : i32, start = 0 : i32}} : tensor<{M}xi32, #triton_gpu.slice<{{dim = 0, parent = #src}}>> + %2 = tt.splat %arg0 : (!tt.ptr) -> tensor<{M}x{N}x!tt.ptr, #src> + %4 = tt.expand_dims %0 {{axis = 1 : i32}} : (tensor<{M}xi32, #triton_gpu.slice<{{dim = 1, parent = #src}}>>) -> tensor<{M}x1xi32, #src> + %43 = tt.expand_dims %42 {{axis = 1 : i32}} : (tensor<{M_tile_size}xi32, #triton_gpu.slice<{{dim = 1, parent = #src}}>>) -> tensor<{M_tile_size}x1xi32, #src> + %5 = arith.muli %4, %cst : tensor<{M}x1xi32, #src> + %44 = arith.muli %43, %cst_n : tensor<{M_tile_size}x1xi32, #src> + %6 = tt.expand_dims %1 {{axis = 0 : i32}} : (tensor<{M}xi32, #triton_gpu.slice<{{dim = 0, parent = #src}}>>) -> tensor<1x{M}xi32, #src> + %7 = tt.broadcast %6 : (tensor<1x{M}xi32, #src>) -> tensor<{M}x{N}xi32, #src> + %8 = tt.broadcast %5 : (tensor<{M}x1xi32, #src>) -> tensor<{M}x{N}xi32, #src> + %9 = arith.addi %8, %7 : tensor<{M}x{N}xi32, #src> + %33 = tt.make_range {{end = {N_tile_size} : i32, start = 0 : i32}} : tensor<{N_tile_size}xi32, #triton_gpu.slice<{{dim = 0, parent = #src}}>> + %34 = tt.splat %arg1 : (!tt.ptr) -> tensor<{M_tile_size}x{N_tile_size}x!tt.ptr, #src> + %37 = tt.expand_dims %33 {{axis = 0 : i32}} : (tensor<{N_tile_size}xi32, #triton_gpu.slice<{{dim = 0, parent = #src}}>>) -> tensor<1x{N_tile_size}xi32, #src> + %38 = tt.broadcast %37 : (tensor<1x{N_tile_size}xi32, #src>) -> tensor<{M_tile_size}x{N_tile_size}xi32, #src> + %39 = tt.broadcast %44 : (tensor<{M_tile_size}x1xi32, #src>) -> tensor<{M_tile_size}x{N_tile_size}xi32, #src> + %40 = arith.addi %38, %39 : tensor<{M_tile_size}x{N_tile_size}xi32, #src> + %10 = tt.addptr %2, %9 : tensor<{M}x{N}x!tt.ptr, #src>, tensor<{M}x{N}xi32, #src> + %11 = tt.load %10 {{cache = 1 : i32, evict = 1 : i32, isVolatile = false}} : tensor<{M}x{N}xf16, #src> + %12 = triton_gpu.view_slice %11[{M_tile_offset}, {N_tile_offset}] [{M_tile_size}, {N_tile_size}] [1, 1] : tensor<{M}x{N}xf16, #src> to tensor<{M_tile_size}x{N_tile_size}xf16, #src> + %13 = tt.addptr %34, %40 : tensor<{M_tile_size}x{N_tile_size}x!tt.ptr, #src>, tensor<{M_tile_size}x{N_tile_size}xi32, #src> + tt.store %13, %12 : tensor<{M_tile_size}x{N_tile_size}xf16, #src> + tt.return + }} +}} +""" + + x_numpy = numpy_random((M, N), dtype_str=dtype) + z_numpy = x_numpy[M_tile_offset:M_tile_offset + M_tile_size, N_tile_offset:N_tile_offset + N_tile_size] + x = to_triton(x_numpy) + # write the IR to a temporary file using mkstemp + import tempfile + with tempfile.NamedTemporaryFile(mode='w', suffix='.ttgir') as f: + f.write(ir) + f.flush() + kernel = triton.compile(f.name) + + z = np.zeros((M_tile_size, N_tile_size)).astype('float16') + z_tri = torch.tensor(z, device=device) + + kernel[(1, 1, 
+    np.testing.assert_equal(z_numpy, to_numpy(z_tri))
+
 
 if torch.version.hip is not None and _get_warp_size() == 64:
     layouts = [
         MfmaLayout(non_k_dim=32, warps_per_cta=[4, 1], is_transposed=True),
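
For reference (an illustrative sketch, not part of the patch): with the third blocked layout used in the test above (sizePerThread = [1, 8], threadsPerWarp = [16, 4], warpsPerCTA = [4, 1], order = [1, 0]), shapePerCTATile is [64, 32], so the offsets and sizes passed to view_slice must be multiples of [64, 32], as the asserts in ViewSliceOpConversion require. A minimal TTGIR usage sketch of the new op's assembly format, where %src, %tile, #blocked and the 128x128/128x32 shapes are made up for illustration and the blocked attribute is written in the pre-CTA spelling assumed for this Triton version:

  #blocked = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [16, 4], warpsPerCTA = [4, 1], order = [1, 0]}>
  %tile = triton_gpu.view_slice %src[0, 64] [128, 32] [1, 1]
          : tensor<128x128xf16, #blocked> to tensor<128x32xf16, #blocked>

Tracing the arithmetic in processBlockedLayout for this example: each thread holds 64 source values (elemsPerThread = [2, 32]), totalSizePerThread = 8, CTAOffsets = [0, 2], CTASizes = [2, 1] and CTAPerShape = [2, 4], so skipElems = 2 * 8 = 16, tensorStride = (4 - 1) * 8 = 24 and lastIdx = 1 * 32 * 1 + 3 * 8 = 56; the copy loop therefore gathers two runs of 8 registers, i.e. the 16 values per thread of the 128x32 result tile.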