Why does Triton's convert to float8e5 cause local memory reads/writes? #4769

Open
MARD1NO opened this issue Sep 20, 2024 · 11 comments

@MARD1NO (Contributor) commented Sep 20, 2024

I just wrote a kernel that contains an x.to(tl.float8e5), and in ncu I found it causes local memory reads/stores.

@ThomasRaoux (Collaborator)

Can you provide a simple kernel example?

@MARD1NO (Contributor, Author) commented Sep 21, 2024

> Can you provide a simple kernel example?

Sure. I copied the tutorial GEMM kernel as an example and just added a conversion of the input matrices to tl.float8e5:

import torch

import triton
import triton.language as tl



@triton.jit
def matmul_kernel(
        # Pointers to matrices
        a_ptr, b_ptr, c_ptr,
        # Matrix dimensions
        M, N, K,
        # The stride variables represent how much to increase the ptr by when moving by 1
        # element in a particular dimension. E.g. `stride_am` is how much to increase `a_ptr`
        # by to get the element one row down (A has M rows).
        stride_am, stride_ak,  #
        stride_bk, stride_bn,  #
        stride_cm, stride_cn,
        # Meta-parameters
        BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,  #
        GROUP_SIZE_M: tl.constexpr,  #
):
    """Kernel for computing the matmul C = A x B.
    A has shape (M, K), B has shape (K, N) and C has shape (M, N)
    """
    # -----------------------------------------------------------
    # Map program ids `pid` to the block of C it should compute.
    # This is done in a grouped ordering to promote L2 data reuse.
    # See above `L2 Cache Optimizations` section for details.
    pid = tl.program_id(axis=0)
    num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
    num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
    num_pid_in_group = GROUP_SIZE_M * num_pid_n
    group_id = pid // num_pid_in_group
    first_pid_m = group_id * GROUP_SIZE_M
    group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
    pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m)
    pid_n = (pid % num_pid_in_group) // group_size_m

    # ----------------------------------------------------------
    # Create pointers for the first blocks of A and B.
    # We will advance this pointer as we move in the K direction
    # and accumulate
    # `a_ptrs` is a block of [BLOCK_SIZE_M, BLOCK_SIZE_K] pointers
    # `b_ptrs` is a block of [BLOCK_SIZE_K, BLOCK_SIZE_N] pointers
    # See above `Pointer Arithmetic` section for details
    offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M
    offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
    offs_k = tl.arange(0, BLOCK_SIZE_K)
    a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)
    b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)

    # -----------------------------------------------------------
    # Iterate to compute a block of the C matrix.
    # We accumulate into a `[BLOCK_SIZE_M, BLOCK_SIZE_N]` block
    # of fp32 values for higher accuracy.
    # `accumulator` will be converted back to fp16 after the loop.
    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
    for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
        # Load the next block of A and B, generate a mask by checking the K dimension.
        # If it is out of bounds, set it to 0.
        a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0)
        b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0)
        
        a = a.to(tl.float8e5)
        b = b.to(tl.float8e5)

        # We accumulate along the K dimension.
        accumulator = tl.dot(a, b, accumulator)
        # Advance the ptrs to the next K block.
        a_ptrs += BLOCK_SIZE_K * stride_ak
        b_ptrs += BLOCK_SIZE_K * stride_bk
    # You can fuse arbitrary activation functions here
    # while the accumulator is still in FP32!
    c = accumulator.to(tl.float16)

    # -----------------------------------------------------------
    # Write back the block of the output matrix C with masks.
    offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :]
    c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)
    tl.store(c_ptrs, c, mask=c_mask)



def matmul(a, b):
    # Check constraints.
    assert a.shape[1] == b.shape[0], "Incompatible dimensions"
    assert a.is_contiguous(), "Matrix A must be contiguous"
    M, K = a.shape
    K, N = b.shape
    # Allocates output.
    c = torch.empty((M, N), device=a.device, dtype=torch.float16)
    # 1D launch kernel where each block gets its own program.

    BLOCK_SIZE_M = 64
    BLOCK_SIZE_N = 128
    BLOCK_SIZE_K = 64  #
    GROUP_SIZE_M = 8

    grid = lambda META: (triton.cdiv(M, BLOCK_SIZE_M) * triton.cdiv(N, BLOCK_SIZE_N), )
    matmul_kernel[grid](
        a, b, c,  #
        M, N, K,  #
        a.stride(0), a.stride(1),  #
        b.stride(0), b.stride(1),  #
        c.stride(0), c.stride(1),  #
        BLOCK_SIZE_M, 
        BLOCK_SIZE_N, 
        BLOCK_SIZE_K, 
        GROUP_SIZE_M
    )
    return c

M = 16
K = 256
N = 256

a = torch.randn((M, K), device='cuda', dtype=torch.float16)
b = torch.randn((K, N), device='cuda', dtype=torch.float16)
b = b.permute(0, 1).contiguous()  # note: permute(0, 1) on a 2-D tensor is a no-op, so b stays row-major
b = b.permute(0, 1)


matmul(a, b)

I profiled it with the Nsight Compute tool, and it looks like the data-type conversion is what produces the local memory traffic:
[Nsight Compute screenshots: local memory load/store activity attributed to the conversion]

I tested this on an H20.
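
A quick way to see where the local traffic comes from, without going back to ncu, is to dump the PTX Triton generated and look for .local instructions. A minimal sketch (not from the thread), meant to replace the launch inside the matmul() wrapper above; it assumes the object returned by the kernel launch exposes asm["ptx"], which recent Triton releases do:

# Sketch: capture the launch handle and scan the generated PTX for local-memory
# instructions (ld.local / st.local / .local declarations).
handle = matmul_kernel[grid](
    a, b, c,  #
    M, N, K,  #
    a.stride(0), a.stride(1),  #
    b.stride(0), b.stride(1),  #
    c.stride(0), c.stride(1),  #
    BLOCK_SIZE_M,
    BLOCK_SIZE_N,
    BLOCK_SIZE_K,
    GROUP_SIZE_M
)
ptx = handle.asm["ptx"]
for line in ptx.splitlines():
    if ".local" in line:
        print(line)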

@Jokeren (Contributor) commented Sep 21, 2024

Likely a register spilling problem

@Jokeren (Contributor) commented Sep 21, 2024

Check your register usage of this kernel
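
One way to check this directly from Python rather than in the profiler (a sketch, not from the thread; it assumes the launch handle exposes n_regs and n_spills, which recent Triton releases populate from the loaded cubin):

# Reusing the `handle` captured in the PTX-dump sketch above.
print("registers per thread:", handle.n_regs)
print("spills:", handle.n_spills)  # nonzero here points at register spilling into local memory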

@MARD1NO (Contributor, Author) commented Sep 21, 2024

> Check your register usage of this kernel

I don't think the problem is register spilling; Nsight Compute shows the kernel uses only 168 registers:
[screenshot: Nsight Compute register usage]

Maybe the conversion uses a non-constant index into an array?
[screenshot]

@ThomasRaoux (Collaborator)

This PR might fix it. Can you try #4776?

@MARD1NO (Contributor, Author) commented Sep 21, 2024

> This PR might fix it. Can you try #4776?

Thanks Thomas, I will try it :D

@MARD1NO (Contributor, Author) commented Sep 21, 2024

> This PR might fix it. Can you try #4776?

With that commit the kernel still has the local memory writes.

@Jokeren (Contributor) commented Sep 21, 2024

It has to do with cvt.rn.satfinite.e5m2x2.f16x2. Taking a look now

@MARD1NO (Contributor, Author) commented Sep 22, 2024

> It has to do with cvt.rn.satfinite.e5m2x2.f16x2. Taking a look now

Yes, the related SASS and PTX show it uses cvt.rn.satfinite.e5m2x2.f16x2:
[screenshots: SASS and PTX around the conversion]

From the PTX, it looks like the local store happens right after the conversion to e5m2:
[screenshot: PTX showing the local store after the cvt]
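
To isolate the conversion from the rest of the GEMM, a minimal cast-only kernel can be used (a sketch, not from the thread; it bitcasts the fp8 result to uint8 for the store so the output can stay a plain torch.uint8 buffer, and it may or may not reproduce the local traffic on a given Triton build, but it makes the PTX around the conversion easy to inspect):

import torch
import triton
import triton.language as tl


@triton.jit
def cast_kernel(x_ptr, y_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
    # Load a block of fp16 values, convert them to float8e5 (E5M2),
    # then bitcast to uint8 so they can be stored into a uint8 buffer.
    pid = tl.program_id(axis=0)
    offs = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offs < n_elements
    x = tl.load(x_ptr + offs, mask=mask, other=0.0)
    y = x.to(tl.float8e5)
    tl.store(y_ptr + offs, y.to(tl.uint8, bitcast=True), mask=mask)


n = 4096
x = torch.randn(n, device="cuda", dtype=torch.float16)
y = torch.empty(n, device="cuda", dtype=torch.uint8)
h = cast_kernel[(triton.cdiv(n, 1024),)](x, y, n, BLOCK_SIZE=1024)
# Print the PTX lines involving the e5m2 conversion or local memory.
print([l for l in h.asm["ptx"].splitlines() if "e5m2" in l or ".local" in l])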

@Jokeren (Contributor) commented Sep 22, 2024

IMO it's probably caused by nvptx not handling 8-bit vector types well. Let me start a discussion and get back to you.
