-
Notifications
You must be signed in to change notification settings - Fork 337
/
Copy pathlogit_process.py
65 lines (51 loc) · 2.12 KB
/
logit_process.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
from typing import List
import numpy as np
from math_utils import softmax
def StopWordsLogitsProcessor(scores, input_ids, eos_token_id=151643, stop_words_ids=None):
    """Force EOS for any sample whose generated tokens end with a stop sequence.

    For each row of ``input_ids`` that ends with one of ``stop_words_ids``,
    the EOS logit in the corresponding row of ``scores`` is boosted to 2**15
    so that greedy/sampled decoding emits EOS next.

    Args:
        scores: 2-D array of logits, shape (batch, vocab). Modified in place.
        input_ids: iterable of 1-D np.ndarray token-id sequences, one per sample.
        eos_token_id: token id whose logit is boosted on a match
            (default 151643, the Qwen EOS id — kept for backward compatibility).
        stop_words_ids: list of stop token sequences; defaults to
            [[151645], [151644]] (Qwen <|im_end|> / <|im_start|>).

    Returns:
        The same ``scores`` array (mutated in place).
    """
    if stop_words_ids is None:
        stop_words_ids = [[151645], [151644]]

    def tokens_match(prev_tokens: np.ndarray, tokens: List[int]) -> bool:
        if len(tokens) == 0:
            # an empty stop sequence matches unconditionally
            return True
        elif len(tokens) > len(prev_tokens):
            # stop sequence longer than the generated prefix -> cannot match
            return False
        elif prev_tokens[-len(tokens) :].tolist() == tokens:
            # generated tokens end with this stop sequence
            return True
        else:
            return False

    stopped_samples = []
    for prev_input_ids_slice in input_ids:
        match = False
        for stop_token_seq in stop_words_ids:
            if tokens_match(prev_input_ids_slice, stop_token_seq):
                # first matching stop sequence is enough; stop scanning
                match = True
                break
        stopped_samples.append(match)
    for i, should_stop in enumerate(stopped_samples):
        if should_stop:
            # large positive logit makes EOS the (near-)certain next token
            scores[i, eos_token_id] = float(2**15)
    return scores
def TopPLogitsWarper(scores, top_p):
    """Nucleus (top-p) filtering: mask out the low-probability tail.

    Tokens whose cumulative probability mass (ascending order) stays at or
    below ``1 - top_p`` are set to ``-inf``; the single most likely token is
    always kept. ``scores`` is expected to be 2-D (batch, vocab).

    Returns a new array; ``scores`` itself is not modified.
    """
    order = np.argsort(scores)
    ranked = np.take_along_axis(scores, order, axis=-1)
    cum_mass = np.cumsum(softmax(ranked, axis=-1), axis=-1)
    # Ascending order: entries whose running mass is <= 1 - top_p form the
    # tail outside the nucleus and get dropped.
    drop_ranked = cum_mass <= (1 - top_p)
    # Always retain at least the top token (last position in ascending order).
    drop_ranked[..., -1:] = 0
    # Scatter the ranked mask back to vocabulary order; every slot is written
    # because `order` is a full permutation, so empty_like needs no init.
    drop = np.empty_like(drop_ranked)
    np.put_along_axis(drop, order, drop_ranked, axis=1)
    return np.where(drop, -np.inf, scores)
def logits_processor(input_ids, scores, top_p=0.5):
    """Full logits pipeline: stop-word EOS forcing, then top-p filtering."""
    return TopPLogitsWarper(StopWordsLogitsProcessor(scores, input_ids), top_p)