
Unable to load *ANY BASE MODEL* in 4bit #78

Open
ApoorvFrontera opened this issue Aug 20, 2024 · 1 comment
ApoorvFrontera commented Aug 20, 2024

Hi VideoLLaMA Team,

I am facing issues when loading any of the base models in 4-bit precision. The following lines try to load the mm_projector_weights, which are stored in 16-bit precision, into a model whose linear layers have already been quantized to 4-bit, which leads to errors:

Code used for loading the model for inference:

from videollama2 import model_init

model_path = 'DAMO-NLP-SG/VideoLLaMA2-8x7B-Base'
model, processor, tokenizer = model_init(model_path, load_4bit=True)

Problematic part of the code:
https://github.com/DAMO-NLP-SG/VideoLLaMA2/blob/main/videollama2/model/__init__.py#L171-L172

mm_projector_weights = load_mm_projector(model_path, token=token)
model.load_state_dict(mm_projector_weights, strict=False)

Error:

RuntimeError: Error(s) in loading state_dict for Videollama2MistralForCausalLM:
size mismatch for model.mm_projector.readout.0.weight: copying a param with shape torch.Size([4096, 4096]) from checkpoint, the shape in current model is torch.Size([8388608, 1]).
size mismatch for model.mm_projector.readout.2.weight: copying a param with shape torch.Size([4096, 4096]) from checkpoint, the shape in current model is torch.Size([8388608, 1]).

How can we load the 16-bit mm_projector_weights into a model that has been quantized to 4-bit?

clownrat6 (Member) commented

According to the BitsAndBytesConfig documentation, all linear layers are replaced by FP4/NF4 layers when load_4bit is set. The checkpoint's 16-bit projector weights therefore no longer match the shapes of the quantized layers, which is why the size-mismatch error is reported.
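Below is a minimal sketch (assuming bitsandbytes, PyTorch, and a CUDA GPU are available; the 4096x4096 layer size is taken from the error message above, and the quant type, fp4 or nf4, does not change the packed shape) that illustrates where the shape [8388608, 1] comes from: a 4-bit layer packs two weights per byte and stores them as a flattened [N, 1] tensor, so 4096 * 4096 = 16,777,216 weights become 8,388,608 packed bytes.

import torch
import bitsandbytes as bnb

# A regular fp16 linear layer keeps its 2-D weight matrix.
fp16_linear = torch.nn.Linear(4096, 4096, bias=False, dtype=torch.float16)
print(fp16_linear.weight.shape)   # torch.Size([4096, 4096])

# A 4-bit layer packs the same weights into bytes when moved to the GPU.
nf4_linear = bnb.nn.Linear4bit(4096, 4096, bias=False, quant_type="nf4")
nf4_linear = nf4_linear.cuda()    # quantization happens on the move to GPU
print(nf4_linear.weight.shape)    # torch.Size([8388608, 1]) -- packed 4-bit storage

Copying a [4096, 4096] fp16 tensor from the checkpoint into that packed parameter via load_state_dict is exactly the size mismatch shown in the traceback.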

A temporary solution is to initialize an unquantized model, load the projector weights into it, and save the merged model weights. The saved weights can then be loaded successfully with load_4bit=True:

from videollama2 import model_init

# Load the base model unquantized; model_init merges the 16-bit projector weights here.
model, processor, tokenizer = model_init('DAMO-NLP-SG/VideoLLaMA2-7B-Base')
model.config.tune_mm_mlp_adapter = False
model.save_pretrained('VideoLLaMA2-7B-full')
tokenizer.save_pretrained('VideoLLaMA2-7B-full')

# The merged checkpoint can now be quantized at load time.
model, processor, tokenizer = model_init('VideoLLaMA2-7B-full', load_4bit=True)
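Presumably, with tune_mm_mlp_adapter = False stored in the saved config, reloading 'VideoLLaMA2-7B-full' no longer goes through the separate 16-bit projector load in videollama2/model/__init__.py; the projector's linear layers get quantized together with the rest of the model, so there is no fp16-to-4-bit state_dict copy left to fail.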
