From c8af5a1dfb1849f4ffb5a29409fdd07f90990a1a Mon Sep 17 00:00:00 2001
From: Boian Petkantchin
Date: Thu, 21 Nov 2024 08:38:14 -0800
Subject: [PATCH] In paged LLM exporting do not add any argument device
 affinities when not sharding (#576)

It is illegal to have ops where some arguments have affinities and some
do not. The operands would then be considered to live on different
devices, which is not currently allowed.

With a recent enough version of IREE, compilation then fails with:

```
failure.mlir:1:1: error: affinity analysis failed to converge (input program may have invalid affinities assigned); use `--iree-stream-annotate-input-affinities` to help identify the invalid affinities
```

This change makes it so that we do not add any argument affinities when
not sharding.
---
 .../sharktank/examples/export_paged_llm_v1.py | 20 +++++++++++---------
 1 file changed, 11 insertions(+), 9 deletions(-)

diff --git a/sharktank/sharktank/examples/export_paged_llm_v1.py b/sharktank/sharktank/examples/export_paged_llm_v1.py
index 791bce87c..6dd9785c3 100644
--- a/sharktank/sharktank/examples/export_paged_llm_v1.py
+++ b/sharktank/sharktank/examples/export_paged_llm_v1.py
@@ -167,11 +167,12 @@ def generate_batch_prefill(bs: int):
             model, llama_config.tensor_parallelism_size
         )

-        # We need to offset the indices for the cache
-        arg_affinities = {key + 3: arg_affinities[key] for key in arg_affinities}
+        if llama_config.tensor_parallelism_size > 1:
+            # We need to offset the indices for the cache
+            arg_affinities = {key + 3: arg_affinities[key] for key in arg_affinities}

-        for i in range(3):
-            arg_affinities[i] = DeviceAffinity("0")
+            for i in range(3):
+                arg_affinities[i] = DeviceAffinity("0")

         dynamic_shapes = {
             "tokens": {1: sl_dim},
@@ -244,12 +245,13 @@ def generate_batch_decode(bs: int):
             arg_affinities,
         ) = setup_cache(model, llama_config.tensor_parallelism_size)

-        # We need to offset the indices for the cache
-        arg_affinities = {key + 4: arg_affinities[key] for key in arg_affinities}
+        if llama_config.tensor_parallelism_size > 1:
+            # We need to offset the indices for the cache
+            arg_affinities = {key + 4: arg_affinities[key] for key in arg_affinities}

-        # Inputs have default affinity 0
-        for i in range(4):
-            arg_affinities[i] = DeviceAffinity("0")
+            # Inputs have default affinity 0
+            for i in range(4):
+                arg_affinities[i] = DeviceAffinity("0")

         dynamic_shapes = {
             "tokens": {},
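
Review note (not part of the commit): below is a minimal, self-contained sketch of the all-or-nothing affinity rule this patch enforces. The helper name `make_arg_affinities`, the `num_leading_args` parameter, and the `DeviceAffinity` stand-in class are illustrative inventions for this note, not names from the sharktank codebase.

```python
from typing import Dict


class DeviceAffinity:
    """Stand-in for the exporter's DeviceAffinity; pins an argument to a device."""

    def __init__(self, device: str):
        self.device = device

    def __repr__(self) -> str:
        return f"DeviceAffinity({self.device!r})"


def make_arg_affinities(
    cache_affinities: Dict[int, DeviceAffinity],
    num_leading_args: int,
    tensor_parallelism_size: int,
) -> Dict[int, DeviceAffinity]:
    # Not sharding: return no affinities at all. A signature where only some
    # arguments carry affinities trips IREE's affinity analysis (the error
    # quoted in the commit message).
    if tensor_parallelism_size <= 1:
        return {}
    # Sharding: shift the cache-argument indices past the leading arguments
    # (3 for prefill, 4 for decode in the patched file), then pin those
    # leading arguments to device 0 so that every argument has an affinity.
    affinities = {k + num_leading_args: v for k, v in cache_affinities.items()}
    for i in range(num_leading_args):
        affinities[i] = DeviceAffinity("0")
    return affinities


# Example: a decode-style signature with 4 leading args.
cache = {0: DeviceAffinity("0"), 1: DeviceAffinity("1")}
assert make_arg_affinities(cache, 4, 1) == {}  # unsharded: no affinities at all
print(make_arg_affinities(cache, 4, 2))  # sharded: indices 0..5 all covered
```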