Skip to content

Commit

Permalink
feat: add stopword checker + iterable generate function
Browse files Browse the repository at this point in the history
  • Loading branch information
Nintorac committed May 12, 2023
1 parent 18c847d commit f905496
Showing 1 changed file with 89 additions and 7 deletions.
96 changes: 89 additions & 7 deletions rwkv_pip_package/src/rwkv/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,15 +7,38 @@
import torch
from torch.nn import functional as F


def end_overlap(a, b):
    """Length of the longest suffix of *a* that is also a prefix of *b*.

    Tries the longest candidate first, so the first hit is the maximal
    overlap; returns 0 when no suffix of *a* starts *b*.
    """
    return next(
        (size for size in range(len(a), 0, -1) if b.startswith(a[-size:])),
        0,
    )

class PIPELINE_ARGS():
    """Hyper-parameters controlling sampling and stopping for PIPELINE generation.

    The list-valued arguments use ``None`` as their default sentinel so that
    every instance gets its own fresh list — a mutable default argument
    (``token_ban=[]``) would be shared across all instances.
    """

    def __init__(self,
                 temperature=1.0,
                 top_p=0.85,
                 top_k=0,
                 alpha_frequency=0.2,
                 alpha_presence=0.2,
                 token_ban=None,
                 token_stop=None,
                 stop_words=None,
                 chunk_len=256
                 ):

        token_ban = token_ban or []
        token_stop = token_stop or []
        stop_words = stop_words or []

        self.temperature = temperature
        self.top_p = top_p
        self.top_k = top_k
        self.alpha_frequency = alpha_frequency # Frequency Penalty (as in GPT-3)
        self.alpha_presence = alpha_presence # Presence Penalty (as in GPT-3)
        self.token_ban = token_ban # ban the generation of some tokens
        self.token_stop = token_stop # stop generation whenever you see any token here
        self.stop_words = stop_words # stop generation whenever the output contains any of these strings
        self.chunk_len = chunk_len # split input into chunks to save VRAM (shorter -> slower)

class PIPELINE():
Expand Down Expand Up @@ -77,12 +100,23 @@ def sample_logits(self, logits, temperature=1.0, top_p=0.85, top_k=0):
probs = probs ** (1.0 / temperature)
out = torch.multinomial(probs, num_samples=1)[0]
return int(out)

def generate(self, ctx, token_count=100, args=PIPELINE_ARGS(), callback=None, state=None):

def generate(self, *args, callback=None, **kwargs):
    """Eagerly drain ``igenerate``, collecting every decoded piece.

    *callback*, when provided, is invoked with each piece as it arrives;
    the pieces are joined and returned as one string.
    """
    collected = []
    for piece in self.igenerate(*args, **kwargs):
        collected.append(piece)
        if callback:
            callback(piece)
    return ''.join(collected)

def igenerate(self, ctx, token_count=100, args=PIPELINE_ARGS(), state=None):
all_tokens = []
out_last = 0
out_str = ''
occurrence = {}

stopword_checker = self.check_stopwords(args.stop_words)
next(stopword_checker)
for i in range(token_count):

# forward & adjust prob.
Expand All @@ -108,9 +142,57 @@ def generate(self, ctx, token_count=100, args=PIPELINE_ARGS(), callback=None, st

# output
tmp = self.decode(all_tokens[out_last:])
if len(all_tokens)==1:
tmp = tmp[1:] # strip leading space
if tmp == '':
continue
if '\ufffd' not in tmp: # is valid utf-8 string?
if callback:
callback(tmp)
out_str += tmp

try:
tmp = stopword_checker.send(tmp)
except StopIteration:
break
out_last = i + 1
return out_str

if tmp is None:
continue
yield tmp
out_str += tmp
out_last = i + 1

@staticmethod
def check_stopwords(stop_words):
    """Coroutine that filters a stream of text deltas against stop words.

    Prime with ``next()``, then ``.send()`` each decoded text delta.
    Every send yields either a string that is safe to emit (it can no
    longer be the start of a stop word) or ``None`` when text must stay
    buffered until more context arrives.  The coroutine ``return``s —
    raising StopIteration in the caller — as soon as the buffered text
    starts with any stop word.
    """

    # The longest stop word bounds how much text ever needs holding back.
    longest_stopword = 0 if len(stop_words)==0 else max(map(len, stop_words))
    chunk = ""       # text withheld because it might be a stop-word prefix
    delta = True     # replaced by the sent-in text after priming
    # yield
    to_yield = None  # value produced at the next yield
    while delta:
        delta = yield to_yield
        chunk = chunk + delta

        if longest_stopword == 0:
            # nothing to check just passthrough
            to_yield = delta
            continue
        if chunk == '':
            to_yield = None
            continue
        # A stop word fully matches the front of the buffer: end generation.
        if any(map(lambda stop_word: chunk.startswith(stop_word), stop_words)):
            return

        # Some suffix of the buffer overlaps the start of a stop word:
        # emit the unambiguous front part, keep the ambiguous tail buffered.
        # NOTE(review): end_overlap never returns more than len(stop_word),
        # so the clamp below appears unreachable — confirm before removing.
        if start_idx := max(map(lambda stop_word: end_overlap(chunk, stop_word), stop_words)):
            if start_idx > longest_stopword:
                start_idx = longest_stopword # can no longer be a stopword so cut it down
            good, chunk = chunk[:-start_idx], chunk[-start_idx:]
            if good:
                to_yield = good
                continue

            to_yield = None
            continue

        # No overlap with any stop word: flush the whole buffer.
        out = chunk
        chunk = ""
        to_yield = out

0 comments on commit f905496

Please sign in to comment.