From fd1cd73ee4efc0a99cf50bd448db0b5cd0003d6e Mon Sep 17 00:00:00 2001
From: Stephen Jia
Date: Thu, 30 Jan 2025 13:45:00 -0800
Subject: [PATCH] [ET-VK][LlaMa] Split SDPA + KV cache operator into SDPA
 operator and KV cache update operator

## Context

https://github.com/pytorch/executorch/pull/7413 and
https://github.com/pytorch/executorch/pull/7412 split the `sdpa_with_kv_cache`
operator into two separate operators, `update_cache` and `custom_sdpa`, to
decouple the cache update step from the actual SDPA computation. Because of
this interface change, SDPA is no longer delegated to Vulkan. To rectify this,
Vulkan must also split `sdpa_with_kv_cache` into two operators.

Note that the new operators are not partitioned yet in this diff because of
complications caused by assertion ops in the graph. The next diff adds a pass
to remove such assertion ops, which allows the new operators to be partitioned.
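For reference, the sketch below illustrates the intended semantics of the
decoupled operators in eager-mode PyTorch. It is only an illustration: the
tensor layout (`[1, seq_len, num_heads, head_dim]`), the single-token-decode
simplification, and all function names are assumptions, not the actual
ExecuTorch operator signatures.

```python
import torch
import torch.nn.functional as F

# Assumed layout: [batch=1, seq_len, num_heads, head_dim], matching the
# width-packed (head dim) layout checks in the Vulkan implementation.


def update_cache(value: torch.Tensor, cache: torch.Tensor, input_pos: int) -> None:
    # Write the new k or v projection into the cache starting at input_pos.
    seq_len = value.shape[1]
    cache[:, input_pos : input_pos + seq_len] = value


def custom_sdpa(
    q: torch.Tensor, k_cache: torch.Tensor, v_cache: torch.Tensor, input_pos: int
) -> torch.Tensor:
    # Single-token decode: the query attends to every cached position up to and
    # including input_pos, so no explicit causal mask is needed.
    assert q.shape[1] == 1
    k = k_cache[:, : input_pos + 1]
    v = v_cache[:, : input_pos + 1]
    # Repeat-interleave kv heads to match the number of query heads (GQA).
    n_rep = q.shape[2] // k.shape[2]
    k = k.repeat_interleave(n_rep, dim=2)
    v = v.repeat_interleave(n_rep, dim=2)
    # scaled_dot_product_attention expects [batch, num_heads, seq_len, head_dim].
    out = F.scaled_dot_product_attention(
        q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)
    )
    return out.transpose(1, 2)


def sdpa_with_kv_cache(q, k_proj, v_proj, k_cache, v_cache, input_pos: int):
    # The fused op is simply the composition of the two decoupled ops.
    update_cache(k_proj, k_cache, input_pos)
    update_cache(v_proj, v_cache, input_pos)
    return custom_sdpa(q, k_cache, v_cache, input_pos)
```

The fused `sdpa_with_kv_cache` then reduces to the composition of the two ops,
which mirrors how `sdpa_with_kv_cache_impl` is reimplemented in terms of
`update_cache_impl` and `sdpa_impl` in `SDPA.cpp` below.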
Differential Revision: [D68916952](https://our.internmc.facebook.com/intern/diff/D68916952/)

[ghstack-poisoned]
---
 .../vulkan/_passes/insert_prepack_nodes.py     |  10 +-
 .../vulkan/_passes/tag_memory_meta_pass.py     |   4 +-
 backends/vulkan/op_registry.py                 |  16 ++-
 .../vulkan/runtime/graph/ops/impl/SDPA.cpp     | 100 +++++++++++++-----
 4 files changed, 99 insertions(+), 31 deletions(-)

diff --git a/backends/vulkan/_passes/insert_prepack_nodes.py b/backends/vulkan/_passes/insert_prepack_nodes.py
index 7876806d6d..bf1fc28ba5 100644
--- a/backends/vulkan/_passes/insert_prepack_nodes.py
+++ b/backends/vulkan/_passes/insert_prepack_nodes.py
@@ -60,6 +60,12 @@ def prepack_not_required(node: torch.fx.Node) -> bool:
         )
         # This pass assumes that the SpecPropPass() has already been applied
         assert "spec" in node.meta
+        # Mutable buffers will not be marked as constant, but it might as well be
+        # for the purposes of memory planning. Mark it as a constant tensor so that
+        # it is handled correctly by the memory planning pass.
+        if not node.meta["spec"].const:
+            assert is_param_node(program, node)
+            node.meta["spec"].const = True
         # Validate that the original node is marked as a constant. Constant tensors
         # do not participate in memory planning.
         assert node.meta["spec"].const
@@ -68,7 +74,9 @@ def prepack_not_required(node: torch.fx.Node) -> bool:
         # Set the mem_obj_id to -1 to indicate that this node requires a dedicated
         # memory object.
         prepack_node.meta["spec"].mem_obj_id = -1
-        node.replace_all_uses_with(prepack_node, lambda x, y=prepack_node: x != y)
+        node.replace_all_uses_with(
+            prepack_node, lambda x, y=prepack_node: (x != y and x.op != "output")
+        )
 
     program.graph.eliminate_dead_code()
     return program
diff --git a/backends/vulkan/_passes/tag_memory_meta_pass.py b/backends/vulkan/_passes/tag_memory_meta_pass.py
index 1d08817e26..44e1b9aeca 100644
--- a/backends/vulkan/_passes/tag_memory_meta_pass.py
+++ b/backends/vulkan/_passes/tag_memory_meta_pass.py
@@ -220,9 +220,7 @@ def should_delay_annotation(self, node: torch.fx.Node) -> bool:  # noqa
 
     def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
-        sorted_nodes: NodeList = topo_sort(list(graph_module.graph.nodes))
-
-        for node in sorted_nodes:
+        for node in graph_module.graph.nodes:
             if not self.should_annotate(node) or self.should_delay_annotation(node):
                 continue
 
diff --git a/backends/vulkan/op_registry.py b/backends/vulkan/op_registry.py
index d70cf93b88..b8eca82904 100644
--- a/backends/vulkan/op_registry.py
+++ b/backends/vulkan/op_registry.py
@@ -478,7 +478,7 @@ def register_convolution_op(features: OpFeatures):
 
 
 @update_features("llama::sdpa_with_kv_cache")
-def register_sdpa_op(features: OpFeatures):
+def register_sdpa_with_kv_cache_op(features: OpFeatures):
     features.texture_impl = TextureImplFeatures(
         valid_packed_dims={PackedDim.WIDTH},
     )
@@ -489,6 +489,20 @@ def register_sdpa_op(features: OpFeatures):
     return features
 
 
+# TODO(ssjia) allow registration after remove assertions pass is implemented
+# @update_features(["llama::update_cache", exir_ops.edge.llama.custom_sdpa.default])
+def register_sdpa_ops(features: OpFeatures):
+    features.texture_impl = TextureImplFeatures(
+        valid_packed_dims={PackedDim.WIDTH},
+    )
+    features.resize_fn = False
+    features.buffer_impl = False
+    features.texture_impl = TextureImplFeatures(
+        valid_packed_dims={PackedDim.WIDTH},
+    )
+    return features
+
+
 @update_features(exir_ops.edge.et_vk.apply_rotary_emb.default)
 def register_rotary_emb_op(features: OpFeatures):
     features.texture_impl = TextureImplFeatures(
diff --git a/backends/vulkan/runtime/graph/ops/impl/SDPA.cpp b/backends/vulkan/runtime/graph/ops/impl/SDPA.cpp
index 2c46201351..6dcf2fc4f4 100644
--- a/backends/vulkan/runtime/graph/ops/impl/SDPA.cpp
+++ b/backends/vulkan/runtime/graph/ops/impl/SDPA.cpp
@@ -176,17 +176,32 @@ void resize_sdpa_out(
   graph->get_tensor(out)->virtual_resize(graph->sizes_of(q_projected));
 }
 
-void sdpa_with_kv_cache_impl(
-    ComputeGraph& graph,
-    const std::vector<ValueRef>& args) {
+void update_cache_impl(ComputeGraph& graph, const std::vector<ValueRef>& args) {
+  int arg_idx = 0;
+  const ValueRef value = args[arg_idx++];
+  const ValueRef cache = args[arg_idx++];
+  const ValueRef input_pos_symint = args[arg_idx++];
+  const ValueRef out = args[arg_idx++];
+
+  // Unused variables
+  (void)out;
+
+  VK_CHECK_COND(graph.size_at<int32_t>(-4, value) == 1);
+  VK_CHECK_COND(graph.size_at<int32_t>(-4, cache) == 1);
+  VK_CHECK_COND(
+      graph.size_at<int32_t>(-1, value) == graph.size_at<int32_t>(-1, cache));
+  VK_CHECK_COND(
+      graph.size_at<int32_t>(-2, value) == graph.size_at<int32_t>(-2, cache));
+
+  add_kv_cache_update_node(graph, input_pos_symint, value, cache);
+}
+
+void sdpa_impl(ComputeGraph& graph, const std::vector<ValueRef>& args) {
   int arg_idx = 0;
   const ValueRef q_projected = args[arg_idx++];
-  const ValueRef k_projected = args[arg_idx++];
-  const ValueRef v_projected = args[arg_idx++];
-  const ValueRef k_cache_data = args[arg_idx++];
-  const ValueRef v_cache_data = args[arg_idx++];
+  const ValueRef k_cache = args[arg_idx++];
+  const ValueRef v_cache = args[arg_idx++];
   const ValueRef input_pos_symint = args[arg_idx++];
-  const ValueRef sequence_len = args[arg_idx++];
   const ValueRef attn_mask = args[arg_idx++];
   const ValueRef dropout_p = args[arg_idx++];
   const ValueRef is_causal = args[arg_idx++];
@@ -195,23 +210,20 @@ void sdpa_with_kv_cache_impl(
   // Output tensors
   const ValueRef out = args[arg_idx++];
 
-  // Unused variables
-  (void)sequence_len;
-
   // Batches must be 1
   VK_CHECK_COND(graph.size_at<int32_t>(-4, q_projected) == 1);
-  VK_CHECK_COND(graph.size_at<int32_t>(-4, k_projected) == 1);
-  VK_CHECK_COND(graph.size_at<int32_t>(-4, v_projected) == 1);
+  VK_CHECK_COND(graph.size_at<int32_t>(-4, k_cache) == 1);
+  VK_CHECK_COND(graph.size_at<int32_t>(-4, v_cache) == 1);
   // k and v projected must have the same shape
-  VK_CHECK_COND(graph.sizes_of(k_projected) == graph.sizes_of(v_projected));
+  VK_CHECK_COND(graph.sizes_of(k_cache) == graph.sizes_of(v_cache));
   // head dim must match between tensors
   VK_CHECK_COND(
       graph.size_at<int32_t>(-1, q_projected) ==
-      graph.size_at<int32_t>(-1, k_projected));
+      graph.size_at<int32_t>(-1, k_cache));
   // All tensors must have the packed dim be the width (head) dimension
   VK_CHECK_COND(graph.packed_dim_of(q_projected) == WHCN::kWidthDim);
-  VK_CHECK_COND(graph.packed_dim_of(k_projected) == WHCN::kWidthDim);
-  VK_CHECK_COND(graph.packed_dim_of(v_projected) == WHCN::kWidthDim);
+  VK_CHECK_COND(graph.packed_dim_of(k_cache) == WHCN::kWidthDim);
+  VK_CHECK_COND(graph.packed_dim_of(v_cache) == WHCN::kWidthDim);
   // Some variables are not supported yet
   VK_CHECK_COND(
       graph.val_is_none(dropout_p) ||
@@ -222,16 +234,8 @@
       graph.val_is_none(is_causal) || graph.extract_scalar<bool>(is_causal));
   VK_CHECK_COND(graph.val_is_none(attn_mask));
 
-  const ValueRef k_cache =
-      prepack_standard_like(graph, k_cache_data, q_projected);
-  const ValueRef v_cache =
-      prepack_standard_like(graph, v_cache_data, q_projected);
-
   const int32_t max_seq_len = graph.size_at<int32_t>(1, k_cache);
 
-  add_kv_cache_update_node(graph, input_pos_symint, k_projected, k_cache);
-  add_kv_cache_update_node(graph, input_pos_symint, v_projected, v_cache);
-
   // Slice caches from 0 to input_pos + sequence_len
   const ValueRef k_cache_sliced = graph.add_tensor_view(k_cache);
   const ValueRef v_cache_sliced = graph.add_tensor_view(v_cache);
@@ -257,7 +261,7 @@ void sdpa_with_kv_cache_impl(
 
   // Repeat interleave
   const int64_t num_heads = graph.size_at<int64_t>(2, q_projected);
-  const int64_t num_kv_heads = graph.size_at<int64_t>(2, k_projected);
+  const int64_t num_kv_heads = graph.size_at<int64_t>(2, k_cache);
 
   const ValueRef num_repeats =
       graph.add_scalar<int64_t>(num_heads / num_kv_heads);
@@ -331,8 +335,52 @@ void sdpa_with_kv_cache_impl(
       new ExecuteNode(resize_sdpa_out, {q_projected, out}));
 }
 
+void sdpa_with_kv_cache_impl(
+    ComputeGraph& graph,
+    const std::vector<ValueRef>& args) {
+  int arg_idx = 0;
+  const ValueRef q_projected = args[arg_idx++];
+  const ValueRef k_projected = args[arg_idx++];
+  const ValueRef v_projected = args[arg_idx++];
+  const ValueRef k_cache_data = args[arg_idx++];
+  const ValueRef v_cache_data = args[arg_idx++];
+  const ValueRef input_pos_symint = args[arg_idx++];
+  const ValueRef sequence_len = args[arg_idx++];
+  const ValueRef attn_mask = args[arg_idx++];
+  const ValueRef dropout_p = args[arg_idx++];
+  const ValueRef is_causal = args[arg_idx++];
+  const ValueRef scale = args[arg_idx++];
+
+  // Output tensors
+  const ValueRef out = args[arg_idx++];
+
+  (void)sequence_len;
+
+  const ValueRef k_cache =
+      prepack_standard_like(graph, k_cache_data, q_projected);
+  const ValueRef v_cache =
+      prepack_standard_like(graph, v_cache_data, q_projected);
+
+  update_cache_impl(graph, {k_projected, k_cache, input_pos_symint, -1});
+  update_cache_impl(graph, {v_projected, v_cache, input_pos_symint, -1});
+
+  sdpa_impl(
+      graph,
+      {q_projected,
+       k_cache,
+       v_cache,
+       input_pos_symint,
+       attn_mask,
+       dropout_p,
+       is_causal,
+       scale,
+       out});
+}
+
 REGISTER_OPERATORS {
   VK_REGISTER_OP(sdpa_with_kv_cache.default, sdpa_with_kv_cache_impl);
+  VK_REGISTER_OP(update_cache.default, update_cache_impl);
+  VK_REGISTER_OP(llama.custom_sdpa.default, sdpa_impl);
 }
 
 } // namespace vkcompute