Softmax tutorial crashes (invalid arith.select) when n_cols is a multiple of 16 but <= 128 #4739

Open
akeley98 opened this issue Sep 17, 2024 · 2 comments

Comments

@akeley98

akeley98 commented Sep 17, 2024

triton.__version__ is 3.0.0 for me

The tutorial code 02-fused-softmax.py from https://triton-lang.org/main/getting-started/tutorials/02-fused-softmax.html fails to compile a kernel during warmup when n_cols is a multiple of 16 that is less than or equal to 128 (i.e., <= 16 * num_warps). The error looks like:

(triton) mantissa@MantissaAmpere:~/junk$ python3 ../Downloads/02-fused-softmax.py 
loc("/home/mantissa/junk/../Downloads/02-fused-softmax.py":97:22): error: 'arith.select' op expected condition type to have the same shape as the result type, expected 'tensor<128xi1, #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [8], order = [0]}>>', but got 'tensor<128xi1, #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [8], order = [0]}>>'
Traceback (most recent call last):
  File "/home/mantissa/junk/../Downloads/02-fused-softmax.py", line 196, in <module>
    y_triton = softmax(x)
  File "/home/mantissa/junk/../Downloads/02-fused-softmax.py", line 144, in softmax
    kernel = softmax_kernel.warmup(y, x, x.stride(0), y.stride(0), n_rows, n_cols, BLOCK_SIZE=BLOCK_SIZE,
  File "/home/mantissa/junk/triton/lib/python3.10/site-packages/triton/runtime/jit.py", line 764, in warmup
    return self.run(grid=grid, warmup=True, *map(MockTensor.wrap_dtype, args), **kwargs)
  File "/home/mantissa/junk/triton/lib/python3.10/site-packages/triton/runtime/jit.py", line 662, in run
    kernel = self.compile(
  File "/home/mantissa/junk/triton/lib/python3.10/site-packages/triton/compiler/compiler.py", line 282, in compile
    next_module = compile_ir(module, metadata)
  File "/home/mantissa/junk/triton/lib/python3.10/site-packages/triton/backends/nvidia/compiler.py", line 317, in <lambda>
    stages["ttgir"] = lambda src, metadata: self.make_ttgir(src, metadata, options, self.capability)
  File "/home/mantissa/junk/triton/lib/python3.10/site-packages/triton/backends/nvidia/compiler.py", line 189, in make_ttgir
    pm.run(mod)
RuntimeError: PassManager::run failed

Repro: modify line 195 in 02-fused-softmax.py from

x = torch.randn(1823, 781, device='cuda')

to

x = torch.randn(1823, 80, device='cuda')

and run.
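
For reference, these are the widths I'd expect to trip this, given the condition above (a multiple of 16 that is at most 16 * num_warps; num_warps appears to be 8 here, judging by warpsPerCTA = [8] in the IR dump):

num_warps = 8  # taken from the blocked layout in the error message above
print([n for n in range(16, 16 * num_warps + 1, 16)])
# [16, 32, 48, 64, 80, 96, 112, 128]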

@akeley98
Author

akeley98 commented Sep 17, 2024

This caching scheme seems broken. The tutorial builds a simple kernel cache around softmax_kernel.warmup, keyed only on BLOCK_SIZE. However, warmup actually specializes on the kernel arguments themselves (not just the meta-parameters), and that is where the compiler gets into trouble: it appears to generate an entirely different kernel when y.stride(0) is divisible by 16. As a result, the cached kernel depends on the arguments of the first call to softmax(), since those determine how the kernel gets compiled. So if I write a function

def do_it(n_cols):
    torch.manual_seed(0)
    x = torch.randn(10000, n_cols, device='cuda')
    y_triton = softmax(x)
    y_torch = torch.softmax(x, axis=1)
    assert torch.allclose(y_triton, y_torch), (y_triton, y_torch)

then run

do_it(n_cols = 80)

then this crashes like before, but if I do

do_it(n_cols = 79)
do_it(n_cols = 80)

then it works OK, as the kernel correctly compiled for the n_cols = 79 case is successfully re-run for the n_cols = 80 case.

By the way, this also means I can crash the example by running

do_it(n_cols = 512)
do_it(n_cols = 511)

because the n_cols = 512 call specializes the cached kernel for aligned data, which doesn't work correctly for n_cols = 511. I'm seeing

Traceback (most recent call last):
  File "/home/mantissa/junk/tl_softmax.py", line 157, in <module>
    do_it(n_cols = 511)
  File "/home/mantissa/junk/tl_softmax.py", line 154, in do_it
    assert torch.allclose(y_triton, y_torch), (y_triton, y_torch)
RuntimeError: CUDA error: misaligned address
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.
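
As far as I can tell, the property being specialized on here is just whether the integer stride argument is divisible by 16, which is easy to check on the host side (this is a sketch of my reading of the behaviour, not anything taken from the tutorial itself):

import torch

for n_cols in (512, 511):
    x = torch.randn(10000, n_cols, device='cuda')
    # The first call's stride decides which assumption gets baked into the cached kernel.
    print(n_cols, x.stride(0), x.stride(0) % 16 == 0)
# 512 512 True   -> first compile assumes the stride is divisible by 16
# 511 511 False  -> the cached "aligned" kernel is reused anyway and faults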

@akeley98
Author

I don't know what is going on with the original bug, but I did some investigation on the second issue I found, and it looks like there already is a cache implemented for JITFunction inside runtime/jit.py that specializes based on whether input tensors are aligned or not. This seems to be implemented with compute_spec_key (checks 16-byte alignment) and the sig_and_spec portion of the key used to index self.cache. Since there's a built-in cache, I'm not sure why the tutorial implements its own (broken) caching system.
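
As a sketch of what I mean (the argument order is taken from the warmup call in the traceback above; any other constexpr arguments the tutorial's kernel takes, and its occupancy-based grid, are left out, so treat this as illustrative only), the wrapper could simply launch the kernel and let JITFunction's built-in cache handle recompilation:

def softmax_builtin_cache(x):
    n_rows, n_cols = x.shape
    BLOCK_SIZE = triton.next_power_of_2(n_cols)
    y = torch.empty_like(x)
    # Plain launch, one program per row: the JIT caches compiled kernels per
    # signature + spec key, which already covers the 16-byte divisibility of the strides.
    softmax_kernel[(n_rows,)](y, x, x.stride(0), y.stride(0), n_rows, n_cols,
                              BLOCK_SIZE=BLOCK_SIZE)
    return y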
