Description
Hi, I've found a bug in the following method, which BERT-K and BERT-LS use to encode sentences.
**** In "methods/bert.py/_logits_for_dropout_target" ****
with torch.no_grad():
    embeddings = self.bert.embeddings(self.list_to_tensor(context_target_enc))
    # (dropout procedure omitted)
    logits = self.bert_mlm(inputs_embeds=embeddings)[0][0, target_tok_off]
Here, "self.bert.embeddings" is used to generate the input embeddings (i.e. "embeddings"), but this module returns token embeddings + token_type embeddings + positional embeddings. However, what "self.bert_mlm" expects as "inputs_embeds" is only the token embeddings; the model adds the token_type and positional embeddings internally. So I think your code ends up adding the token_type and positional embeddings twice. To fix this, I think the first line needs to be changed to "self.bert.embeddings.word_embeddings(self.list_to_tensor(context_target_enc))".
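If it helps, here is a small, self-contained script that illustrates the double addition, assuming the code builds on the Hugging Face transformers BERT implementation (the "bert-base-uncased" checkpoint, the tokenizer call, and the allclose checks below are only for illustration and are not from the repo):

import torch
from transformers import BertForMaskedLM, BertTokenizer

tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
mlm = BertForMaskedLM.from_pretrained("bert-base-uncased")
mlm.eval()  # make dropout a no-op so the comparison is deterministic
bert = mlm.bert  # the underlying BertModel

ids = torch.tensor([tokenizer.encode("the cat [MASK] on the mat")])

with torch.no_grad():
    # Full embedding pipeline: word + token_type + positional embeddings (plus LayerNorm).
    # Passing this as inputs_embeds makes the model add the token_type and
    # positional embeddings a second time.
    full_embeds = bert.embeddings(ids)
    # Raw word (token) embeddings only; the model adds the rest itself.
    word_embeds = bert.embeddings.word_embeddings(ids)

    ref_logits = mlm(input_ids=ids)[0]
    fixed_logits = mlm(inputs_embeds=word_embeds)[0]
    buggy_logits = mlm(inputs_embeds=full_embeds)[0]

# Expected: the word-embeddings-only path matches feeding input_ids directly,
# while the full-embeddings path does not.
print(torch.allclose(fixed_logits, ref_logits, atol=1e-5))
print(torch.allclose(buggy_logits, ref_logits, atol=1e-5))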
Best,
Takashi