Fix streaming hieroglyphs (#1492)
- When printing is held back by only the last n characters, the cut can fall in the middle of a multi-byte character, producing an invalid UTF-8 byte sequence that is rendered as �.
- Hold back the last n tokens instead. This fixes the issue.

Before the fix
```
visual_language_chat.py ./tiny-random-minicpmv-2_6 ./images <<< $'Describe the images?'
��������������������룅 encouraging룅 encouraging룅 encouraging룅 encouraging룅 encouraging룅 encouraging룅 encouraging
```
After the fix
```
룅튜룅튜룅튜룅튜룅튜룅 encouraging룅 encouraging룅 encouraging룅 encouraging룅 encouraging룅 encouraging
```
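For context, a minimal standalone sketch (not part of this commit) of why a cut that lands inside a multi-byte character produces �. In the C++ streamer the delayed tail was measured in bytes (`delay_n_chars`), so the cutoff could split a UTF-8 sequence:
```
# Illustration only (not from the repository): each Hangul syllable is 3 bytes
# in UTF-8, so cutting the byte sequence mid-character decodes to U+FFFD.
raw = "룅튜".encode("utf-8")                        # 6 bytes
print(raw.decode("utf-8"))                          # '룅튜' - complete sequence decodes cleanly
print(raw[:-1].decode("utf-8", errors="replace"))   # '룅' plus one or more U+FFFD (�)
```
Delaying whole tokens instead keeps every printed prefix on a token boundary, and the `decoded_lengths` bookkeeping in the changes below additionally skips prefixes that still end in an incomplete character.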
CVS-159227

---------

Co-authored-by: Ilya Lavrenov <[email protected]>
pavel-esir and ilya-lavrenov authored Jan 13, 2025
1 parent 1b46dc0 commit 9cf1601
Showing 5 changed files with 81 additions and 38 deletions.
34 changes: 18 additions & 16 deletions samples/python/multinomial_causal_lm/multinomial_causal_lm.py
@@ -31,6 +31,7 @@ def __init__(self, tokenizer):
self.tokens_cache = []
self.text_queue = queue.Queue()
self.print_len = 0
self.decoded_lengths = []

def __iter__(self):
"""
@@ -80,34 +81,35 @@ def put(self, token_id: int) -> bool:
Returns:
bool: True if generation should be stopped, False otherwise.
"""
"""
self.tokens_cache.append(token_id)
text = self.tokenizer.decode(self.tokens_cache)
self.decoded_lengths.append(len(text))

word = ''
delay_n_chars = 4
delay_n_tokens = 3
if len(text) > self.print_len and '\n' == text[-1]:
# Flush the cache after the new line symbol.
word = text[self.print_len:]
self.tokens_cache = []
self.decoded_lengths = []
self.print_len = 0
elif len(text) >= 3 and text[-1] == chr(65533):
elif len(text) > 0 and text[-1] == chr(65533):
# Don't print incomplete text.
pass
elif len(text) > self.print_len + delay_n_chars:
# It is possible to have a shorter text after adding new token.
# Print to output only if text length is increased.
# Also, in some cases adding the next token can shorten the text,
# e.g. when apostrophe removing regex had worked after adding new tokens.
# Several last characters are delayed before flushed to output.
word = text[self.print_len:-delay_n_chars]
self.print_len = len(text) - delay_n_chars
self.put_word(word)

self.decoded_lengths[-1] = -1
elif len(self.tokens_cache) >= delay_n_tokens:
print_until = self.decoded_lengths[-delay_n_tokens]
if print_until != -1 and print_until > self.print_len:
# It is possible to have a shorter text after adding new token.
# Print to output only if text length is increased and text is complete (print_until != -1).
word = text[self.print_len:print_until]
self.print_len = print_until
self.put_word(word)

if self.get_stop_flag():
# When generation is stopped from streamer then end is not called, need to call it here manually.
self.end()
return True # True means stop generation
else:
return False # False means continue generation
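For reference, the token-delay bookkeeping above condensed into a standalone, runnable toy. The `ByteTokenizer` and `stream` helpers are illustrative stand-ins only (tokens are single UTF-8 bytes), not part of this commit; the newline flush and stop-flag handling are omitted:
```
# Illustrative toy only: tokens are single UTF-8 bytes, so a partially
# received character decodes to U+FFFD exactly like an incomplete token cache.
class ByteTokenizer:
    def decode(self, token_ids):
        return bytes(token_ids).decode("utf-8", errors="replace")

def stream(token_ids, delay_n_tokens=3):
    tokenizer = ByteTokenizer()
    tokens_cache, decoded_lengths, print_len = [], [], 0
    for token_id in token_ids:
        tokens_cache.append(token_id)
        text = tokenizer.decode(tokens_cache)
        decoded_lengths.append(len(text))
        if text and text[-1] == chr(65533):
            decoded_lengths[-1] = -1           # text is incomplete at this token
            continue
        if len(tokens_cache) >= delay_n_tokens:
            # Only emit text that was already stable delay_n_tokens tokens ago.
            print_until = decoded_lengths[-delay_n_tokens]
            if print_until != -1 and print_until > print_len:
                yield text[print_len:print_until]
                print_len = print_until
    text = tokenizer.decode(tokens_cache)
    if len(text) > print_len:                  # flush the delayed tail
        yield text[print_len:]

print(list(stream(list("룅튜 ok".encode("utf-8")))))   # no chunk ever contains '�'
```
Joining the yielded chunks reproduces the full decoded text, with no � ever reaching the output.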

37 changes: 23 additions & 14 deletions src/cpp/src/text_callback_streamer.cpp
@@ -15,28 +15,36 @@ bool TextCallbackStreamer::put(int64_t token) {
std::stringstream res;
m_tokens_cache.push_back(token);
std::string text = m_tokenizer.decode(m_tokens_cache);
m_decoded_lengths.push_back(text.length());

if (!text.empty() && '\n' == text.back() && text.size() > print_len) {
if (!text.empty() && '\n' == text.back() && text.size() > m_printed_len) {
// Flush the cache after the new line symbol
res << std::string_view{text.data() + print_len, text.size() - print_len};
res << std::string_view{text.data() + m_printed_len, text.size() - m_printed_len};
m_tokens_cache.clear();
print_len = 0;
m_decoded_lengths.clear();
m_printed_len = 0;
return on_finalized_subword_callback(res.str());
}

// In some cases adding the next token can shorten the text,
// e.g. when apostrophe removing regex had worked after adding new tokens.
// Several last characters are delayed before flushed to output.
constexpr size_t delay_n_chars = 4;
constexpr size_t delay_n_tokens = 3;
auto print_until = m_decoded_lengths[m_decoded_lengths.size() - delay_n_tokens];
constexpr char replacement[] = "\xef\xbf\xbd"; // MSVC with /utf-8 fails to compile � directly with newline in string literal error.
if (text.size() >= 3 && text.compare(text.size() - 3, 3, replacement) == 0) {
m_decoded_lengths[m_decoded_lengths.size() - 1] = -1;
// Don't print incomplete text
return on_finalized_subword_callback(res.str());
} else if (text.size() > print_len + delay_n_chars) {
}
// In some cases adding the next token can shorten the text,
// e.g. when apostrophe removing regex had worked after adding new tokens.
// Printing several last tokens is delayed.
if (m_tokens_cache.size() < delay_n_tokens) {
return on_finalized_subword_callback(res.str());
}
if (print_until != -1 && print_until > m_printed_len) {
// It is possible to have a shorter text after adding new token.
// Print to output only if text length is increased.
res << std::string_view{text.data() + print_len, text.size() - print_len - delay_n_chars} << std::flush;
print_len = text.size() - delay_n_chars;
res << std::string_view{text.data() + m_printed_len, print_until - m_printed_len} << std::flush;
m_printed_len = print_until;
}

return on_finalized_subword_callback(res.str());
@@ -45,11 +53,12 @@ bool TextCallbackStreamer::put(int64_t token) {
void TextCallbackStreamer::end() {
std::stringstream res;
std::string text = m_tokenizer.decode(m_tokens_cache);
if (text.size() <= print_len)
return ;
res << std::string_view{text.data() + print_len, text.size() - print_len} << std::flush;
if (text.size() <= m_printed_len)
return;
res << std::string_view{text.data() + m_printed_len, text.size() - m_printed_len} << std::flush;
m_tokens_cache.clear();
print_len = 0;
m_decoded_lengths.clear();
m_printed_len = 0;
on_finalized_subword_callback(res.str());
return;
}
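A small aside on the `"\xef\xbf\xbd"` literal above: U+FFFD is exactly the three-byte UTF-8 sequence EF BF BD, and it is the same character the Python sample checks via `chr(65533)` (illustrative check, not part of the commit):
```
# U+FFFD (REPLACEMENT CHARACTER) is EF BF BD in UTF-8 and has code point 65533.
assert "\ufffd".encode("utf-8") == b"\xef\xbf\xbd"
assert chr(65533) == "\ufffd"
```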
4 changes: 3 additions & 1 deletion src/cpp/src/text_callback_streamer.hpp
@@ -12,6 +12,7 @@ namespace genai {
class TextCallbackStreamer: public StreamerBase {
public:
bool put(int64_t token) override;

void end() override;

TextCallbackStreamer(const Tokenizer& tokenizer, std::function<bool(std::string)> callback);
@@ -20,7 +21,8 @@ class TextCallbackStreamer: public StreamerBase {
private:
Tokenizer m_tokenizer;
std::vector<int64_t> m_tokens_cache;
size_t print_len = 0;
std::vector<int64_t> m_decoded_lengths;
size_t m_printed_len = 0;
};

} // namespace genai
34 changes: 30 additions & 4 deletions tests/python_tests/test_generate_api.py
@@ -361,25 +361,51 @@ def test_callback_batch_fail(callback):
pipe.generate(['1', '2'], ov_genai.GenerationConfig(), callback)


@pytest.mark.parametrize("callback", [print, user_defined_callback, lambda subword: print(subword)])
class StremerWithResults:
results: List[str] = []
def __init__(self):
self.results = []

def accumulate(self, subword) -> bool:
self.results.append(subword)
return False

def get_result_str(self) -> str:
return ''.join(self.results)


@pytest.mark.parametrize("callback", [print, user_defined_callback, lambda subword: print(subword), StremerWithResults()])
@pytest.mark.precommit
@pytest.mark.nightly
def test_callback_kwargs_one_string(callback):
streamer_class = None
if isinstance(callback, StremerWithResults):
streamer_class = callback
callback = callback.accumulate
pipe = read_model(get_models_list()[0])[4]
pipe.generate('table is made of', max_new_tokens=10, streamer=callback)
res = pipe.generate('table is made of', max_new_tokens=10, streamer=callback)
if isinstance(streamer_class, StremerWithResults):
assert res == streamer_class.get_result_str()

@pytest.mark.parametrize("callback", [print, user_defined_callback, lambda subword: print(subword)])

@pytest.mark.parametrize("callback", [print, user_defined_callback, lambda subword: print(subword), StremerWithResults()])
@pytest.mark.precommit
@pytest.mark.nightly
@pytest.mark.parametrize("model_descr", get_models_list())
def test_callback_decoding_metallama(model_descr, callback):
streamer_class = None
if isinstance(callback, StremerWithResults):
streamer_class = callback
callback = callback.accumulate
# On Meta-Llama this prompt generates output which can shorten after adding new tokens.
# Test that streamer correctly handles such cases.
prompt = 'I have an interview about product speccing with the company Weekend Health. Give me an example of a question they might ask with regards about a new feature'
if model_descr[0] != 'meta-llama/Meta-Llama-3-8B-Instruct':
pytest.skip()
pipe = read_model(model_descr)[4]
pipe.generate(prompt, max_new_tokens=300, streamer=callback)
res = pipe.generate(prompt, max_new_tokens=300, streamer=callback)
if isinstance(streamer_class, StremerWithResults):
assert res == streamer_class.get_result_str()


@pytest.mark.parametrize("callback", [print, user_defined_callback, lambda subword: print(subword)])
10 changes: 7 additions & 3 deletions tests/python_tests/test_vlm_api.py
@@ -46,6 +46,7 @@ def get_ov_model(cache):
@pytest.mark.nightly
def test_vlm_pipeline(cache):
def streamer(word: str) -> bool:
result_from_streamer.append(word)
return False

models_path = get_ov_model(cache)
@@ -54,14 +55,17 @@ def streamer(word: str) -> bool:
images = []
for link in links:
images.append(get_image_by_link(link))

pipe = VLMPipeline(models_path, "CPU")
pipe.start_chat()

pipe.generate(prompts[0], images=images, generation_config=get_greedy(), streamer=streamer)
result_from_streamer = []
res = pipe.generate(prompts[0], images=images, generation_config=get_greedy(), streamer=streamer)
assert res.texts[0] == ''.join(result_from_streamer)

for prompt in prompts[1:]:
pipe.generate(prompt, generation_config=get_greedy(), streamer=streamer)
result_from_streamer = []
res = pipe.generate(prompt, generation_config=get_greedy(), streamer=streamer)
assert res.texts[0] == ''.join(result_from_streamer)

pipe.finish_chat()

