Skip to content

Commit

Permalink
Merge pull request #13 from chtanch/qwen-template
Browse files Browse the repository at this point in the history
Add instruction templates for Qwen and ChatGLM3
  • Loading branch information
sgwhat authored Mar 4, 2024
2 parents c82384d + e05b76c commit 8c67c21
Show file tree
Hide file tree
Showing 4 changed files with 61 additions and 11 deletions.
24 changes: 24 additions & 0 deletions instruction-templates/ChatGLM3.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
instruction_template: |-
  {#- Scan the conversation once to learn whether a system message is present. -#}
  {%- set ns = namespace(found=false) -%}
  {%- for message in messages -%}
    {%- if message['role'] == 'system' -%}
      {%- set ns.found = true -%}
    {%- endif -%}
  {%- endfor -%}
  {#- ChatGLM3 injects no default system prompt: emit the empty string. -#}
  {%- if not ns.found -%}
    {{- '' -}}
  {%- endif %}
  {#- Render every turn with its ChatGLM3 role tag, one role per branch. -#}
  {%- for message in messages %}
    {%- if message['role'] == 'system' -%}
      {{- '<|system|>\n' + message['content'] + '\n' -}}
    {%- elif message['role'] == 'user' -%}
      {{- '<|user|>\n' + message['content'] + '\n' -}}
    {%- else -%}
      {{- '<|assistant|>\n' + message['content'] + '\n' -}}
    {%- endif -%}
  {%- endfor -%}
  {#- Open a fresh assistant turn so generation continues from here. -#}
  {%- if add_generation_prompt -%}
    {{- '<|assistant|>\n' -}}
  {%- endif -%}
24 changes: 24 additions & 0 deletions instruction-templates/Qwen.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
instruction_template: |-
  {#- Scan the conversation once to learn whether a system message is present. -#}
  {%- set ns = namespace(found=false) -%}
  {%- for message in messages -%}
    {%- if message['role'] == 'system' -%}
      {%- set ns.found = true -%}
    {%- endif -%}
  {%- endfor -%}
  {#- No default system prompt is injected when none exists: emit the empty string. -#}
  {%- if not ns.found -%}
    {{- '' -}}
  {%- endif %}
  {#- Render every turn in ChatML form: <|im_start|>role ... <|im_end|>. -#}
  {%- for message in messages %}
    {%- if message['role'] == 'system' -%}
      {{- '<|im_start|>system\n' + message['content'] + '<|im_end|>\n' -}}
    {%- elif message['role'] == 'user' -%}
      {{- '<|im_start|>user\n' + message['content'] + '<|im_end|>\n' -}}
    {%- else -%}
      {{- '<|im_start|>assistant\n' + message['content'] + '<|im_end|>\n' -}}
    {%- endif -%}
  {%- endfor -%}
  {#- Open a fresh assistant turn so generation continues from here. -#}
  {%- if add_generation_prompt -%}
    {{- '<|im_start|>assistant\n' -}}
  {%- endif -%}
2 changes: 2 additions & 0 deletions models/config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -190,3 +190,5 @@
instruction_template: 'ChatML'
.*synthia:
instruction_template: 'Synthia'
.*qwen:
instruction_template: 'Qwen'
22 changes: 11 additions & 11 deletions modules/text_generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -314,17 +314,17 @@ def generate_reply_HF(question, original_question, seed, state, stopping_strings
generate_params.update({'synced_gpus': True})

#tune the prompt based on qwen
QWEN_PROMPT_FORMAT = """
<|im_start|>system
You are a helpful assistant.
<|im_end|>
<|im_start|>user
{prompt}
<|im_end|>
<|im_start|>assistant
"""
if shared.model.config.model_type == "qwen":
question = QWEN_PROMPT_FORMAT.format(prompt=question)
# QWEN_PROMPT_FORMAT = """
# <|im_start|>system
# You are a helpful assistant.
# <|im_end|>
# <|im_start|>user
# {prompt}
# <|im_end|>
# <|im_start|>assistant
# """
# if shared.model.config.model_type == "qwen":
# question = QWEN_PROMPT_FORMAT.format(prompt=question)

# Encode the input
input_ids = encode(question, add_bos_token=state['add_bos_token'], truncation_length=get_max_prompt_length(state))
Expand Down

0 comments on commit 8c67c21

Please sign in to comment.