From e82bd92e8ec0a81560076f7345ca6a9c65f7e585 Mon Sep 17 00:00:00 2001
From: JP
Date: Thu, 23 May 2024 07:39:22 +0000
Subject: [PATCH] standardize prompt formatting for vllm

---
 eval/vllm_runner.py | 13 ++++++++++++-
 1 file changed, 12 insertions(+), 1 deletion(-)

diff --git a/eval/vllm_runner.py b/eval/vllm_runner.py
index ea09537..7b542bc 100644
--- a/eval/vllm_runner.py
+++ b/eval/vllm_runner.py
@@ -121,8 +121,19 @@ def chunk_dataframe(df, chunk_size):
     for batch in (pbar := tqdm(df_chunks, total=len(df))):
         prompts = batch["prompt"].tolist()
         print(f"Generating completions for {len(prompts)} prompts")
+        prompt_tokens = []
+        prompt_token_sizes = []
+        for prompt in prompts:
+            token_ids = tokenizer.encode(prompt, add_special_tokens=False)
+            # add bos token if not already present in prompt
+            if token_ids[0] != tokenizer.bos_token_id:
+                token_ids = [tokenizer.bos_token_id] + token_ids
+            prompt_tokens.append(token_ids)
+            prompt_token_sizes.append(len(token_ids))
+        print(f"Average prompt size: {sum(prompt_token_sizes)/len(prompt_token_sizes):.0f}")
         start_time = time.time()
-        outputs = llm.generate(prompts, sampling_params)
+        # outputs = llm.generate(prompts, sampling_params)  # if you prefer to use prompts instead of token_ids
+        outputs = llm.generate(sampling_params=sampling_params, prompt_token_ids=prompt_tokens)
         print(
             f"Generated {len(outputs)} completions in {time.time() - start_time:.2f} seconds"
         )
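
For context, the token-id-based call the patch switches to can be exercised on its own. The sketch below is not part of the patch: the model name, prompts, and sampling settings are placeholders, and it assumes a vLLM release from around the date of this commit, where LLM.generate() still accepts the prompt_token_ids keyword (later releases deprecate it in favor of token-based prompt inputs).

    # Minimal standalone sketch of the BOS handling and token-id generation
    # shown in the patch above. Model name, prompts, and sampling settings
    # are placeholders, not values from the patch.
    from transformers import AutoTokenizer
    from vllm import LLM, SamplingParams

    model_name = "your-org/your-model"  # placeholder: the model under evaluation
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    llm = LLM(model=model_name)
    sampling_params = SamplingParams(temperature=0, max_tokens=512)

    prompts = ["example prompt 1", "example prompt 2"]  # placeholders
    prompt_tokens = []
    for prompt in prompts:
        token_ids = tokenizer.encode(prompt, add_special_tokens=False)
        # prepend the BOS token only if the tokenizer did not already supply it
        if tokenizer.bos_token_id is not None and token_ids[0] != tokenizer.bos_token_id:
            token_ids = [tokenizer.bos_token_id] + token_ids
        prompt_tokens.append(token_ids)

    outputs = llm.generate(sampling_params=sampling_params, prompt_token_ids=prompt_tokens)
    for output in outputs:
        print(output.outputs[0].text)

Passing pre-tokenized prompts this way makes the BOS handling explicit and identical across models, instead of relying on whatever special-token behavior vLLM's internal tokenization applies to raw strings.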