From f90549634a695efad95d643ffbaf7b0b46e6e248 Mon Sep 17 00:00:00 2001
From: Nintorac Dev
Date: Fri, 12 May 2023 13:12:28 +1000
Subject: [PATCH] feat: add stopword checker + iterable generate function

---
 rwkv_pip_package/src/rwkv/utils.py | 95 +++++++++++++++++++++++++++---
 1 file changed, 88 insertions(+), 7 deletions(-)

diff --git a/rwkv_pip_package/src/rwkv/utils.py b/rwkv_pip_package/src/rwkv/utils.py
index b69de621..4af050c4 100644
--- a/rwkv_pip_package/src/rwkv/utils.py
+++ b/rwkv_pip_package/src/rwkv/utils.py
@@ -7,8 +7,31 @@
 import torch
 from torch.nn import functional as F
 
+
+def end_overlap(a, b):
+    # length of the longest suffix of `a` that is also a prefix of `b`
+    for i in reversed(range(1, len(a) + 1)):
+        if b.startswith(a[-i:]):
+            return i
+    return 0
+
 class PIPELINE_ARGS():
-    def __init__(self, temperature=1.0, top_p=0.85, top_k=0, alpha_frequency=0.2, alpha_presence=0.2, token_ban=[], token_stop=[], chunk_len=256):
+    def __init__(self,
+                 temperature=1.0,
+                 top_p=0.85,
+                 top_k=0,
+                 alpha_frequency=0.2,
+                 alpha_presence=0.2,
+                 token_ban=None,
+                 token_stop=None,
+                 stop_words=None,
+                 chunk_len=256
+                 ):
+        # None defaults avoid sharing one mutable list between instances
+        token_ban = token_ban or []
+        token_stop = token_stop or []
+        stop_words = stop_words or []
+
         self.temperature = temperature
         self.top_p = top_p
         self.top_k = top_k
@@ -16,6 +39,7 @@ def __init__(self, temperature=1.0, top_p=0.85, top_k=0, alpha_frequency=0.2, al
         self.alpha_presence = alpha_presence # Presence Penalty (as in GPT-3)
         self.token_ban = token_ban # ban the generation of some tokens
         self.token_stop = token_stop # stop generation whenever you see any token here
+        self.stop_words = stop_words # stop generation whenever the output ends with any of these strings
         self.chunk_len = chunk_len # split input into chunks to save VRAM (shorter -> slower)
 
 class PIPELINE():
@@ -77,12 +101,23 @@ def sample_logits(self, logits, temperature=1.0, top_p=0.85, top_k=0):
         probs = probs ** (1.0 / temperature)
         out = torch.multinomial(probs, num_samples=1)[0]
         return int(out)
-
-    def generate(self, ctx, token_count=100, args=PIPELINE_ARGS(), callback=None, state=None):
+
+    def generate(self, *args, callback=None, **kwargs):  # collect igenerate's deltas into one string
+        outstr = []
+        for delta in self.igenerate(*args, **kwargs):
+            outstr.append(delta)
+            if callback:
+                callback(delta)
+        return ''.join(outstr)
+
+    def igenerate(self, ctx, token_count=100, args=PIPELINE_ARGS(), state=None):
         all_tokens = []
         out_last = 0
         out_str = ''
         occurrence = {}
+
+        stopword_checker = self.check_stopwords(args.stop_words)
+        next(stopword_checker)  # prime the coroutine
         for i in range(token_count):
 
             # forward & adjust prob.
@@ -108,9 +143,55 @@ def generate(self, ctx, token_count=100, args=PIPELINE_ARGS(), callback=None, st
 
             # output
             tmp = self.decode(all_tokens[out_last:])
+            if len(all_tokens) == 1:
+                tmp = tmp[1:]  # strip the leading space from the first token
+            if tmp == '':
+                continue
             if '\ufffd' not in tmp: # is valid utf-8 string?
-                if callback:
-                    callback(tmp)
-                out_str += tmp
+
+                try:
+                    tmp = stopword_checker.send(tmp)
+                except StopIteration:
+                    # a stop word was completed; end generation
+                    break
                 out_last = i + 1
-        return out_str
+
+                if tmp is None:
+                    # a possible stop word prefix is being held back
+                    continue
+                yield tmp
+                out_str += tmp
+
+    @staticmethod
+    def check_stopwords(stop_words):
+        # Coroutine: send() decoded text deltas in; get back text that is
+        # safe to emit, None while a possible stop word is buffered, or
+        # StopIteration once a stop word has fully appeared.
+        longest_stopword = max(map(len, stop_words), default=0)
+        chunk = ""
+        to_yield = None
+        while True:
+            delta = yield to_yield
+            if not delta:
+                return
+            chunk += delta
+
+            if longest_stopword == 0:
+                # nothing to check, just pass the delta through
+                to_yield = delta
+                continue
+            if any(stop_word in chunk for stop_word in stop_words):
+                # a stop word has fully appeared: end generation
+                return
+
+            if start_idx := max(end_overlap(chunk, stop_word) for stop_word in stop_words):
+                # the buffered tail could still grow into a stop word: emit
+                # everything before it and keep the tail
+                good, chunk = chunk[:-start_idx], chunk[-start_idx:]
+                to_yield = good if good else None
+                continue
+
+            # no overlap with any stop word: flush the whole buffer
+            out = chunk
+            chunk = ""
+            to_yield = out
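
A minimal sketch of the stop-word coroutine on its own, assuming the patched
rwkv package is importable; check_stopwords is a staticmethod, so no model is
needed, and the "<|endoftext|>" stop word below is just an example:

    from rwkv.utils import PIPELINE

    checker = PIPELINE.check_stopwords(["<|endoftext|>"])
    next(checker)  # prime the coroutine

    print(checker.send("Hello"))        # -> "Hello" (no overlap, passes straight through)
    print(checker.send(" world <|en"))  # -> " world " ("<|en" stays buffered as a possible stop word)
    try:
        checker.send("doftext|>")       # completes "<|endoftext|>"
    except StopIteration:
        print("stop word seen; generation would end here")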
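
And a hypothetical end-to-end streaming loop over the new igenerate API; the
model path, strategy, tokenizer file, and prompt below are placeholders, not
part of the patch:

    from rwkv.model import RWKV
    from rwkv.utils import PIPELINE, PIPELINE_ARGS

    model = RWKV(model='/path/to/rwkv-model.pth', strategy='cpu fp32')
    pipeline = PIPELINE(model, '20B_tokenizer.json')
    args = PIPELINE_ARGS(temperature=1.0, top_p=0.85, stop_words=['\n\nUser:'])

    # deltas stream out as they are decoded, with stop words filtered;
    # the loop ends as soon as a stop word is generated
    for delta in pipeline.igenerate('User: hi\n\nBot:', token_count=200, args=args):
        print(delta, end='', flush=True)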