Skip to content

Commit

Permalink
Add support for MFMA layout to view_slice op (#442)
Browse files Browse the repository at this point in the history
Co-authored-by: Ognjen <[email protected]>
  • Loading branch information
oplavsic and Ognjen authored Jan 3, 2024
1 parent 6a52056 commit bcea305
Show file tree
Hide file tree
Showing 2 changed files with 51 additions and 44 deletions.
24 changes: 11 additions & 13 deletions lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -894,25 +894,22 @@ struct ViewSliceOpConversion
: ConvertTritonGPUOpToLLVMPattern<triton::gpu::ViewSliceOp>(typeConverter,
benefit) {}

LogicalResult
processBlockedLayout(triton::gpu::ViewSliceOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
LogicalResult processLayout(triton::gpu::ViewSliceOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
Location loc = op->getLoc();
auto srcTy = op.getSource().getType().dyn_cast<RankedTensorType>();
auto srcLayout = srcTy.getEncoding().dyn_cast<BlockedEncodingAttr>();
assert(
srcLayout &&
"Currently only blocked layout is supported in view_slice instruction");
auto srcLayout = srcTy.getEncoding();
auto srcShape = srcTy.getShape();
auto resultTy = op.getType().template cast<RankedTensorType>();
auto vals = this->getTypeConverter()->unpackLLElements(
loc, adaptor.getSource(), rewriter, srcTy);

auto elemsPerThread = mlir::triton::gpu::getElemsPerThread(srcTy);
auto sizePerThread = srcLayout.getSizePerThread();
auto sizePerThread = mlir::triton::gpu::getSizePerThread(srcLayout);
auto totalSizePerThread = sizePerThread[0] * sizePerThread[1];
auto order = srcLayout.getOrder();
auto shapePerCTA = getShapePerCTATile(srcLayout, srcShape);
auto order = mlir::triton::gpu::getOrder(srcLayout);
auto shapePerCTA =
mlir::triton::gpu::getShapePerCTATile(srcLayout, srcShape);
shapePerCTA[0] = std::min(srcShape[0], (long)shapePerCTA[0]);
shapePerCTA[1] = std::min(srcShape[1], (long)shapePerCTA[1]);

Expand Down Expand Up @@ -969,10 +966,11 @@ struct ViewSliceOpConversion
ConversionPatternRewriter &rewriter) const override {

auto srcTy = op.getSource().getType().dyn_cast<RankedTensorType>();
if (srcTy.getEncoding().dyn_cast<BlockedEncodingAttr>()) {
return processBlockedLayout(op, adaptor, rewriter);
if (srcTy.getEncoding().isa<BlockedEncodingAttr>() ||
srcTy.getEncoding().isa<MfmaEncodingAttr>()) {
return processLayout(op, adaptor, rewriter);
} else {
assert(false && "unsupported layout in viewSlice");
assert(false && "Unsupported layout in viewSlice.");
return failure();
}
}
Expand Down
71 changes: 40 additions & 31 deletions python/test/unit/language/test_core_amd.py
Original file line number Diff line number Diff line change
Expand Up @@ -2933,50 +2933,58 @@ def test_convert2d(dtype, shape, src_layout, dst_layout, device='cuda'):
assert torch.equal(z, x)


layouts = [
view_layouts = [
BlockedLayout([2, 2], [16, 4], [2, 2], [1, 0], [1, 1], [1, 1], [0, 1]),
BlockedLayout([2, 2], [16, 4], [2, 2], [0, 1], [1, 1], [1, 1], [0, 1]),
BlockedLayout([1, 8], [16, 4], [4, 1], [1, 0], [1, 1], [1, 1], [0, 1]),
BlockedLayout([1, 8], [16, 4], [4, 1], [0, 1], [1, 1], [1, 1], [0, 1]),
MfmaLayout(non_k_dim=32, warps_per_cta=[4, 1], is_transposed=True),
MfmaLayout(non_k_dim=32, warps_per_cta=[4, 1], is_transposed=False),
]

@pytest.mark.parametrize("M, N, M_tile_size, N_tile_size, M_tile_offset, N_tile_offset", [[256, 128, 256, 32, 0, 0], [256, 256, 128, 64, 64, 128], [128, 128, 128, 32, 0, 0], [128, 128, 128, 32, 0, 64]])
blocked_layouts = [
BlockedLayout([1, 8], [16, 4], [4, 1], [1, 0], [1, 1], [1, 1], [0, 1]),
]
@pytest.mark.parametrize("M, N, M_tile_size, N_tile_size, M_tile_offset, N_tile_offset", [[256, 128, 256, 32, 0, 0], [128, 128, 128, 32, 0, 0], [128, 128, 128, 32, 0, 64]])
@pytest.mark.parametrize("dtype", ['float16'])
@pytest.mark.parametrize("src_layout", layouts)
def test_view_slice(dtype, M, N, M_tile_size, N_tile_size, M_tile_offset, N_tile_offset, src_layout, device='cuda'):
@pytest.mark.parametrize("view_layout", view_layouts)
@pytest.mark.parametrize("blocked_layout", blocked_layouts)
def test_view_slice(dtype, M, N, M_tile_size, N_tile_size, M_tile_offset, N_tile_offset, blocked_layout, view_layout, device='cuda'):
if torch.version.hip is None:
pytest.skip("view_slice is AMD specific instruction.")

ir = f"""
#src = {src_layout}
#blocked = {blocked_layout}
#view_layout = {view_layout}
""" + """
module attributes {"triton_gpu.num-ctas" = 1, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = """ + str(_get_warp_size()) + f""" : i32}} {{
tt.func public @kernel_0d1d(%arg0: !tt.ptr<f16> {{tt.divisibility = 16 : i32}}, %arg1: !tt.ptr<f16> {{tt.divisibility = 16 : i32}}) {{
%cst = arith.constant dense<{N}> : tensor<{M}x1xi32, #src>
%cst_n = arith.constant dense<{N_tile_size}> : tensor<{M_tile_size}x1xi32, #src>
%0 = tt.make_range {{end = {M} : i32, start = 0 : i32}} : tensor<{M}xi32, #triton_gpu.slice<{{dim = 1, parent = #src}}>>
%42 = tt.make_range {{end = {M_tile_size} : i32, start = 0 : i32}} : tensor<{M_tile_size}xi32, #triton_gpu.slice<{{dim = 1, parent = #src}}>>
%1 = tt.make_range {{end = {M} : i32, start = 0 : i32}} : tensor<{M}xi32, #triton_gpu.slice<{{dim = 0, parent = #src}}>>
%2 = tt.splat %arg0 : (!tt.ptr<f16>) -> tensor<{M}x{N}x!tt.ptr<f16>, #src>
%4 = tt.expand_dims %0 {{axis = 1 : i32}} : (tensor<{M}xi32, #triton_gpu.slice<{{dim = 1, parent = #src}}>>) -> tensor<{M}x1xi32, #src>
%43 = tt.expand_dims %42 {{axis = 1 : i32}} : (tensor<{M_tile_size}xi32, #triton_gpu.slice<{{dim = 1, parent = #src}}>>) -> tensor<{M_tile_size}x1xi32, #src>
%5 = arith.muli %4, %cst : tensor<{M}x1xi32, #src>
%44 = arith.muli %43, %cst_n : tensor<{M_tile_size}x1xi32, #src>
%6 = tt.expand_dims %1 {{axis = 0 : i32}} : (tensor<{M}xi32, #triton_gpu.slice<{{dim = 0, parent = #src}}>>) -> tensor<1x{M}xi32, #src>
%7 = tt.broadcast %6 : (tensor<1x{M}xi32, #src>) -> tensor<{M}x{N}xi32, #src>
%8 = tt.broadcast %5 : (tensor<{M}x1xi32, #src>) -> tensor<{M}x{N}xi32, #src>
%9 = arith.addi %8, %7 : tensor<{M}x{N}xi32, #src>
%33 = tt.make_range {{end = {N_tile_size} : i32, start = 0 : i32}} : tensor<{N_tile_size}xi32, #triton_gpu.slice<{{dim = 0, parent = #src}}>>
%34 = tt.splat %arg1 : (!tt.ptr<f16>) -> tensor<{M_tile_size}x{N_tile_size}x!tt.ptr<f16>, #src>
%37 = tt.expand_dims %33 {{axis = 0 : i32}} : (tensor<{N_tile_size}xi32, #triton_gpu.slice<{{dim = 0, parent = #src}}>>) -> tensor<1x{N_tile_size}xi32, #src>
%38 = tt.broadcast %37 : (tensor<1x{N_tile_size}xi32, #src>) -> tensor<{M_tile_size}x{N_tile_size}xi32, #src>
%39 = tt.broadcast %44 : (tensor<{M_tile_size}x1xi32, #src>) -> tensor<{M_tile_size}x{N_tile_size}xi32, #src>
%40 = arith.addi %38, %39 : tensor<{M_tile_size}x{N_tile_size}xi32, #src>
%10 = tt.addptr %2, %9 : tensor<{M}x{N}x!tt.ptr<f16>, #src>, tensor<{M}x{N}xi32, #src>
%11 = tt.load %10 {{cache = 1 : i32, evict = 1 : i32, isVolatile = false}} : tensor<{M}x{N}xf16, #src>
%12 = triton_gpu.view_slice %11[{M_tile_offset}, {N_tile_offset}] [{M_tile_size}, {N_tile_size}] [1, 1] : tensor<{M}x{N}xf16, #src> to tensor<{M_tile_size}x{N_tile_size}xf16, #src>
%13 = tt.addptr %34, %40 : tensor<{M_tile_size}x{N_tile_size}x!tt.ptr<f16>, #src>, tensor<{M_tile_size}x{N_tile_size}xi32, #src>
tt.store %13, %12 : tensor<{M_tile_size}x{N_tile_size}xf16, #src>
%cst = arith.constant dense<{N}> : tensor<{M}x1xi32, #blocked>
%cst_n = arith.constant dense<{N_tile_size}> : tensor<{M_tile_size}x1xi32, #blocked>
%0 = tt.make_range {{end = {M} : i32, start = 0 : i32}} : tensor<{M}xi32, #triton_gpu.slice<{{dim = 1, parent = #blocked}}>>
%42 = tt.make_range {{end = {M_tile_size} : i32, start = 0 : i32}} : tensor<{M_tile_size}xi32, #triton_gpu.slice<{{dim = 1, parent = #blocked}}>>
%1 = tt.make_range {{end = {M} : i32, start = 0 : i32}} : tensor<{M}xi32, #triton_gpu.slice<{{dim = 0, parent = #blocked}}>>
%2 = tt.splat %arg0 : (!tt.ptr<f16>) -> tensor<{M}x{N}x!tt.ptr<f16>, #blocked>
%4 = tt.expand_dims %0 {{axis = 1 : i32}} : (tensor<{M}xi32, #triton_gpu.slice<{{dim = 1, parent = #blocked}}>>) -> tensor<{M}x1xi32, #blocked>
%43 = tt.expand_dims %42 {{axis = 1 : i32}} : (tensor<{M_tile_size}xi32, #triton_gpu.slice<{{dim = 1, parent = #blocked}}>>) -> tensor<{M_tile_size}x1xi32, #blocked>
%5 = arith.muli %4, %cst : tensor<{M}x1xi32, #blocked>
%44 = arith.muli %43, %cst_n : tensor<{M_tile_size}x1xi32, #blocked>
%6 = tt.expand_dims %1 {{axis = 0 : i32}} : (tensor<{M}xi32, #triton_gpu.slice<{{dim = 0, parent = #blocked}}>>) -> tensor<1x{M}xi32, #blocked>
%7 = tt.broadcast %6 : (tensor<1x{M}xi32, #blocked>) -> tensor<{M}x{N}xi32, #blocked>
%8 = tt.broadcast %5 : (tensor<{M}x1xi32, #blocked>) -> tensor<{M}x{N}xi32, #blocked>
%9 = arith.addi %8, %7 : tensor<{M}x{N}xi32, #blocked>
%33 = tt.make_range {{end = {N_tile_size} : i32, start = 0 : i32}} : tensor<{N_tile_size}xi32, #triton_gpu.slice<{{dim = 0, parent = #blocked}}>>
%34 = tt.splat %arg1 : (!tt.ptr<f16>) -> tensor<{M_tile_size}x{N_tile_size}x!tt.ptr<f16>, #blocked>
%37 = tt.expand_dims %33 {{axis = 0 : i32}} : (tensor<{N_tile_size}xi32, #triton_gpu.slice<{{dim = 0, parent = #blocked}}>>) -> tensor<1x{N_tile_size}xi32, #blocked>
%38 = tt.broadcast %37 : (tensor<1x{N_tile_size}xi32, #blocked>) -> tensor<{M_tile_size}x{N_tile_size}xi32, #blocked>
%39 = tt.broadcast %44 : (tensor<{M_tile_size}x1xi32, #blocked>) -> tensor<{M_tile_size}x{N_tile_size}xi32, #blocked>
%40 = arith.addi %38, %39 : tensor<{M_tile_size}x{N_tile_size}xi32, #blocked>
%10 = tt.addptr %2, %9 : tensor<{M}x{N}x!tt.ptr<f16>, #blocked>, tensor<{M}x{N}xi32, #blocked>
%11 = tt.load %10 {{cache = 1 : i32, evict = 1 : i32, isVolatile = false}} : tensor<{M}x{N}xf16, #blocked>
%12 = triton_gpu.convert_layout %11 : (tensor<{M}x{N}xf16, #blocked>) -> tensor<{M}x{N}xf16, #view_layout>
%13 = triton_gpu.view_slice %12[{M_tile_offset}, {N_tile_offset}] [{M_tile_size}, {N_tile_size}] [1, 1] : tensor<{M}x{N}xf16, #view_layout> to tensor<{M_tile_size}x{N_tile_size}xf16, #view_layout>
%14 = triton_gpu.convert_layout %13 : (tensor<{M_tile_size}x{N_tile_size}xf16, #view_layout>) -> tensor<{M_tile_size}x{N_tile_size}xf16, #blocked>
%15 = tt.addptr %34, %40 : tensor<{M_tile_size}x{N_tile_size}x!tt.ptr<f16>, #blocked>, tensor<{M_tile_size}x{N_tile_size}xi32, #blocked>
tt.store %15, %14 : tensor<{M_tile_size}x{N_tile_size}xf16, #blocked>
tt.return
}}
}}
Expand All @@ -2998,6 +3006,7 @@ def test_view_slice(dtype, M, N, M_tile_size, N_tile_size, M_tile_offset, N_tile
kernel[(1, 1, 1)](x.data_ptr(), z_tri)
np.testing.assert_equal(z_numpy, to_numpy(z_tri))


if torch.version.hip is not None and _get_warp_size() == 64:
layouts = [
MfmaLayout(non_k_dim=32, warps_per_cta=[4, 1], is_transposed=True),
Expand Down

0 comments on commit bcea305

Please sign in to comment.