Code from src/modeling:

```python
if self.training:
    # compute the generation loss only for positive samples
    # generated_q_hidden = generated_q_hidden[:,None,:,:].view(bz, self.args.sample_num, -1, p_hidden.size(-1))[:,0,:,:] # [bz, seq_len, 768]
    if mlm_labels is not None and mlm_labels['decoder_mlm_labels'] is not None:
        mlm_loss += self.mlm_loss(generated_q_hidden, mlm_labels['decoder_mlm_labels'])  # query generation loss
```
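How does this ensure that `mlm_loss` is computed only for the positive samples? The line that would keep only the positives (the `.view(...)[:,0,:,:]` slice) is commented out, so how are the negative samples filtered out?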
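For reference, here is a minimal NumPy sketch of the two usual ways to restrict a token-level loss to positives. It assumes, inferring only from the commented-out line and not from the rest of the repo, that the batch holds `bz` groups of `sample_num` queries laid out as `[bz * sample_num, seq_len, ...]`, with the positive at index 0 of each group. `select_positives`, `mask_negative_labels` and `masked_token_ce` are hypothetical helpers, and `logits` stands in for whatever `self.mlm_loss` computes from `generated_q_hidden` (its internals are not shown in the issue). `IGNORE_INDEX = -100` matches the default `ignore_index` of PyTorch's `CrossEntropyLoss`.

```python
import numpy as np

IGNORE_INDEX = -100  # same value as PyTorch's default CrossEntropyLoss ignore_index


def select_positives(x: np.ndarray, bz: int, sample_num: int) -> np.ndarray:
    """Option 1: keep only the first row of each group of `sample_num` rows.

    Assumes x has shape [bz * sample_num, ...] with the positive at position 0
    of every group; mirrors the commented-out .view(bz, sample_num, ...)[:, 0].
    """
    return x.reshape(bz, sample_num, *x.shape[1:])[:, 0]


def mask_negative_labels(labels: np.ndarray, bz: int, sample_num: int) -> np.ndarray:
    """Option 2: keep the full batch but set every negative's labels to IGNORE_INDEX."""
    grouped = labels.reshape(bz, sample_num, -1).copy()  # copy: don't mutate the caller's labels
    grouped[:, 1:] = IGNORE_INDEX  # rows 1..sample_num-1 of each group are negatives
    return grouped.reshape(labels.shape)


def masked_token_ce(logits: np.ndarray, labels: np.ndarray) -> float:
    """Mean token cross-entropy over positions whose label is not IGNORE_INDEX."""
    flat_logits = logits.reshape(-1, logits.shape[-1])
    flat_labels = labels.reshape(-1)
    keep = flat_labels != IGNORE_INDEX
    z = flat_logits[keep]
    z = z - z.max(axis=-1, keepdims=True)  # subtract row max for numerical stability
    log_probs = z - np.log(np.exp(z).sum(axis=-1, keepdims=True))
    return float(-log_probs[np.arange(len(z)), flat_labels[keep]].mean())


if __name__ == "__main__":
    bz, sample_num, seq_len, vocab = 2, 3, 4, 5
    rng = np.random.default_rng(0)
    logits = rng.normal(size=(bz * sample_num, seq_len, vocab))
    labels = rng.integers(0, vocab, size=(bz * sample_num, seq_len))

    # Slice hidden states/logits AND labels the same way, or the shapes won't match.
    loss_sliced = masked_token_ce(select_positives(logits, bz, sample_num),
                                  select_positives(labels, bz, sample_num))
    loss_masked = masked_token_ce(logits, mask_negative_labels(labels, bz, sample_num))
    print(loss_sliced, loss_masked)
```

Both routes give the same value on the same batch. Slicing drops the negatives before the vocabulary projection, which saves compute; masking keeps the batch shape intact, which can be simpler when other losses need the full batch. If neither is applied, the loss presumably also covers the negatives, unless the data pipeline already masks their `decoder_mlm_labels` (that part is not visible in this snippet).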