
Commit e82bd92

standardize prompt formatting for vllm

1 parent 16875cf

File tree

1 file changed: +12 -1 lines changed

eval/vllm_runner.py

Lines changed: 12 additions & 1 deletion
@@ -121,8 +121,19 @@ def chunk_dataframe(df, chunk_size):
 for batch in (pbar := tqdm(df_chunks, total=len(df))):
     prompts = batch["prompt"].tolist()
     print(f"Generating completions for {len(prompts)} prompts")
+    prompt_tokens = []
+    prompt_token_sizes = []
+    for prompt in prompts:
+        token_ids = tokenizer.encode(prompt, add_special_tokens=False)
+        # add bos token if not already present in prompt
+        if token_ids[0] != tokenizer.bos_token_id:
+            token_ids = [tokenizer.bos_token_id] + token_ids
+        prompt_tokens.append(token_ids)
+        prompt_token_sizes.append(len(token_ids))
+    print(f"Average prompt size: {sum(prompt_token_sizes)/len(prompt_token_sizes):.0f}")
     start_time = time.time()
-    outputs = llm.generate(prompts, sampling_params)
+    # outputs = llm.generate(prompts, sampling_params)  # if you prefer to use prompts instead of token_ids
+    outputs = llm.generate(sampling_params=sampling_params, prompt_token_ids=prompt_tokens)
     print(
         f"Generated {len(outputs)} completions in {time.time() - start_time:.2f} seconds"
     )
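
The change boils down to: tokenize each prompt up front, prepend the BOS token unless the prompt text already embeds one, and pass vLLM the token ids instead of raw strings, so every model sees an identically formatted prompt. Below is a minimal standalone sketch of that pattern. The model name and prompts are placeholders, and the prompt_token_ids= keyword follows the vLLM API called in this diff (newer vLLM releases expect token ids wrapped in TokensPrompt objects instead).

from transformers import AutoTokenizer
from vllm import LLM, SamplingParams

model_name = "your-org/your-model"  # placeholder; use the model under evaluation
tokenizer = AutoTokenizer.from_pretrained(model_name)
llm = LLM(model=model_name)
sampling_params = SamplingParams(temperature=0, max_tokens=512)

prompts = ["example prompt one", "example prompt two"]  # placeholder prompts

prompt_tokens = []
for prompt in prompts:
    # encode without special tokens so BOS placement is fully under our control
    token_ids = tokenizer.encode(prompt, add_special_tokens=False)
    # prepend BOS exactly once, mirroring the check in this commit
    if token_ids[0] != tokenizer.bos_token_id:
        token_ids = [tokenizer.bos_token_id] + token_ids
    prompt_tokens.append(token_ids)

outputs = llm.generate(sampling_params=sampling_params, prompt_token_ids=prompt_tokens)
for output in outputs:
    print(output.outputs[0].text)

Note that the BOS check assumes the tokenizer actually defines one; for tokenizers where tokenizer.bos_token_id is None, the prepend step should be skipped.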
