@@ -116,35 +116,38 @@ def setup_cache(model, shard_count):
            page_count=hp.context_length // llama_config.block_seq_stride
        )
        page_dim = torch.export.Dim("page")
+
        dynamic_shapes = [{0: page_dim}]
+        unpacked = cache_state
+        arg_affinities = {}
+        shard_dim = None
+
+        # Need to unpack that state when sharded
+        if llama_config.tensor_parallelism_size > 1:
+            shard_dim = cache_state[0].shard_dim
+
+            unpacked = [[shard._data for shard in cs.shards] for cs in cache_state]
+            dynamic_shapes = [
+                [ds] * llama_config.tensor_parallelism_size for ds in dynamic_shapes
+            ]
+
+            for i in range(llama_config.tensor_parallelism_size):
+                arg_affinities[i] = DeviceAffinity(str(i))
+
+        return unpacked, shard_dim, dynamic_shapes, arg_affinities
+
    elif model.config.kv_cache_type == "direct":
        cache_state = model.cache.allocate(bs=1)
        # Direct cache dimensions:
        # 2 * transformer_block_count of...
        # [bs, seq_length, attn_head_count, attn_head_dim]
        dynamic_shapes = [None]
+        arg_affinities = {}
+        shard_dim = None
+        return torch.stack(cache_state), shard_dim, dynamic_shapes, arg_affinities
    else:
        raise NotImplementedError(f"Unsupported KV cache type: {type(model.cache)}")

-    unpacked = cache_state
-    dynamic_shapes = dynamic_shapes
-    arg_affinities = {}
-    shard_dim = None
-
-    # Need to unpacke that state when sharded
-    if llama_config.tensor_parallelism_size > 1:
-        shard_dim = cache_state[0].shard_dim
-
-        unpacked = [[shard._data for shard in cs.shards] for cs in cache_state]
-        dynamic_shapes = [
-            [ds] * llama_config.tensor_parallelism_size for ds in dynamic_shapes
-        ]
-
-        for i in range(llama_config.tensor_parallelism_size):
-            arg_affinities[i] = DeviceAffinity(str(i))
-
-    return torch.stack(unpacked), shard_dim, dynamic_shapes, arg_affinities
-

def repack_cache(cache, shard_dim):
    return [SplitPrimitiveTensor(ts=c, shard_dim=shard_dim) for c in cache]
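Side note on the sharded "paged" branch above: each dynamic-shape spec is replicated once per shard so the exported signature lines up with the flattened per-device tensors. A minimal, self-contained sketch of that replication follows; it is illustration only, not part of the diff, and `tp` is a stand-in for llama_config.tensor_parallelism_size.

import torch

# Stand-in values for illustration; the real code derives these from llama_config.
tp = 2
page_dim = torch.export.Dim("page")
dynamic_shapes = [{0: page_dim}]

# Same replication as the sharded branch: one spec per shard of every cache entry.
replicated = [[ds] * tp for ds in dynamic_shapes]
assert replicated == [[{0: page_dim}, {0: page_dim}]]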
@@ -184,7 +187,13 @@ def generate_batch_prefill(bs: int):
        arg_device=arg_affinities,
    )
    def _(model, tokens, seq_lens, seq_block_ids, cs):
-        cache_tensors = torch.unbind(cs)
+        if (
+            model.config.tensor_parallelism_size == 1
+            and model.config.kv_cache_type == "direct"
+        ):
+            cache_tensors = torch.unbind(cs)
+        else:
+            cache_tensors = cs

        sl = tokens.shape[1]
        input_mask = model.input_mask(seq_lens, sl)
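Side note on the prefill change above: torch.unbind is now applied only when the cache arrives as one stacked tensor (the single-device "direct" case); for paged or sharded caches, cs is already a flat list of tensors. A minimal sketch of what that unbind step does, with assumed toy dimensions (illustration only, not part of the diff):

import torch

# Assumed toy shape: 2 * transformer_block_count = 4, bs = 1, seq_length = 8,
# attn_head_count = 2, attn_head_dim = 16.
stacked = torch.zeros(4, 1, 8, 2, 16)

# Splits the stacked direct cache back into per-block tensors along dim 0.
cache_tensors = torch.unbind(stacked)
assert len(cache_tensors) == 4
assert cache_tensors[0].shape == (1, 8, 2, 16)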