Skip to content

Commit 8218198

Browse files
Add block_seq_stride param to perplexity
1 parent d1980c7 commit 8218198

File tree

4 files changed

+32
-24
lines changed

4 files changed

+32
-24
lines changed

sharktank/sharktank/evaluate/perplexity_iree.py

+28-19
Original file line numberDiff line numberDiff line change
@@ -68,12 +68,14 @@ def __init__(
6868
kv_cache_type,
6969
tensor_parallelism_size,
7070
attention_kernel,
71+
block_seq_stride,
7172
):
7273
self.torch_device = torch_device
7374
self.iree_device = iree_device
7475
self.iree_hip_target = iree_hip_target
7576
self.iree_hal_target_backends = iree_hal_target_backends
7677
self.kv_cache_type = kv_cache_type
78+
self.block_seq_stride = block_seq_stride
7779
self.activation_dtype = torch.float16
7880
self.attention_dtype = torch.float16
7981
self.tensor_parallelism_size = tensor_parallelism_size
@@ -136,6 +138,7 @@ def compile_model(self, weight_path_str):
136138
iree_hal_target_backends=self.iree_hal_target_backends,
137139
attention_kernel=self.attention_kernel,
138140
tensor_parallelism_size=self.tensor_parallelism_size,
141+
block_seq_stride=self.block_seq_stride,
139142
)
140143
vmfb_path = export_artifacts.get_artifacts()
141144
return vmfb_path
@@ -145,7 +148,7 @@ def load_model(self, weight_path, tokenizer, vmfb_path):
145148

