tl.dot with batched (3D) input works only in emulation mode, i.e. when os.environ['TRITON_INTERPRET'] = '1' is set #5620

Open
warpuv opened this issue Jan 15, 2025 · 0 comments

warpuv commented Jan 15, 2025

Describe the bug

tl.dot with batched (3D) input works only in emulation mode, i.e. when os.environ['TRITON_INTERPRET'] = '1' is set.
The issue may actually be with tl.load rather than with tl.dot.

None of the possible values (1, 2, 4, 8, 16) of BLOCK_SIZE_BATCH work.
If you remove the for _ in range(0, k_size, BLOCK_SIZE) loop and keep only its body, the kernel works fine (see the sketch after the problematic snippet below).

Problematic part:

acc = tl.zeros((BLOCK_SIZE_BATCH, BLOCK_SIZE, BLOCK_SIZE), dtype=tl.float32)
for _ in range(0, k_size, BLOCK_SIZE):
    a = tl.load(offs_a_mat)
    b = tl.load(offs_b_mat)
    acc = tl.dot(a, b, acc)

    offs_a_mat += BLOCK_SIZE * stride_a2
    offs_b_mat += BLOCK_SIZE * stride_b1
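
For reference, the loop-less variant mentioned above (a sketch covering only a single K tile, i.e. just the loop body) does not hit the error:

acc = tl.zeros((BLOCK_SIZE_BATCH, BLOCK_SIZE, BLOCK_SIZE), dtype=tl.float32)
a = tl.load(offs_a_mat)
b = tl.load(offs_b_mat)
acc = tl.dot(a, b, acc)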

If you move one of the tl.load statements outside the loop, the error does not occur:

a = tl.load(offs_a_mat)
for _ in range(0, k_size, BLOCK_SIZE):
    b = tl.load(offs_b_mat)
    acc = tl.dot(a, b, acc)

    #offs_a_mat += BLOCK_SIZE * stride_a2
    offs_b_mat += BLOCK_SIZE * stride_b1

Full code to reproduce:

import os

#os.environ['TRITON_INTERPRET'] = '1'
from functools import partial

import torch
import triton
import triton.language as tl


@triton.jit
def bmm_dot_kernel(
    k_size,
    a_ptr, b_ptr, c_ptr,
    stride_a0, stride_a1, stride_a2,
    stride_b0, stride_b1, stride_b2,
    stride_c0, stride_c1, stride_c2,
    BLOCK_SIZE: tl.constexpr,
    BLOCK_SIZE_BATCH: tl.constexpr,
):
    pid_m, pid_n, pid_batch = tl.program_id(0), tl.program_id(1), tl.program_id(2)

    rm_vec        = pid_m * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    rn_vec        = pid_n * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    rk_vec        = tl.arange(0, BLOCK_SIZE)

    rbatch_vec_   = pid_batch * BLOCK_SIZE_BATCH + tl.arange(0, BLOCK_SIZE_BATCH)
    rbatch_vec3d  = tl.expand_dims(tl.expand_dims(rbatch_vec_, 1), 1)
    
    offs_a_mat = a_ptr + rbatch_vec3d*stride_a0 + tl.expand_dims(rm_vec, 1)*stride_a1 + tl.expand_dims(rk_vec, 0)*stride_a2
    offs_b_mat = b_ptr + rbatch_vec3d*stride_b0 + tl.expand_dims(rk_vec, 1)*stride_b1 + tl.expand_dims(rn_vec, 0)*stride_b2

    acc = tl.zeros((BLOCK_SIZE_BATCH, BLOCK_SIZE, BLOCK_SIZE), dtype=tl.float32)
    for _ in range(0, k_size, BLOCK_SIZE):
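        # a and b are rank-3 tiles of shape (BLOCK_SIZE_BATCH, BLOCK_SIZE, BLOCK_SIZE);
        # the compile error shown below is reported at these tl.load calls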
        a = tl.load(offs_a_mat)
        b = tl.load(offs_b_mat)
        acc = tl.dot(a, b, acc)

        offs_a_mat += BLOCK_SIZE * stride_a2
        offs_b_mat += BLOCK_SIZE * stride_b1

    c = c_ptr + rbatch_vec3d*stride_c0 + tl.expand_dims(rm_vec, 1)*stride_c1 + tl.expand_dims(rn_vec, 0)*stride_c2
    
    tl.store(c, acc)

def bmm_dot_op(a: torch.Tensor, b: torch.Tensor, matmul_k_fn):
    batch_size=a.shape[0]
    k_size = a.shape[2]
    m = a.shape[1]
    n = b.shape[2]
    assert m == n # m and n are equal in this test

    BLOCK_SIZE=16 # block sizes for m, n, k are all equal
    BLOCK_SIZE_BATCH=2 # can be 1, 2, 4, 8, 16

    c = torch.zeros(batch_size, m, n, device=a.device, dtype=a.dtype)

    grid = (1, 1, batch_size // BLOCK_SIZE_BATCH)
    matmul_k_fn[grid](
        k_size,
        a, b, c,
        a.stride(0), a.stride(1), a.stride(2),
        b.stride(0), b.stride(1), b.stride(2),
        c.stride(0), c.stride(1), c.stride(2),
        BLOCK_SIZE=BLOCK_SIZE,
        BLOCK_SIZE_BATCH=BLOCK_SIZE_BATCH
    )
    return c


bmm_dot = partial(bmm_dot_op, matmul_k_fn=bmm_dot_kernel)

if __name__ == '__main__':
    dtype=torch.float16
    k = 128 # can be 16, 32, 128, 256... etc.
    
    mn = 16
    batch_size = 16
    a = torch.ones(batch_size, mn, k, dtype=dtype, device='cuda')
    b = torch.ones(batch_size, k, mn, dtype=dtype, device='cuda')

    c_ref = torch.bmm(a, b)
    c_test = bmm_dot(a, b)

    print(c_test)
    allclose = torch.allclose(c_test, c_ref, rtol=0.1, atol=0.1)
    assert allclose
    print(f'Test COMPLETE! {allclose=}')

Stack trace from Triton 3.1 installed from PyPI:

CUDA_VISIBLE_DEVICES=4 python3 ./triton_kernels/bmm_dot.py 
loc("/workspace/src/my/cuda_projects/peer_cuda/./triton_kernels/bmm_dot.py":35:20): error: offsets must have the same rank as input
Traceback (most recent call last):
  File "/workspace/src/my/cuda_projects/peer_cuda/./triton_kernels/bmm_dot.py", line 83, in <module>
    c_test = bmm_dot(a, b)
             ^^^^^^^^^^^^^
  File "/workspace/src/my/cuda_projects/peer_cuda/./triton_kernels/bmm_dot.py", line 59, in bmm_dot_op
    matmul_k_fn[grid](
  File "/workspace/pip_envs/cp312_torch251cu124_u2404/lib/python3.12/site-packages/triton/runtime/jit.py", line 345, in <lambda>
    return lambda *args, **kwargs: self.run(grid=grid, warmup=False, *args, **kwargs)
                                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/workspace/pip_envs/cp312_torch251cu124_u2404/lib/python3.12/site-packages/triton/runtime/jit.py", line 662, in run
    kernel = self.compile(
             ^^^^^^^^^^^^^
  File "/workspace/pip_envs/cp312_torch251cu124_u2404/lib/python3.12/site-packages/triton/compiler/compiler.py", line 282, in compile
    next_module = compile_ir(module, metadata)
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/workspace/pip_envs/cp312_torch251cu124_u2404/lib/python3.12/site-packages/triton/backends/nvidia/compiler.py", line 317, in <lambda>
    stages["ttgir"] = lambda src, metadata: self.make_ttgir(src, metadata, options, self.capability)
                                            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/workspace/pip_envs/cp312_torch251cu124_u2404/lib/python3.12/site-packages/triton/backends/nvidia/compiler.py", line 189, in make_ttgir
    pm.run(mod)
RuntimeError: PassManager::run failed

Stack trace from the latest Triton on the main branch (commit 59bd2dc):

CUDA_VISIBLE_DEVICES=4 python3 ./triton_kernels/bmm_dot.py 
/workspace/src/my/cuda_projects/peer_cuda/./triton_kernels/bmm_dot.py:35:20: error: offsets must have the same rank as input
        a = tl.load(offs_a_mat)
                   ^
/workspace/src/my/cuda_projects/peer_cuda/./triton_kernels/bmm_dot.py:35:20: note: see current operation: %81 = "ttg.memdesc_subview"(%76, %79, %80) : (!ttg.memdesc<2x16x16xf16, #ttg.shared<{vec = 8, perPhase = 4, maxPhase = 2, order = [2, 0, 1], hasLeadingOffset = false}>, #ttg.shared_memory, mutable, 2x2x16x16>, i32, i32) -> !ttg.memdesc<2x16x16xf16, #ttg.shared<{vec = 8, perPhase = 4, maxPhase = 2, order = [2, 0, 1], hasLeadingOffset = false}>, #ttg.shared_memory, mutable, 2x2x16x16>
Traceback (most recent call last):
  File "/workspace/src/my/cuda_projects/peer_cuda/./triton_kernels/bmm_dot.py", line 83, in <module>
    c_test = bmm_dot(a, b)
             ^^^^^^^^^^^^^
  File "/workspace/src/my/cuda_projects/peer_cuda/./triton_kernels/bmm_dot.py", line 59, in bmm_dot_op
    matmul_k_fn[grid](
  File "/workspace/src/external/triton/python/triton/runtime/jit.py", line 336, in <lambda>
    return lambda *args, **kwargs: self.run(grid=grid, warmup=False, *args, **kwargs)
                                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/workspace/src/external/triton/python/triton/runtime/jit.py", line 563, in run
    kernel = self.compile(src, target=target, options=options.__dict__)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/workspace/src/external/triton/python/triton/compiler/compiler.py", line 281, in compile
    next_module = compile_ir(module, metadata)
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/workspace/src/external/triton/python/triton/backends/nvidia/compiler.py", line 409, in <lambda>
    stages["ttgir"] = lambda src, metadata: self.make_ttgir(src, metadata, options, capability)
                                            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/workspace/src/external/triton/python/triton/backends/nvidia/compiler.py", line 270, in make_ttgir
    pm.run(mod)
RuntimeError: PassManager::run failed

Environment details

nvidia-smi: Driver Version: 470.161.03, CUDA Version: 11.4
GPU: NVIDIA A100
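
For comparison, a possible workaround (a sketch only, not verified on this setup; the kernel name bmm_dot_kernel_2d and its launch grid are illustrative, reusing the imports from the repro above) is to keep tl.dot two-dimensional and fold the batch dimension into the launch grid, with one batch element per program:

@triton.jit
def bmm_dot_kernel_2d(
    k_size,
    a_ptr, b_ptr, c_ptr,
    stride_a0, stride_a1, stride_a2,
    stride_b0, stride_b1, stride_b2,
    stride_c0, stride_c1, stride_c2,
    BLOCK_SIZE: tl.constexpr,
):
    pid_m, pid_n, pid_batch = tl.program_id(0), tl.program_id(1), tl.program_id(2)

    rm_vec = pid_m * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    rn_vec = pid_n * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    rk_vec = tl.arange(0, BLOCK_SIZE)

    # 2D pointer blocks; the batch element is selected with a scalar pid_batch offset
    offs_a_mat = a_ptr + pid_batch*stride_a0 + tl.expand_dims(rm_vec, 1)*stride_a1 + tl.expand_dims(rk_vec, 0)*stride_a2
    offs_b_mat = b_ptr + pid_batch*stride_b0 + tl.expand_dims(rk_vec, 1)*stride_b1 + tl.expand_dims(rn_vec, 0)*stride_b2

    acc = tl.zeros((BLOCK_SIZE, BLOCK_SIZE), dtype=tl.float32)
    for _ in range(0, k_size, BLOCK_SIZE):
        a = tl.load(offs_a_mat)
        b = tl.load(offs_b_mat)
        acc = tl.dot(a, b, acc)  # plain 2D dot, no batch dimension in the tiles

        offs_a_mat += BLOCK_SIZE * stride_a2
        offs_b_mat += BLOCK_SIZE * stride_b1

    c = c_ptr + pid_batch*stride_c0 + tl.expand_dims(rm_vec, 1)*stride_c1 + tl.expand_dims(rn_vec, 0)*stride_c2
    tl.store(c, acc)

The launch would then use grid = (1, 1, batch_size) and drop BLOCK_SIZE_BATCH. This only sidesteps the 3D path; it does not explain why the batched form fails.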

warpuv added the bug label Jan 15, 2025