Add option --prompt_index (openvinotoolkit#481)
Run only the prompt(s) whose indices are given via the new --prompt_index option.
wgzintel authored Jun 9, 2024
1 parent 9902928 commit 1ee4f38
Showing 2 changed files with 58 additions and 24 deletions.
78 changes: 54 additions & 24 deletions llm_bench/python/benchmark.py
@@ -290,27 +290,36 @@ def run_text_generation_benchmark(model_path, framework, device, args, num_iters
     iter_data_list = []
     warmup_md5 = {}
     input_text_list = utils.model_utils.get_prompts(args)
+    text_gen_fn = run_text_generation if not use_genai else run_text_generation_genai
+    if args['prompt_index'] is None:
+        prompt_idx_list = [prompt_idx for prompt_idx, input_text in enumerate(input_text_list)]
+        text_list = input_text_list
+    else:
+        prompt_idx_list = []
+        text_list = []
+        for i in args['prompt_index']:
+            if 0 <= i < len(input_text_list):
+                text_list.append(input_text_list[i])
+                prompt_idx_list.append(i)
     if len(input_text_list) == 0:
         raise RuntimeError('==Failure prompts is empty ==')
     log.info(f"Numbeams: {args['num_beams']}, benchmarking iter nums(exclude warm-up): {num_iters}, "
-             f'prompt nums: {len(input_text_list)}')
+             f'prompt nums: {len(text_list)}, prompt idx: {prompt_idx_list}')

     # if num_iters == 0, just output warm-up data
-    text_gen_fn = run_text_generation if not use_genai else run_text_generation_genai
     proc_id = os.getpid()
-    prompt_idx_list = [prompt_idx for prompt_idx, input_text in enumerate(input_text_list)]
     if args['subsequent'] is False:
         for num in range(num_iters + 1):
-            for prompt_idx, input_text in enumerate(input_text_list):
+            for idx, input_text in enumerate(text_list):
                 if num == 0:
                     log.info(f'[warm-up] Input text: {input_text}')
-                text_gen_fn(input_text, num, model, tokenizer, args, iter_data_list, warmup_md5, prompt_idx, bench_hook, model_precision, proc_id)
+                text_gen_fn(input_text, num, model, tokenizer, args, iter_data_list, warmup_md5, prompt_idx_list[idx], bench_hook, model_precision, proc_id)
     else:
-        for prompt_idx, input_text in enumerate(input_text_list):
+        for idx, input_text in enumerate(text_list):
             for num in range(num_iters + 1):
                 if num == 0:
                     log.info(f'[warm-up] Input text: {input_text}')
-                text_gen_fn(input_text, num, model, tokenizer, args, iter_data_list, warmup_md5, prompt_idx, bench_hook, model_precision, proc_id)
+                text_gen_fn(input_text, num, model, tokenizer, args, iter_data_list, warmup_md5, prompt_idx_list[idx], bench_hook, model_precision, proc_id)

     utils.metrics_print.print_average(iter_data_list, prompt_idx_list, args['batch_size'], True)
     return iter_data_list, pretrain_time
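The selection logic added above is repeated almost verbatim in the two image paths below, so a minimal standalone sketch may help; select_prompts is a hypothetical helper written for illustration, not a function from this commit:

```python
# Illustrative only: mirrors the prompt selection added in this hunk.
def select_prompts(input_text_list, prompt_index):
    """Return (text_list, prompt_idx_list) for the requested prompt indices.

    prompt_index is None -> run every prompt; otherwise keep only the
    in-range indices, in the order they were requested.
    """
    if prompt_index is None:
        return list(input_text_list), list(range(len(input_text_list)))
    text_list, prompt_idx_list = [], []
    for i in prompt_index:
        if 0 <= i < len(input_text_list):
            text_list.append(input_text_list[i])
            prompt_idx_list.append(i)
    return text_list, prompt_idx_list


# e.g. three prompts in the prompt file and "-pi 2 0" on the command line:
texts, idxs = select_prompts(['a', 'b', 'c'], [2, 0])
assert texts == ['c', 'a'] and idxs == [2, 0]
```

The original prompt index is carried through to text_gen_fn and to print_average, so results stay attributable to the prompt's position in the file even when only a subset is run.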
@@ -383,27 +392,35 @@ def run_image_generation_benchmark(model_path, framework, device, args, num_iter
     pipe, pretrain_time = FW_UTILS[framework].create_image_gen_model(model_path, device, **args)
     iter_data_list = []
     input_image_list = utils.model_utils.get_image_param_from_prompt_file(args)
-    if len(input_image_list) == 0:
-        raise RuntimeError('==Failure prompts is empty ==')

     if framework == "ov":
         stable_diffusion_hook.new_text_encoder(pipe)
         stable_diffusion_hook.new_unet(pipe)
         stable_diffusion_hook.new_vae_decoder(pipe)

-    log.info(f'Benchmarking iter nums(exclude warm-up): {num_iters}, prompt nums: {len(input_image_list)}')
+    if args['prompt_index'] is None:
+        prompt_idx_list = [image_id for image_id, input_text in enumerate(input_image_list)]
+        image_list = input_image_list
+    else:
+        prompt_idx_list = []
+        image_list = []
+        for i in args['prompt_index']:
+            if 0 <= i < len(input_image_list):
+                image_list.append(input_image_list[i])
+                prompt_idx_list.append(i)
+    if len(image_list) == 0:
+        raise RuntimeError('==Failure prompts is empty ==')
+    log.info(f'Benchmarking iter nums(exclude warm-up): {num_iters}, prompt nums: {len(image_list)}, prompt idx: {prompt_idx_list}')

     # if num_iters == 0, just output warm-up data
     proc_id = os.getpid()
-    prompt_idx_list = [image_id for image_id, image_param in enumerate(input_image_list)]
     if args['subsequent'] is False:
         for num in range(num_iters + 1):
-            for image_id, image_param in enumerate(input_image_list):
-                run_image_generation(image_param, num, image_id, pipe, args, iter_data_list, proc_id)
+            for image_id, image_param in enumerate(image_list):
+                run_image_generation(image_param, num, prompt_idx_list[image_id], pipe, args, iter_data_list, proc_id)
     else:
-        for image_id, image_param in enumerate(input_image_list):
+        for image_id, image_param in enumerate(image_list):
             for num in range(num_iters + 1):
-                run_image_generation(image_param, num, image_id, pipe, args, iter_data_list, proc_id)
+                run_image_generation(image_param, num, prompt_idx_list[image_id], pipe, args, iter_data_list, proc_id)

     utils.metrics_print.print_average(iter_data_list, prompt_idx_list, args['batch_size'], False)
     return iter_data_list, pretrain_time
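Both benchmark paths keep the pre-existing `subsequent` behaviour: with subsequent disabled the iteration number is the outer loop, otherwise each selected prompt finishes its warm-up and all iterations before the next prompt starts. A small sketch of the two orderings, using a hypothetical schedule() helper in place of the real generation calls:

```python
# Illustrative ordering only; each (num, idx) pair stands for one generation call.
def schedule(num_iters, prompt_idx_list, subsequent):
    order = []
    if not subsequent:
        # iteration-major: warm-up (num == 0) plus num_iters passes over all prompts
        for num in range(num_iters + 1):
            for idx in prompt_idx_list:
                order.append((num, idx))
    else:
        # prompt-major: each prompt runs warm-up + all iterations back to back
        for idx in prompt_idx_list:
            for num in range(num_iters + 1):
                order.append((num, idx))
    return order


assert schedule(1, [0, 2], subsequent=False) == [(0, 0), (0, 2), (1, 0), (1, 2)]
assert schedule(1, [0, 2], subsequent=True) == [(0, 0), (1, 0), (0, 2), (1, 2)]
```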
@@ -506,21 +523,32 @@ def run_ldm_super_resolution_benchmark(model_path, framework, device, args, num_
         images = [images]
     else:
         raise RuntimeError('==Failure image is empty ==')
-    log.info(f'Benchmarking iter nums(exclude warm-up): {num_iters}, prompt nums: {len(images)}')

+    prompt_idx_list = [image_id for image_id, image_param in enumerate(images)]
+    if args['prompt_index'] is None:
+        prompt_idx_list = [image_id for image_id, input_text in enumerate(images)]
+        image_list = images
+    else:
+        prompt_idx_list = []
+        image_list = []
+        for i in args['prompt_index']:
+            if 0 <= i < len(images):
+                image_list.append(images[i])
+                prompt_idx_list.append(i)
+    if len(image_list) == 0:
+        raise RuntimeError('==Failure prompts is empty ==')
+    log.info(f'Benchmarking iter nums(exclude warm-up): {num_iters}, prompt nums: {len(image_list)}, prompt idx: {prompt_idx_list}')

     # if num_iters == 0, just output warm-up data
     proc_id = os.getpid()
-    prompt_idx_list = [image_id for image_id, image_param in enumerate(images)]
     for num in range(num_iters + 1):
-        image_id = 0
-        for img in images:
+        for image_id, img in enumerate(image_list):
             if num == 0:
                 if args["output_dir"] is not None:
-                    utils.output_file.output_image_input_text(str(img['prompt']), args, image_id, None, proc_id)
-            log.info(f"[{'warm-up' if num == 0 else num}] Input image={img['prompt']}")
-            run_ldm_super_resolution(img, num, pipe, args, framework, iter_data_list, image_id, tm_list, proc_id)
+                    utils.output_file.output_image_input_text(str(img['prompt']), args, prompt_idx_list[image_id], None, proc_id)
+            log.info(f"[{'warm-up' if num == 0 else num}] Input image={img['prompt']}")
+            run_ldm_super_resolution(img, num, pipe, args, framework, iter_data_list, prompt_idx_list[image_id], tm_list, proc_id)
             tm_list.clear()
-            image_id = image_id + 1
     utils.metrics_print.print_average(iter_data_list, prompt_idx_list, 1, False)

     return iter_data_list, pretrain_time
@@ -549,6 +577,8 @@ def get_argprser():
     parser.add_argument('-f', '--framework', default='ov', help='framework')
     parser.add_argument('-p', '--prompt', default=None, help='one prompt')
     parser.add_argument('-pf', '--prompt_file', default=None, help='prompt file in jsonl format')
+    parser.add_argument('-pi', '--prompt_index', nargs='+', type=num_iters_type, default=None,
+                        help='Run the specified prompt index. You can specify multiple prompt indexes, separated by spaces.')
     parser.add_argument(
         '-ic',
         '--infer_count',
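For context, a minimal sketch of how the new option is parsed; non_negative_int below is a stand-in for num_iters_type, whose implementation is not shown in this diff and is assumed here to validate non-negative integers:

```python
import argparse


def non_negative_int(value):  # assumption: behaves like num_iters_type
    ivalue = int(value)
    if ivalue < 0:
        raise argparse.ArgumentTypeError(f'{value} is not a non-negative integer')
    return ivalue


parser = argparse.ArgumentParser()
parser.add_argument('-pi', '--prompt_index', nargs='+', type=non_negative_int, default=None,
                    help='Run the specified prompt index. Multiple indexes are separated by spaces.')

# e.g. "benchmark.py ... -pi 0 2" would yield [0, 2]; omitting the flag yields None
print(parser.parse_args(['-pi', '0', '2']).prompt_index)  # [0, 2]
print(parser.parse_args([]).prompt_index)                 # None
```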
4 changes: 4 additions & 0 deletions llm_bench/python/utils/model_utils.py
@@ -136,6 +136,10 @@ def analyze_args(args):
     model_args['subsequent'] = args.subsequent
     model_args['output_dir'] = args.output_dir
     model_args['genai'] = args.genai
+    model_args['prompt_index'] = [] if args.prompt_index is not None else None
+    if model_args['prompt_index'] is not None:
+        # Deduplication
+        [model_args['prompt_index'].append(i) for i in args.prompt_index if i not in model_args['prompt_index']]

     model_framework = args.framework
     model_path = Path(args.model)
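The list comprehension above deduplicates args.prompt_index while preserving the order in which the indices were given on the command line. An equivalent, arguably more conventional spelling (not part of this commit) is:

```python
# Order-preserving deduplication without a side-effect list comprehension.
prompt_index = [1, 0, 1, 2, 0]
deduped = list(dict.fromkeys(prompt_index))  # dicts preserve insertion order in Python 3.7+
assert deduped == [1, 0, 2]
```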
