[ExecuTorch][BE] Split kv cache and SDPA for better code sharing #7413

Merged: 39 commits merged into main on Jan 23, 2025

Conversation

@kimishpatel (Contributor) commented Dec 20, 2024

Stack from ghstack (oldest at bottom):

Summary:

Why?
We have coupled SDPA with the KV cache for a while. Initially this was done
when we implemented the sdpa_with_kv_cache custom op to reduce the multiple-copy
overhead of the KV cache update. (This could have been done by having
separate custom kv cache update and custom sdpa ops; recent changes
enabled this.)
As a result of the SDPA module owning the KV cache, we get a) a non-composable
implementation and b) difficulty reusing model definitions and components
from repos like torchtune. The outcome is that we have multiple definitions
of the same model, Llama, lying around in ExecuTorch, TorchChat, and torchtune. This
diff and subsequent ones will try to move in the direction where the custom
KV cache and custom SDPA become decoupled and composable, making them more
module-swap friendly with torchtune's model definition.

How?
Earlier PRs decoupled the KV cache update from SDPA. So now:

  1. Decouple the SDPA nn.Module from the KV cache.
  2. Standardize on the KVCache and SDPA interface: both KVCache and SDPA
     operate on q, k, v tensors in [B, n_heads, seq_len, head_dim] format
     (see the sketch below).
  3. Step 2 will introduce multiple transposes when KVCache and SDPA are
     replaced by custom modules, but we will write a graph pass to undo
     those.
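
For illustration, here is a minimal sketch of the decoupled, standardized interface described in 1 and 2. Class and argument names are illustrative, not the exact ExecuTorch definitions; both modules operate on [B, n_heads, seq_len, head_dim] tensors.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class KVCache(nn.Module):
    """Owns only the cache buffers and the update; knows nothing about SDPA."""

    def __init__(self, max_batch, max_seq_len, n_heads, head_dim, dtype=torch.float32):
        super().__init__()
        shape = (max_batch, n_heads, max_seq_len, head_dim)
        self.register_buffer("k_cache", torch.zeros(shape, dtype=dtype))
        self.register_buffer("v_cache", torch.zeros(shape, dtype=dtype))

    def update(self, input_pos, k, v):
        # k, v: [B, n_heads, seq_len, head_dim]; write the new entries into the
        # cache and return the full cached keys/values in the same layout.
        self.k_cache[:, :, input_pos] = k
        self.v_cache[:, :, input_pos] = v
        return self.k_cache, self.v_cache


class SDPA(nn.Module):
    """Owns only the attention computation; takes q, k, v in the standard layout."""

    def forward(self, q, k, v, mask=None):
        # q: [B, n_heads, seq_len, head_dim]; k, v: same layout (full cache length).
        return F.scaled_dot_product_attention(q, k, v, attn_mask=mask)
```

An attention module can then compose the two: `k_all, v_all = self.kv_cache.update(input_pos, k, v)` followed by `self.sdpa(q, k_all, v_all, mask)`, and either piece can be swapped for a custom or backend-specific implementation independently.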

Test Plan:
Existing tests.
Make sure perf doesn't regress.

Differential Revision: D67914054

Summary:
This enables easier module swapping with model definitions from torchtune.

Test Plan:
CI

pytorch-bot bot commented Dec 20, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/executorch/7413

Note: Links to docs will display an error until the docs builds have been completed.

❌ 2 New Failures

As of commit b6a4eb5 with merge base f4e77c7:

NEW FAILURES - The following jobs have failed:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

kimishpatel added a commit that referenced this pull request Dec 20, 2024
ghstack-source-id: 6356acba83a82cb7d19747187a254a735fa77d28
Pull Request resolved: #7413
@facebook-github-bot added the `CLA Signed` label Dec 20, 2024
kimishpatel added a commit that referenced this pull request Dec 20, 2024
Summary:

+ Make all the backend-specific KVCache and SDPA implementations abide by
  the new API

ghstack-source-id: 369434c4d64e6d4500ecfea03b0fd99945b30461
Pull Request resolved: #7413
@kimishpatel changed the title from "Changes to split kv cache and sdpa" to "[ExecuTorch][BE] Split kv cache and SDPA for better code sharing" on Dec 21, 2024
kimishpatel added a commit that referenced this pull request Dec 21, 2024
ghstack-source-id: 6289ce22a2c190da7e38e098ba8a5d0254d6bf9d
Pull Request resolved: #7413
@kimishpatel requested a review from cccclai on December 21, 2024 00:21
@@ -212,6 +215,13 @@ def export(self) -> "LLMEdgeManager":

return self

def run_canonical_optimizations(self):
Contributor:
Comment on this function

@@ -47,20 +37,21 @@ def forward(
seqlen,
mask,
):
q = q.transpose(1, 2) # (bs, seqlen, n_local_heads, head_dim)
Contributor:
I just thought about it again, and adding this transpose here and also earlier in llama_transformer.py so that we can share code for kv_cache.py (this is the reason, right?) doesn't really make sense, since we are already using a custom export-friendly KV cache anyway: https://github.com/pytorch/executorch/blob/main/extension/llm/modules/kv_cache.py#L13

Contributor Author:
So this is really standardizing what the SDPA API is. That is, the input tensors q, k, and v are in [bs, num_heads, seq_len, head_dim]. This is the standard API. It actually applies to the output as well, but I haven't fixed that part yet.

So now imagine you did not use a KV cache but just wanted to replace SDPA. This allows you to do that.

Regarding the KV cache itself, the transpose_cache argument is also a relic from the time of using sdpa_with_kv_cache. We don't really need the transpose_cache arg at all. Removing it will bring the cache closer to torchtune's KV cache implementation. In fact, I would like us to not have to use our own KV cache.
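
For illustration, with the layout standardized, swapping in a custom SDPA no longer requires touching the KV cache at all. A minimal sketch (the `sdpa` attribute name is hypothetical, not the exact module layout):

```python
import torch.nn as nn
import torch.nn.functional as F


class CustomSDPA(nn.Module):
    # Drop-in replacement honoring the same [bs, num_heads, seq_len, head_dim] contract.
    def forward(self, q, k, v, mask=None):
        return F.scaled_dot_product_attention(q, k, v, attn_mask=mask)


def replace_sdpa(model: nn.Module) -> nn.Module:
    # Swap every submodule exposed as `sdpa`, whether or not a KV cache is present.
    for module in model.modules():
        if isinstance(getattr(module, "sdpa", None), nn.Module):
            module.sdpa = CustomSDPA()
    return model
```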

Contributor:
Sure, I see you are dealing with the paired transposes later anyway. Also, an extra bit of info: from my own benchmarking, these transposes don't seem to add much overhead.

@@ -212,6 +215,13 @@ def export(self) -> "LLMEdgeManager":

return self

def run_canonical_optimizations(self):
Contributor:
Please also add checks to make sure self.pre_autograd_graph_module is not None; basically, this needs to be run after export().
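
A minimal sketch of what that guard could look like, assuming the manager stores the exported module in `self.pre_autograd_graph_module` as the diff above suggests (not the exact ExecuTorch implementation):

```python
def run_canonical_optimizations(self):
    # Must be called after export(), which populates pre_autograd_graph_module.
    assert (
        self.pre_autograd_graph_module is not None
    ), "run_canonical_optimizations() must be called after export()"
    # ... apply canonical passes here, e.g. removing paired transposes/permutes ...
    return self
```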

kimishpatel added a commit that referenced this pull request Jan 7, 2025
ghstack-source-id: 05da5d038c624436cdb92009b57cf2b7645ec2b2
Pull Request resolved: #7413
@kimishpatel has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.

@jackzhxng (Contributor) left a comment:

Looks good 👍🏻 pending comments + testing

return dim_0, dim_1


class RemoveRedundantTransposes(ExportPass):
Contributor:
Would name this something else, since it is also undoing permutes.

Contributor Author:
Yeah, good point. I should rename it to permute, since transpose is a special case of permute.
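
For illustration, the core idea of the pass is to cancel back-to-back transpose (or permute) pairs introduced by the layout standardization. A simplified sketch using plain torch.fx (the actual pass in this PR operates on the exported ATen graph and handles more cases):

```python
import torch.fx as fx


def remove_paired_transposes(gm: fx.GraphModule) -> fx.GraphModule:
    changed = False
    for node in list(gm.graph.nodes):
        if node.op == "call_method" and node.target == "transpose":
            parent = node.args[0]
            if (
                isinstance(parent, fx.Node)
                and parent.op == "call_method"
                and parent.target == "transpose"
                and node.args[1:] == parent.args[1:]  # same dims swapped twice -> no-op
            ):
                node.replace_all_uses_with(parent.args[0])
                changed = True
    if changed:
        gm.graph.eliminate_dead_code()
        gm.recompile()
    return gm
```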

graph_module.graph.eliminate_dead_code()
graph_module.recompile()

return PassResult(graph_module, graph_changed)
Contributor:
Should we just make this a default pass and run this in to_edge always?

Contributor Author:
So for now I am going to add it here in the next PR. I need to keep this here just to not break any backend delegation stuff for now.

@jackzhxng
Update: synced offline and wanted to put the notes from the sync somewhere:

  • TorchTune has a focus on server-side finetuning; we will use their model definitions as a starting point when new models are released, to bootstrap development.
    • As such, updates to TorchTune don't necessarily need to be pulled in, since usually they will have nothing to do with inference performance or guarantee of exportability. I.e., if there are a few new commits on attention, it is not necessary that we update the TorchTune pin in order to maintain parity.
  • We will keep separate custom copies of inference-optimized and exportable versions of major bottleneck modules such as attention, sdpa, and kv_cache, which we will module swap in (see the sketch below).
  • While this PR helps overall with code reuse since it lets us swap in sdpa (although in TorchTune the sdpa is currently a callable), there are no guarantees moving forward that the sdpa will remain swappable, for example if it moves out of the callable and the bare nn.functional sdpa sits in the attention module itself. Then we can't swap sdpa anymore, and it's better to just swap at the attention level.
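
For illustration, a swap at the attention level could look like the sketch below; the matching rule and the `make_export_attention` factory are hypothetical, not TorchTune's actual API.

```python
import torch.nn as nn


def swap_attention(model: nn.Module, make_export_attention) -> nn.Module:
    # Recursively replace any child whose class name looks like an attention
    # module with an export-friendly equivalent built from the original.
    for name, child in model.named_children():
        if child.__class__.__name__.endswith("Attention"):
            setattr(model, name, make_export_attention(child))
        else:
            swap_attention(child, make_export_attention)
    return model
```

The factory would be expected to copy weights and config from the original attention module.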

cc @kimishpatel @tarun292

kimishpatel added a commit that referenced this pull request Jan 13, 2025
ghstack-source-id: 14f5d764f14b33b700ca333ad6d2a1a505858b55
Pull Request resolved: #7413
@kimishpatel has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.

@kimishpatel added the `module: llm` label Jan 13, 2025
kimishpatel added a commit that referenced this pull request Jan 21, 2025
ghstack-source-id: 1079c1d5ae98562c85b832e937fdffaabe1dc575
Pull Request resolved: #7413
@kimishpatel has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.

kimishpatel added a commit that referenced this pull request Jan 22, 2025
ghstack-source-id: abaea2cc096952f7bf4d31399dc3aa64e10cfff2
Pull Request resolved: #7413
kimishpatel added a commit that referenced this pull request Jan 22, 2025
ghstack-source-id: a5601e067724f883a1d9a9527f3d165b9dd36060
Pull Request resolved: #7413
@kimishpatel has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.

kimishpatel added a commit that referenced this pull request Jan 22, 2025
ghstack-source-id: fbb04f1e148028359803c4bd649c5ea78a378545
Pull Request resolved: #7413
@kimishpatel has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.

@facebook-github-bot merged commit 7a59069 into main Jan 23, 2025
73 of 77 checks passed
@facebook-github-bot deleted the gh/kimishpatel/149/head branch January 23, 2025 01:51
SS-JIA added a commit that referenced this pull request Jan 30, 2025
…KV cache update operator

## Context

#7413 and #7412 split the `sdpa_with_kv_cache` operator into two separate operators, `update_cache` and `custom_sdpa`, to decouple the cache update step from the actual SDPA computation.

As a result, SDPA is no longer being delegated on Vulkan because of this interface change. To rectify this, Vulkan must also split `sdpa_with_kv_cache` into two operators.

Note that in this diff the new operators are not yet partitioned, because of complications caused by assertion ops in the graph. The next diff adds a pass to remove such assertion ops, which allows the new operators to be partitioned.

Differential Revision: [D68916952](https://our.internmc.facebook.com/intern/diff/D68916952/)

[ghstack-poisoned]
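
Conceptually, the fused op is replaced by a cache-update step followed by attention over the cached tensors. A hedged pseudocode sketch of the call sequence (argument lists are illustrative, not the exact custom-op signatures):

```python
# Before: one fused custom op did both the cache write and the attention.
#   out = sdpa_with_kv_cache(q, k, v, k_cache, v_cache, input_pos, seq_len, ...)
#
# After: the cache update and the SDPA computation are separate ops
# (illustrative argument order only):
#   update_cache(k, k_cache, input_pos)   # write new keys into the cache
#   update_cache(v, v_cache, input_pos)   # write new values into the cache
#   out = custom_sdpa(q, k_cache, v_cache, input_pos, ...)  # attend over cached k/v
```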
SS-JIA added a commit that referenced this pull request Jan 30, 2025
…KV cache update operator + Add `RemoveAsserts` pass and apply it during Llama export

**Note**: This diff is a combination of D68919676 (#8068) and D68919678 (no pull request). I decided to combine the two because `ghexport` was having problems exporting the second diff, and because both diffs are needed for `export_llama` to work, so it makes more sense to just have a single diff.

## Context


Recently, some assertion ops were added to the Llama source code.

Unfortunately, this causes issues for the Vulkan delegate because runtime assertions are not yet supported in Vulkan and the assertion ops cause graph breaks due to not being supported.

To prevent graph breaks when delegating to Vulkan, apply a pass to remove assertion ops during the llama export.

Differential Revision: [D68922404](https://our.internmc.facebook.com/intern/diff/D68922404/)

[ghstack-poisoned]
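
For illustration, such a pass can simply drop assertion nodes from the graph. A simplified sketch on a torch.fx graph (the actual `RemoveAsserts` pass may match specific ATen assertion ops rather than a name filter):

```python
import torch.fx as fx


def remove_asserts(gm: fx.GraphModule) -> fx.GraphModule:
    # Assertion ops are side-effect-only nodes; dropping them avoids graph
    # breaks when partitioning for a delegate that cannot execute them.
    for node in list(gm.graph.nodes):
        if node.op == "call_function" and "assert" in str(node.target) and not node.users:
            gm.graph.erase_node(node)
    gm.recompile()
    return gm
```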
zonglinpeng pushed a commit to zonglinpeng/executorch that referenced this pull request Jan 30, 2025
Differential Revision: D67914054

Pull Request resolved: pytorch#7413
Labels: CLA Signed, module: llm, topic: not user facing