In paged LLM exporting do not add any argument device affinities when not sharding (#576)

It is illegal to have ops where some arguments have affinities and some do not.
The operands would be considered to be on different devices, which is not
allowed right now.
With a recent enough version of IREE, compilation will fail with:

```
failure.mlir:1:1: error: affinity analysis failed to converge (input program may have invalid affinities assigned); use `--iree-stream-annotate-input-affinities` to help identify the invalid affinities
```

This change makes it so that we do not add any affinities when not
sharding.
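
For illustration, here is a minimal, self-contained sketch of the guarded pattern the diff below applies. `assign_arg_affinities` and the stand-in `DeviceAffinity` class are hypothetical helpers for this sketch, not the exporter's actual API:

```python
# Stand-in for the exporter's DeviceAffinity type, for illustration only.
class DeviceAffinity:
    def __init__(self, device_id: str):
        self.device_id = device_id

    def __repr__(self) -> str:
        return f'DeviceAffinity("{self.device_id}")'


def assign_arg_affinities(
    cache_affinities: dict[int, DeviceAffinity],
    num_leading_args: int,
    tensor_parallelism_size: int,
) -> dict[int, DeviceAffinity]:
    """Either annotate every argument or annotate none of them.

    When not sharding (tensor_parallelism_size == 1), return no
    affinities at all, so no op is left with a mix of annotated and
    unannotated operands.
    """
    if tensor_parallelism_size <= 1:
        return {}
    # Shift the cache argument indices past the leading arguments
    # (the diff uses 3 for prefill and 4 for decode).
    arg_affinities = {
        index + num_leading_args: affinity
        for index, affinity in cache_affinities.items()
    }
    # The leading arguments all get default affinity 0.
    for i in range(num_leading_args):
        arg_affinities[i] = DeviceAffinity("0")
    return arg_affinities


# Sharded: every argument ends up with an affinity.
print(assign_arg_affinities({0: DeviceAffinity("0"), 1: DeviceAffinity("1")}, 3, 2))
# Not sharded: no affinities at all.
print(assign_arg_affinities({0: DeviceAffinity("0")}, 3, 1))
```

Returning an empty mapping in the unsharded case mirrors the commit: either every argument carries an affinity or none does, which keeps IREE's affinity analysis consistent.
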
sogartar authored Nov 21, 2024
1 parent 9535984 commit c8af5a1
Showing 1 changed file with 11 additions and 9 deletions.

sharktank/sharktank/examples/export_paged_llm_v1.py

```diff
@@ -167,11 +167,12 @@ def generate_batch_prefill(bs: int):
             model, llama_config.tensor_parallelism_size
         )

-        # We need to offset the indices for the cache
-        arg_affinities = {key + 3: arg_affinities[key] for key in arg_affinities}
+        if llama_config.tensor_parallelism_size > 1:
+            # We need to offset the indices for the cache
+            arg_affinities = {key + 3: arg_affinities[key] for key in arg_affinities}

-        for i in range(3):
-            arg_affinities[i] = DeviceAffinity("0")
+            for i in range(3):
+                arg_affinities[i] = DeviceAffinity("0")

         dynamic_shapes = {
             "tokens": {1: sl_dim},
@@ -244,12 +245,13 @@ def generate_batch_decode(bs: int):
             arg_affinities,
         ) = setup_cache(model, llama_config.tensor_parallelism_size)

-        # We need to offset the indices for the cache
-        arg_affinities = {key + 4: arg_affinities[key] for key in arg_affinities}
+        if llama_config.tensor_parallelism_size > 1:
+            # We need to offset the indices for the cache
+            arg_affinities = {key + 4: arg_affinities[key] for key in arg_affinities}

-        # Inputs have default affinity 0
-        for i in range(4):
-            arg_affinities[i] = DeviceAffinity("0")
+            # Inputs have default affinity 0
+            for i in range(4):
+                arg_affinities[i] = DeviceAffinity("0")

         dynamic_shapes = {
             "tokens": {},
```