146149
self.config = LlamaModelConfig(
147150
hp=configs.LlamaHParams.from_gguf_props(weight_path.properties),
148-
block_seq_stride=16,
151+
block_seq_stride=self.block_seq_stride,
149152
kv_cache_type=self.kv_cache_type,
150153
device=self.torch_device,
151154
activation_dtype=self.activation_dtype,
@@ -394,6 +397,7 @@ def run_perplexity(
394397
tensor_parallelism_size,
395398
attention_kernel,
396399
num_prompts,
400+
block_seq_stride,
397401
):
398402
start = time.time()
399403
perplexity = Perplexity(
@@ -404,6 +408,7 @@ def run_perplexity(
404408
kv_cache_type=kv_cache_type,
405409
tensor_parallelism_size=tensor_parallelism_size,
406410
attention_kernel=attention_kernel,
411+
block_seq_stride=block_seq_stride,
407412
)
408413

409414
perplexity.get_prompts(num_prompts=num_prompts)
@@ -425,8 +430,18 @@ def run_perplexity(
425430

426431
def main(argv):
427432
parser = cli.create_parser()
428-
parser.add_argument("--kv-cache-type", default="paged", help="KV cache type")
429-
parser.add_argument("--torch-device", help="Torch device (or default)")
433+
parser.add_argument(
434+
"--attention-kernel",
435+
type=str,
436+
default="decomposed",
437+
choices=["decomposed", "torch_sdpa"],
438+
)
439+
parser.add_argument(
440+
"--block-seq-stride",
441+
help="Block sequence stride for paged KV cache, must divide evenly into the context length",
442+
type=int,
443+
default="32",
444+
)
430445
parser.add_argument("--iree-device", help="List an IREE device (e.g., 'hip://0')")
431446
parser.add_argument(
432447
"--iree-hip-target",
@@ -440,48 +455,42 @@ def main(argv):
440455
default="rocm",
441456
help="Specify the iree-hal target backends (e.g., rocm)",
442457
)
458+
parser.add_argument("--kv-cache-type", default="paged", help="KV cache type")
443459
parser.add_argument(
444-
"--attention-kernel",
445-
type=str,
446-
default="decomposed",
447-
choices=["decomposed", "torch_sdpa"],
460+
"--num-prompts",
461+
type=int,
462+
default=100,
463+
help="Number of prompts for perplexity test (1 to 100)",
448464
)
449465
parser.add_argument(
450466
"--tensor-parallelism-size",
451467
type=int,
452468
default=1,
453469
help="Number of devices for tensor parallel sharding",
454470
)
455-
parser.add_argument(
456-
"--num-prompts",
457-
type=int,
458-
default=100,
459-
help="Number of prompts for perplexity test",
460-
)
471+
parser.add_argument("--torch-device", help="Torch device (or default)")
461472

462473
cli.add_tokenizer_options(parser)
463474
cli.add_input_dataset_options(parser)
464475
args = cli.parse(parser, args=argv)
465476

466477
torch_device = torch.device(args.torch_device) if args.torch_device else None
467-
iree_device = args.iree_device
468-
kv_cache_type = args.kv_cache_type
469478
weight_path = cli.get_input_dataset(args)
470479
tokenizer = cli.get_tokenizer(args)
471-
weight_path_str = str(args.irpa_file)
472480

473481
ppl = run_perplexity(
474482
weight_path=weight_path,
475-
weight_path_str=weight_path_str,
483+
weight_path_str=str(args.irpa_file),
476484
tokenizer=tokenizer,
477485
torch_device=torch_device,
478-
iree_device=iree_device,
486+
iree_device=args.iree_device,
479487
iree_hip_target=args.iree_hip_target,
480488
iree_hal_target_backends=args.iree_hal_target_backends,
481-
kv_cache_type=kv_cache_type,
489+
kv_cache_type=args.kv_cache_type,
482490
tensor_parallelism_size=args.tensor_parallelism_size,
483491
attention_kernel=args.attention_kernel,
484492
num_prompts=args.num_prompts,
493+
block_seq_stride=args.block_seq_stride,
485494
)
486495

487496
logger.info(f"\n{json.dumps(ppl, indent=2)}")

sharktank/sharktank/examples/export_paged_llm_v1.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ def main():
4949
"--block-seq-stride",
5050
help="Block sequence stride for paged KV cache, must divide evenly into the context length",
5151
type=int,
52-
default="16",
52+
default="32",
5353
)
5454
parser.add_argument(
5555
"--verbose",

sharktank/sharktank/layers/configs/llm_configs.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -144,7 +144,7 @@ class LlamaModelConfig:
144144

145145
# Block sequence stride for a paged KV cache. This must divide evenly
146146
# into the context length.
147-
block_seq_stride: int = 16
147+
block_seq_stride: int = 32
148148

149149
# Either "paged" or "direct".
150150
kv_cache_type: str = "paged"

sharktank/sharktank/utils/export_artifacts.py

+2-3
Original file line numberDiff line numberDiff line change
@@ -92,7 +92,7 @@ def __init__(
9292
iree_hal_target_backends: str,
9393
attention_kernel: str,
9494
tensor_parallelism_size: int,
95-
block_seq_stride: Optional[int] = None,
95+
block_seq_stride: int,
9696
):
9797
self.sharktank_dir = str(
9898
Path(os.path.dirname(os.path.abspath(__file__))).parent.parent.parent
@@ -180,14 +180,13 @@ def export_to_mlir(
180180
f"--output-mlir={mlir_path}",
181181
f"--output-config={json_path}",
182182
f"--bs={str(self.batch_size)}",
183+
f"--block-seq-stride={self.block_seq_stride}",
183184
]
184185
if skip_decode:
185186
export_args.append("--skip-decode")
186187
if self.attention_kernel in ["decomposed", "torch"]:
187188
export_args.append("--attention-kernel")
188189
export_args.append(self.attention_kernel)
189-
if self.block_seq_stride:
190-
export_args.append(f"--block-seq-stride={self.block_seq_stride}")
191190

192191
cwd = self.sharktank_dir
193192
cmd = subprocess.list2cmdline(export_args)

0 commit comments

Comments (0)