[fix] get vocab_size from logits
huyiwen committed Aug 6, 2024
1 parent def3be9 · commit 87a0922
Showing 1 changed file with 4 additions and 2 deletions.
6 changes: 4 additions & 2 deletions utilization/model/huggingface_model.py
@@ -439,10 +439,11 @@ def get_ppl_with_cache(
         exact_match: bool = False,
     ) -> List[Tuple[float, int]]:
         logits, labels, input_lengths = self.get_cache(batched_targets, prefix_cache, return_caches=False)
+        vocab_size = logits.shape[-1]
         last_logits = torch.cat(prefix_cache.next_logits, dim=0).to(logits.device)
         shift_logits = torch.cat([last_logits, logits[:, :-1]], dim=-2)
         labels[labels == self.tokenizer.pad_token_id] = -100
-        probs = self.loss_fct(shift_logits.view(-1, self.model.config.vocab_size),
+        probs = self.loss_fct(shift_logits.view(-1, vocab_size),
                               labels.view(-1)).view(labels.size(0), -1)
 
         if exact_match:
@@ -508,10 +509,11 @@ def get_ppl(
         logits = self.model(
             input_ids=batched_encodings["input_ids"], attention_mask=batched_encodings["attention_mask"]
         ).logits
+        vocab_size = logits.shape[-1]
         shift_logits = logits.detach()[:, :-1].contiguous()
         shift_labels = batched_encodings["input_ids"][:, 1:].contiguous()
         shift_labels[shift_labels == self.tokenizer.pad_token_id] = -100
-        probs = self.loss_fct(shift_logits.view(-1, self.model.config.vocab_size),
+        probs = self.loss_fct(shift_logits.view(-1, vocab_size),
                               shift_labels.view(-1)).view(shift_labels.size(0), -1).cpu()
 
         tgt_starts = [None] * len(batched_inputs)
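A note on the fix: self.model.config.vocab_size can fall out of sync with the actual width of the logits tensor, for example when the output embedding is resized or padded at load time without the config being updated. Reading the size off logits.shape[-1] keeps the view() call consistent with the tensor it reshapes. A minimal sketch of the failure mode (the concrete sizes are illustrative, not taken from this repository):

import torch

config_vocab_size = 32000                      # stale value from model.config
actual_vocab_size = 32064                      # e.g. lm_head padded at load time

logits = torch.randn(2, 8, actual_vocab_size)  # (batch, seq_len, vocab)

try:
    logits.view(-1, config_vocab_size)         # 2 * 8 * 32064 is not divisible by 32000
except RuntimeError as err:
    print("config-based view fails:", err)

vocab_size = logits.shape[-1]                  # always matches the tensor
flat = logits.view(-1, vocab_size)             # shape: (batch * seq_len, vocab)
print(flat.shape)                              # torch.Size([16, 32064])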
