size mismatch between model and ckpt #162

Closed
jiachengc opened this issue Sep 11, 2024 · 2 comments

@jiachengc

Thanks for this solid work. I have a question about finetuning on the 'esc50' dataset. I tried to run finetune-esc50.sh but got the error below:

File "/data/jiacheng/CLAP/src/laion_clap/clap_module/factory.py", line 155, in create_model model.load_state_dict(ckpt) File "/home/jiacheng/anaconda3/envs/clap/lib/python3.10/site-packages/torch/nn/modules/module.py", line 2215, in load_state_dict raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format( RuntimeError: Error(s) in loading state_dict for CLAP: Missing key(s) in state_dict: "audio_branch.conv_block1.conv1.weight", "audio_branch.conv_block1.conv2.weight", "audio_branch.conv_block1.bn1.weight", "audio_branch.conv_block1.bn1.bias", "audio_branch.conv_block1.bn1.running_mean", "audio_branch.conv_block1.bn1.running_var", "audio_branch.conv_block1.bn2.weight", "audio_branch.conv_block1.bn2.bias", "audio_branch.conv_block1.bn2.running_mean", "audio_branch.conv_block1.bn2.running_var", "audio_branch.conv_block2.conv1.weight", "audio_branch.conv_block2.conv2.weight", "audio_branch.conv_block2.bn1.weight", "audio_branch.conv_block2.bn1.bias", "audio_branch.conv_block2.bn1.running_mean", "audio_branch.conv_block2.bn1.running_var", "audio_branch.conv_block2.bn2.weight", "audio_branch.conv_block2.bn2.bias", "audio_branch.conv_block2.bn2.running_mean", "audio_branch.conv_block2.bn2.running_var", "audio_branch.conv_block3.conv1.weight", "audio_branch.conv_block3.conv2.weight", "audio_branch.conv_block3.bn1.weight", "audio_branch.conv_block3.bn1.bias", "audio_branch.conv_block3.bn1.running_mean", "audio_branch.conv_block3.bn1.running_var", "audio_branch.conv_block3.bn2.weight", "audio_branch.conv_block3.bn2.bias", "audio_branch.conv_block3.bn2.running_mean", "audio_branch.conv_block3.bn2.running_var", "audio_branch.conv_block4.conv1.weight", "audio_branch.conv_block4.conv2.weight", "audio_branch.conv_block4.bn1.weight", "audio_branch.conv_block4.bn1.bias", "audio_branch.conv_block4.bn1.running_mean", "audio_branch.conv_block4.bn1.running_var", "audio_branch.conv_block4.bn2.weight", "audio_branch.conv_block4.bn2.bias", "audio_branch.conv_block4.bn2.running_mean", "audio_branch.conv_block4.bn2.running_var", "audio_branch.conv_block5.conv1.weight", "audio_branch.conv_block5.conv2.weight", "audio_branch.conv_block5.bn1.weight", "audio_branch.conv_block5.bn1.bias", "audio_branch.conv_block5.bn1.running_mean", "audio_branch.conv_block5.bn1.running_var", "audio_branch.conv_block5.bn2.weight", "audio_branch.conv_block5.bn2.bias", "audio_branch.conv_block5.bn2.running_mean", "audio_branch.conv_block5.bn2.running_var", "audio_branch.conv_block6.conv1.weight", "audio_branch.conv_block6.conv2.weight", "audio_branch.conv_block6.bn1.weight", "audio_branch.conv_block6.bn1.bias", "audio_branch.conv_block6.bn1.running_mean", "audio_branch.conv_block6.bn1.running_var", "audio_branch.conv_block6.bn2.weight", "audio_branch.conv_block6.bn2.bias", "audio_branch.conv_block6.bn2.running_mean", "audio_branch.conv_block6.bn2.running_var", "audio_branch.fc1.weight", "audio_branch.fc1.bias", "audio_branch.fc_audioset.weight", "audio_branch.fc_audioset.bias". 
Unexpected key(s) in state_dict: "audio_branch.patch_embed.proj.weight", "audio_branch.patch_embed.proj.bias", "audio_branch.patch_embed.norm.weight", "audio_branch.patch_embed.norm.bias", "audio_branch.layers.0.blocks.0.norm1.weight", "audio_branch.layers.0.blocks.0.norm1.bias", "audio_branch.layers.0.blocks.0.attn.relative_position_bias_table", "audio_branch.layers.0.blocks.0.attn.relative_position_index", "audio_branch.layers.0.blocks.0.attn.qkv.weight", "audio_branch.layers.0.blocks.0.attn.qkv.bias", "audio_branch.layers.0.blocks.0.attn.proj.weight", "audio_branch.layers.0.blocks.0.attn.proj.bias", "audio_branch.layers.0.blocks.0.norm2.weight", "audio_branch.layers.0.blocks.0.norm2.bias", "audio_branch.layers.0.blocks.0.mlp.fc1.weight", "audio_branch.layers.0.blocks.0.mlp.fc1.bias", "audio_branch.layers.0.blocks.0.mlp.fc2.weight", "audio_branch.layers.0.blocks.0.mlp.fc2.bias", "audio_branch.layers.0.blocks.1.attn_mask", "audio_branch.layers.0.blocks.1.norm1.weight", "audio_branch.layers.0.blocks.1.norm1.bias", "audio_branch.layers.0.blocks.1.attn.relative_position_bias_table", "audio_branch.layers.0.blocks.1.attn.relative_position_index", "audio_branch.layers.0.blocks.1.attn.qkv.weight", "audio_branch.layers.0.blocks.1.attn.qkv.bias", "audio_branch.layers.0.blocks.1.attn.proj.weight", "audio_branch.layers.0.blocks.1.attn.proj.bias", "audio_branch.layers.0.blocks.1.norm2.weight", "audio_branch.layers.0.blocks.1.norm2.bias", "audio_branch.layers.0.blocks.1.mlp.fc1.weight", "audio_branch.layers.0.blocks.1.mlp.fc1.bias", "audio_branch.layers.0.blocks.1.mlp.fc2.weight", "audio_branch.layers.0.blocks.1.mlp.fc2.bias", "audio_branch.layers.0.downsample.reduction.weight", "audio_branch.layers.0.downsample.norm.weight", "audio_branch.layers.0.downsample.norm.bias", "audio_branch.layers.1.blocks.0.norm1.weight", "audio_branch.layers.1.blocks.0.norm1.bias", "audio_branch.layers.1.blocks.0.attn.relative_position_bias_table", "audio_branch.layers.1.blocks.0.attn.relative_position_index", "audio_branch.layers.1.blocks.0.attn.qkv.weight", "audio_branch.layers.1.blocks.0.attn.qkv.bias", "audio_branch.layers.1.blocks.0.attn.proj.weight", "audio_branch.layers.1.blocks.0.attn.proj.bias", "audio_branch.layers.1.blocks.0.norm2.weight", "audio_branch.layers.1.blocks.0.norm2.bias", "audio_branch.layers.1.blocks.0.mlp.fc1.weight", "audio_branch.layers.1.blocks.0.mlp.fc1.bias", "audio_branch.layers.1.blocks.0.mlp.fc2.weight", "audio_branch.layers.1.blocks.0.mlp.fc2.bias", "audio_branch.layers.1.blocks.1.attn_mask", "audio_branch.layers.1.blocks.1.norm1.weight", "audio_branch.layers.1.blocks.1.norm1.bias", "audio_branch.layers.1.blocks.1.attn.relative_position_bias_table", "audio_branch.layers.1.blocks.1.attn.relative_position_index", "audio_branch.layers.1.blocks.1.attn.qkv.weight", "audio_branch.layers.1.blocks.1.attn.qkv.bias", "audio_branch.layers.1.blocks.1.attn.proj.weight", "audio_branch.layers.1.blocks.1.attn.proj.bias", "audio_branch.layers.1.blocks.1.norm2.weight", "audio_branch.layers.1.blocks.1.norm2.bias", "audio_branch.layers.1.blocks.1.mlp.fc1.weight", "audio_branch.layers.1.blocks.1.mlp.fc1.bias", "audio_branch.layers.1.blocks.1.mlp.fc2.weight", "audio_branch.layers.1.blocks.1.mlp.fc2.bias", "audio_branch.layers.1.downsample.reduction.weight", "audio_branch.layers.1.downsample.norm.weight", "audio_branch.layers.1.downsample.norm.bias", "audio_branch.layers.2.blocks.0.norm1.weight", "audio_branch.layers.2.blocks.0.norm1.bias", "audio_branch.layers.2.blocks.0.attn.relative_position_bias_table", 
"audio_branch.layers.2.blocks.0.attn.relative_position_index", "audio_branch.layers.2.blocks.0.attn.qkv.weight", "audio_branch.layers.2.blocks.0.attn.qkv.bias", "audio_branch.layers.2.blocks.0.attn.proj.weight", "audio_branch.layers.2.blocks.0.attn.proj.bias", "audio_branch.layers.2.blocks.0.norm2.weight", "audio_branch.layers.2.blocks.0.norm2.bias", "audio_branch.layers.2.blocks.0.mlp.fc1.weight", "audio_branch.layers.2.blocks.0.mlp.fc1.bias", "audio_branch.layers.2.blocks.0.mlp.fc2.weight", "audio_branch.layers.2.blocks.0.mlp.fc2.bias", "audio_branch.layers.2.blocks.1.attn_mask", "audio_branch.layers.2.blocks.1.norm1.weight", "audio_branch.layers.2.blocks.1.norm1.bias", "audio_branch.layers.2.blocks.1.attn.relative_position_bias_table", "audio_branch.layers.2.blocks.1.attn.relative_position_index", "audio_branch.layers.2.blocks.1.attn.qkv.weight", "audio_branch.layers.2.blocks.1.attn.qkv.bias", "audio_branch.layers.2.blocks.1.attn.proj.weight", "audio_branch.layers.2.blocks.1.attn.proj.bias", "audio_branch.layers.2.blocks.1.norm2.weight", "audio_branch.layers.2.blocks.1.norm2.bias", "audio_branch.layers.2.blocks.1.mlp.fc1.weight", "audio_branch.layers.2.blocks.1.mlp.fc1.bias", "audio_branch.layers.2.blocks.1.mlp.fc2.weight", "audio_branch.layers.2.blocks.1.mlp.fc2.bias", "audio_branch.layers.2.blocks.2.norm1.weight", "audio_branch.layers.2.blocks.2.norm1.bias", "audio_branch.layers.2.blocks.2.attn.relative_position_bias_table", "audio_branch.layers.2.blocks.2.attn.relative_position_index", "audio_branch.layers.2.blocks.2.attn.qkv.weight", "audio_branch.layers.2.blocks.2.attn.qkv.bias", "audio_branch.layers.2.blocks.2.attn.proj.weight", "audio_branch.layers.2.blocks.2.attn.proj.bias", "audio_branch.layers.2.blocks.2.norm2.weight", "audio_branch.layers.2.blocks.2.norm2.bias", "audio_branch.layers.2.blocks.2.mlp.fc1.weight", "audio_branch.layers.2.blocks.2.mlp.fc1.bias", "audio_branch.layers.2.blocks.2.mlp.fc2.weight", "audio_branch.layers.2.blocks.2.mlp.fc2.bias", "audio_branch.layers.2.blocks.3.attn_mask", "audio_branch.layers.2.blocks.3.norm1.weight", "audio_branch.layers.2.blocks.3.norm1.bias", "audio_branch.layers.2.blocks.3.attn.relative_position_bias_table", "audio_branch.layers.2.blocks.3.attn.relative_position_index", "audio_branch.layers.2.blocks.3.attn.qkv.weight", "audio_branch.layers.2.blocks.3.attn.qkv.bias", "audio_branch.layers.2.blocks.3.attn.proj.weight", "audio_branch.layers.2.blocks.3.attn.proj.bias", "audio_branch.layers.2.blocks.3.norm2.weight", "audio_branch.layers.2.blocks.3.norm2.bias", "audio_branch.layers.2.blocks.3.mlp.fc1.weight", "audio_branch.layers.2.blocks.3.mlp.fc1.bias", "audio_branch.layers.2.blocks.3.mlp.fc2.weight", "audio_branch.layers.2.blocks.3.mlp.fc2.bias", "audio_branch.layers.2.blocks.4.norm1.weight", "audio_branch.layers.2.blocks.4.norm1.bias", "audio_branch.layers.2.blocks.4.attn.relative_position_bias_table", "audio_branch.layers.2.blocks.4.attn.relative_position_index", "audio_branch.layers.2.blocks.4.attn.qkv.weight", "audio_branch.layers.2.blocks.4.attn.qkv.bias", "audio_branch.layers.2.blocks.4.attn.proj.weight", "audio_branch.layers.2.blocks.4.attn.proj.bias", "audio_branch.layers.2.blocks.4.norm2.weight", "audio_branch.layers.2.blocks.4.norm2.bias", "audio_branch.layers.2.blocks.4.mlp.fc1.weight", "audio_branch.layers.2.blocks.4.mlp.fc1.bias", "audio_branch.layers.2.blocks.4.mlp.fc2.weight", "audio_branch.layers.2.blocks.4.mlp.fc2.bias", "audio_branch.layers.2.blocks.5.attn_mask", "audio_branch.layers.2.blocks.5.norm1.weight", 
"audio_branch.layers.2.blocks.5.norm1.bias", "audio_branch.layers.2.blocks.5.attn.relative_position_bias_table", "audio_branch.layers.2.blocks.5.attn.relative_position_index", "audio_branch.layers.2.blocks.5.attn.qkv.weight", "audio_branch.layers.2.blocks.5.attn.qkv.bias", "audio_branch.layers.2.blocks.5.attn.proj.weight", "audio_branch.layers.2.blocks.5.attn.proj.bias", "audio_branch.layers.2.blocks.5.norm2.weight", "audio_branch.layers.2.blocks.5.norm2.bias", "audio_branch.layers.2.blocks.5.mlp.fc1.weight", "audio_branch.layers.2.blocks.5.mlp.fc1.bias", "audio_branch.layers.2.blocks.5.mlp.fc2.weight", "audio_branch.layers.2.blocks.5.mlp.fc2.bias", "audio_branch.layers.2.downsample.reduction.weight", "audio_branch.layers.2.downsample.norm.weight", "audio_branch.layers.2.downsample.norm.bias", "audio_branch.layers.3.blocks.0.norm1.weight", "audio_branch.layers.3.blocks.0.norm1.bias", "audio_branch.layers.3.blocks.0.attn.relative_position_bias_table", "audio_branch.layers.3.blocks.0.attn.relative_position_index", "audio_branch.layers.3.blocks.0.attn.qkv.weight", "audio_branch.layers.3.blocks.0.attn.qkv.bias", "audio_branch.layers.3.blocks.0.attn.proj.weight", "audio_branch.layers.3.blocks.0.attn.proj.bias", "audio_branch.layers.3.blocks.0.norm2.weight", "audio_branch.layers.3.blocks.0.norm2.bias", "audio_branch.layers.3.blocks.0.mlp.fc1.weight", "audio_branch.layers.3.blocks.0.mlp.fc1.bias", "audio_branch.layers.3.blocks.0.mlp.fc2.weight", "audio_branch.layers.3.blocks.0.mlp.fc2.bias", "audio_branch.layers.3.blocks.1.norm1.weight", "audio_branch.layers.3.blocks.1.norm1.bias", "audio_branch.layers.3.blocks.1.attn.relative_position_bias_table", "audio_branch.layers.3.blocks.1.attn.relative_position_index", "audio_branch.layers.3.blocks.1.attn.qkv.weight", "audio_branch.layers.3.blocks.1.attn.qkv.bias", "audio_branch.layers.3.blocks.1.attn.proj.weight", "audio_branch.layers.3.blocks.1.attn.proj.bias", "audio_branch.layers.3.blocks.1.norm2.weight", "audio_branch.layers.3.blocks.1.norm2.bias", "audio_branch.layers.3.blocks.1.mlp.fc1.weight", "audio_branch.layers.3.blocks.1.mlp.fc1.bias", "audio_branch.layers.3.blocks.1.mlp.fc2.weight", "audio_branch.layers.3.blocks.1.mlp.fc2.bias", "audio_branch.norm.weight", "audio_branch.norm.bias", "audio_branch.tscam_conv.weight", "audio_branch.tscam_conv.bias", "audio_branch.head.weight", "audio_branch.head.bias". size mismatch for audio_projection.0.weight: copying a param with shape torch.Size([512, 768]) from checkpoint, the shape in current model is torch.Size([512, 2048]).

And here is my running script:
cd /data/jiacheng/CLAP/src/laion_clap/
export HF_HOME=~/.cache/huggingface/hub/hub/
python -m evaluate.eval_linear_probe \
    --save-frequency 50 \
    --save-top-performance 3 \
    --save-most-recent \
    --dataset-type="webdataset" \
    --datasetpath="/data/jiacheng/train3/" \
    --precision="fp32" \
    --warmup 0 \
    --batch-size=160 \
    --lr=1e-4 \
    --wd=0.1 \
    --epochs=100 \
    --workers=4 \
    --use-bn-sync \
    --freeze-text \
    --amodel PANN-14 \
    --tmodel roberta \
    --report-to "wandb" \
    --wandb-notes "10.14-finetune-esc50" \
    --datasetnames "esc50" \
    --datasetinfos "train" \
    --seed 3407 \
    --logs /data/jiacheng/CLAP/clap_logs \
    --gather-with-grad \
    --lp-loss="ce" \
    --lp-metrics="acc" \
    --lp-lr=1e-4 \
    --lp-mlp \
    --class-label-path="/data/jiacheng/CLAP/class_labels/ESC50_class_labels_indices_space.json" \
    --openai-model-cache-dir ~/.cache/huggingface/hub/hub \
    --pretrained="/data/jiacheng/CLAP/laion_clap/630k/" \
    --data-filling "repeatpad" \
    --data-truncating "rand_trunc" \
    --optimizer "adam"

I understand this error: the shape of the 'audio_projection.0.weight' tensor in the checkpoint is [512, 768], which does not match the projection layer of the audio encoder in the current model, where a shape of [512, 2048] is expected.
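For reference, this is how I can confirm the mismatch directly from the checkpoint (a minimal sketch; the .pt filename below is a placeholder for whatever file sits inside the 630k directory):

```python
# Minimal sketch: inspect the checkpoint's audio projection weight without building the model.
# The filename below is a placeholder; use the actual .pt file inside the 630k directory.
import torch

ckpt = torch.load("/data/jiacheng/CLAP/laion_clap/630k/checkpoint.pt", map_location="cpu")
state_dict = ckpt.get("state_dict", ckpt)  # some checkpoints nest the weights under "state_dict"

# Strip a possible "module." prefix left over from DistributedDataParallel training.
state_dict = {k[len("module."):] if k.startswith("module.") else k: v for k, v in state_dict.items()}

w = state_dict["audio_projection.0.weight"]
print("checkpoint audio_projection.0.weight shape:", tuple(w.shape))
# Prints (512, 768) here, while a model built with --amodel PANN-14 expects (512, 2048),
# because the PANN-14 audio branch outputs a 2048-dim embedding.
```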

Do you have any suggestions on how to fix this error? By the way, I loaded the ckpt here and used the esc50 dataset here. Thanks a lot. @lukewys

@kayleeliyx

Hey! I ran into the same issue. May I ask how you solved it?

@tbrouns

tbrouns commented Oct 29, 2024

This is probably because the pretrained model that you're giving as an argument (--pretrained="/data/jiacheng/CLAP/laion_clap/630k/") is based on the HTSAT-tiny audio encoder, which has an output dimension of 768.

The model graph that is being built by the finetune-esc50.sh script, on the other hand, uses the PANN-14 architecture, which has an output dimension of 2048 (see the --amodel PANN-14 argument in that script).

You can check paragraph 3.3 in the original paper for the model dimensions:
https://arxiv.org/pdf/2211.06687

This is why the original script uses:
--pretrained="/fsx/clap_logs/2022_10_14-04_05_14-model_PANN-14-lr_0.0001-b_160-j_6-p_fp32/checkpoints"
(which is a checkpoint based on the PANN-14 audio encoder)
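If you want to double-check which audio encoder a given checkpoint was trained with, you can look at its key names (a rough sketch, with a placeholder filename): HTSAT/Swin checkpoints have patch_embed and layers.* keys under audio_branch, while PANN checkpoints have conv_block* keys, which is exactly what the missing/unexpected key lists in your error show.

```python
# Sketch: guess the audio encoder family from the checkpoint's key names.
# The filename is a placeholder; adjust it to your own checkpoint file.
import torch

ckpt = torch.load("/data/jiacheng/CLAP/laion_clap/630k/checkpoint.pt", map_location="cpu")
state_dict = ckpt.get("state_dict", ckpt)
keys = list(state_dict.keys())

has_htsat_keys = any("audio_branch.patch_embed" in k for k in keys)  # Swin/HTSAT transformer encoder
has_pann_keys = any("audio_branch.conv_block1" in k for k in keys)   # PANN CNN encoder

print("HTSAT-style keys:", has_htsat_keys, "| PANN-style keys:", has_pann_keys)
# For the 630k checkpoint this should report HTSAT-style keys, matching the
# "Unexpected key(s)" list in the error above.
```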

You might get away with using these arguments:

--pretrained="/data/jiacheng/CLAP/laion_clap/630k/"
--amodel HTSAT-tiny

But that's something to test.
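As a quick sanity check (an untested sketch using the laion_clap pip package rather than the training scripts, with a placeholder checkpoint filename), the 630k checkpoint should load cleanly when the audio encoder is set to HTSAT-tiny:

```python
# Untested sketch: verify the checkpoint matches an HTSAT-tiny model via the laion_clap package.
# Assumes a non-fusion 630k checkpoint; the path/filename below is a placeholder.
import laion_clap

model = laion_clap.CLAP_Module(enable_fusion=False, amodel="HTSAT-tiny", tmodel="roberta")
model.load_ckpt("/data/jiacheng/CLAP/laion_clap/630k/checkpoint.pt")
# If this loads without missing/unexpected-key or size-mismatch errors, the same
# checkpoint should also work with --amodel HTSAT-tiny in the finetuning script.
```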
