
get embeddings error #4

Open
pillowill opened this issue Jun 29, 2024 · 2 comments

@pillowill

Dear author,
When I use the command "python MLM_SFP.py --pretraining bert_mul_2.pth --data_embedding my_rna.fa --embedding_output rRNABert_emb.csv --batch 40",
I get the following error:
RuntimeError: Error(s) in loading state_dict for BertForMaskedLM:
Missing key(s) in state_dict: "bert.embeddings.word_embeddings.weight", "bert.embeddings.position_embeddings.weight", "bert.embeddings.token_type_embeddings.weight", "bert.embeddings.LayerNorm.gamma", "bert.embeddings.LayerNorm.beta", "bert.encoder.layer.0.attention.selfattn.query.weight", "bert.encoder.layer.0.attention.selfattn.query.bias", "bert.encoder.layer.0.attention.selfa...

@sunyunlee

sunyunlee commented Nov 20, 2024

Hi, I am getting the same error message when trying to extract embeddings from the pre-trained model without fine-tuning it. I assume it has to do with a discrepancy between the parameter names of the initialized model and those in the saved weights. Has this issue been addressed/fixed? Thanks in advance.
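
A quick way to confirm this kind of mismatch (a minimal sketch; the checkpoint path is the one from the original command, and the "module." prefix is only a guess at this point) is to print a few of the parameter names stored in the checkpoint:

import torch

# Load only the saved parameters, not the model itself.
state_dict = torch.load("bert_mul_2.pth", map_location="cpu")

# Print the first few parameter names; an unexpected leading
# component (e.g. "module." left over from torch.nn.DataParallel)
# would explain the missing-key errors above.
for key in list(state_dict.keys())[:5]:
    print(key)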

@sunyunlee

sunyunlee commented Nov 20, 2024

I was able to figure out the issue: the OrderedDict in the pretrained file has different parameter names from the ones the BertForMaskedLM object expects. Each key carries one extra leading component.

import torch
from collections import OrderedDict

file_path = 'bert_mul_2.pth'

# Load the raw state dict from the checkpoint.
state_dict = torch.load(file_path, map_location="cpu")

new_state_dict = OrderedDict()

for key, value in state_dict.items():
    # Drop the first dot-separated component of each parameter name
    # so the keys match what the model expects.
    new_key = ".".join(key.split(".")[1:])
    new_state_dict[new_key] = value.clone()

# Save the renamed weights to a new checkpoint file.
torch.save(new_state_dict, 'bert_mul_2_correction.pth')

# Sanity check: print the corrected parameter names.
for key in new_state_dict.keys():
    print(key)

I ran this first to generate a corrected weight file, then used bert_mul_2_correction.pth as the --pretraining argument.
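
An equivalent in-memory variant (a sketch along the same lines; "model" here stands for whatever MLM_SFP.py constructs, so the last line is only illustrative) that skips writing a second file:

import torch

state_dict = torch.load("bert_mul_2.pth", map_location="cpu")

# Strip the first dot-separated component from every key in one pass.
fixed = {key.split(".", 1)[1]: value for key, value in state_dict.items()}

# model.load_state_dict(fixed)  # then load into the model as usual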
