Skip to content

Commit

Permalink
--other=remove print info
Browse files: browse the repository at this point in the history
  • Loading branch information
echoht committed May 6, 2023
1 parent 026adca commit 77d7a60
Showing 1 changed file with 1 addition and 3 deletions.
4 changes: 1 addition & 3 deletions train.py
Original file line number Diff line number Diff line change
Expand Up @@ -251,7 +251,7 @@ def rrhf_loss(self, scores, idxs, rw_scores):
# rw_diff = rw_scores.unsqueeze(0) - rw_scores.unsqueeze(-1) # b * b # batch * cand
# aval = torch.bitwise_and(rw_diff > 0, diff < 0)[0]
# return -diff[aval].sum()
print(scores.shape) # score shape (batch * cand)

cand = rw_scores.shape[1]
new_scores = scores.reshape(-1, cand) # batch * cand
diff = new_scores.unsqueeze(1) - new_scores.unsqueeze(-1) # batch * cand * cand
def sft_loss(self, logit_label, idxs, rw_scores):
    """Return the SFT loss computed on each task's highest-reward candidate.

    Args:
        logit_label: Tensor annotated in the original code as
            (batch * cand) * L; it is reshaped here to batch * cand * L
            using the candidate count taken from ``rw_scores``.
        idxs: Not used by this method.
        rw_scores: Reward scores, batch * cand. Every task is expected to
            have the same number of candidate responses.

    Returns:
        Scalar tensor: the negated mean over the rows of ``logit_label``
        that belong to the best-scoring candidate of every task.
    """
    best_cand = torch.argmax(rw_scores, dim=1)  # batch
    # Every task has the same number of responses.
    cand = rw_scores.shape[1]
    per_task = torch.reshape(
        logit_label, (-1, cand, logit_label.shape[-1])
    )  # batch * cand * L
    # Pair task b with its own best candidate. The previous `[:1, max_idx]`
    # indexing only ever read task 0, which is wrong whenever batch > 1.
    rows = torch.arange(per_task.shape[0], device=best_cand.device)
    expert_rows = per_task[rows, best_cand]  # batch * L
    return -expert_rows.mean()

def compute_loss(self, model, inputs, return_outputs=False):
if self.args.only_use_provide:
Expand Down

0 comments on commit 77d7a60

Please sign in to comment.