Skip to content

Commit

Permalink
--other=fix index error
Browse files Browse the repository at this point in the history
  • Loading branch information
echoht committed May 6, 2023
1 parent 77d7a60 commit 701aaeb
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion train.py
Original file line number Diff line number Diff line change
Expand Up @@ -264,7 +264,7 @@ def sft_loss(self, logit_label, idxs, rw_scores): # (batch * cand) *L
# 每个task的response个数均相同
cand = rw_scores.shape[1]
logit_label_batch = torch.reshape(logit_label, (-1, cand, logit_label.shape[-1])) # batch * cand * L
- expert_response_logit_label = logit_label_batch[:1, max_idx].squeeze() # batch * L
+ expert_response_logit_label = torch.gather(logit_label_batch, dim=1, index=max_idx.view(-1, 1, 1).repeat(1, 1, logit_label_batch.size(-1))).squeeze() # batch * L
return -torch.sum(expert_response_logit_label.mean())

def compute_loss(self, model, inputs, return_outputs=False):
Expand Down

1 comment on commit 701aaeb

@Unkrible
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

expert_response_logit_label = logit_label_batch[torch.arange(rw_scores.shape[0]), max_idx].squeeze()
也可以

Please sign in to comment.