
Commit

For chat template case, re-order prompt limiting so limits are applied before the chat template is used
pseudotensor committed Mar 30, 2024
1 parent c00f15a commit c18b652
Showing 1 changed file with 16 additions and 11 deletions.
src/gen.py (27 changes: 16 additions & 11 deletions)
@@ -5841,13 +5841,6 @@ def get_limited_prompt(instruction,
     if is_gradio_vision_model(base_model):
         use_chat_template = False
 
-    if use_chat_template:
-        context2 = apply_chat_template(instruction, system_prompt, history, tokenizer)
-        iinput = ''
-        context = ''
-    else:
-        context2 = history_to_context_func(history)
-
     context1 = context
     if context1 is None:
         context1 = ''
@@ -5872,17 +5865,27 @@ def get_limited_prompt(instruction,

     context1, num_context1_tokens = H2OTextGenerationPipeline.limit_prompt(context1, tokenizer,
                                                                            max_prompt_length=max_input_tokens)
-    context2_trial, num_context2_tokens = H2OTextGenerationPipeline.limit_prompt(context2, tokenizer,
-                                                                                 max_prompt_length=max_input_tokens)
-    if not use_chat_template:
-        context2 = context2_trial
 
     iinput, num_iinput_tokens = H2OTextGenerationPipeline.limit_prompt(iinput, tokenizer,
                                                                        max_prompt_length=max_input_tokens)
     # leave bit for instruction regardless of system prompt
     system_prompt, num_system_tokens = H2OTextGenerationPipeline.limit_prompt(system_prompt, tokenizer,
                                                                               max_prompt_length=int(
                                                                                   max_input_tokens * 0.9))
+    if use_chat_template:
+        context2 = apply_chat_template(instruction, system_prompt, history, tokenizer)
+        iinput = ''
+        context = ''
+    else:
+        context2 = history_to_context_func(history)
+
+    context2_trial, num_context2_tokens = H2OTextGenerationPipeline.limit_prompt(context2, tokenizer,
+                                                                                 max_prompt_length=max_input_tokens)
+    if not use_chat_template:
+        context2 = context2_trial
+    else:
+        num_context2_tokens = 0
+
     # limit system prompt
     if prompter:
         prompter.system_prompt = system_prompt
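
The point of the re-ordering above: each component (system_prompt, iinput, context) is token-limited first, and only then rendered through the chat template, so truncation can never land inside template control tokens. A minimal self-contained sketch of the same ordering using Hugging Face's tokenizer.apply_chat_template; the limit_text helper, the model name, and the budgets are illustrative assumptions, not code from this commit:

from transformers import AutoTokenizer

def limit_text(text, tokenizer, max_tokens):
    # Truncate text to at most max_tokens tokens; return (text, token_count).
    ids = tokenizer.encode(text, add_special_tokens=False)[:max_tokens]
    return tokenizer.decode(ids), len(ids)

tokenizer = AutoTokenizer.from_pretrained("HuggingFaceH4/zephyr-7b-beta")
max_input_tokens = 2048

# 1) Limit each component first (mirrors the re-ordered code above; the
#    system prompt gets a 0.9 budget to leave room for the instruction).
system_prompt, _ = limit_text("You are a concise assistant. ...", tokenizer,
                              int(max_input_tokens * 0.9))
instruction, _ = limit_text("Summarize the following notes ...", tokenizer,
                            max_input_tokens)

# 2) Only then render the chat template from the already-limited pieces.
messages = [{"role": "system", "content": system_prompt},
            {"role": "user", "content": instruction}]
context2 = tokenizer.apply_chat_template(messages, tokenize=False,
                                         add_generation_prompt=True)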
@@ -5951,6 +5954,8 @@ def get_limited_prompt(instruction,
     history_to_use = history[0 + chat_index:]
 
     if use_chat_template:
+        instruction, _ = H2OTextGenerationPipeline.limit_prompt(instruction, tokenizer,
+                                                                max_prompt_length=non_doc_max_length)
         context2 = apply_chat_template(instruction, system_prompt, history_to_use, tokenizer)
     else:
         context2 = history_to_context_func(history_to_use)
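
The last hunk applies the same idea at the retry stage: instruction is re-limited to non_doc_max_length before the template is rendered again over the truncated history. For contrast, a sketch of the failure mode the new ordering avoids, i.e. the old order of rendering first and truncating second (same illustrative model as above, not code from this repo):

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("HuggingFaceH4/zephyr-7b-beta")
messages = [{"role": "user", "content": "word " * 5000}]  # oversized user turn
max_input_tokens = 2048

# Render the full template first, then truncate the rendered string.
rendered = tokenizer.apply_chat_template(messages, tokenize=False,
                                         add_generation_prompt=True)
ids = tokenizer.encode(rendered, add_special_tokens=False)
truncated = tokenizer.decode(ids[:max_input_tokens])
# The token budget lands inside the user turn, so the trailing assistant
# marker appended by add_generation_prompt is cut off and the model is left
# with a malformed prompt that never opens an assistant turn.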
