@@ -68,12 +68,14 @@ def __init__(
68
68
kv_cache_type ,
69
69
tensor_parallelism_size ,
70
70
attention_kernel ,
71
+ block_seq_stride ,
71
72
):
72
73
self .torch_device = torch_device
73
74
self .iree_device = iree_device
74
75
self .iree_hip_target = iree_hip_target
75
76
self .iree_hal_target_backends = iree_hal_target_backends
76
77
self .kv_cache_type = kv_cache_type
78
+ self .block_seq_stride = block_seq_stride
77
79
self .activation_dtype = torch .float16
78
80
self .attention_dtype = torch .float16
79
81
self .tensor_parallelism_size = tensor_parallelism_size
@@ -136,6 +138,7 @@ def compile_model(self, weight_path_str):
136
138
iree_hal_target_backends = self .iree_hal_target_backends ,
137
139
attention_kernel = self .attention_kernel ,
138
140
tensor_parallelism_size = self .tensor_parallelism_size ,
141
+ block_seq_stride = self .block_seq_stride ,
139
142
)
140
143
vmfb_path = export_artifacts .get_artifacts ()
141
144
return vmfb_path
@@ -145,7 +148,7 @@ def load_model(self, weight_path, tokenizer, vmfb_path):
145
148
146
149
self .config = LlamaModelConfig (
147
150
hp = configs .LlamaHParams .from_gguf_props (weight_path .properties ),
148
- block_seq_stride = 16 ,
151
+ block_seq_stride = self . block_seq_stride ,
149
152
kv_cache_type = self .kv_cache_type ,
150
153
device = self .torch_device ,
151
154
activation_dtype = self .activation_dtype ,
@@ -394,6 +397,7 @@ def run_perplexity(
394
397
tensor_parallelism_size ,
395
398
attention_kernel ,
396
399
num_prompts ,
400
+ block_seq_stride ,
397
401
):
398
402
start = time .time ()
399
403
perplexity = Perplexity (
@@ -404,6 +408,7 @@ def run_perplexity(
404
408
kv_cache_type = kv_cache_type ,
405
409
tensor_parallelism_size = tensor_parallelism_size ,
406
410
attention_kernel = attention_kernel ,
411
+ block_seq_stride = block_seq_stride ,
407
412
)
408
413
409
414
perplexity .get_prompts (num_prompts = num_prompts )
@@ -425,8 +430,18 @@ def run_perplexity(
425
430
426
431
def main (argv ):
427
432
parser = cli .create_parser ()
428
- parser .add_argument ("--kv-cache-type" , default = "paged" , help = "KV cache type" )
429
- parser .add_argument ("--torch-device" , help = "Torch device (or default)" )
433
+ parser .add_argument (
434
+ "--attention-kernel" ,
435
+ type = str ,
436
+ default = "decomposed" ,
437
+ choices = ["decomposed" , "torch_sdpa" ],
438
+ )
439
+ parser .add_argument (
440
+ "--block-seq-stride" ,
441
+ help = "Block sequence stride for paged KV cache, must divide evenly into the context length" ,
442
+ type = int ,
443
+ default = "32" ,
444
+ )
430
445
parser .add_argument ("--iree-device" , help = "List an IREE device (e.g., 'hip://0')" )
431
446
parser .add_argument (
432
447
"--iree-hip-target" ,
@@ -440,48 +455,42 @@ def main(argv):
440
455
default = "rocm" ,
441
456
help = "Specify the iree-hal target backends (e.g., rocm)" ,
442
457
)
458
+ parser .add_argument ("--kv-cache-type" , default = "paged" , help = "KV cache type" )
443
459
parser .add_argument (
444
- "--attention-kernel " ,
445
- type = str ,
446
- default = "decomposed" ,
447
- choices = [ "decomposed" , "torch_sdpa" ] ,
460
+ "--num-prompts " ,
461
+ type = int ,
462
+ default = 100 ,
463
+ help = "Number of prompts for perplexity test (1 to 100)" ,
448
464
)
449
465
parser .add_argument (
450
466
"--tensor-parallelism-size" ,
451
467
type = int ,
452
468
default = 1 ,
453
469
help = "Number of devices for tensor parallel sharding" ,
454
470
)
455
- parser .add_argument (
456
- "--num-prompts" ,
457
- type = int ,
458
- default = 100 ,
459
- help = "Number of prompts for perplexity test" ,
460
- )
471
+ parser .add_argument ("--torch-device" , help = "Torch device (or default)" )
461
472
462
473
cli .add_tokenizer_options (parser )
463
474
cli .add_input_dataset_options (parser )
464
475
args = cli .parse (parser , args = argv )
465
476
466
477
torch_device = torch .device (args .torch_device ) if args .torch_device else None
467
- iree_device = args .iree_device
468
- kv_cache_type = args .kv_cache_type
469
478
weight_path = cli .get_input_dataset (args )
470
479
tokenizer = cli .get_tokenizer (args )
471
- weight_path_str = str (args .irpa_file )
472
480
473
481
ppl = run_perplexity (
474
482
weight_path = weight_path ,
475
- weight_path_str = weight_path_str ,
483
+ weight_path_str = str ( args . irpa_file ) ,
476
484
tokenizer = tokenizer ,
477
485
torch_device = torch_device ,
478
- iree_device = iree_device ,
486
+ iree_device = args . iree_device ,
479
487
iree_hip_target = args .iree_hip_target ,
480
488
iree_hal_target_backends = args .iree_hal_target_backends ,
481
- kv_cache_type = kv_cache_type ,
489
+ kv_cache_type = args . kv_cache_type ,
482
490
tensor_parallelism_size = args .tensor_parallelism_size ,
483
491
attention_kernel = args .attention_kernel ,
484
492
num_prompts = args .num_prompts ,
493
+ block_seq_stride = args .block_seq_stride ,
485
494
)
486
495
487
496
logger .info (f"\n { json .dumps (ppl , indent = 2 )} " )
0 commit comments