Open
Description
I get the following error when running the mosaic-bert recipe. It only occurs with bf16; the same recipe works with fp32.
Traceback (most recent call last):
File "<string>", line 21, in _bwd_kernel
KeyError: ('2-.-0-.-0-d82511111ad128294e9d31a6ac684238-7929002797455b30efce6e41eddc6b57-3aa563e00c5c695dd945e23b09a86848-d962222789c30252d492a16cca3bf467-ff946bd4b3b4a4cbdf8cedc6e1c658e0-5c5e32ff210f3b7f56c98ca29917c25e-06f0df2d61979d629033f4a22eff5198-0dd03b0bd512a184b3512b278d9dfa59-d35ab04ae841e2714a253c523530b071', (torch.bfloat16, torch.bfloat16, torch.bfloat16, torch.float32, torch.bfloat16, torch.float32, torch.bfloat16, torch.bfloat16, torch.float32, torch.float32, 'fp32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32'), ('matrix', False, 64, False, True, True, True, 128, 128), (True, True, True, True, True, True, True, True, True, True, (False,), (True, False), (True, False), (True, False), (True, False), (True, False), (True, False), (True, False), (True, False), (True, False), (True, False), (True, False), (True, False), (True, False), (True, False), (True, False), (True, False), (True, False), (True, False), (True, False), (True, False), (True, False), (True, False), (True, False), (True, False), (False, False), (True, False), (True, False), (True, False), (True, False), (False, False), (False, False)))
During handling of the above exception, another exception occurred:
Traceback (most recent call last):
File "/coc/scratch/sreddy65/examples/examples/bert/main.py", line 141, in <module>
main(cfg)
File "/coc/scratch/sreddy65/examples/examples/bert/main.py", line 128, in main
trainer.fit()
File "/nethome/sreddy65/miniconda3/envs/mosaic/lib/python3.10/site-packages/composer/trainer/trainer.py", line 1787, in fit
self._train_loop()
File "/nethome/sreddy65/miniconda3/envs/mosaic/lib/python3.10/site-packages/composer/trainer/trainer.py", line 1950, in _train_loop
total_loss_dict = self._train_batch(use_grad_scaling)
File "/nethome/sreddy65/miniconda3/envs/mosaic/lib/python3.10/site-packages/composer/trainer/trainer.py", line 2126, in _train_batch
optimizer.step(closure=lambda **kwargs: self._train_microbatches(
File "/nethome/sreddy65/miniconda3/envs/mosaic/lib/python3.10/site-packages/torch/optim/lr_scheduler.py", line 68, in wrapper
return wrapped(*args, **kwargs)
File "/nethome/sreddy65/miniconda3/envs/mosaic/lib/python3.10/site-packages/torch/optim/optimizer.py", line 140, in wrapper
out = func(*args, **kwargs)
File "/nethome/sreddy65/miniconda3/envs/mosaic/lib/python3.10/site-packages/torch/autograd/grad_mode.py", line 27, in decorate_context
return func(*args, **kwargs)
File "/nethome/sreddy65/miniconda3/envs/mosaic/lib/python3.10/site-packages/composer/optim/decoupled_weight_decay.py", line 289, in step
loss = closure()
File "/nethome/sreddy65/miniconda3/envs/mosaic/lib/python3.10/site-packages/composer/trainer/trainer.py", line 2126, in <lambda>
optimizer.step(closure=lambda **kwargs: self._train_microbatches(
File "/nethome/sreddy65/miniconda3/envs/mosaic/lib/python3.10/site-packages/composer/trainer/trainer.py", line 2209, in _train_microbatches
microbatch_loss_dict = self._train_microbatch(use_grad_scaling, current_batch_size, is_final_microbatch)
File "/nethome/sreddy65/miniconda3/envs/mosaic/lib/python3.10/site-packages/composer/trainer/trainer.py", line 2305, in _train_microbatch
microbatch_loss.backward(create_graph=self._backwards_create_graph)
File "/nethome/sreddy65/miniconda3/envs/mosaic/lib/python3.10/site-packages/torch/_tensor.py", line 488, in backward
torch.autograd.backward(
File "/nethome/sreddy65/miniconda3/envs/mosaic/lib/python3.10/site-packages/torch/autograd/__init__.py", line 197, in backward
Variable._execution_engine.run_backward( # Calls into the C++ engine to run the backward pass
File "/nethome/sreddy65/miniconda3/envs/mosaic/lib/python3.10/site-packages/torch/autograd/function.py", line 267, in apply
return user_fn(self, *args)
File "/nethome/sreddy65/examples/examples/bert/src/flash_attn_triton.py", line 1041, in backward
_flash_attn_backward(do,
File "/nethome/sreddy65/examples/examples/bert/src/flash_attn_triton.py", line 949, in _flash_attn_backward
_bwd_kernel[grid]( # type: ignore
File "/nethome/sreddy65/miniconda3/envs/mosaic/lib/python3.10/site-packages/triton/runtime/jit.py", line 106, in launcher
return self.run(*args, grid=grid, **kwargs)
File "/nethome/sreddy65/miniconda3/envs/mosaic/lib/python3.10/site-packages/triton/runtime/autotuner.py", line 73, in run
timings = {config: self._bench(*args, config=config, **kwargs)
File "/nethome/sreddy65/miniconda3/envs/mosaic/lib/python3.10/site-packages/triton/runtime/autotuner.py", line 73, in <dictcomp>
timings = {config: self._bench(*args, config=config, **kwargs)
File "/nethome/sreddy65/miniconda3/envs/mosaic/lib/python3.10/site-packages/triton/runtime/autotuner.py", line 63, in _bench
return do_bench(kernel_call)
File "/nethome/sreddy65/miniconda3/envs/mosaic/lib/python3.10/site-packages/triton/testing.py", line 136, in do_bench
fn()
File "/nethome/sreddy65/miniconda3/envs/mosaic/lib/python3.10/site-packages/triton/runtime/autotuner.py", line 62, in kernel_call
self.fn.run(*args, num_warps=config.num_warps, num_stages=config.num_stages, **current)
File "/nethome/sreddy65/miniconda3/envs/mosaic/lib/python3.10/site-packages/triton/runtime/autotuner.py", line 200, in run
return self.fn.run(*args, **kwargs)
File "<string>", line 43, in _bwd_kernel
RuntimeError: Triton Error [CUDA]: invalid argument
Metadata
Metadata
Assignees
Labels
No labels