From c9acb08c09d79b25058195848b929c0658c9409d Mon Sep 17 00:00:00 2001 From: vTuanpham Date: Thu, 19 Sep 2024 11:46:40 +0700 Subject: [PATCH] fix: fix problem where square brackets occur in the text and got rm --- providers/groq_provider.py | 22 +++++++++++----------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/providers/groq_provider.py b/providers/groq_provider.py index 1a804d4..bed3814 100644 --- a/providers/groq_provider.py +++ b/providers/groq_provider.py @@ -74,17 +74,18 @@ def construct_schema_prompt(schema: dict) -> str: return schema_prompt + json_prompt @staticmethod - def remove_brackets(text: str) -> str: + def remove_custom_brackets(text: str) -> str: """ - Remove leading and trailing bracketed expressions from a given text. + Remove leading and trailing custom bracketed expressions from a given text. + Custom brackets are defined as {|[|{ and }|]|}. Args: - text (str): The input string from which bracketed expressions should be removed. + text (str): The input string from which custom bracketed expressions should be removed. Returns: - str: The text with leading and trailing bracketed expressions removed. + str: The text with leading and trailing custom bracketed expressions removed. """ - pattern = r'^\s*\[.*?\]\s*|\s*\[.*?\]\s*$' + pattern = r'^\s*\{\|\[\|\{.*?\}\|\]\|\}\s*|\s*\{\|\[\|\{.*?\}\|\]\|\}\s*$' return re.sub(pattern, '', text, flags=re.DOTALL | re.MULTILINE) @throttle(calls_per_minute=28, verbose=False, break_interval=1200, break_duration=60, jitter=3) @@ -150,8 +151,8 @@ def _do_translate(self, input_data: Union[str, List[str]], else: translated_system_prompt, translated_postfix_prompt = CACHE_INIT_PROMPT[(src, dest)] - prefix_prompt_block = "[START_TRANSLATION_BLOCK]" - postfix_prompt_block = "[END_TRANSLATION_BLOCK]" + prefix_prompt_block = "{|[|{START_TRANSLATION_BLOCK}|]|}" + postfix_prompt_block = "{|[|{END_TRANSLATION_BLOCK}|]|}" prefix_separator = "=" * 10 postfix_separator = "=" * 10 @@ -163,7 +164,6 @@ def _do_translate(self, input_data: Union[str, List[str]], translated_system_prompt += "\n\n" + postfix_system_prompt if postfix_system_prompt else "" translated_prompt = prefix_prompt + "\n\n" + prompt + "\n\n" + postfix_prompt + "\n\n" + translated_postfix_prompt - chat_args = { "messages": [ { @@ -176,8 +176,8 @@ def _do_translate(self, input_data: Union[str, List[str]], } ], "model": "llama3-8b-8192", - "temperature": 0.25, - "top_p": 0.35, + "temperature": 0.3, + "top_p": 0.4, "max_tokens": 8000, "stream": False, } @@ -229,7 +229,7 @@ def _do_translate(self, input_data: Union[str, List[str]], # Clean the translation output if the model repeat the prefix and postfix prompt final_result = final_result.replace(prefix_separator, "").replace(postfix_separator, "") final_result = final_result.replace(prefix_prompt_block, "").replace(postfix_prompt_block, "") - final_result = self.remove_brackets(final_result).strip() + final_result = self.remove_custom_brackets(final_result).strip() try: if data_type == "list":