
RuntimeError: Triton Error [CUDA]: invalid argument #237

Open
sameerreddy13 opened this issue Mar 14, 2023 · 17 comments

@sameerreddy13

I'm getting the following error when running the mosaic-bert recipe. It only happens with bf16; fp32 works fine.

Traceback (most recent call last):
  File "<string>", line 21, in _bwd_kernel
KeyError: ('2-.-0-.-0-d82511111ad128294e9d31a6ac684238-7929002797455b30efce6e41eddc6b57-3aa563e00c5c695dd945e23b09a86848-d962222789c30252d492a16cca3bf467-ff946bd4b3b4a4cbdf8cedc6e1c658e0-5c5e32ff210f3b7f56c98ca29917c25e-06f0df2d61979d629033f4a22eff5198-0dd03b0bd512a184b3512b278d9dfa59-d35ab04ae841e2714a253c523530b071', (torch.bfloat16, torch.bfloat16, torch.bfloat16, torch.float32, torch.bfloat16, torch.float32, torch.bfloat16, torch.bfloat16, torch.float32, torch.float32, 'fp32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32'), ('matrix', False, 64, False, True, True, True, 128, 128), (True, True, True, True, True, True, True, True, True, True, (False,), (True, False), (True, False), (True, False), (True, False), (True, False), (True, False), (True, False), (True, False), (True, False), (True, False), (True, False), (True, False), (True, False), (True, False), (True, False), (True, False), (True, False), (True, False), (True, False), (True, False), (True, False), (True, False), (True, False), (True, False), (False, False), (True, False), (True, False), (True, False), (True, False), (False, False), (False, False)))
During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/coc/scratch/sreddy65/examples/examples/bert/main.py", line 141, in <module>
    main(cfg)
  File "/coc/scratch/sreddy65/examples/examples/bert/main.py", line 128, in main
    trainer.fit()
  File "/nethome/sreddy65/miniconda3/envs/mosaic/lib/python3.10/site-packages/composer/trainer/trainer.py", line 1787, in fit
    self._train_loop()
  File "/nethome/sreddy65/miniconda3/envs/mosaic/lib/python3.10/site-packages/composer/trainer/trainer.py", line 1950, in _train_loop
    total_loss_dict = self._train_batch(use_grad_scaling)
  File "/nethome/sreddy65/miniconda3/envs/mosaic/lib/python3.10/site-packages/composer/trainer/trainer.py", line 2126, in _train_batch
    optimizer.step(closure=lambda **kwargs: self._train_microbatches(
  File "/nethome/sreddy65/miniconda3/envs/mosaic/lib/python3.10/site-packages/torch/optim/lr_scheduler.py", line 68, in wrapper
    return wrapped(*args, **kwargs)
  File "/nethome/sreddy65/miniconda3/envs/mosaic/lib/python3.10/site-packages/torch/optim/optimizer.py", line 140, in wrapper
    out = func(*args, **kwargs)
  File "/nethome/sreddy65/miniconda3/envs/mosaic/lib/python3.10/site-packages/torch/autograd/grad_mode.py", line 27, in decorate_context
    return func(*args, **kwargs)
  File "/nethome/sreddy65/miniconda3/envs/mosaic/lib/python3.10/site-packages/composer/optim/decoupled_weight_decay.py", line 289, in step
    loss = closure()
  File "/nethome/sreddy65/miniconda3/envs/mosaic/lib/python3.10/site-packages/composer/trainer/trainer.py", line 2126, in <lambda>
    optimizer.step(closure=lambda **kwargs: self._train_microbatches(
  File "/nethome/sreddy65/miniconda3/envs/mosaic/lib/python3.10/site-packages/composer/trainer/trainer.py", line 2209, in _train_microbatches
    microbatch_loss_dict = self._train_microbatch(use_grad_scaling, current_batch_size, is_final_microbatch)
  File "/nethome/sreddy65/miniconda3/envs/mosaic/lib/python3.10/site-packages/composer/trainer/trainer.py", line 2305, in _train_microbatch
    microbatch_loss.backward(create_graph=self._backwards_create_graph)
  File "/nethome/sreddy65/miniconda3/envs/mosaic/lib/python3.10/site-packages/torch/_tensor.py", line 488, in backward
    torch.autograd.backward(
  File "/nethome/sreddy65/miniconda3/envs/mosaic/lib/python3.10/site-packages/torch/autograd/__init__.py", line 197, in backward
    Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
  File "/nethome/sreddy65/miniconda3/envs/mosaic/lib/python3.10/site-packages/torch/autograd/function.py", line 267, in apply
    return user_fn(self, *args)
  File "/nethome/sreddy65/examples/examples/bert/src/flash_attn_triton.py", line 1041, in backward
    _flash_attn_backward(do,
  File "/nethome/sreddy65/examples/examples/bert/src/flash_attn_triton.py", line 949, in _flash_attn_backward
    _bwd_kernel[grid](  # type: ignore
  File "/nethome/sreddy65/miniconda3/envs/mosaic/lib/python3.10/site-packages/triton/runtime/jit.py", line 106, in launcher
    return self.run(*args, grid=grid, **kwargs)
  File "/nethome/sreddy65/miniconda3/envs/mosaic/lib/python3.10/site-packages/triton/runtime/autotuner.py", line 73, in run
    timings = {config: self._bench(*args, config=config, **kwargs)
  File "/nethome/sreddy65/miniconda3/envs/mosaic/lib/python3.10/site-packages/triton/runtime/autotuner.py", line 73, in <dictcomp>
    timings = {config: self._bench(*args, config=config, **kwargs)
  File "/nethome/sreddy65/miniconda3/envs/mosaic/lib/python3.10/site-packages/triton/runtime/autotuner.py", line 63, in _bench
    return do_bench(kernel_call)
  File "/nethome/sreddy65/miniconda3/envs/mosaic/lib/python3.10/site-packages/triton/testing.py", line 136, in do_bench
    fn()
  File "/nethome/sreddy65/miniconda3/envs/mosaic/lib/python3.10/site-packages/triton/runtime/autotuner.py", line 62, in kernel_call
    self.fn.run(*args, num_warps=config.num_warps, num_stages=config.num_stages, **current)
  File "/nethome/sreddy65/miniconda3/envs/mosaic/lib/python3.10/site-packages/triton/runtime/autotuner.py", line 200, in run
    return self.fn.run(*args, **kwargs)
  File "<string>", line 43, in _bwd_kernel
RuntimeError: Triton Error [CUDA]: invalid argument
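
For reference, a minimal standalone sketch that exercises the same backward kernel outside the Trainer might look like the following. This assumes the flash_attn_qkvpacked_func signature in examples/bert/src/flash_attn_triton.py and uses made-up shapes; it is untested and not a verified reproducer of this exact error.

import torch
from src.flash_attn_triton import flash_attn_qkvpacked_func  # run from examples/bert/

# Packed QKV of shape (batch, seqlen, 3, nheads, headdim) in bf16, the dtype that triggers the error
qkv = torch.randn(2, 128, 3, 12, 64, device='cuda', dtype=torch.bfloat16, requires_grad=True)

out = flash_attn_qkvpacked_func(qkv)  # forward pass through the Triton kernel
out.sum().backward()                  # backward launches _bwd_kernel, where the CUDA error is raised above
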
@alextrott16 (Contributor)

Hi @sameerreddy13, thanks for bringing this to our attention!

Can you help us understand how your environment is set up? We're wondering if somehow the wrong triton version is getting used.

@sameerreddy13 (Author)

Yep! I can't use Docker in my env, but other than that this is the environment:

8x Nvidia A40

  Operating System: Ubuntu 18.04.6 LTS
            Kernel: Linux 4.15.0-204-generic
      Architecture: x86-64

PyTorch:

pytorch=1.13.1=py3.10_cuda11.7_cudnn8.5.0_0
pytorch-cuda=11.7=h67b0de4_1

Python:

Python 3.10.9

Conda Env:
conda_env_mosaic.txt

@samhavens (Contributor)

@sameerreddy13 could you check whether this happens when using Python 3.9.x? We haven't tested 3.10 much.

@sameerreddy13 (Author) commented Mar 18, 2023

Still happens with Python 3.9.16.

@sameerreddy13 (Author) commented Mar 19, 2023

So I upgraded to torch 2.0 and dropped in the new flash attention module, and it works without a hitch. It might be simpler to switch to this, for BERT at least. I can make a PR for this if it's wanted. This was after spending a while trying to debug the kernel issue.

@eldarkurtic

@sameerreddy13 I am having the same issue; would you be able to share the diffs for "dropped in the new flash attention module and it works without a hitch"?

@sameerreddy13 (Author) commented Apr 12, 2023

Hey @eldarkurtic, here is the main diff. You can drop the or True from the if condition (I put it in while testing and forgot to remove it on my fork).

diff --git a/examples/bert/src/bert_layers.py b/examples/bert/src/bert_layers.py
index 4f8403c..db68ba8 100644
--- a/examples/bert/src/bert_layers.py
+++ b/examples/bert/src/bert_layers.py
@@ -209,18 +209,22 @@ class BertUnpadSelfAttention(nn.Module):
                         'b s (t h d) -> b s t h d',
                         t=3,
                         h=self.num_attention_heads)
-        if self.p_dropout or flash_attn_qkvpacked_func is None:
+        # NOTE: FLASH ATTENTION
+        if self.p_dropout or flash_attn_qkvpacked_func is None or True:
             # if we have nonzero attention dropout (e.g. during fine-tuning) or no Triton, compute attention in PyTorch
             q = qkv[:, :, 0, :, :].permute(0, 2, 1, 3)  # b h s d
-            k = qkv[:, :, 1, :, :].permute(0, 2, 3, 1)  # b h d s
+            # k = qkv[:, :, 1, :, :].permute(0, 2, 3, 1)  # b h d s
+            k = qkv[:, :, 1, :, :].permute(0, 2, 1, 3)  # b h s d
             v = qkv[:, :, 2, :, :].permute(0, 2, 1, 3)  # b h s d
-            attention_scores = torch.matmul(q, k) / math.sqrt(
-                self.attention_head_size)
-            attention_scores = attention_scores + bias
-            attention_probs = nn.functional.softmax(attention_scores, dim=-1)
-            attention_probs = self.dropout(attention_probs)
-            attention = torch.matmul(attention_probs, v).permute(0, 2, 1,
-                                                                 3)  # b s h d
+            # attention_scores = torch.matmul(q, k) / math.sqrt(
+            #     self.attention_head_size)
+            # attention_scores = attention_scores + bias
+            # attention_probs = nn.functional.softmax(attention_scores, dim=-1)
+            # attention_probs = self.dropout(attention_probs)
+            # attention = torch.matmul(attention_probs, v).permute(0, 2, 1, 3)
+            # ALWAYS RUNS
+            attention = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=self.p_dropout, is_causal=False)
+            attention = attention.permute(0, 2, 1, 3)
         else:
             # Triton implementation only supports 0 attention dropout
             convert_dtype = qkv.dtype not in [torch.float16, torch.bfloat16]

and a file to just test if flash attention is available in your env:

diff --git a/examples/bert/test_pytorch2.0_attention.py b/examples/bert/test_pytorch2.0_attention.py
new file mode 100644
index 0000000..b30140d
--- /dev/null
+++ b/examples/bert/test_pytorch2.0_attention.py
@@ -0,0 +1,51 @@
+# Lets define a helpful benchmarking function:
+import torch.utils.benchmark as benchmark
+import torch.nn.functional as F
+import torch
+def benchmark_torch_function_in_microseconds(f, *args, **kwargs):
+    t0 = benchmark.Timer(
+        stmt="f(*args, **kwargs)", globals={"args": args, "kwargs": kwargs, "f": f}
+    )
+    return t0.blocked_autorange().mean * 1e6
+
+# Lets define the hyper-parameters of our input
+device = 'cuda'
+batch_size = 32
+max_sequence_len = 1024
+num_heads = 32
+embed_dimension = 32
+
+dtype = torch.float16
+
+query = torch.rand(batch_size, num_heads, max_sequence_len, embed_dimension, device=device, dtype=dtype)
+key = torch.rand(batch_size, num_heads, max_sequence_len, embed_dimension, device=device, dtype=dtype)
+value = torch.rand(batch_size, num_heads, max_sequence_len, embed_dimension, device=device, dtype=dtype)
+
+print(f"The default implementation runs in {benchmark_torch_function_in_microseconds(F.scaled_dot_product_attention, query, key, value):.3f} microseconds")
+
+# Lets explore the speed of each of the 3 implementations
+from torch.backends.cuda import sdp_kernel, SDPBackend
+
+# Helpful arg mapper
+backend_map = {
+    SDPBackend.MATH: {"enable_math": True, "enable_flash": False, "enable_mem_efficient": False},
+    SDPBackend.FLASH_ATTENTION: {"enable_math": False, "enable_flash": True, "enable_mem_efficient": False},
+    SDPBackend.EFFICIENT_ATTENTION: {
+        "enable_math": False, "enable_flash": False, "enable_mem_efficient": True}
+}
+
+with sdp_kernel(**backend_map[SDPBackend.MATH]):
+    print(f"The math implementation runs in {benchmark_torch_function_in_microseconds(F.scaled_dot_product_attention, query, key, value):.3f} microseconds")
+
+
+with sdp_kernel(**backend_map[SDPBackend.FLASH_ATTENTION]):
+    try:
+        print(f"The flash attention implementation runs in {benchmark_torch_function_in_microseconds(F.scaled_dot_product_attention, query, key, value):.3f} microseconds")
+    except RuntimeError:
+        print("FlashAttention is not supported. See warnings for reasons.")
+
+with sdp_kernel(**backend_map[SDPBackend.EFFICIENT_ATTENTION]):
+    try:
+        print(f"The memory efficient implementation runs in {benchmark_torch_function_in_microseconds(F.scaled_dot_product_attention, query, key, value):.3f} microseconds")
+    except RuntimeError:
+        print("EfficientAttention is not supported. See warnings for reasons.")

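For reference, the same change without the or True debug condition would look roughly like the sketch below. It assumes the surrounding variables of BertUnpadSelfAttention in bert_layers.py and PyTorch >= 2.0, and, like the diff above, it passes attn_mask=None, so the additive bias that the original matmul path adds is dropped.

# Sketch of the PyTorch-2.0 attention path, replacing the Triton kernel call.
# qkv has shape (b, s, 3, h, d), as in bert_layers.py.
q = qkv[:, :, 0, :, :].permute(0, 2, 1, 3)  # b h s d
k = qkv[:, :, 1, :, :].permute(0, 2, 1, 3)  # b h s d (SDPA transposes k internally)
v = qkv[:, :, 2, :, :].permute(0, 2, 1, 3)  # b h s d
attention = torch.nn.functional.scaled_dot_product_attention(
    q, k, v,
    attn_mask=None,              # note: drops the additive bias used by the original code path
    dropout_p=self.p_dropout,
    is_causal=False)
attention = attention.permute(0, 2, 1, 3)   # back to b s h d
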
@eldarkurtic

Thanks a lot @sameerreddy13 !

@godfrey-cw commented Apr 23, 2023

In case it's useful to anyone, I'm strangely getting this error when running the glue test script python glue.py yamls/test/glue.yaml model.name=mosaic_bert && rm -rf local-finetune-checkpoints but not the pretraining test script composer main.py yamls/test/main.yaml model.name=mosaic_bert. My env is essentially miniconda's base (python v3.9.12) plus pip install -e ".[bert]".

FWIW I'm on Tesla GPUs since that's what I can access at the moment.

EDIT: I believe I'm getting this w/ fp32. In /examples/bert/yamls/test/glue.yaml I have the block:

# Basic run configuration, additional details will be added to this name for each GLUE task, and each random seed
base_run_name: glue-finetuning-benchmark-test
default_seed: 1111
precision: fp32

@godfrey-cw

Update: on a slightly different environment (listed below; I needed to make some tweaks to build apex on my system), I'm getting this issue with composer main.py yamls/main/mosaic-bert-base-uncased.yaml as well, with fp32 precision on A100s. However, when I switch to bf16, everything seems to work. No idea why bf16 vs. fp32 is seemingly the opposite for me of what @sameerreddy13 is experiencing!

My hunch is that this issue is related to triton-lang/triton#1512.

Relevant pieces of my environment for this post:

torch: 1.12.1, compiled for CUDA 11.3
CUDA 11.4 (I know, weird given the above but I'm not sure that's the issue ...)
triton: the one listed in your requirements.txt
apex: github main branch
GPUs: A100 (same issue w/ 40 or 80 GB)
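
In case it helps with comparing setups, here is a quick way to confirm which torch / CUDA / triton builds are actually active in a given environment (a small sketch; nothing here is specific to this repo):

# Report the versions importable from the current Python environment
import torch
import triton

print("torch:", torch.__version__)
print("built for CUDA:", torch.version.cuda)
print("triton:", triton.__version__)
if torch.cuda.is_available():
    print("GPU:", torch.cuda.get_device_name(0))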

@mitchellnw commented Apr 26, 2023

Yep, looks like a similar error to what I'm seeing. I still haven't been able to resolve mine; if anybody has any advice it would be much appreciated (triton-lang/triton#1512).

@dakinggg (Collaborator)

I have a feeling this is some combination of CUDA/triton/torch versions... but it's not something we have encountered at all. The MosaicBERT work was done mostly on torch 1.12.1+cu116, I believe, and we've since run the code with torch 1.13.1+cu117. @mitchellnw, what does your environment look like?

@mitchellnw

Thanks, yeah, you're probably right. I'm on torch 2.0.0+cu118 with triton 2.0.0. I'll try torch 1.13.1+cu117 and see if that works.

@dakinggg (Collaborator)

We've also been using triton 2.0.0.dev20221202 for the GPT stuff and triton 2.0.0.dev20221103 for the BERT stuff. No particular reason they are different; it's just the way it happened. So I think the older triton version has been used with both torch 1.12 and 1.13, and the newer triton version has been used with 1.12.

@mitchellnw

Thanks, really appreciate it! I'll mess around with versions (probably later this week) and see if that fixes things.

@godfrey-cw

I actually tried switching my triton version to 2.0.0 (with everything else in the environment I listed above the same, if I recall correctly) and got a completely different error ('invalid source', essentially the same error that appears in triton-lang/triton#1098). My guess is that the syntax of triton's dot function changed between 2.0.0.dev20221103 and 2.0.0?

@malteos commented Oct 5, 2023

Any news on this issue?

The error does not occur when using fp32 precision, but how can this be fixed for bf16?
