[ET-VK][LlaMa] Split SDPA + KV cache operator into SDPA operator and KV cache update operator

## Context

#7413 and #7412 split the `sdpa_with_kv_cache` operator into two separate operators, `update_cache` and `custom_sdpa`, to decouple the cache update step from the actual SDPA computation.

As a result of this interface change, SDPA is no longer delegated to Vulkan. To rectify this, the Vulkan backend must also split `sdpa_with_kv_cache` into two operators.
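
For intuition, here is a reference sketch of the split semantics in plain PyTorch. This is not the actual op schema or kernel: the `(1, seq_len, num_heads, head_dim)` layout and the GQA repeat-interleave follow the shape checks in `SDPA.cpp` below, while the causal-mask and scaling details are assumptions.

```python
import math

import torch


def update_cache(value: torch.Tensor, cache: torch.Tensor, input_pos: int) -> None:
    # `update_cache`: write the new K or V projections into the cache,
    # starting at input_pos along the sequence dimension.
    seq_len = value.size(1)
    cache[:, input_pos : input_pos + seq_len] = value


def custom_sdpa(
    q: torch.Tensor,
    k_cache: torch.Tensor,
    v_cache: torch.Tensor,
    input_pos: int,
) -> torch.Tensor:
    # `custom_sdpa`: causal attention over the cache slice [0, input_pos + seq_len).
    seq_len = q.size(1)
    k = k_cache[:, : input_pos + seq_len]
    v = v_cache[:, : input_pos + seq_len]
    # Repeat-interleave KV heads to match the query head count (GQA).
    n_rep = q.size(2) // k.size(2)
    k = k.repeat_interleave(n_rep, dim=2)
    v = v.repeat_interleave(n_rep, dim=2)
    # (1, seq, heads, head_dim) -> (1, heads, seq, head_dim)
    q, k, v = (t.transpose(1, 2) for t in (q, k, v))
    attn = (q @ k.transpose(-2, -1)) / math.sqrt(q.size(-1))
    # Causal mask: the query at absolute position input_pos + i may only
    # attend to key positions <= input_pos + i.
    key_pos = torch.arange(k.size(-2))
    query_pos = input_pos + torch.arange(seq_len).unsqueeze(-1)
    attn = attn.masked_fill(key_pos > query_pos, float("-inf"))
    return (attn.softmax(dim=-1) @ v).transpose(1, 2)
```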

Note that in this diff the new operators are not yet partitioned, because of complications caused by assertion ops in the graph. The next diff adds a pass to remove such assertion ops, which allows the new operators to be partitioned.

Differential Revision: [D68916952](https://our.internmc.facebook.com/intern/diff/D68916952/)

[ghstack-poisoned]
SS-JIA committed Jan 30, 2025
1 parent afc5a50 commit fd1cd73
Showing 4 changed files with 99 additions and 31 deletions.
10 changes: 9 additions & 1 deletion backends/vulkan/_passes/insert_prepack_nodes.py
@@ -60,6 +60,12 @@ def prepack_not_required(node: torch.fx.Node) -> bool:
)
# This pass assumes that the SpecPropPass() has already been applied
assert "spec" in node.meta
# Mutable buffers will not be marked as constant, but they might as well be
# for the purposes of memory planning. Mark them as constant tensors so that
# they are handled correctly by the memory planning pass.
if not node.meta["spec"].const:
assert is_param_node(program, node)
node.meta["spec"].const = True
# Validate that the original node is marked as a constant. Constant tensors
# do not participate in memory planning.
assert node.meta["spec"].const
@@ -68,7 +74,9 @@ def prepack_not_required(node: torch.fx.Node) -> bool:
# Set the mem_obj_id to -1 to indicate that this node requires a dedicated
# memory object.
prepack_node.meta["spec"].mem_obj_id = -1
node.replace_all_uses_with(prepack_node, lambda x, y=prepack_node: x != y)
node.replace_all_uses_with(
prepack_node, lambda x, y=prepack_node: (x != y and x.op != "output")
)

program.graph.eliminate_dead_code()
return program
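
The new predicate keeps uses by the graph's `output` node pointing at the original node instead of the prepacked copy. In `torch.fx` terms, the positional lambda is the `delete_user_cb`; below is an equivalent named-argument sketch (the rationale, that mutable buffers appear in graph outputs and must stay referenced there, is an inference rather than something stated in the diff):

```python
import torch


def replace_uses_except_output(
    node: torch.fx.Node, prepack_node: torch.fx.Node
) -> None:
    # Replace every use of `node` with `prepack_node`, except uses by the
    # prepack node itself and by the graph's output node.
    node.replace_all_uses_with(
        prepack_node,
        delete_user_cb=lambda user: user is not prepack_node
        and user.op != "output",
    )
```
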
4 changes: 1 addition & 3 deletions backends/vulkan/_passes/tag_memory_meta_pass.py
@@ -220,9 +220,7 @@ def should_delay_annotation(self, node: torch.fx.Node) -> bool:

# noqa
def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
sorted_nodes: NodeList = topo_sort(list(graph_module.graph.nodes))

for node in sorted_nodes:
for node in graph_module.graph.nodes:
if not self.should_annotate(node) or self.should_delay_annotation(node):
continue
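
Dropping the explicit topological sort relies on `torch.fx` maintaining `graph.nodes` in topological order, so iterating the list directly is equivalent. A quick self-check sketch, assuming standard `torch.fx` invariants:

```python
import torch


def check_topological_order(gm: torch.fx.GraphModule) -> None:
    # torch.fx keeps graph.nodes ordered so that every node appears after
    # all of its inputs, which is why the pass can drop topo_sort and
    # iterate graph.nodes directly.
    seen = set()
    for node in gm.graph.nodes:
        assert all(arg in seen for arg in node.all_input_nodes), node
        seen.add(node)
```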

16 changes: 15 additions & 1 deletion backends/vulkan/op_registry.py
@@ -478,7 +478,7 @@ def register_convolution_op(features: OpFeatures):


@update_features("llama::sdpa_with_kv_cache")
def register_sdpa_op(features: OpFeatures):
def register_sdpa_with_kv_cache_op(features: OpFeatures):
features.texture_impl = TextureImplFeatures(
valid_packed_dims={PackedDim.WIDTH},
)
Expand All @@ -489,6 +489,20 @@ def register_sdpa_op(features: OpFeatures):
return features


# TODO(ssjia): allow registration after the assertion removal pass is implemented
# @update_features(["llama::update_cache", exir_ops.edge.llama.custom_sdpa.default])
def register_sdpa_ops(features: OpFeatures):
features.resize_fn = False
features.buffer_impl = False
features.texture_impl = TextureImplFeatures(
valid_packed_dims={PackedDim.WIDTH},
)
return features
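
As noted in the commit message, registration is blocked until assertion ops can be removed from the graph. A hypothetical sketch of such a removal pass follows (the real pass lands in the next diff; the op set here is an assumption):

```python
import torch
from torch.fx.passes.infra.pass_base import PassBase, PassResult

# Hypothetical op set; the real pass in the next diff may target different ops.
ASSERTION_OPS = {
    torch.ops.aten._assert_scalar.default,
    torch.ops.aten.sym_constrain_range_for_size.default,
}


class RemoveAssertsPass(PassBase):
    def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
        modified = False
        # Copy the node list: erasing nodes while iterating the live list is unsafe.
        for node in list(graph_module.graph.nodes):
            if node.op == "call_function" and node.target in ASSERTION_OPS:
                graph_module.graph.erase_node(node)
                modified = True
        if modified:
            graph_module.graph.eliminate_dead_code()
            graph_module.recompile()
        return PassResult(graph_module, modified)
```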


@update_features(exir_ops.edge.et_vk.apply_rotary_emb.default)
def register_rotary_emb_op(features: OpFeatures):
features.texture_impl = TextureImplFeatures(
100 changes: 74 additions & 26 deletions backends/vulkan/runtime/graph/ops/impl/SDPA.cpp
@@ -176,17 +176,32 @@ void resize_sdpa_out(
graph->get_tensor(out)->virtual_resize(graph->sizes_of(q_projected));
}

void sdpa_with_kv_cache_impl(
ComputeGraph& graph,
const std::vector<ValueRef>& args) {
void update_cache_impl(ComputeGraph& graph, const std::vector<ValueRef>& args) {
int arg_idx = 0;
const ValueRef value = args[arg_idx++];
const ValueRef cache = args[arg_idx++];
const ValueRef input_pos_symint = args[arg_idx++];
const ValueRef out = args[arg_idx++];

// Unused variables
(void)out;

VK_CHECK_COND(graph.size_at<int32_t>(-4, value) == 1);
VK_CHECK_COND(graph.size_at<int32_t>(-4, cache) == 1);
VK_CHECK_COND(
graph.size_at<int32_t>(-1, value) == graph.size_at<int32_t>(-1, cache));
VK_CHECK_COND(
graph.size_at<int32_t>(-2, value) == graph.size_at<int32_t>(-2, cache));

add_kv_cache_update_node(graph, input_pos_symint, value, cache);
}

void sdpa_impl(ComputeGraph& graph, const std::vector<ValueRef>& args) {
int arg_idx = 0;
const ValueRef q_projected = args[arg_idx++];
const ValueRef k_projected = args[arg_idx++];
const ValueRef v_projected = args[arg_idx++];
const ValueRef k_cache_data = args[arg_idx++];
const ValueRef v_cache_data = args[arg_idx++];
const ValueRef k_cache = args[arg_idx++];
const ValueRef v_cache = args[arg_idx++];
const ValueRef input_pos_symint = args[arg_idx++];
const ValueRef sequence_len = args[arg_idx++];
const ValueRef attn_mask = args[arg_idx++];
const ValueRef dropout_p = args[arg_idx++];
const ValueRef is_causal = args[arg_idx++];
@@ -195,23 +210,20 @@ void sdpa_with_kv_cache_impl(
// Output tensors
const ValueRef out = args[arg_idx++];

// Unused variables
(void)sequence_len;

// Batches must be 1
VK_CHECK_COND(graph.size_at<int32_t>(-4, q_projected) == 1);
VK_CHECK_COND(graph.size_at<int32_t>(-4, k_projected) == 1);
VK_CHECK_COND(graph.size_at<int32_t>(-4, v_projected) == 1);
VK_CHECK_COND(graph.size_at<int32_t>(-4, k_cache) == 1);
VK_CHECK_COND(graph.size_at<int32_t>(-4, v_cache) == 1);
// k and v projected must have the same shape
VK_CHECK_COND(graph.sizes_of(k_projected) == graph.sizes_of(v_projected));
VK_CHECK_COND(graph.sizes_of(k_cache) == graph.sizes_of(v_cache));
// head dim must match between tensors
VK_CHECK_COND(
graph.size_at<int32_t>(-1, q_projected) ==
graph.size_at<int32_t>(-1, k_projected));
graph.size_at<int32_t>(-1, k_cache));
// All tensors must have the packed dim be the width (head) dimension
VK_CHECK_COND(graph.packed_dim_of(q_projected) == WHCN::kWidthDim);
VK_CHECK_COND(graph.packed_dim_of(k_projected) == WHCN::kWidthDim);
VK_CHECK_COND(graph.packed_dim_of(v_projected) == WHCN::kWidthDim);
VK_CHECK_COND(graph.packed_dim_of(k_cache) == WHCN::kWidthDim);
VK_CHECK_COND(graph.packed_dim_of(v_cache) == WHCN::kWidthDim);
// Some variables are not supported yet
VK_CHECK_COND(
graph.val_is_none(dropout_p) ||
Expand All @@ -222,16 +234,8 @@ void sdpa_with_kv_cache_impl(
graph.val_is_none(is_causal) || graph.extract_scalar<bool>(is_causal));
VK_CHECK_COND(graph.val_is_none(attn_mask));

const ValueRef k_cache =
prepack_standard_like(graph, k_cache_data, q_projected);
const ValueRef v_cache =
prepack_standard_like(graph, v_cache_data, q_projected);

const int32_t max_seq_len = graph.size_at<int32_t>(1, k_cache);

add_kv_cache_update_node(graph, input_pos_symint, k_projected, k_cache);
add_kv_cache_update_node(graph, input_pos_symint, v_projected, v_cache);

// Slice caches from 0 to input_pos + sequence_len
const ValueRef k_cache_sliced = graph.add_tensor_view(k_cache);
const ValueRef v_cache_sliced = graph.add_tensor_view(v_cache);
Expand All @@ -257,7 +261,7 @@ void sdpa_with_kv_cache_impl(

// Repeat interleave
const int64_t num_heads = graph.size_at<int64_t>(2, q_projected);
const int64_t num_kv_heads = graph.size_at<int64_t>(2, k_projected);
const int64_t num_kv_heads = graph.size_at<int64_t>(2, k_cache);

const ValueRef num_repeats =
graph.add_scalar<int64_t>(num_heads / num_kv_heads);
@@ -331,8 +335,52 @@ void sdpa_with_kv_cache_impl(
new ExecuteNode(resize_sdpa_out, {q_projected, out}));
}

void sdpa_with_kv_cache_impl(
ComputeGraph& graph,
const std::vector<ValueRef>& args) {
int arg_idx = 0;
const ValueRef q_projected = args[arg_idx++];
const ValueRef k_projected = args[arg_idx++];
const ValueRef v_projected = args[arg_idx++];
const ValueRef k_cache_data = args[arg_idx++];
const ValueRef v_cache_data = args[arg_idx++];
const ValueRef input_pos_symint = args[arg_idx++];
const ValueRef sequence_len = args[arg_idx++];
const ValueRef attn_mask = args[arg_idx++];
const ValueRef dropout_p = args[arg_idx++];
const ValueRef is_causal = args[arg_idx++];
const ValueRef scale = args[arg_idx++];

// Output tensors
const ValueRef out = args[arg_idx++];

(void)sequence_len;

const ValueRef k_cache =
prepack_standard_like(graph, k_cache_data, q_projected);
const ValueRef v_cache =
prepack_standard_like(graph, v_cache_data, q_projected);

update_cache_impl(graph, {k_projected, k_cache, input_pos_symint, -1});
update_cache_impl(graph, {v_projected, v_cache, input_pos_symint, -1});

sdpa_impl(
graph,
{q_projected,
k_cache,
v_cache,
input_pos_symint,
attn_mask,
dropout_p,
is_causal,
scale,
out});
}

REGISTER_OPERATORS {
VK_REGISTER_OP(sdpa_with_kv_cache.default, sdpa_with_kv_cache_impl);
VK_REGISTER_OP(update_cache.default, update_cache_impl);
VK_REGISTER_OP(llama.custom_sdpa.default, sdpa_impl);
}

} // namespace vkcompute
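
As a sanity check, running the two new steps in sequence should reproduce the fused behavior. A toy usage sketch built on the reference functions from the Context section (shapes are arbitrary; batch size 1 matches the checks above):

```python
import torch

torch.manual_seed(0)
B, S, H, KVH, D, MAX_SEQ = 1, 4, 8, 2, 16, 32

q = torch.randn(B, S, H, D)
k = torch.randn(B, S, KVH, D)
v = torch.randn(B, S, KVH, D)
k_cache = torch.zeros(B, MAX_SEQ, KVH, D)
v_cache = torch.zeros(B, MAX_SEQ, KVH, D)

# What sdpa_with_kv_cache used to do in one shot: update both caches, then attend.
input_pos = 0
update_cache(k, k_cache, input_pos)
update_cache(v, v_cache, input_pos)
out = custom_sdpa(q, k_cache, v_cache, input_pos)
print(out.shape)  # torch.Size([1, 4, 8, 16])
```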
