Skip to content

Commit

Permalink
fix: fix problem where square brackets occur in the text and got rm
Browse files Browse the repository at this point in the history
  • Loading branch information
vTuanpham committed Sep 19, 2024
1 parent 22d1696 commit c9acb08
Showing 1 changed file with 11 additions and 11 deletions.
22 changes: 11 additions & 11 deletions providers/groq_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,17 +74,18 @@ def construct_schema_prompt(schema: dict) -> str:
return schema_prompt + json_prompt

@staticmethod
def remove_brackets(text: str) -> str:
def remove_custom_brackets(text: str) -> str:
"""
Remove leading and trailing bracketed expressions from a given text.
Remove leading and trailing custom bracketed expressions from a given text.
Custom brackets are defined as {|[|{ and }|]|}.
Args:
text (str): The input string from which bracketed expressions should be removed.
text (str): The input string from which custom bracketed expressions should be removed.
Returns:
str: The text with leading and trailing bracketed expressions removed.
str: The text with leading and trailing custom bracketed expressions removed.
"""
pattern = r'^\s*\[.*?\]\s*|\s*\[.*?\]\s*$'
pattern = r'^\s*\{\|\[\|\{.*?\}\|\]\|\}\s*|\s*\{\|\[\|\{.*?\}\|\]\|\}\s*$'
return re.sub(pattern, '', text, flags=re.DOTALL | re.MULTILINE)

@throttle(calls_per_minute=28, verbose=False, break_interval=1200, break_duration=60, jitter=3)
Expand Down Expand Up @@ -150,8 +151,8 @@ def _do_translate(self, input_data: Union[str, List[str]],
else:
translated_system_prompt, translated_postfix_prompt = CACHE_INIT_PROMPT[(src, dest)]

prefix_prompt_block = "[START_TRANSLATION_BLOCK]"
postfix_prompt_block = "[END_TRANSLATION_BLOCK]"
prefix_prompt_block = "{|[|{START_TRANSLATION_BLOCK}|]|}"
postfix_prompt_block = "{|[|{END_TRANSLATION_BLOCK}|]|}"
prefix_separator = "=" * 10
postfix_separator = "=" * 10

Expand All @@ -163,7 +164,6 @@ def _do_translate(self, input_data: Union[str, List[str]],
translated_system_prompt += "\n\n" + postfix_system_prompt if postfix_system_prompt else ""
translated_prompt = prefix_prompt + "\n\n" + prompt + "\n\n" + postfix_prompt + "\n\n" + translated_postfix_prompt


chat_args = {
"messages": [
{
Expand All @@ -176,8 +176,8 @@ def _do_translate(self, input_data: Union[str, List[str]],
}
],
"model": "llama3-8b-8192",
"temperature": 0.25,
"top_p": 0.35,
"temperature": 0.3,
"top_p": 0.4,
"max_tokens": 8000,
"stream": False,
}
Expand Down Expand Up @@ -229,7 +229,7 @@ def _do_translate(self, input_data: Union[str, List[str]],
# Clean the translation output if the model repeat the prefix and postfix prompt
final_result = final_result.replace(prefix_separator, "").replace(postfix_separator, "")
final_result = final_result.replace(prefix_prompt_block, "").replace(postfix_prompt_block, "")
final_result = self.remove_brackets(final_result).strip()
final_result = self.remove_custom_brackets(final_result).strip()

try:
if data_type == "list":
Expand Down

0 comments on commit c9acb08

Please sign in to comment.