From c33ea3d792b14319f9a4f6702aa788d98dc3781c Mon Sep 17 00:00:00 2001
From: shounakb1 <32771603+shounakb1@users.noreply.github.com>
Date: Mon, 28 Aug 2023 18:58:25 +0530
Subject: [PATCH] Update generation.py to add support for repetition penalty

Added support for repetition penalty in generate method.
---
 llama/generation.py | 12 ++++++++++++
 1 file changed, 12 insertions(+)

diff --git a/llama/generation.py b/llama/generation.py
index 508095b04..45df873de 100755
--- a/llama/generation.py
+++ b/llama/generation.py
@@ -110,6 +110,7 @@ def generate(
         temperature: float = 0.6,
         top_p: float = 0.9,
         logprobs: bool = False,
+        repetition_penalty: float = (1.0 / 0.9),
         echo: bool = False,
     ) -> Tuple[List[List[int]], Optional[List[List[float]]]]:
         params = self.model.params
@@ -133,6 +134,17 @@ def generate(
         input_text_mask = tokens != pad_id
         for cur_pos in range(min_prompt_len, total_len):
             logits = self.model.forward(tokens[:, prev_pos:cur_pos], prev_pos)
+            if repetition_penalty != 1.0:
+                # Penalize every token already present in each sequence, CTRL-style:
+                # logits has shape (batch, seq_len, vocab); only the last position
+                # (the next-token distribution) is adjusted, per batch element.
+                logits_new = logits.clone()
+                for i in range(len(tokens)):
+                    for token in set(tokens[i].tolist()):
+                        # A negative score must be multiplied (not divided) by the
+                        # penalty so that repetition is always made less likely.
+                        if logits[i, -1, token] < 0:
+                            logits_new[i, -1, token] = logits[i, -1, token] * repetition_penalty
+                        else:
+                            logits_new[i, -1, token] = logits[i, -1, token] / repetition_penalty
+                logits = logits_new
 if logprobs:
     token_logprobs[:, prev_pos + 1 : cur_pos + 1] = -F.cross_entropy(
         input=logits.transpose(1, 2),