UnboundLocalError: local variable 'out' referenced before assignment #1412

Open
chuangzhidan opened this issue Dec 24, 2024 · 3 comments

@chuangzhidan

[rank0]: Traceback (most recent call last):
[rank0]: File "/usr/lib/python3.10/runpy.py", line 196, in _run_module_as_main
[rank0]: return _run_code(code, main_globals, None,
[rank0]: File "/usr/lib/python3.10/runpy.py", line 86, in _run_code
[rank0]: exec(code, run_globals)
[rank0]: File "/usr/local/lib/python3.10/dist-packages/openrlhf/cli/train_sft.py", line 249, in
[rank0]: train(args)
[rank0]: File "/usr/local/lib/python3.10/dist-packages/openrlhf/cli/train_sft.py", line 132, in train
[rank0]: trainer.fit(args, consumed_samples, num_update_steps_per_epoch)
[rank0]: File "/usr/local/lib/python3.10/dist-packages/openrlhf/trainer/sft_trainer.py", line 141, in fit
[rank0]: output = self.model(inputs, attention_mask=attention_mask, return_output=True)
[rank0]: File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
[rank0]: return self._call_impl(*args, **kwargs)
[rank0]: File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1747, in _call_impl
[rank0]: return forward_call(*args, **kwargs)
[rank0]: File "/usr/local/lib/python3.10/dist-packages/openrlhf/models/actor.py", line 208, in forward
[rank0]: output = self.model(sequences, attention_mask=attention_mask, position_ids=position_ids)
[rank0]: File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
[rank0]: return self._call_impl(*args, **kwargs)
[rank0]: File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1747, in _call_impl
[rank0]: return forward_call(*args, **kwargs)
[rank0]: File "/usr/local/lib/python3.10/dist-packages/deepspeed/utils/nvtx.py", line 18, in wrapped_fn
[rank0]: ret_val = func(*args, **kwargs)
[rank0]: File "/usr/local/lib/python3.10/dist-packages/deepspeed/runtime/engine.py", line 1899, in forward
[rank0]: loss = self.module(*inputs, **kwargs)
[rank0]: File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
[rank0]: return self._call_impl(*args, **kwargs)
[rank0]: File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1747, in _call_impl
[rank0]: return forward_call(*args, **kwargs)
[rank0]: File "/usr/local/lib/python3.10/dist-packages/transformers/models/qwen2/modeling_qwen2.py", line 1164, in forward
[rank0]: outputs = self.model(
[rank0]: File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
[rank0]: return self._call_impl(*args, **kwargs)
[rank0]: File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1747, in _call_impl
[rank0]: return forward_call(*args, **kwargs)
[rank0]: File "/usr/local/lib/python3.10/dist-packages/transformers/models/qwen2/modeling_qwen2.py", line 883, in forward
[rank0]: layer_outputs = self._gradient_checkpointing_func(
[rank0]: File "/usr/local/lib/python3.10/dist-packages/torch/_compile.py", line 32, in inner
[rank0]: return disable_fn(*args, **kwargs)
[rank0]: File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/eval_frame.py", line 632, in _fn
[rank0]: return fn(*args, **kwargs)
[rank0]: File "/usr/local/lib/python3.10/dist-packages/torch/utils/checkpoint.py", line 496, in checkpoint
[rank0]: ret = function(*args, **kwargs)
[rank0]: File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
[rank0]: return self._call_impl(*args, **kwargs)
[rank0]: File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1747, in _call_impl
[rank0]: return forward_call(*args, **kwargs)
[rank0]: File "/usr/local/lib/python3.10/dist-packages/transformers/models/qwen2/modeling_qwen2.py", line 623, in forward
[rank0]: hidden_states, self_attn_weights, present_key_value = self.self_attn(
[rank0]: File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
[rank0]: return self._call_impl(*args, **kwargs)
[rank0]: File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1747, in _call_impl
[rank0]: return forward_call(*args, **kwargs)
[rank0]: File "/usr/local/lib/python3.10/dist-packages/transformers/models/qwen2/modeling_qwen2.py", line 443, in forward
[rank0]: attn_output = _flash_attention_forward(
[rank0]: File "/usr/local/lib/python3.10/dist-packages/transformers/modeling_flash_attention_utils.py", line 297, in _flash_attention_forward
[rank0]: attn_output = flash_attn_func(
[rank0]: File "/usr/local/lib/python3.10/dist-packages/flash_attn-2.7.2.post1-py3.10-linux-x86_64.egg/flash_attn/flash_attn_interface.py", line 1189, in flash_attn_func
[rank0]: return FlashAttnFunc.apply(
[rank0]: File "/usr/local/lib/python3.10/dist-packages/torch/autograd/function.py", line 575, in apply
[rank0]: return super().apply(*args, **kwargs)  # type: ignore[misc]
[rank0]: File "/usr/local/lib/python3.10/dist-packages/flash_attn-2.7.2.post1-py3.10-linux-x86_64.egg/flash_attn/flash_attn_interface.py", line 853, in forward
[rank0]: return out if not return_softmax else (out, softmax_lse, S_dmask)
[rank0]: UnboundLocalError: local variable 'out' referenced before assignment

Is this a flash_attn version problem? Thank you!
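
For context, the exception itself is an ordinary Python pattern: a local variable assigned only inside a conditional branch is unbound whenever that branch is skipped. A minimal sketch of the failure shape (names are illustrative, not flash_attn's actual source):

import torch

def forward_sketch(q, is_grad):
    # If `out` is only assigned on the grad-enabled path, any call where
    # `is_grad` is False raises UnboundLocalError at the return, which is
    # exactly the shape of the traceback above.
    if is_grad:
        out = q * 2  # stand-in for the real attention kernel call
    return out

forward_sketch(torch.ones(2), True)   # fine
forward_sketch(torch.ones(2), False)  # UnboundLocalError: local variable 'out' ...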

root@e4b47fc2098b:/workspace# nvcc -V
nvcc: NVIDIA (R) Cuda compiler driver
Copyright (c) 2005-2024 NVIDIA Corporation
Built on Thu_Jun__6_02:18:23_PDT_2024
Cuda compilation tools, release 12.5, V12.5.82
Build cuda_12.5.r12.5/compiler.34385749_0

root@e4b47fc2098b:/workspace# pip show torch
Name: torch
Version: 2.5.1

Summary: Tensors and Dynamic neural networks in Python with strong GPU acceleration
Home-page: https://pytorch.org/
Author: PyTorch Team
Author-email: [email protected]
License: BSD-3-Clause
Location: /usr/local/lib/python3.10/dist-packages
Requires: filelock, fsspec, jinja2, networkx, nvidia-cublas-cu12, nvidia-cuda-cupti-cu12, nvidia-cuda-nvrtc-cu12, nvidia-cuda-runtime-cu12, nvidia-cudnn-cu12, nvidia-cufft-cu12, nvidia-curand-cu12, nvidia-cusolver-cu12, nvidia-cusparse-cu12, nvidia-nccl-cu12, nvidia-nvjitlink-cu12, nvidia-nvtx-cu12, sympy, triton, typing-extensions
Required-by: accelerate, bitsandbytes, compressed-tensors, deepspeed, flash-attn, lightning-thunder, openrlhf, optimum, peft, torch-tensorrt, torchmetrics, torchvision, vllm, xformers

root@e4b47fc2098b:/workspace# python -c "import torch; print(torch.__version__); print(torch.cuda.is_available()); print(torch.version.cuda)"
2.5.1+cu124
True
12.4

root@e4b47fc2098b:/workspace# python -c "from flash_attn.utils.distributed import all_gather"
root@e4b47fc2098b:/workspace#

@tvosch

tvosch commented Dec 25, 2024

+1
Experiencing the same problem with HuggingFace's Transformers Trainer and Flash Attention 2.7.2.post1 on single-GPU pre-training with Llama 3.1.

System Information

OS: Ubuntu 22.04
GPU: AMD MI250X, ROCm 6.3
Python: 3.10

Pip list output

torch: 2.4.0a0+gitb4b81bd (latest ROCm PyTorch container)
transformers: 4.47.1
accelerate: 1.2.1
flash_attn: 2.7.2.post1 (compiled from source)

Error output

0:     attn_output = _flash_attention_forward(
0:   File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/transformers/modeling_flash_attention_utils.py", line 363, in _flash_attention_forward
0:     attn_output = flash_attn_func(
0:   File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/flash_attn-2.7.2.post1-py3.10-linux-x86_64.egg/flash_attn/flash_attn_interface.py", line 1189, in flash_attn_func
0:     return FlashAttnFunc.apply(
0:   File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/autograd/function.py", line 574, in apply
0:     return super().apply(*args, **kwargs)  # type: ignore[misc]
0:   File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/flash_attn-2.7.2.post1-py3.10-linux-x86_64.egg/flash_attn/flash_attn_interface.py", line 853, in forward
0:     return out if not return_softmax else (out, softmax_lse, S_dmask)
0: UnboundLocalError: local variable 'out' referenced before assignment

Attempted Solutions

  • Turning off Flash Attention 2 in AutoModelForCausalLM works fine (see the sketch below)
  • Will update with further attempts at removing the packed/constant-length dataset
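
A minimal sketch of that first workaround, with the checkpoint name purely illustrative; attn_implementation selects the attention backend in transformers:

import torch
from transformers import AutoModelForCausalLM

# Fall back to PyTorch SDPA (or "eager") instead of "flash_attention_2",
# which sidesteps the flash_attn code path entirely.
model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-3.1-8B",   # illustrative checkpoint name
    attn_implementation="sdpa",
    torch_dtype=torch.bfloat16,
)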

@liguohao96

It seems to be related to issue #1390.
My personal suggestion is to not use the latest version.
The versions that use is_grad in flash_attn_interface.py introduce a lot of potential bugs (2.7.2 seems OK).
As far as I know, the latest version makes the backward pass fail with PyTorch 2.5.1+cu118, because torch.is_grad_enabled() seems to always return False inside torch.autograd.Function (related to pytorch/pytorch#56370).
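
A quick probe of that claim: PyTorch deliberately runs a custom autograd.Function's forward in no-grad mode (the behavior discussed in pytorch/pytorch#56370), so a torch.is_grad_enabled() check inside forward sees False even when the caller has grad enabled:

import torch

class Probe(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        # Grad mode is switched off while a custom Function's forward runs.
        print("inside forward:", torch.is_grad_enabled())
        return x * 2

    @staticmethod
    def backward(ctx, grad_out):
        return grad_out * 2

x = torch.ones(2, requires_grad=True)
print("outside:", torch.is_grad_enabled())  # True
y = Probe.apply(x)                          # prints: inside forward: False
y.sum().backward()                          # backward still works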

@liguohao96

Here is my test code:

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint
from torch.nn.attention import SDPBackend, sdpa_kernel

import flash_attn
from flash_attn import flash_attn_interface

class AttnBase(nn.Module):
    def __init__(self, d_model, n_head=4):
        super().__init__()

        self.n_head = n_head
        self.d_head = d_model // self.n_head

        self.q_proj = nn.Linear(d_model, self.n_head*self.d_head)
        self.k_proj = nn.Linear(d_model, self.n_head*self.d_head)
        self.v_proj = nn.Linear(d_model, self.n_head*self.d_head)

    def attn_impl(self, q, k, v):
        raise NotImplementedError
    
    def forward(self, hidden_states):
        bsz, q_len, _ = hidden_states.size()

        query_states = self.q_proj(hidden_states).unflatten(-1, (self.n_head, self.d_head))
        key_states   = self.k_proj(hidden_states).unflatten(-1, (self.n_head, self.d_head))
        value_states = self.v_proj(hidden_states).unflatten(-1, (self.n_head, self.d_head))

        ret = self.attn_impl(query_states, key_states, value_states)

        assert ret.shape[:2] == (bsz, q_len), f"output shape {ret.shape} mismatch with input {hidden_states.shape}"
        return ret

class SDPAAttnBlock(AttnBase):
    def __init__(self, *args, backend=SDPBackend.MATH, **kwargs):
        super().__init__(*args, **kwargs)

        self.backend = backend

    def attn_impl(self, q, k, v):
        with sdpa_kernel(self.backend):
            ret = F.scaled_dot_product_attention(q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)).transpose(1, 2).flatten(2)
        return ret

class FlashAttnBlock(AttnBase):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def attn_impl(self, q, k, v):
        # print("attn_impl", torch.is_grad_enabled(), [x.requires_grad for x in [q, k, v]])
        assert torch.is_grad_enabled() is True
        assert any(x.requires_grad for x in [q, k, v]) is True

        ret = flash_attn_interface.flash_attn_func(q, k, v).flatten(2)
        return ret

def main(args):
    print("==== version ===")
    print("torch version =", torch.__version__)
    print("CUDA  version =", torch.version.cuda)
    print("flash_attn version =", flash_attn.__version__)

    device  = torch.device("cuda")

    D_MODEL = 128
    N_HEAD  = 4
    BS      = 1
    LEN     = 16

    base_model = nn.ModuleList([AttnBase(D_MODEL, N_HEAD) for _ in range(4)]).to(device)

    flash_model = nn.ModuleList([FlashAttnBlock(D_MODEL, N_HEAD) for _ in range(4)]).bfloat16().to(device).train()
    sdpa_modelm = nn.ModuleList([SDPAAttnBlock( D_MODEL, N_HEAD) for _ in range(4)]).bfloat16().to(device).train()
    sdpa_modelf = nn.ModuleList([SDPAAttnBlock( D_MODEL, N_HEAD) for _ in range(4)]).bfloat16().to(device).train()
    sdpa_modele = nn.ModuleList([SDPAAttnBlock( D_MODEL, N_HEAD) for _ in range(4)]).bfloat16().to(device).train()
    sdpa_modelc = nn.ModuleList([SDPAAttnBlock( D_MODEL, N_HEAD) for _ in range(4)]).bfloat16().to(device).train()

    # Assign the backend per block: setting the attribute on the ModuleList
    # itself would not reach the SDPAAttnBlock modules inside it.
    for blk in sdpa_modelm: blk.backend = SDPBackend.MATH
    for blk in sdpa_modelf: blk.backend = SDPBackend.FLASH_ATTENTION
    for blk in sdpa_modele: blk.backend = SDPBackend.EFFICIENT_ATTENTION
    for blk in sdpa_modelc: blk.backend = SDPBackend.CUDNN_ATTENTION

    test_dict = {
        "sdpa_math":       sdpa_modelm,
        "sdpa_flash":      sdpa_modelf,
        "sdpa_efficient":  sdpa_modele,
        "sdpa_cudnn":      sdpa_modelc,
        "flash":           flash_model,
    }

    for k, v in test_dict.items():
        v.load_state_dict(base_model.state_dict())

    data = torch.randn(BS, LEN, D_MODEL, dtype=torch.bfloat16, device=device).requires_grad_(True)

    @torch.enable_grad()
    def test_fn(use_checkpoint=False, **checkpoint_kwargs):
        for name, model in test_dict.items():
            try:
                x = data
                if use_checkpoint:
                    for blk in model:
                        x = torch.utils.checkpoint.checkpoint(blk, x, **checkpoint_kwargs)
                else:
                    for blk in model:
                        x = blk(x)

                x.mean().backward()

                print(f"[{name}] OK")

            except Exception as ex:
                import traceback
                print(traceback.format_exc())

                print(f"[{name}] Fail")
    
    print("===checkpoint=True reentrant=True")
    test_fn(use_checkpoint=True, use_reentrant=True)
    print("===checkpoint=True reentrant=False")
    test_fn(use_checkpoint=True, use_reentrant=False)

    print("===checkpoint=False")
    test_fn(use_checkpoint=False)


if __name__ == "__main__":
    main(None)
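
The three test_fn calls compare reentrant checkpointing, non-reentrant checkpointing, and a plain forward/backward, mirroring the torch.utils.checkpoint frame in the original traceback; if the flash entry fails where the SDPA backends pass, that points at grad-mode propagation rather than the attention math itself.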
