[Llama] Remove inplace read for KVCache #849

Merged 1 commit on Jan 21, 2025
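
In short, this change drops the in-place read path from the paged KV cache: callers no longer preallocate xk_temp/xv_temp buffers and pass them down for read() to fill; read() now gathers and returns the K/V partitions itself. Below is a condensed before/after of the call shape, reconstructed from the hunks that follow (a sketch only; cache and the surrounding locals stand in for the repository's objects):

def restore_kv_before(cache, cache_state, xk_temp, xv_temp, block_index, seq_block_ids, kv_seq_len):
    # Old path: the caller preallocates [bs, context_length, head_count_kv, head_dim]
    # temporaries and read() copies the gathered pages into them in place.
    return cache.read(
        cache_state,
        read_into_partitions=[
            xk_temp[:, 0:kv_seq_len, ...],
            xv_temp[:, 0:kv_seq_len, ...],
        ],
        transformer_block_index=block_index,
        seq_len=kv_seq_len,
        page_ids=seq_block_ids,
    )

def restore_kv_after(cache, cache_state, block_index, seq_block_ids, kv_seq_len):
    # New path: read() returns the gathered K/V partitions directly, so the models no
    # longer allocate or thread xk_temp/xv_temp through every layer.
    return cache.read(
        cache_state,
        transformer_block_index=block_index,
        seq_len=kv_seq_len,
        page_ids=seq_block_ids,
    )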
36 changes: 1 addition & 35 deletions sharktank/sharktank/export_layer/export_paged_attention.py
@@ -37,8 +37,6 @@ def paged_attention(
start_positions: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
cache_state: list[torch.Tensor] = None,
xk_temp: Optional[torch.Tensor] = None,
xv_temp: Optional[torch.Tensor] = None,
):

bs, batch_seq_len, _, _ = xq.shape
@@ -54,8 +52,6 @@ def paged_attention(
kv_seq_len=kv_seq_len,
start_positions=start_positions,
cache_state=cache_state,
xk_temp=xk_temp,
xv_temp=xv_temp,
)
elif attention_block.cache.is_direct:
xk, xv = attention_block.transact_cache_direct(
@@ -112,35 +108,7 @@ def run_llama(
start_positions: Optional[torch.Tensor] = None,
):

if phase == "decode":
bs, _, _, _ = xq.shape

# Allocate per-block temporary K/V tensors. These temporaries hold
# one block's K/V state for the maximum context length.
xk_temp = torch.empty(
[
bs,
config.hp.context_length,
config.hp.attention_head_count_kv,
config.hp.attn_head_dim,
],
dtype=config.activation_dtype,
device=config.device,
)
xv_temp = torch.empty(
[
bs,
config.hp.context_length,
config.hp.attention_head_count_kv,
config.hp.attn_head_dim,
],
dtype=config.activation_dtype,
device=config.device,
)
elif phase == "prefill":
xk_temp = None
xv_temp = None
else:
if phase not in ["prefill", "decode"]:
raise ValueError("'phase' argument needs to be either 'prefill' or 'decode'")

h = paged_attention(
@@ -153,8 +121,6 @@ def run_llama(
attention_mask=attention_mask,
cache_state=cache_state,
seq_block_ids=seq_block_ids,
xk_temp=xk_temp,
xv_temp=xv_temp,
)

return h
31 changes: 11 additions & 20 deletions sharktank/sharktank/layers/kv_cache.py
@@ -158,25 +158,22 @@ def read(
self,
state: list[Union[torch.Tensor, SplitPrimitiveTensor]],
*,
read_into_partitions: list[Union[torch.Tensor, SplitPrimitiveTensor]],
transformer_block_index: int,
seq_len: int,
page_ids: Union[torch.Tensor, ReplicatedTensor],
page_ids: Optional[Union[torch.Tensor, ReplicatedTensor]] = None,
):
"""Reads cache partitions from the page table for the given page_ids.
"""Reads K/V caches the page table for the given page_ids.

Args:
state: State struct as returned from allocate().
read_into_partitions: List of cache partitions to read into in-place.
transformer_block_index: The index of the transformer block accessing
the cache.
page_ids: Tensor of [bs, max_seqlen // block_pos_stride] of page ids
to access.

Returns a tuple of cache partitions (i.e. k and v caches for the transformer
block), linearized. Note that this reference approach to reading by
materializing linearly may not be terribly efficient unless if the
compiler can fuse the gather.
Returns the K/V cache partitions, linearized. Note that this reference
approach to reading by materializing linearly may not be terribly
efficient unless the compiler can fuse the gather.
"""
page_table = self.unflatten_page_table(state) # 6D

@@ -204,32 +201,26 @@ def read(
transformer_block_index * transformer_block_stride
)

def read_cache_partition(
index: int, into_partition: Union[torch.Tensor, SplitPrimitiveTensor]
):
subblock_ids = (
(base_subblock_ids + index) if index > 0 else base_subblock_ids
)
def read_cache_partition(index: int):
subblock_ids = base_subblock_ids + index
# TODO: Potentially clamp all page 0 indices to the mask value.
# Or even better, require that the ids are replicated such that access is
# legal.
# Now for each of the k/v attn_block_ids, which have been adjusted to
# index into the sub-pages, we flatten to do a linear index_select
# copy of the sub-blocks by collapsing the first two dims so we have
# a linear list.
# TODO: Can be rewritten into inplace with out= on index_select.
selected = (
ops.index_select(subblock_table, 0, subblock_ids.flatten(0, 1))
.unflatten(0, blocked_shape[0:2])
.flatten(1, 2)
)
# trace_tensor("kv.selected", selected)
into_partition[...] = selected
return selected

for index, read_into_partition in enumerate(read_into_partitions):
read_cache_partition(index, read_into_partition)
key = read_cache_partition(0)
value = read_cache_partition(1)

return tuple([p[:, :seq_len, :] for p in read_into_partitions])
return key[:, :seq_len], value[:, :seq_len]

def write_timestep(
self,
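
For reference, here is a self-contained sketch of the gather-based read that the new code path performs, using plain torch and made-up shapes. The page-table layout below (one sub-block per page, transformer block, and K/V partition) is an assumption for illustration; the real implementation goes through sharktank's ops and sharded tensor types.

import torch

bs, blocks_per_seq, block_seq_stride = 2, 3, 4
transformer_block_count, cache_partition_count = 4, 2   # two partitions: K and V
head_count_kv, head_dim = 2, 8
page_count = 16

# One sub-block per (page, transformer block, partition), flattened on dim 0.
subblock_table = torch.randn(
    page_count * transformer_block_count * cache_partition_count,
    block_seq_stride,
    head_count_kv,
    head_dim,
)

page_ids = torch.randint(0, page_count, (bs, blocks_per_seq))
transformer_block_index = 1
page_stride = transformer_block_count * cache_partition_count
base_subblock_ids = page_ids * page_stride + transformer_block_index * cache_partition_count

def read_cache_partition(index: int) -> torch.Tensor:
    subblock_ids = base_subblock_ids + index
    # Collapse [bs, blocks_per_seq] so one index_select gathers every sub-block, then
    # restore the batch/block dims and linearize the sequence dimension.
    return (
        torch.index_select(subblock_table, 0, subblock_ids.flatten(0, 1))
        .unflatten(0, (bs, blocks_per_seq))
        .flatten(1, 2)
    )

seq_len = 10
key, value = read_cache_partition(0), read_cache_partition(1)
key, value = key[:, :seq_len], value[:, :seq_len]
print(key.shape, value.shape)  # torch.Size([2, 10, 2, 8]) for both

The trailing slice to seq_len plays the role of the old [:, :seq_len, :] slicing of the caller-provided partitions.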
11 changes: 0 additions & 11 deletions sharktank/sharktank/layers/paged_llama_attention_block.py
@@ -98,8 +98,6 @@ def forward(
attention_mask: Optional[torch.Tensor] = None,
embedding_batch_mask: Optional[torch.Tensor] = None,
cache_state: list[torch.Tensor] = None,
xk_temp: Optional[torch.Tensor] = None,
xv_temp: Optional[torch.Tensor] = None,
):
assert bool(start_index is not None) ^ bool(embedding_batch_mask is not None)

@@ -158,8 +156,6 @@ def forward(
kv_seq_len=kv_seq_len,
start_positions=start_positions,
cache_state=cache_state,
xk_temp=xk_temp,
xv_temp=xv_temp,
)

# Expand kv heads for GQA.
@@ -245,8 +241,6 @@ def transact_cache(
seq_block_ids: Optional[torch.Tensor],
kv_seq_len: int,
start_positions: Optional[torch.Tensor] = None,
xk_temp: Optional[torch.Tensor] = None,
xv_temp: Optional[torch.Tensor] = None,
):
cache = self.cache
# Manage the cache.
@@ -266,7 +260,6 @@ def transact_cache(
# use a memory efficient attention kernel that can do indirect
# reads, skipping this materialization. This path is taken for
# a decode step.
assert xk_temp is not None and xv_temp is not None
assert xk_cache_update.shape[1] == 1
assert xv_cache_update.shape[1] == 1
assert kv_seq_len == seq_block_ids.shape[1] * cache.block_seq_stride
@@ -286,10 +279,6 @@ def transact_cache(
# Restore from the cache.
xk, xv = cache.read(
cache_state,
read_into_partitions=[
xk_temp[:, 0:kv_seq_len, ...],
xv_temp[:, 0:kv_seq_len, ...],
],
transformer_block_index=self.block_index,
page_ids=seq_block_ids,
seq_len=kv_seq_len,
26 changes: 0 additions & 26 deletions sharktank/sharktank/models/grok/grok.py
@@ -170,37 +170,13 @@ def decode(
self._assert_device(attention_mask, dtype=self.activation_dtype)
self._assert_device(start_positions)
self._assert_device(*cache_state, dtype=self.activation_dtype)
bs, _ = tokens.shape
# Precompute a position based mask for computing rope embeddings
# as it is the same for all blocks.
embedding_batch_mask = self.attention_embedding.compute_batch_mask(
start_positions, batch_seq_len=1
)
self.trace_tensor("grok.embedding_batch_mask", embedding_batch_mask)

# Allocate per-block temporary K/V tensors. These temporaries hold
# one block's K/V state for the maximum context length.
xk_temp = torch.empty(
[
bs,
self.context_length,
self.hp.attention_head_count_kv,
self.hp.attn_head_dim,
],
dtype=self.config.activation_dtype,
device=self.device,
)
xv_temp = torch.empty(
[
bs,
self.context_length,
self.hp.attention_head_count_kv,
self.hp.attn_head_dim,
],
dtype=self.config.activation_dtype,
device=self.device,
)

h = self.token_embedding(tokens)
h *= math.sqrt(h.shape[-1])
self.trace_tensor("grok.token_embedding", h)
@@ -220,8 +196,6 @@ def decode(
attention_mask=attention_mask,
cache_state=cache_state,
seq_block_ids=seq_block_ids,
xk_temp=xk_temp,
xv_temp=xv_temp,
)
self.trace_tensor(f"grok.attn_block.{block_idx}.output", h)

52 changes: 0 additions & 52 deletions sharktank/sharktank/models/llama/llama.py
@@ -186,59 +186,13 @@ def decode(
self._assert_device(start_positions)
self._assert_device(*cache_state, dtype=self.activation_dtype)

bs, _ = tokens.shape
# Precompute a position based mask for computing rope embeddings
# as it is the same for all blocks.
embedding_batch_mask = self.attention_embedding.compute_batch_mask(
start_positions, batch_seq_len=1
)
self.trace_tensor("llama.embedding_batch_mask", embedding_batch_mask)

# Allocate per-block temporary K/V tensors. These temporaries hold
# one block's K/V state for the maximum context length.
if self.config.tensor_parallelism_size == 1:
xk_temp = torch.empty(
[
bs,
self.context_length,
self.hp.attention_head_count_kv,
self.hp.attn_head_dim,
],
dtype=self.config.activation_dtype,
device=self.device,
)
xv_temp = torch.empty(
[
bs,
self.context_length,
self.hp.attention_head_count_kv,
self.hp.attn_head_dim,
],
dtype=self.config.activation_dtype,
device=self.device,
)
else:
shard_size = [
bs,
self.context_length,
self.hp.attention_head_count_kv // self.config.tensor_parallelism_size,
self.hp.attn_head_dim,
]
xk_temp_shard = [
torch.empty(
shard_size, dtype=self.config.activation_dtype, device=self.device
)
for _ in range(self.config.tensor_parallelism_size)
]
xv_temp_shard = [
torch.empty(
shard_size, dtype=self.config.activation_dtype, device=self.device
)
for _ in range(self.config.tensor_parallelism_size)
]
xk_temp = SplitPrimitiveTensor(ts=xk_temp_shard, shard_dim=2)
xv_temp = SplitPrimitiveTensor(ts=xv_temp_shard, shard_dim=2)

h = self.token_embedding(tokens)
self.trace_tensor("llama.token_embedding", h)

@@ -254,8 +208,6 @@ def decode(
attention_mask=attention_mask,
cache_state=cache_state,
seq_block_ids=seq_block_ids,
xk_temp=xk_temp,
xv_temp=xv_temp,
)
self.trace_tensor(f"llama.attn_block.{block_idx}.output", h)

@@ -323,8 +275,6 @@ def forward(
attention_mask: Optional[torch.Tensor] = None,
embedding_batch_mask: Optional[torch.Tensor] = None,
cache_state: list[torch.Tensor] = None,
xk_temp: Optional[torch.Tensor] = None,
xv_temp: Optional[torch.Tensor] = None,
):
h = self.attn(
h,
@@ -335,8 +285,6 @@ def forward(
attention_mask=attention_mask,
embedding_batch_mask=embedding_batch_mask,
cache_state=cache_state,
xk_temp=xk_temp,
xv_temp=xv_temp,
)

# Feed forward network.
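
For a rough sense of what the deleted per-call allocations cost, here is a back-of-the-envelope calculation with hypothetical llama-like numbers (not taken from this repository's configs): decode() previously created two buffers of shape [bs, context_length, attention_head_count_kv, attn_head_dim] in the activation dtype on every call.

import torch

# Hypothetical example numbers: bs=4, 4k context, 8 KV heads, head_dim=128, fp16 activations.
bs, context_length, head_count_kv, head_dim = 4, 4096, 8, 128
bytes_per_elem = torch.finfo(torch.float16).bits // 8
per_buffer = bs * context_length * head_count_kv * head_dim * bytes_per_elem
print(f"{2 * per_buffer / 2**20:.0f} MiB of xk_temp/xv_temp avoided per decode call")  # 64 MiB here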
25 changes: 0 additions & 25 deletions sharktank/sharktank/models/mixtral/mixtral.py
@@ -177,29 +177,6 @@ def decode(
)
self.trace_tensor("mixtral.embedding_batch_mask", embedding_batch_mask)

# Allocate per-block temporary K/V tensors. These temporaries hold
# one block's K/V state for the maximum context length.
xk_temp = torch.empty(
[
bs,
self.context_length,
self.hp.attention_head_count_kv,
self.hp.attn_head_dim,
],
dtype=self.config.activation_dtype,
device=self.device,
)
xv_temp = torch.empty(
[
bs,
self.context_length,
self.hp.attention_head_count_kv,
self.hp.attn_head_dim,
],
dtype=self.config.activation_dtype,
device=self.device,
)

h = self.token_embedding(tokens)
self.trace_tensor("mixtral.token_embedding", h)

@@ -218,8 +195,6 @@ def decode(
attention_mask=attention_mask,
cache_state=cache_state,
seq_block_ids=seq_block_ids,
xk_temp=xk_temp,
xv_temp=xv_temp,
)
self.trace_tensor(f"mixtral.attn_block.{block_idx}.output", h)
