Skip to content

Commit 9bd8a53

Browse files
committed
cleanup
1 parent 9d0e19a commit 9bd8a53

File tree

1 file changed

+17
-7
lines changed

1 file changed

+17
-7
lines changed

sharktank/sharktank/layers/kv_cache.py

+17-7
Original file line numberDiff line numberDiff line change
@@ -115,19 +115,27 @@ def allocate(self, *, bs: int) -> list[torch.Tensor]:
115115
116116
Each tensor has shape: [bs, sl, attn_head_count, attn_head_dim]
117117
"""
118-
shards = [[torch.empty(
119-
[bs, self.seq_length, self.attn_head_count // self.shard_count, self.attn_head_dim],
118+
allocations = [
119+
torch.empty(
120+
[
121+
bs,
122+
self.seq_length,
123+
self.attn_head_count,
124+
self.attn_head_dim,
125+
],
120126
dtype=self.dtype,
121127
device=self.device,
122-
) for i in range(self.shard_count)]
128+
)
123129
for _ in range(2 * self.transformer_block_count)
124130
]
125131

126132
if self.shard_count == 1:
127-
return [shard[0] for shard in shards]
128-
129-
return [SplitPrimitiveTensor(ts=shrds, shard_dim=2) for shrds in shards]
133+
return allocations
130134

135+
return [
136+
ops.reshard_split(allocation, dim=2, count=self.shard_count)
137+
for allocation in allocations
138+
]
131139

132140
def read(
133141
self,
@@ -156,7 +164,9 @@ def read(
156164
read_count = len(read_into_partitions)
157165
reads = []
158166
for i in range(read_count):
159-
reads.append(state[transformer_block_index * read_count + i][:, :seq_len, :, :])
167+
reads.append(
168+
state[transformer_block_index * read_count + i][:, :seq_len, :, :]
169+
)
160170

161171
return tuple(reads)
162172

0 commit comments

Comments
 (0)