Describe the bug

`tl.dot` with batched (3D) input works only in emulation mode, i.e. when `os.environ['TRITON_INTERPRET'] = '1'` is set. The issue may actually be with `tl.load` rather than with `tl.dot`.

None of the possible values (1, 2, 4, 8, 16) of `BLOCK_SIZE_BATCH` work.

If you remove the `for _ in range(0, k_size, BLOCK_SIZE)` loop and keep only its body, it works fine. Likewise, if one of the `tl.load` statements is moved outside the loop, the error does not occur (a sketch of that variant follows the full repro below).
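Problematic part (excerpted from the full repro below): the body of the k-loop performs two 3D `tl.load`s followed by a batched `tl.dot`, and the compile error points at the first `tl.load`:

```python
for _ in range(0, k_size, BLOCK_SIZE):
    a = tl.load(offs_a_mat)  # compilation fails here: "offsets must have the same rank as input"
    b = tl.load(offs_b_mat)
    acc = tl.dot(a, b, acc)
    offs_a_mat += BLOCK_SIZE * stride_a2
    offs_b_mat += BLOCK_SIZE * stride_b1
```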
Full code to reproduce:

```python
import os
# os.environ['TRITON_INTERPRET'] = '1'

from functools import partial

import torch

import triton
import triton.language as tl


@triton.jit
def bmm_dot_kernel(
        k_size,
        a_ptr, b_ptr, c_ptr,
        stride_a0, stride_a1, stride_a2,
        stride_b0, stride_b1, stride_b2,
        stride_c0, stride_c1, stride_c2,
        BLOCK_SIZE: tl.constexpr,
        BLOCK_SIZE_BATCH: tl.constexpr,
):
    pid_m, pid_n, pid_batch = tl.program_id(0), tl.program_id(1), tl.program_id(2)
    rm_vec = pid_m * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    rn_vec = pid_n * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    rk_vec = tl.arange(0, BLOCK_SIZE)
    rbatch_vec_ = pid_batch * BLOCK_SIZE_BATCH + tl.arange(0, BLOCK_SIZE_BATCH)
    rbatch_vec3d = tl.expand_dims(tl.expand_dims(rbatch_vec_, 1), 1)

    offs_a_mat = a_ptr + rbatch_vec3d * stride_a0 + tl.expand_dims(rk_vec, 1) * stride_a1 + tl.expand_dims(rn_vec, 0) * stride_a2
    offs_b_mat = b_ptr + rbatch_vec3d * stride_b0 + tl.expand_dims(rk_vec, 1) * stride_b1 + tl.expand_dims(rn_vec, 0) * stride_b2
    acc = tl.zeros((BLOCK_SIZE_BATCH, BLOCK_SIZE, BLOCK_SIZE), dtype=tl.float32)

    for _ in range(0, k_size, BLOCK_SIZE):
        a = tl.load(offs_a_mat)
        b = tl.load(offs_b_mat)
        acc = tl.dot(a, b, acc)
        offs_a_mat += BLOCK_SIZE * stride_a2
        offs_b_mat += BLOCK_SIZE * stride_b1

    c = c_ptr + rbatch_vec3d * stride_c0 + tl.expand_dims(rm_vec, 1) * stride_c1 + tl.expand_dims(rn_vec, 0) * stride_c2
    tl.store(c, acc)


def bmm_dot_op(a: torch.Tensor, b: torch.Tensor, matmul_k_fn):
    batch_size = a.shape[0]
    k_size = a.shape[2]
    m = a.shape[1]
    n = b.shape[2]
    assert m == n  # m and n are equal in this test

    BLOCK_SIZE = 16  # block sizes for m, n, k are all equal
    BLOCK_SIZE_BATCH = 2  # can be 1, 2, 4, 8, 16

    c = torch.zeros(batch_size, m, n, device=a.device, dtype=a.dtype)
    grid = (1, 1, batch_size // BLOCK_SIZE_BATCH)
    matmul_k_fn[grid](
        k_size,
        a, b, c,
        a.stride(0), a.stride(1), a.stride(2),
        b.stride(0), b.stride(1), b.stride(2),
        c.stride(0), c.stride(1), c.stride(2),
        BLOCK_SIZE=BLOCK_SIZE,
        BLOCK_SIZE_BATCH=BLOCK_SIZE_BATCH,
    )
    return c


bmm_dot = partial(bmm_dot_op, matmul_k_fn=bmm_dot_kernel)


if __name__ == '__main__':
    dtype = torch.float16
    k = 128  # can be 16, 32, 128, 256... etc.
    mn = 16
    batch_size = 16

    a = torch.ones(batch_size, mn, k, dtype=dtype, device='cuda')
    b = torch.ones(batch_size, k, mn, dtype=dtype, device='cuda')
    c_ref = torch.bmm(a, b)

    c_test = bmm_dot(a, b)
    print(c_test)

    allclose = torch.allclose(c_test, c_ref, rtol=0.1, atol=0.1)
    assert allclose
    print(f'Test COMPLETE! {allclose=}')
```
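For reference, a sketch of the variant mentioned above, with the `tl.load` of `a` hoisted out of the k-loop. This is not a numerically equivalent kernel (the same `a` tile is reused for every k step); it only illustrates the shape of the change under which the compiler error no longer occurs:

```python
a = tl.load(offs_a_mat)  # hoisted: a single 3D load before the loop
for _ in range(0, k_size, BLOCK_SIZE):
    b = tl.load(offs_b_mat)
    acc = tl.dot(a, b, acc)
    offs_b_mat += BLOCK_SIZE * stride_b1
```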
Stack trace from Triton 3.1 installed from pip:

```
CUDA_VISIBLE_DEVICES=4 python3 ./triton_kernels/bmm_dot.py
loc("/workspace/src/my/cuda_projects/peer_cuda/./triton_kernels/bmm_dot.py":35:20): error: offsets must have the same rank as input
Traceback (most recent call last):
  File "/workspace/src/my/cuda_projects/peer_cuda/./triton_kernels/bmm_dot.py", line 83, in <module>
    c_test = bmm_dot(a, b)
             ^^^^^^^^^^^^^
  File "/workspace/src/my/cuda_projects/peer_cuda/./triton_kernels/bmm_dot.py", line 59, in bmm_dot_op
    matmul_k_fn[grid](
  File "/workspace/pip_envs/cp312_torch251cu124_u2404/lib/python3.12/site-packages/triton/runtime/jit.py", line 345, in <lambda>
    return lambda *args, **kwargs: self.run(grid=grid, warmup=False, *args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/workspace/pip_envs/cp312_torch251cu124_u2404/lib/python3.12/site-packages/triton/runtime/jit.py", line 662, in run
    kernel = self.compile(
             ^^^^^^^^^^^^^
  File "/workspace/pip_envs/cp312_torch251cu124_u2404/lib/python3.12/site-packages/triton/compiler/compiler.py", line 282, in compile
    next_module = compile_ir(module, metadata)
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/workspace/pip_envs/cp312_torch251cu124_u2404/lib/python3.12/site-packages/triton/backends/nvidia/compiler.py", line 317, in <lambda>
    stages["ttgir"] = lambda src, metadata: self.make_ttgir(src, metadata, options, self.capability)
                                            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/workspace/pip_envs/cp312_torch251cu124_u2404/lib/python3.12/site-packages/triton/backends/nvidia/compiler.py", line 189, in make_ttgir
    pm.run(mod)
RuntimeError: PassManager::run failed
```
Stack trace from the latest Triton built from the main branch (commit 59bd2dc):

```
CUDA_VISIBLE_DEVICES=4 python3 ./triton_kernels/bmm_dot.py
/workspace/src/my/cuda_projects/peer_cuda/./triton_kernels/bmm_dot.py:35:20: error: offsets must have the same rank as input
    a = tl.load(offs_a_mat)
                   ^
/workspace/src/my/cuda_projects/peer_cuda/./triton_kernels/bmm_dot.py:35:20: note: see current operation: %81 = "ttg.memdesc_subview"(%76, %79, %80) : (!ttg.memdesc<2x16x16xf16, #ttg.shared<{vec = 8, perPhase = 4, maxPhase = 2, order = [2, 0, 1], hasLeadingOffset = false}>, #ttg.shared_memory, mutable, 2x2x16x16>, i32, i32) -> !ttg.memdesc<2x16x16xf16, #ttg.shared<{vec = 8, perPhase = 4, maxPhase = 2, order = [2, 0, 1], hasLeadingOffset = false}>, #ttg.shared_memory, mutable, 2x2x16x16>
Traceback (most recent call last):
  File "/workspace/src/my/cuda_projects/peer_cuda/./triton_kernels/bmm_dot.py", line 83, in <module>
    c_test = bmm_dot(a, b)
             ^^^^^^^^^^^^^
  File "/workspace/src/my/cuda_projects/peer_cuda/./triton_kernels/bmm_dot.py", line 59, in bmm_dot_op
    matmul_k_fn[grid](
  File "/workspace/src/external/triton/python/triton/runtime/jit.py", line 336, in <lambda>
    return lambda *args, **kwargs: self.run(grid=grid, warmup=False, *args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/workspace/src/external/triton/python/triton/runtime/jit.py", line 563, in run
    kernel = self.compile(src, target=target, options=options.__dict__)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/workspace/src/external/triton/python/triton/compiler/compiler.py", line 281, in compile
    next_module = compile_ir(module, metadata)
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/workspace/src/external/triton/python/triton/backends/nvidia/compiler.py", line 409, in <lambda>
    stages["ttgir"] = lambda src, metadata: self.make_ttgir(src, metadata, options, capability)
                                            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/workspace/src/external/triton/python/triton/backends/nvidia/compiler.py", line 270, in make_ttgir
    pm.run(mod)
RuntimeError: PassManager::run failed
```
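For completeness, the working (emulated) path mentioned above is just the same script with the interpreter enabled, i.e. with the commented-out line at the top of the repro restored before `triton` is imported:

```python
import os
os.environ['TRITON_INTERPRET'] = '1'  # enable Triton's interpreter (emulation) before importing triton

import triton  # with this set, the same bmm_dot repro runs and the allclose check passes
```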
Environment details

nvidia-smi: Driver Version: 470.161.03, CUDA Version: 11.4
GPU: NVIDIA A100