I understand this error: the 'audio_projection.0.weight' tensor in the checkpoint has shape [512, 768], while the projection layer in the model being built expects [512, 2048].
Do you have any suggestions on how to fix this? By the way, I loaded the checkpoint from here and used the ESC-50 dataset from here. Thanks a lot. @lukewys
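For reference, one way to confirm what the checkpoint actually contains is to open it directly with PyTorch and print the projection weight's shape. This is only a minimal sketch, not code from the repo: the filename is a placeholder for whichever .pt file sits in the 630k directory, and the 'state_dict'/'module.' unwrapping is an assumption about common checkpoint layouts.

import torch

# Placeholder path: point this at the actual .pt file inside the 630k directory.
ckpt = torch.load("630k-checkpoint.pt", map_location="cpu")

# Assumption: some checkpoints nest the weights under "state_dict" and/or prefix
# keys with "module." from DistributedDataParallel; unwrap both if present.
state_dict = ckpt.get("state_dict", ckpt)
state_dict = {k.removeprefix("module."): v for k, v in state_dict.items()}

print(state_dict["audio_projection.0.weight"].shape)
# torch.Size([512, 768])  -> the projection expects a 768-d audio embedding
# torch.Size([512, 2048]) -> the projection expects a 2048-d audio embedding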
This is probably because the pretrained model you pass as an argument (--pretrained="/data/jiacheng/CLAP/laion_clap/630k/") is based on the HTSAT-tiny audio encoder, whose output dimension is 768, whereas the model graph built by the finetune-esc50.sh script uses the PANN-14 architecture, whose output dimension is 2048 (see the --amodel PANN-14 argument in finetune-esc50.sh).
This is why the original script uses --pretrained="/fsx/clap_logs/2022_10_14-04_05_14-model_PANN-14-lr_0.0001-b_160-j_6-p_fp32/checkpoints", which is a checkpoint trained with the PANN-14 audio encoder.
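Another way to double-check which audio encoder a checkpoint was trained with is to look at its parameter names: PANN/CNN14-style encoders carry audio_branch.conv_block* keys, while HTSAT (Swin Transformer) encoders carry audio_branch.patch_embed and audio_branch.layers.* keys, which are exactly the two families that appear in the Missing/Unexpected key lists of the error. The helper below is a heuristic sketch added here for illustration, not an official utility from the repo, and the placeholder path and unwrapping follow the same assumptions as the previous snippet.

import torch

def guess_audio_encoder(state_dict):
    # Heuristic: infer the audio branch family from parameter name prefixes.
    if any(k.startswith("audio_branch.conv_block") for k in state_dict):
        return "PANN-style CNN (what --amodel PANN-14 builds)"
    if any(k.startswith("audio_branch.patch_embed") for k in state_dict):
        return "HTSAT-style Swin Transformer (what --amodel HTSAT-tiny builds)"
    return "unknown"

ckpt = torch.load("630k-checkpoint.pt", map_location="cpu")  # placeholder path
state_dict = ckpt.get("state_dict", ckpt)
state_dict = {k.removeprefix("module."): v for k, v in state_dict.items()}
print(guess_audio_encoder(state_dict))

If this reports an HTSAT-style checkpoint, the options are the one suggested above (point --pretrained at a PANN-14 checkpoint) or, presumably, changing --amodel so that the model graph matches the encoder the checkpoint was trained with.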
Thanks for this solid work. I have a question about fine-tuning on the ESC-50 dataset. I tried to run finetune-esc50.sh but got the error below:
File "/data/jiacheng/CLAP/src/laion_clap/clap_module/factory.py", line 155, in create_model model.load_state_dict(ckpt) File "/home/jiacheng/anaconda3/envs/clap/lib/python3.10/site-packages/torch/nn/modules/module.py", line 2215, in load_state_dict raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format( RuntimeError: Error(s) in loading state_dict for CLAP: Missing key(s) in state_dict: "audio_branch.conv_block1.conv1.weight", "audio_branch.conv_block1.conv2.weight", "audio_branch.conv_block1.bn1.weight", "audio_branch.conv_block1.bn1.bias", "audio_branch.conv_block1.bn1.running_mean", "audio_branch.conv_block1.bn1.running_var", "audio_branch.conv_block1.bn2.weight", "audio_branch.conv_block1.bn2.bias", "audio_branch.conv_block1.bn2.running_mean", "audio_branch.conv_block1.bn2.running_var", "audio_branch.conv_block2.conv1.weight", "audio_branch.conv_block2.conv2.weight", "audio_branch.conv_block2.bn1.weight", "audio_branch.conv_block2.bn1.bias", "audio_branch.conv_block2.bn1.running_mean", "audio_branch.conv_block2.bn1.running_var", "audio_branch.conv_block2.bn2.weight", "audio_branch.conv_block2.bn2.bias", "audio_branch.conv_block2.bn2.running_mean", "audio_branch.conv_block2.bn2.running_var", "audio_branch.conv_block3.conv1.weight", "audio_branch.conv_block3.conv2.weight", "audio_branch.conv_block3.bn1.weight", "audio_branch.conv_block3.bn1.bias", "audio_branch.conv_block3.bn1.running_mean", "audio_branch.conv_block3.bn1.running_var", "audio_branch.conv_block3.bn2.weight", "audio_branch.conv_block3.bn2.bias", "audio_branch.conv_block3.bn2.running_mean", "audio_branch.conv_block3.bn2.running_var", "audio_branch.conv_block4.conv1.weight", "audio_branch.conv_block4.conv2.weight", "audio_branch.conv_block4.bn1.weight", "audio_branch.conv_block4.bn1.bias", "audio_branch.conv_block4.bn1.running_mean", "audio_branch.conv_block4.bn1.running_var", "audio_branch.conv_block4.bn2.weight", "audio_branch.conv_block4.bn2.bias", "audio_branch.conv_block4.bn2.running_mean", "audio_branch.conv_block4.bn2.running_var", "audio_branch.conv_block5.conv1.weight", "audio_branch.conv_block5.conv2.weight", "audio_branch.conv_block5.bn1.weight", "audio_branch.conv_block5.bn1.bias", "audio_branch.conv_block5.bn1.running_mean", "audio_branch.conv_block5.bn1.running_var", "audio_branch.conv_block5.bn2.weight", "audio_branch.conv_block5.bn2.bias", "audio_branch.conv_block5.bn2.running_mean", "audio_branch.conv_block5.bn2.running_var", "audio_branch.conv_block6.conv1.weight", "audio_branch.conv_block6.conv2.weight", "audio_branch.conv_block6.bn1.weight", "audio_branch.conv_block6.bn1.bias", "audio_branch.conv_block6.bn1.running_mean", "audio_branch.conv_block6.bn1.running_var", "audio_branch.conv_block6.bn2.weight", "audio_branch.conv_block6.bn2.bias", "audio_branch.conv_block6.bn2.running_mean", "audio_branch.conv_block6.bn2.running_var", "audio_branch.fc1.weight", "audio_branch.fc1.bias", "audio_branch.fc_audioset.weight", "audio_branch.fc_audioset.bias". 
Unexpected key(s) in state_dict: "audio_branch.patch_embed.proj.weight", "audio_branch.patch_embed.proj.bias", "audio_branch.patch_embed.norm.weight", "audio_branch.patch_embed.norm.bias", "audio_branch.layers.0.blocks.0.norm1.weight", "audio_branch.layers.0.blocks.0.norm1.bias", "audio_branch.layers.0.blocks.0.attn.relative_position_bias_table", "audio_branch.layers.0.blocks.0.attn.relative_position_index", "audio_branch.layers.0.blocks.0.attn.qkv.weight", "audio_branch.layers.0.blocks.0.attn.qkv.bias", "audio_branch.layers.0.blocks.0.attn.proj.weight", "audio_branch.layers.0.blocks.0.attn.proj.bias", "audio_branch.layers.0.blocks.0.norm2.weight", "audio_branch.layers.0.blocks.0.norm2.bias", "audio_branch.layers.0.blocks.0.mlp.fc1.weight", "audio_branch.layers.0.blocks.0.mlp.fc1.bias", "audio_branch.layers.0.blocks.0.mlp.fc2.weight", "audio_branch.layers.0.blocks.0.mlp.fc2.bias", "audio_branch.layers.0.blocks.1.attn_mask", "audio_branch.layers.0.blocks.1.norm1.weight", "audio_branch.layers.0.blocks.1.norm1.bias", "audio_branch.layers.0.blocks.1.attn.relative_position_bias_table", "audio_branch.layers.0.blocks.1.attn.relative_position_index", "audio_branch.layers.0.blocks.1.attn.qkv.weight", "audio_branch.layers.0.blocks.1.attn.qkv.bias", "audio_branch.layers.0.blocks.1.attn.proj.weight", "audio_branch.layers.0.blocks.1.attn.proj.bias", "audio_branch.layers.0.blocks.1.norm2.weight", "audio_branch.layers.0.blocks.1.norm2.bias", "audio_branch.layers.0.blocks.1.mlp.fc1.weight", "audio_branch.layers.0.blocks.1.mlp.fc1.bias", "audio_branch.layers.0.blocks.1.mlp.fc2.weight", "audio_branch.layers.0.blocks.1.mlp.fc2.bias", "audio_branch.layers.0.downsample.reduction.weight", "audio_branch.layers.0.downsample.norm.weight", "audio_branch.layers.0.downsample.norm.bias", "audio_branch.layers.1.blocks.0.norm1.weight", "audio_branch.layers.1.blocks.0.norm1.bias", "audio_branch.layers.1.blocks.0.attn.relative_position_bias_table", "audio_branch.layers.1.blocks.0.attn.relative_position_index", "audio_branch.layers.1.blocks.0.attn.qkv.weight", "audio_branch.layers.1.blocks.0.attn.qkv.bias", "audio_branch.layers.1.blocks.0.attn.proj.weight", "audio_branch.layers.1.blocks.0.attn.proj.bias", "audio_branch.layers.1.blocks.0.norm2.weight", "audio_branch.layers.1.blocks.0.norm2.bias", "audio_branch.layers.1.blocks.0.mlp.fc1.weight", "audio_branch.layers.1.blocks.0.mlp.fc1.bias", "audio_branch.layers.1.blocks.0.mlp.fc2.weight", "audio_branch.layers.1.blocks.0.mlp.fc2.bias", "audio_branch.layers.1.blocks.1.attn_mask", "audio_branch.layers.1.blocks.1.norm1.weight", "audio_branch.layers.1.blocks.1.norm1.bias", "audio_branch.layers.1.blocks.1.attn.relative_position_bias_table", "audio_branch.layers.1.blocks.1.attn.relative_position_index", "audio_branch.layers.1.blocks.1.attn.qkv.weight", "audio_branch.layers.1.blocks.1.attn.qkv.bias", "audio_branch.layers.1.blocks.1.attn.proj.weight", "audio_branch.layers.1.blocks.1.attn.proj.bias", "audio_branch.layers.1.blocks.1.norm2.weight", "audio_branch.layers.1.blocks.1.norm2.bias", "audio_branch.layers.1.blocks.1.mlp.fc1.weight", "audio_branch.layers.1.blocks.1.mlp.fc1.bias", "audio_branch.layers.1.blocks.1.mlp.fc2.weight", "audio_branch.layers.1.blocks.1.mlp.fc2.bias", "audio_branch.layers.1.downsample.reduction.weight", "audio_branch.layers.1.downsample.norm.weight", "audio_branch.layers.1.downsample.norm.bias", "audio_branch.layers.2.blocks.0.norm1.weight", "audio_branch.layers.2.blocks.0.norm1.bias", "audio_branch.layers.2.blocks.0.attn.relative_position_bias_table", 
"audio_branch.layers.2.blocks.0.attn.relative_position_index", "audio_branch.layers.2.blocks.0.attn.qkv.weight", "audio_branch.layers.2.blocks.0.attn.qkv.bias", "audio_branch.layers.2.blocks.0.attn.proj.weight", "audio_branch.layers.2.blocks.0.attn.proj.bias", "audio_branch.layers.2.blocks.0.norm2.weight", "audio_branch.layers.2.blocks.0.norm2.bias", "audio_branch.layers.2.blocks.0.mlp.fc1.weight", "audio_branch.layers.2.blocks.0.mlp.fc1.bias", "audio_branch.layers.2.blocks.0.mlp.fc2.weight", "audio_branch.layers.2.blocks.0.mlp.fc2.bias", "audio_branch.layers.2.blocks.1.attn_mask", "audio_branch.layers.2.blocks.1.norm1.weight", "audio_branch.layers.2.blocks.1.norm1.bias", "audio_branch.layers.2.blocks.1.attn.relative_position_bias_table", "audio_branch.layers.2.blocks.1.attn.relative_position_index", "audio_branch.layers.2.blocks.1.attn.qkv.weight", "audio_branch.layers.2.blocks.1.attn.qkv.bias", "audio_branch.layers.2.blocks.1.attn.proj.weight", "audio_branch.layers.2.blocks.1.attn.proj.bias", "audio_branch.layers.2.blocks.1.norm2.weight", "audio_branch.layers.2.blocks.1.norm2.bias", "audio_branch.layers.2.blocks.1.mlp.fc1.weight", "audio_branch.layers.2.blocks.1.mlp.fc1.bias", "audio_branch.layers.2.blocks.1.mlp.fc2.weight", "audio_branch.layers.2.blocks.1.mlp.fc2.bias", "audio_branch.layers.2.blocks.2.norm1.weight", "audio_branch.layers.2.blocks.2.norm1.bias", "audio_branch.layers.2.blocks.2.attn.relative_position_bias_table", "audio_branch.layers.2.blocks.2.attn.relative_position_index", "audio_branch.layers.2.blocks.2.attn.qkv.weight", "audio_branch.layers.2.blocks.2.attn.qkv.bias", "audio_branch.layers.2.blocks.2.attn.proj.weight", "audio_branch.layers.2.blocks.2.attn.proj.bias", "audio_branch.layers.2.blocks.2.norm2.weight", "audio_branch.layers.2.blocks.2.norm2.bias", "audio_branch.layers.2.blocks.2.mlp.fc1.weight", "audio_branch.layers.2.blocks.2.mlp.fc1.bias", "audio_branch.layers.2.blocks.2.mlp.fc2.weight", "audio_branch.layers.2.blocks.2.mlp.fc2.bias", "audio_branch.layers.2.blocks.3.attn_mask", "audio_branch.layers.2.blocks.3.norm1.weight", "audio_branch.layers.2.blocks.3.norm1.bias", "audio_branch.layers.2.blocks.3.attn.relative_position_bias_table", "audio_branch.layers.2.blocks.3.attn.relative_position_index", "audio_branch.layers.2.blocks.3.attn.qkv.weight", "audio_branch.layers.2.blocks.3.attn.qkv.bias", "audio_branch.layers.2.blocks.3.attn.proj.weight", "audio_branch.layers.2.blocks.3.attn.proj.bias", "audio_branch.layers.2.blocks.3.norm2.weight", "audio_branch.layers.2.blocks.3.norm2.bias", "audio_branch.layers.2.blocks.3.mlp.fc1.weight", "audio_branch.layers.2.blocks.3.mlp.fc1.bias", "audio_branch.layers.2.blocks.3.mlp.fc2.weight", "audio_branch.layers.2.blocks.3.mlp.fc2.bias", "audio_branch.layers.2.blocks.4.norm1.weight", "audio_branch.layers.2.blocks.4.norm1.bias", "audio_branch.layers.2.blocks.4.attn.relative_position_bias_table", "audio_branch.layers.2.blocks.4.attn.relative_position_index", "audio_branch.layers.2.blocks.4.attn.qkv.weight", "audio_branch.layers.2.blocks.4.attn.qkv.bias", "audio_branch.layers.2.blocks.4.attn.proj.weight", "audio_branch.layers.2.blocks.4.attn.proj.bias", "audio_branch.layers.2.blocks.4.norm2.weight", "audio_branch.layers.2.blocks.4.norm2.bias", "audio_branch.layers.2.blocks.4.mlp.fc1.weight", "audio_branch.layers.2.blocks.4.mlp.fc1.bias", "audio_branch.layers.2.blocks.4.mlp.fc2.weight", "audio_branch.layers.2.blocks.4.mlp.fc2.bias", "audio_branch.layers.2.blocks.5.attn_mask", "audio_branch.layers.2.blocks.5.norm1.weight", 
"audio_branch.layers.2.blocks.5.norm1.bias", "audio_branch.layers.2.blocks.5.attn.relative_position_bias_table", "audio_branch.layers.2.blocks.5.attn.relative_position_index", "audio_branch.layers.2.blocks.5.attn.qkv.weight", "audio_branch.layers.2.blocks.5.attn.qkv.bias", "audio_branch.layers.2.blocks.5.attn.proj.weight", "audio_branch.layers.2.blocks.5.attn.proj.bias", "audio_branch.layers.2.blocks.5.norm2.weight", "audio_branch.layers.2.blocks.5.norm2.bias", "audio_branch.layers.2.blocks.5.mlp.fc1.weight", "audio_branch.layers.2.blocks.5.mlp.fc1.bias", "audio_branch.layers.2.blocks.5.mlp.fc2.weight", "audio_branch.layers.2.blocks.5.mlp.fc2.bias", "audio_branch.layers.2.downsample.reduction.weight", "audio_branch.layers.2.downsample.norm.weight", "audio_branch.layers.2.downsample.norm.bias", "audio_branch.layers.3.blocks.0.norm1.weight", "audio_branch.layers.3.blocks.0.norm1.bias", "audio_branch.layers.3.blocks.0.attn.relative_position_bias_table", "audio_branch.layers.3.blocks.0.attn.relative_position_index", "audio_branch.layers.3.blocks.0.attn.qkv.weight", "audio_branch.layers.3.blocks.0.attn.qkv.bias", "audio_branch.layers.3.blocks.0.attn.proj.weight", "audio_branch.layers.3.blocks.0.attn.proj.bias", "audio_branch.layers.3.blocks.0.norm2.weight", "audio_branch.layers.3.blocks.0.norm2.bias", "audio_branch.layers.3.blocks.0.mlp.fc1.weight", "audio_branch.layers.3.blocks.0.mlp.fc1.bias", "audio_branch.layers.3.blocks.0.mlp.fc2.weight", "audio_branch.layers.3.blocks.0.mlp.fc2.bias", "audio_branch.layers.3.blocks.1.norm1.weight", "audio_branch.layers.3.blocks.1.norm1.bias", "audio_branch.layers.3.blocks.1.attn.relative_position_bias_table", "audio_branch.layers.3.blocks.1.attn.relative_position_index", "audio_branch.layers.3.blocks.1.attn.qkv.weight", "audio_branch.layers.3.blocks.1.attn.qkv.bias", "audio_branch.layers.3.blocks.1.attn.proj.weight", "audio_branch.layers.3.blocks.1.attn.proj.bias", "audio_branch.layers.3.blocks.1.norm2.weight", "audio_branch.layers.3.blocks.1.norm2.bias", "audio_branch.layers.3.blocks.1.mlp.fc1.weight", "audio_branch.layers.3.blocks.1.mlp.fc1.bias", "audio_branch.layers.3.blocks.1.mlp.fc2.weight", "audio_branch.layers.3.blocks.1.mlp.fc2.bias", "audio_branch.norm.weight", "audio_branch.norm.bias", "audio_branch.tscam_conv.weight", "audio_branch.tscam_conv.bias", "audio_branch.head.weight", "audio_branch.head.bias". size mismatch for audio_projection.0.weight: copying a param with shape torch.Size([512, 768]) from checkpoint, the shape in current model is torch.Size([512, 2048]).
And here is my running script:
cd /data/jiacheng/CLAP/src/laion_clap/
export HF_HOME=~/.cache/huggingface/hub/hub/
python -m evaluate.eval_linear_probe \
    --save-frequency 50 \
    --save-top-performance 3 \
    --save-most-recent \
    --dataset-type="webdataset" \
    --datasetpath="/data/jiacheng/train3/" \
    --precision="fp32" \
    --warmup 0 \
    --batch-size=160 \
    --lr=1e-4 \
    --wd=0.1 \
    --epochs=100 \
    --workers=4 \
    --use-bn-sync \
    --freeze-text \
    --amodel PANN-14 \
    --tmodel roberta \
    --report-to "wandb" \
    --wandb-notes "10.14-finetune-esc50" \
    --datasetnames "esc50" \
    --datasetinfos "train" \
    --seed 3407 \
    --logs /data/jiacheng/CLAP/clap_logs \
    --gather-with-grad \
    --lp-loss="ce" \
    --lp-metrics="acc" \
    --lp-lr=1e-4 \
    --lp-mlp \
    --class-label-path="/data/jiacheng/CLAP/class_labels/ESC50_class_labels_indices_space.json" \
    --openai-model-cache-dir ~/.cache/huggingface/hub/hub \
    --pretrained="/data/jiacheng/CLAP/laion_clap/630k/" \
    --data-filling "repeatpad" \
    --data-truncating "rand_trunc" \
    --optimizer "adam"
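Since the command ties --amodel and --pretrained together, a small pre-flight check can catch this kind of mismatch before training starts. The sketch below is an assumption, not part of the repo: it simply scans the directory passed to --pretrained for .pt files and reports the projection input width of each one.

import glob
import os
import torch

# Placeholder: the directory passed to --pretrained in the command above.
pretrained_dir = "/data/jiacheng/CLAP/laion_clap/630k/"

# Rough mapping from projection input width to audio encoder family.
DIM_TO_ENCODER = {768: "HTSAT-tiny", 2048: "PANN-14"}

for path in glob.glob(os.path.join(pretrained_dir, "*.pt")):
    ckpt = torch.load(path, map_location="cpu")
    state_dict = ckpt.get("state_dict", ckpt)
    state_dict = {k.removeprefix("module."): v for k, v in state_dict.items()}
    width = state_dict["audio_projection.0.weight"].shape[1]
    print(f"{os.path.basename(path)}: projection input width {width} "
          f"-> {DIM_TO_ENCODER.get(width, 'unknown encoder')}")

If every checkpoint in that directory reports a width of 768, the --amodel PANN-14 flag in the command above cannot load it, which reproduces the size-mismatch error in the traceback.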