
Training error #9

Open
Mi5sssss opened this issue Jun 18, 2024 · 2 comments

@Mi5sssss

Is it possible to have a training script example? I encountered a tensor mismatch when I increase the training batch size beyond 1, as follows (batch size 4):
Original Traceback (most recent call last):
  File "/home/xier2/miniconda3/envs/mod/lib/python3.12/site-packages/torch/nn/parallel/parallel_apply.py", line 83, in _worker
    output = module(*input, **kwargs)
  File "/home/xier2/miniconda3/envs/mod/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/xier2/miniconda3/envs/mod/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/xier2/miniconda3/envs/mod/lib/python3.12/site-packages/transformers/models/llama/modeling_llama.py", line 1164, in forward
    outputs = self.model(
  File "/home/xier2/miniconda3/envs/mod/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/xier2/miniconda3/envs/mod/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/xier2/miniconda3/envs/mod/lib/python3.12/site-packages/transformers/models/llama/modeling_llama.py", line 968, in forward
    layer_outputs = decoder_layer(
  File "/home/xier2/miniconda3/envs/mod/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/xier2/miniconda3/envs/mod/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/xier2/new_mod/Mixture-of-depths/MoD/MoD.py", line 70, in forward
    block_output = self.block(
  File "/home/xier2/miniconda3/envs/mod/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/xier2/miniconda3/envs/mod/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/xier2/miniconda3/envs/mod/lib/python3.12/site-packages/transformers/models/llama/modeling_llama.py", line 713, in forward
    hidden_states, self_attn_weights, present_key_value = self.self_attn(
  File "/home/xier2/miniconda3/envs/mod/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/xier2/miniconda3/envs/mod/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/xier2/miniconda3/envs/mod/lib/python3.12/site-packages/transformers/models/llama/modeling_llama.py", line 649, in forward
    attn_output = torch.nn.functional.scaled_dot_product_attention(
RuntimeError: The expanded size of the tensor (1022) must match the existing size (511) at non-singleton dimension 3. Target sizes: [1, 32, 511, 1022]. Tensor sizes: [1, 1, 511, 511]

If I keep the training batch size at 1, I get a tuple index error:
Original Traceback (most recent call last):
  File "/home/xier2/miniconda3/envs/mod/lib/python3.12/site-packages/torch/nn/parallel/parallel_apply.py", line 83, in _worker
    output = module(*input, **kwargs)
  File "/home/xier2/miniconda3/envs/mod/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/xier2/miniconda3/envs/mod/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/xier2/miniconda3/envs/mod/lib/python3.12/site-packages/transformers/models/llama/modeling_llama.py", line 1164, in forward
    outputs = self.model(
  File "/home/xier2/miniconda3/envs/mod/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/xier2/miniconda3/envs/mod/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/xier2/miniconda3/envs/mod/lib/python3.12/site-packages/transformers/models/llama/modeling_llama.py", line 981, in forward
    next_decoder_cache = layer_outputs[2 if output_attentions else 1]
IndexError: tuple index out of range
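
For reference, a minimal sketch of the shape constraint behind the first RuntimeError (purely illustrative, not code from this repository): torch.nn.functional.scaled_dot_product_attention broadcasts the attention mask against the [batch, heads, q_len, kv_len] score tensor, so a mask built for 511 tokens cannot be expanded to a key/value length of 1022. The shapes below are taken from the error message; everything else is hypothetical.

import torch
import torch.nn.functional as F

q = torch.randn(1, 32, 511, 64)                        # queries: 511 tokens
k = torch.randn(1, 32, 1022, 64)                       # keys: 1022 tokens
v = torch.randn(1, 32, 1022, 64)
mask = torch.zeros(1, 1, 511, 511, dtype=torch.bool)   # mask built for a 511-token sequence

# Raises: "The expanded size of the tensor (1022) must match the existing size (511)
# at non-singleton dimension 3."
out = F.scaled_dot_product_attention(q, k, v, attn_mask=mask)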


JAVI897 commented Jul 16, 2024

Have you been able to solve this? I am encountering the same error.


pharaohcaptain commented Jul 31, 2024

The "tuple index out of range" issue may arise because multiplying the embedding by the router's weights can produce values that exceed the representational range of float16, resulting in inf.
A direct fix is to normalize the weights:

def forward(self, x):
    original_type = x.dtype
    # Run the router in float32 so its outputs cannot overflow float16 and become inf.
    self.weight_predictor.to(torch.float32)
    weights = self.weight_predictor(
        x.to(self.weight_predictor.weight.dtype)
    ).squeeze(-1)  # [batch_size, seq_len]
    # Normalize along the sequence dimension to keep the weights in a bounded range.
    weights = weights / torch.sum(weights, dim=-1, keepdim=True)
    return weights.to(original_type)
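
As a quick illustration of the underlying overflow (a standalone sketch, not part of the repository): float16 saturates to inf above roughly 65504, while float32 represents the same value without issue, which is why computing the router weights in float32 and normalizing them avoids the inf.

import torch

# 70000 is beyond the float16 range, so it overflows to inf.
print(torch.tensor(70000.0, dtype=torch.float16))  # tensor(inf, dtype=torch.float16)
# The same value fits comfortably in float32.
print(torch.tensor(70000.0, dtype=torch.float32))  # tensor(70000.)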
