Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add sentence accumulator #545

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions src/llm/output/actions_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,3 +24,6 @@ def modify_sentence_content(self, cut_content: sentence_content, last_content: s
# if action.is_interrupting:
# settings.stop_generation = True
return cut_content, last_content

def get_cut_indicators(self) -> list[str]:
    """Return the characters at which the sentence accumulator should cut.

    This parser only needs to inspect text around the colon character.
    """
    return list(":")
5 changes: 4 additions & 1 deletion src/llm/output/change_character_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,4 +53,7 @@ def cut_sentence(self, output: str, current_settings: sentence_generation_settin
return None, output #There is a ':' in the text, but it doesn't seem to be part of a character change

def modify_sentence_content(self, cut_content: sentence_content, last_content: sentence_content | None, settings: sentence_generation_settings) -> tuple[sentence_content | None, sentence_content | None]:
return cut_content, last_content
return cut_content, last_content

def get_cut_indicators(self) -> list[str]:
    """Return the characters at which the sentence accumulator should cut.

    Character changes are announced with a colon, so that is the only
    indicator this parser asks for.
    """
    colon_only: list[str] = [":"]
    return colon_only
3 changes: 3 additions & 0 deletions src/llm/output/narration_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,9 @@ def __init__(self, narration_start_chars: list[str] = ["*","(","["], narration_e
else:
self.__start_speech_reg = never_match_anything_regex
self.__end_speech_reg = never_match_anything_regex

def get_cut_indicators(self) -> list[str]:
    """Return every narration/speech delimiter character as a cut indicator."""
    indicators: list[str] = []
    for char_group in (self.__narration_start_chars,
                       self.__narration_end_chars,
                       self.__speech_start_chars,
                       self.__speech_end_chars):
        indicators.extend(char_group)
    return indicators

def cut_sentence(self, output: str, current_settings: sentence_generation_settings) -> tuple[sentence_content | None, str]:
output = output.lstrip()
Expand Down
3 changes: 3 additions & 0 deletions src/llm/output/output_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,3 +70,6 @@ def cut_sentence(self, output: str, current_settings: sentence_generation_settin
@abstractmethod
def modify_sentence_content(self, cut_content: sentence_content, last_content: sentence_content | None, settings: sentence_generation_settings) -> tuple[sentence_content | None, sentence_content | None]:
return cut_content, last_content

def get_cut_indicators(self) -> list[str]:
    """Default implementation: no cut indicators.

    Parsers that need the accumulator to cut at specific characters
    override this and return their own indicator list.
    """
    return list()
43 changes: 43 additions & 0 deletions src/llm/output/sentence_accumulator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
import re


class sentence_accumulator:
    """Accumulates the token-wise output of an LLM into raw sentences.

    Chunks of streamed output are added via accumulate(). Whenever the
    buffered text contains a run of cut indicator characters,
    has_next_sentence() prepares the text up to and including that run and
    get_next_sentence() hands it out. Text a downstream parser could not
    consume can be handed back via refuse() and is prepended to the next
    returned sentence.
    """

    def __init__(self, cut_indicators: list[str]) -> None:
        """Set up the accumulator.

        Args:
            cut_indicators: single characters marking a potential sentence
                boundary (e.g. ".", "!", ":").
        """
        self.__cut_indicators = cut_indicators
        # Streamed LLM output that has not yet been cut into sentences.
        self.__unprocessed_llm_output: str = ""
        if cut_indicators:
            # Lazily match everything up to and including a run of indicator
            # characters. re.escape keeps regex metacharacters ("*", "(",
            # "[") literal inside the character class; the previous
            # hand-rolled escaping ("\\" + "\\".join(...)) would turn a
            # letter indicator such as "d" into the class escape "\d".
            escaped_chars = re.escape("".join(cut_indicators))
            self.__sentence_end_reg = re.compile(f"^.*?[{escaped_chars}]+")
        else:
            # "[]" is an invalid pattern, so with no indicators fall back to
            # a regex that can never match instead of raising at startup.
            self.__sentence_end_reg = re.compile(r"(?!x)x")
        # Text refused by a parser; prepended to the next returned sentence.
        self.__unparseable: str = ""
        # The next complete raw sentence, already cut from the buffer.
        self.__prepared_match: str = ""

    def has_next_sentence(self) -> bool:
        """Check whether a complete raw sentence is available.

        Side effect: on a successful match the matched text is moved from
        the internal buffer into the slot consumed by get_next_sentence().
        """
        if len(self.__prepared_match) > 0:
            return True  # a previously prepared sentence is still pending

        match = self.__sentence_end_reg.match(self.__unprocessed_llm_output)
        if not match:
            return False
        self.__prepared_match = match.group()
        self.__unprocessed_llm_output = self.__unprocessed_llm_output.removeprefix(self.__prepared_match)
        return True

    def get_next_sentence(self) -> str:
        """Return the prepared sentence, prefixed by any refused text."""
        result = self.__unparseable + self.__prepared_match
        self.__unparseable = ""
        self.__prepared_match = ""
        return result

    def accumulate(self, llm_output: str):
        """Append a chunk of LLM output, normalizing newlines to spaces."""
        llm_output = llm_output.replace('\r\n', ' ')
        llm_output = llm_output.replace('\n', ' ')
        self.__unprocessed_llm_output += llm_output

    def refuse(self, refused_text: str):
        """Hand back text a parser could not process.

        The text replaces any previously refused text and is prepended to
        the next sentence returned by get_next_sentence().
        """
        self.__unparseable = refused_text




5 changes: 4 additions & 1 deletion src/llm/output/sentence_end_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,4 +22,7 @@ def cut_sentence(self, output: str, current_settings: sentence_generation_settin
return sentence_content(current_settings.current_speaker, matched_text, current_settings.sentence_type, False), rest

def modify_sentence_content(self, cut_content: sentence_content, last_content: sentence_content | None, settings: sentence_generation_settings) -> tuple[sentence_content | None, sentence_content | None]:
return cut_content, last_content
return cut_content, last_content

def get_cut_indicators(self) -> list[str]:
    """The sentence-terminating characters double as accumulator cut points."""
    end_chars = self.__end_of_sentence_chars
    return end_chars
45 changes: 29 additions & 16 deletions src/output_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import time
import unicodedata
from openai import APIConnectionError
from src.llm.output.sentence_accumulator import sentence_accumulator
from src.config.definitions.llm_definitions import NarrationHandlingEnum
from src.llm.output.max_count_sentences_parser import max_count_sentences_parser
from src.llm.output.sentence_length_parser import sentence_length_parser
Expand Down Expand Up @@ -143,6 +144,13 @@ async def process_response(self, active_character: Character, blocking_queue: se
sentence_length_parser(self.__config.number_words_tts),
max_count_sentences_parser(self.__config.max_response_sentences, not characters.contains_player_character())
])

cut_indicators: set[str] = set()
for parser in parser_chain:
indicators = parser.get_cut_indicators()
for i in indicators:
cut_indicators.add(i)
accumulator: sentence_accumulator = sentence_accumulator(list(cut_indicators))

try:
current_sentence: str = ''
Expand All @@ -160,26 +168,31 @@ async def process_response(self, active_character: Character, blocking_queue: se
logging.log(self.loglevel, f"LLM took {round(time.time() - start_time, 5)} seconds to respond")
first_token = False

current_sentence += content
raw_response += content
parsed_sentence: sentence_content | None = None
# Apply parsers
for parser in parser_chain:
if not parsed_sentence: # Try to extract a complete sentence
parsed_sentence, current_sentence = parser.cut_sentence(current_sentence, settings)
if parsed_sentence: # Apply modifications if we already have a sentence
parsed_sentence, pending_sentence = parser.modify_sentence_content(parsed_sentence, pending_sentence, settings)
accumulator.accumulate(content)
while accumulator.has_next_sentence():
current_sentence = accumulator.get_next_sentence()
# current_sentence += content
parsed_sentence: sentence_content | None = None
# Apply parsers
for parser in parser_chain:
if not parsed_sentence: # Try to extract a complete sentence
parsed_sentence, current_sentence = parser.cut_sentence(current_sentence, settings)
if parsed_sentence: # Apply modifications if we already have a sentence
parsed_sentence, pending_sentence = parser.modify_sentence_content(parsed_sentence, pending_sentence, settings)
if settings.stop_generation:
break
if settings.stop_generation:
break
accumulator.refuse(current_sentence)
# Process sentences from the parser chain
if parsed_sentence:
if not self.__config.narration_handling == NarrationHandlingEnum.CUT_NARRATIONS or parsed_sentence.sentence_type != SentenceTypeEnum.NARRATION:
new_sentence = self.generate_sentence(parsed_sentence)
blocking_queue.put(new_sentence)
parsed_sentence = None
if settings.stop_generation:
break

# Process sentences from the parser chain
if parsed_sentence:
if not self.__config.narration_handling == NarrationHandlingEnum.CUT_NARRATIONS or parsed_sentence.sentence_type != SentenceTypeEnum.NARRATION:
new_sentence = self.generate_sentence(parsed_sentence)
blocking_queue.put(new_sentence)
parsed_sentence = None
break
break #if the streaming_call() completed without exception, break the while loop

except Exception as e:
Expand Down