forked from qwopqwop200/GPTQ-for-LLaMa
-
Notifications
You must be signed in to change notification settings - Fork 7
/
Copy pathcustom_streamer.py
73 lines (60 loc) · 2.96 KB
/
custom_streamer.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
from transformers import TextStreamer
from multiprocessing import Queue
from typing import Optional
class TextIteratorStreamer(TextStreamer):
    """
    Streamer that stores print-ready text in a caller-supplied queue, to be used by a downstream application as an
    iterator. This is useful for applications that benefit from accessing the generated text in a non-blocking way
    (e.g. in an interactive Gradio demo).

    Unlike a streamer that creates its own queue, this class writes to (and iterates over) the queue object handed
    to the constructor, so producer and consumer can share a queue that was created elsewhere.

    <Tip warning={true}>

    The API for the streamer classes is still under development and may change in the future.

    </Tip>

    Parameters:
        tokenizer (`AutoTokenizer`):
            The tokenizer used to decode the tokens.
        mp_queue (`multiprocessing.Queue`):
            The queue that finalized text chunks are put into and that iteration reads from. It must accept
            `put(item, timeout=...)` and `get(timeout=...)`.
        skip_prompt (`bool`, *optional*, defaults to `False`):
            Whether to skip the prompt to `.generate()` or not. Useful e.g. for chatbots.
        timeout (`float`, *optional*):
            The timeout passed to every `put` and `get` on the queue. If `None`, the queue will block
            indefinitely. Useful to handle exceptions in `.generate()`, when it is called in a separate thread.
        decode_kwargs (`dict`, *optional*):
            Additional keyword arguments to pass to the tokenizer's `decode` method.

    Raises:
        Whatever the queue raises when `timeout` elapses (for `multiprocessing.Queue` that is `queue.Full` on
        `put` and `queue.Empty` on `get`); the exception is not caught here.

    Examples:

        ```python
        >>> from multiprocessing import Queue
        >>> from threading import Thread
        >>> from transformers import AutoModelForCausalLM, AutoTokenizer
        >>> from custom_streamer import TextIteratorStreamer

        >>> tok = AutoTokenizer.from_pretrained("gpt2")
        >>> model = AutoModelForCausalLM.from_pretrained("gpt2")
        >>> inputs = tok(["An increasing sequence: one,"], return_tensors="pt")
        >>> streamer = TextIteratorStreamer(tok, Queue())

        >>> # Run the generation in a separate thread, so that we can fetch the generated text in a non-blocking way.
        >>> generation_kwargs = dict(inputs, streamer=streamer, max_new_tokens=20)
        >>> thread = Thread(target=model.generate, kwargs=generation_kwargs)
        >>> thread.start()
        >>> generated_text = ""
        >>> for new_text in streamer:
        ...     generated_text += new_text
        >>> generated_text
        'An increasing sequence: one, two, three, four, five, six, seven, eight, nine, ten, eleven,'
        ```
    """

    def __init__(
        self,
        tokenizer: "AutoTokenizer",
        mp_queue: Queue,
        skip_prompt: bool = False,
        timeout: Optional[float] = None,
        **decode_kwargs,
    ):
        super().__init__(tokenizer, skip_prompt, **decode_kwargs)
        # The queue is owned by the caller; this object only keeps a reference to it.
        self.text_queue = mp_queue
        # Sentinel placed on the queue after the last chunk to mark the end of the stream.
        self.stop_signal = None
        self.timeout = timeout

    def on_finalized_text(self, text: str, stream_end: bool = False):
        """Put the new text in the queue. If the stream is ending, also put a stop signal in the queue."""
        self.text_queue.put(text, timeout=self.timeout)
        if stream_end:
            self.text_queue.put(self.stop_signal, timeout=self.timeout)

    def __iter__(self):
        """Return the streamer itself; it is its own iterator."""
        return self

    def __next__(self):
        """Return the next text chunk from the queue, or stop iteration once the stop signal is received."""
        value = self.text_queue.get(timeout=self.timeout)
        # Compare by equality, not identity: items that travel through a cross-process queue are
        # pickled copies, so an identity check would only be reliable for singletons such as `None`.
        if value == self.stop_signal:
            raise StopIteration
        return value