LoRA MoE with k_proj, up_proj, down_proj #13
Hi, thanks for the library!

When we try to compose LoRA experts that have k_proj, up_proj, and down_proj in the target_modules, we face a shape-mismatch error. Everything works fine when the target modules are only q_proj and v_proj. Any suggestions on how to fix this?
Hi, thanks for opening the issue. Can you please provide more details on your implementation and the experts used for merging?
The experts are LoRA-finetuned Llama 7B models with target modules ['down_proj', 'v_proj', 'up_proj', 'q_proj', 'k_proj'], and this was the error encountered.
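(For concreteness, the experts' adapter configuration would have looked roughly like the sketch below. This is a hypothetical reconstruction using peft; the rank, alpha, and dropout values are placeholders, and only target_modules is taken from the report above.)

```python
from peft import LoraConfig

# Hypothetical reconstruction of the experts' adapter config; r/lora_alpha/
# lora_dropout are placeholder values. The relevant part is target_modules,
# which goes beyond the attention projections (q_proj/v_proj) that compose
# without error.
lora_config = LoraConfig(
    r=16,
    lora_alpha=32,
    target_modules=["down_proj", "v_proj", "up_proj", "q_proj", "k_proj"],
    lora_dropout=0.05,
    task_type="CAUSAL_LM",
)
```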
@aksh555, could you share the scripts for replicating the model composition and the forward pass?
I hit the same bug. Looking forward to a fix :)

The code I used to merge the Llama-3 LoRA experts:

```python
"""
Replaces FF (MLP) layers using MoE; all remaining weights are averaged.
"""
import os

import torch
from mergoo.compose_experts import ComposeExperts
from mergoo.models.modeling_llama import LlamaForCausalLM
model_id = "moe_model/llama3_lora_moe"
config = {
"model_type": "llama",
"num_experts_per_tok": 2,
"base_model": "../Meta-Llama-3-8B-Instruct",
"experts": [
{"expert_name": "adapter_1", "model_id": "llama3_lora_1"},
{"expert_name": "adapter_2", "model_id": "llama3_lora_2"},
],
}
# create the merged checkpoint if it does not already exist
if not os.path.exists(model_id):
expertcomposer = ComposeExperts(config)
expertcomposer.compose()
expertcomposer.save_checkpoint(model_id)
# load the composed checkpoint
model = LlamaForCausalLM.from_pretrained(
model_id, torch_dtype=torch.float16, device_map="auto"
)  # the 'gate' / router layers are untrained, so a load warning will appear for them
out = model(torch.tensor([[1, 2, 3, 33, 44]], device=model.device))
print("done") The error messages: The overall error information: Traceback (most recent call last):
File "/data/wentao/slz/mergoo/examples/compose_lora_mistral.py", line 35, in <module>
out = model(torch.tensor([[1, 2, 3, 33, 44]], device=model.device))
File "/home/wentao/anaconda3/envs/qwen_moe/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/home/wentao/anaconda3/envs/qwen_moe/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
return forward_call(*args, **kwargs)
File "/home/wentao/anaconda3/envs/qwen_moe/lib/python3.10/site-packages/mergoo/models/modeling_llama.py", line 1177, in forward
outputs = self.model(
File "/home/wentao/anaconda3/envs/qwen_moe/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/home/wentao/anaconda3/envs/qwen_moe/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
return forward_call(*args, **kwargs)
File "/home/wentao/anaconda3/envs/qwen_moe/lib/python3.10/site-packages/mergoo/models/modeling_llama.py", line 1020, in forward
layer_outputs = decoder_layer(
File "/home/wentao/anaconda3/envs/qwen_moe/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/home/wentao/anaconda3/envs/qwen_moe/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
return forward_call(*args, **kwargs)
File "/home/wentao/anaconda3/envs/qwen_moe/lib/python3.10/site-packages/mergoo/models/modeling_llama.py", line 756, in forward
hidden_states = self.mlp(hidden_states)
File "/home/wentao/anaconda3/envs/qwen_moe/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/home/wentao/anaconda3/envs/qwen_moe/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
return forward_call(*args, **kwargs)
File "/home/wentao/anaconda3/envs/qwen_moe/lib/python3.10/site-packages/mergoo/models/modeling_llama.py", line 242, in forward
down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
File "/home/wentao/anaconda3/envs/qwen_moe/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/home/wentao/anaconda3/envs/qwen_moe/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
return forward_call(*args, **kwargs)
File "/home/wentao/anaconda3/envs/qwen_moe/lib/python3.10/site-packages/mergoo/compose_layers.py", line 126, in forward
gate_logits = self.gate(x) # b,s,N
File "/home/wentao/anaconda3/envs/qwen_moe/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/home/wentao/anaconda3/envs/qwen_moe/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
return forward_call(*args, **kwargs)
File "/home/wentao/anaconda3/envs/qwen_moe/lib/python3.10/site-packages/torch/nn/modules/linear.py", line 114, in forward
return F.linear(input, self.weight, self.bias)
RuntimeError: mat1 and mat2 shapes cannot be multiplied (5x14336 and 4096x2)
```
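(The arithmetic in the error is consistent with the router being built for the wrong input width: the MoE gate wrapped around down_proj receives the MLP's intermediate activation, whose last dimension is 14336 for Llama-3 8B, while the gate's weight is 4096x2, i.e. built for the hidden size with two experts. A minimal sketch that reproduces the same shape clash in plain PyTorch; this is hypothetical and not mergoo's actual code:)

```python
import torch
import torch.nn as nn

hidden_size = 4096         # Llama-3 8B hidden size: what the router was built for
intermediate_size = 14336  # Llama-3 8B MLP intermediate size: what down_proj sees
num_experts = 2

# Router built against hidden_size, matching the 4096x2 weight in the traceback.
gate = nn.Linear(hidden_size, num_experts, bias=False)

# down_proj's input for a 5-token sequence has intermediate_size features.
x = torch.randn(1, 5, intermediate_size)

gate(x)  # RuntimeError: mat1 and mat2 shapes cannot be multiplied (5x14336 and 4096x2)
```

If that reading is right, the fix would be for the composer to build the down_proj router with in_features = intermediate_size rather than hidden_size; q_proj and v_proj never hit this because both take hidden-size inputs.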
Hi @Aurora-slz, could you please share the adapter configurations for your experts?
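(For anyone following along: assuming the experts were saved with peft, the adapter configuration being asked for is stored in each expert's adapter_config.json and can be loaded directly; "llama3_lora_1" below is the expert path from the script above.)

```python
from peft import PeftConfig

# Load and print the saved adapter configuration (target_modules, r,
# lora_alpha, ...) for one of the experts used in the composition script.
print(PeftConfig.from_pretrained("llama3_lora_1"))
```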