UnboundLocalError: local variable 'out' referenced before assignment #1412
System Information
OS: Ubuntu 22.04 (full environment details are listed after the error output below)

Attempted Solutions
It seems to be related to this issue: #1390.
Here is my test code:

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint
from torch.nn.attention import SDPBackend, sdpa_kernel

import flash_attn
from flash_attn import flash_attn_interface


class AttnBase(nn.Module):
    def __init__(self, d_model, n_head=4):
        super().__init__()
        self.n_head = n_head
        self.d_head = d_model // self.n_head
        self.q_proj = nn.Linear(d_model, self.n_head * self.d_head)
        self.k_proj = nn.Linear(d_model, self.n_head * self.d_head)
        self.v_proj = nn.Linear(d_model, self.n_head * self.d_head)

    def attn_impl(self, q, k, v):
        raise NotImplementedError

    def forward(self, hidden_states):
        bsz, q_len, _ = hidden_states.size()
        query_states = self.q_proj(hidden_states).unflatten(-1, (self.n_head, self.d_head))
        key_states = self.k_proj(hidden_states).unflatten(-1, (self.n_head, self.d_head))
        value_states = self.v_proj(hidden_states).unflatten(-1, (self.n_head, self.d_head))
        ret = self.attn_impl(query_states, key_states, value_states)
        assert ret.shape[:2] == (bsz, q_len), f"output shape {ret.shape} mismatch with input {hidden_states.shape}"
        return ret


class SDPAAttnBlock(AttnBase):
    def __init__(self, *args, backend=SDPBackend.MATH, **kwargs):
        super().__init__(*args, **kwargs)
        self.backend = backend

    def attn_impl(self, q, k, v):
        with sdpa_kernel(self.backend):
            ret = F.scaled_dot_product_attention(
                q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)
            ).transpose(1, 2).flatten(2)
        return ret


class FlashAttnBlock(AttnBase):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def attn_impl(self, q, k, v):
        # print("attn_impl", torch.is_grad_enabled(), [x.requires_grad for x in [q, k, v]])
        assert torch.is_grad_enabled() is True
        assert any(x.requires_grad for x in [q, k, v]) is True
        ret = flash_attn_interface.flash_attn_func(q, k, v).flatten(2)
        return ret


def main(args):
    print("==== version ===")
    print("torch version =", torch.__version__)
    print("CUDA version =", torch.version.cuda)
    print("flash_attn version =", flash_attn.__version__)

    device = torch.device("cuda")
    D_MODEL = 128
    N_HEAD = 4
    BS = 1
    LEN = 16

    base_model = nn.ModuleList([AttnBase(D_MODEL, N_HEAD) for _ in range(4)]).to(device)
    flash_model = nn.ModuleList([FlashAttnBlock(D_MODEL, N_HEAD) for _ in range(4)]).bfloat16().to(device).train()
    sdpa_modelm = nn.ModuleList([SDPAAttnBlock(D_MODEL, N_HEAD) for _ in range(4)]).bfloat16().to(device).train()
    sdpa_modelf = nn.ModuleList([SDPAAttnBlock(D_MODEL, N_HEAD) for _ in range(4)]).bfloat16().to(device).train()
    sdpa_modele = nn.ModuleList([SDPAAttnBlock(D_MODEL, N_HEAD) for _ in range(4)]).bfloat16().to(device).train()
    sdpa_modelc = nn.ModuleList([SDPAAttnBlock(D_MODEL, N_HEAD) for _ in range(4)]).bfloat16().to(device).train()
    # Set the backend on every block; setting the attribute on the ModuleList itself
    # would not reach the blocks' attn_impl.
    for blk in sdpa_modelm:
        blk.backend = SDPBackend.MATH
    for blk in sdpa_modelf:
        blk.backend = SDPBackend.FLASH_ATTENTION
    for blk in sdpa_modele:
        blk.backend = SDPBackend.EFFICIENT_ATTENTION
    for blk in sdpa_modelc:
        blk.backend = SDPBackend.CUDNN_ATTENTION
    test_dict = {
        "sdpa_math": sdpa_modelm,
        "sdpa_flash": sdpa_modelf,
        "sdpa_efficient": sdpa_modele,
        "sdpa_cudnn": sdpa_modelc,
        "flash": flash_model,
    }
    for k, v in test_dict.items():
        v.load_state_dict(base_model.state_dict())

    data = torch.randn(BS, LEN, D_MODEL, dtype=torch.bfloat16, device=device).requires_grad_(True)

    @torch.enable_grad()
    def test_fn(use_checkpoint=False, **checkpoint_kwargs):
        for name, model in test_dict.items():
            try:
                x = data
                if use_checkpoint:
                    for blk in model:
                        x = torch.utils.checkpoint.checkpoint(blk, x, **checkpoint_kwargs)
                else:
                    for blk in model:
                        x = blk(x)
                x.mean().backward()
                print(f"[{name}] OK")
            except Exception as ex:
                import traceback
                print(traceback.format_exc())
                print(f"[{name}] Fail")

    print("===checkpoint=True reentrant=True")
    test_fn(use_checkpoint=True, use_reentrant=True)
    print("===checkpoint=True reentrant=False")
    test_fn(use_checkpoint=True, use_reentrant=False)
    print("===checkpoint=False")
    test_fn(use_checkpoint=False)


if __name__ == "__main__":
    main(None)
Error output
[rank0]: Traceback (most recent call last):
[rank0]: File "/usr/lib/python3.10/runpy.py", line 196, in _run_module_as_main
[rank0]: return _run_code(code, main_globals, None,
[rank0]: File "/usr/lib/python3.10/runpy.py", line 86, in _run_code
[rank0]: exec(code, run_globals)
[rank0]: File "/usr/local/lib/python3.10/dist-packages/openrlhf/cli/train_sft.py", line 249, in
[rank0]: train(args)
[rank0]: File "/usr/local/lib/python3.10/dist-packages/openrlhf/cli/train_sft.py", line 132, in train
[rank0]: trainer.fit(args, consumed_samples, num_update_steps_per_epoch)
[rank0]: File "/usr/local/lib/python3.10/dist-packages/openrlhf/trainer/sft_trainer.py", line 141, in fit
[rank0]: output = self.model(inputs, attention_mask=attention_mask, return_output=True)
[rank0]: File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
[rank0]: return self._call_impl(*args, **kwargs)
[rank0]: File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1747, in _call_impl
[rank0]: return forward_call(*args, **kwargs)
[rank0]: File "/usr/local/lib/python3.10/dist-packages/openrlhf/models/actor.py", line 208, in forward
[rank0]: output = self.model(sequences, attention_mask=attention_mask, position_ids=position_ids)
[rank0]: File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
[rank0]: return self._call_impl(*args, **kwargs)
[rank0]: File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1747, in _call_impl
[rank0]: return forward_call(*args, **kwargs)
[rank0]: File "/usr/local/lib/python3.10/dist-packages/deepspeed/utils/nvtx.py", line 18, in wrapped_fn
[rank0]: ret_val = func(*args, **kwargs)
[rank0]: File "/usr/local/lib/python3.10/dist-packages/deepspeed/runtime/engine.py", line 1899, in forward
[rank0]: loss = self.module(*inputs, **kwargs)
[rank0]: File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
[rank0]: return self._call_impl(*args, **kwargs)
[rank0]: File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1747, in _call_impl
[rank0]: return forward_call(*args, **kwargs)
[rank0]: File "/usr/local/lib/python3.10/dist-packages/transformers/models/qwen2/modeling_qwen2.py", line 1164, in forward
[rank0]: outputs = self.model(
[rank0]: File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
[rank0]: return self._call_impl(*args, **kwargs)
[rank0]: File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1747, in _call_impl
[rank0]: return forward_call(*args, **kwargs)
[rank0]: File "/usr/local/lib/python3.10/dist-packages/transformers/models/qwen2/modeling_qwen2.py", line 883, in forward
[rank0]: layer_outputs = self._gradient_checkpointing_func(
[rank0]: File "/usr/local/lib/python3.10/dist-packages/torch/_compile.py", line 32, in inner
[rank0]: return disable_fn(*args, **kwargs)
[rank0]: File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/eval_frame.py", line 632, in _fn
[rank0]: return fn(*args, **kwargs)
[rank0]: File "/usr/local/lib/python3.10/dist-packages/torch/utils/checkpoint.py", line 496, in checkpoint
[rank0]: ret = function(*args, **kwargs)
[rank0]: File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
[rank0]: return self._call_impl(*args, **kwargs)
[rank0]: File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1747, in _call_impl
[rank0]: return forward_call(*args, **kwargs)
[rank0]: File "/usr/local/lib/python3.10/dist-packages/transformers/models/qwen2/modeling_qwen2.py", line 623, in forward
[rank0]: hidden_states, self_attn_weights, present_key_value = self.self_attn(
[rank0]: File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
[rank0]: return self._call_impl(*args, **kwargs)
[rank0]: File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1747, in _call_impl
[rank0]: return forward_call(*args, **kwargs)
[rank0]: File "/usr/local/lib/python3.10/dist-packages/transformers/models/qwen2/modeling_qwen2.py", line 443, in forward
[rank0]: attn_output = _flash_attention_forward(
[rank0]: File "/usr/local/lib/python3.10/dist-packages/transformers/modeling_flash_attention_utils.py", line 297, in _flash_attention_forward
[rank0]: attn_output = flash_attn_func(
[rank0]: File "/usr/local/lib/python3.10/dist-packages/flash_attn-2.7.2.post1-py3.10-linux-x86_64.egg/flash_attn/flash_attn_interface.py", line 1189, in flash_attn_func
[rank0]: return FlashAttnFunc.apply(
[rank0]: File "/usr/local/lib/python3.10/dist-packages/torch/autograd/function.py", line 575, in apply
[rank0]: return super().apply(*args, **kwargs) # type: ignore[misc]
[rank0]: File "/usr/local/lib/python3.10/dist-packages/flash_attn-2.7.2.post1-py3.10-linux-x86_64.egg/flash_attn/flash_attn_interface.py", line 853, in forward
[rank0]: return out if not return_softmax else (out, softmax_lse, S_dmask)
[rank0]: UnboundLocalError: local variable 'out' referenced before assignment
Is this a flash_attn version problem? Thank you!
root@e4b47fc2098b:/workspace# nvcc -V
nvcc: NVIDIA (R) Cuda compiler driver
Copyright (c) 2005-2024 NVIDIA Corporation
Built on Thu_Jun__6_02:18:23_PDT_2024
Cuda compilation tools, release 12.5, V12.5.82
Build cuda_12.5.r12.5/compiler.34385749_0
root@e4b47fc2098b:/workspace# pip show torch
Name: torch
Version: 2.5.1
Summary: Tensors and Dynamic neural networks in Python with strong GPU acceleration
Home-page: https://pytorch.org/
Author: PyTorch Team
Author-email: [email protected]
License: BSD-3-Clause
Location: /usr/local/lib/python3.10/dist-packages
Requires: filelock, fsspec, jinja2, networkx, nvidia-cublas-cu12, nvidia-cuda-cupti-cu12, nvidia-cuda-nvrtc-cu12, nvidia-cuda-runtime-cu12, nvidia-cudnn-cu12, nvidia-cufft-cu12, nvidia-curand-cu12, nvidia-cusolver-cu12, nvidia-cusparse-cu12, nvidia-nccl-cu12, nvidia-nvjitlink-cu12, nvidia-nvtx-cu12, sympy, triton, typing-extensions
Required-by: accelerate, bitsandbytes, compressed-tensors, deepspeed, flash-attn, lightning-thunder, openrlhf, optimum, peft, torch-tensorrt, torchmetrics, torchvision, vllm, xformers
root@e4b47fc2098b:/workspace# python -c "import torch; print(torch.__version__); print(torch.cuda.is_available()); print(torch.version.cuda)"
2.5.1+cu124
True
12.4
root@e4b47fc2098b:/workspace# python -c "from flash_attn.utils.distributed import all_gather"
root@e4b47fc2098b:/workspace#
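As a quick way to narrow this down, here is a minimal sanity check. This is only a sketch: it assumes an Ampere-or-newer CUDA GPU, uses fp16 inputs with head dimension 32 (matching D_MODEL=128 and N_HEAD=4 above), and calls flash_attn_func directly, outside transformers/DeepSpeed/gradient checkpointing. If this also raises the same UnboundLocalError, the installed flash_attn build is the likely problem rather than the calling code.

import torch

import flash_attn
from flash_attn import flash_attn_func

print("flash_attn version:", flash_attn.__version__)
print("torch version:", torch.__version__, "| CUDA:", torch.version.cuda)

# Tiny (batch, seqlen, nheads, headdim) fp16 tensors on the GPU.
q = torch.randn(1, 16, 4, 32, dtype=torch.float16, device="cuda", requires_grad=True)
k = torch.randn_like(q, requires_grad=True)
v = torch.randn_like(q, requires_grad=True)

out = flash_attn_func(q, k, v)         # forward pass through the CUDA extension
out.float().mean().backward()          # backward pass, mirroring the training setup
print("OK, output shape:", out.shape)  # expected: torch.Size([1, 16, 4, 32])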