Commit d6a55bc

Merge pull request #4 from chunyuan-w/chunyuan/flex_fix
fix arg mapping; remove assert; support other buffer
2 parents b251b7d + 844da1c

File tree: 3 files changed (+44 -10)

  torch/_inductor/codegen/cpp.py
  torch/_inductor/codegen/cpp_mha_template.py
  torch/_inductor/kernel/flex_attention.py

torch/_inductor/codegen/cpp.py (+2 -3)

@@ -1906,7 +1906,7 @@ def load(self, name: str, index: sympy.Expr):
         return csevar

     def store(self, name, index, value, mode=None):
-        # assert "buf" in name
+        assert "buf" in name
         var = self.args.output(name)
         index = self.rename_indexing(index)
         if mode is None:

@@ -3132,8 +3132,7 @@ def load(self, name: str, index: sympy.Expr):
         return super().load(name, new_index)

     def store(self, name, index, value, mode=None):
-        # TODO: fix me
-        # assert "buf" in name
+        assert "buf" in name
         var = self.args.output(name)

         inner = self.inner_itervar()
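
Note on the two store() hunks above: the assertion had been commented out while the subgraph output still used the placeholder name "arg0_1" (see the old modification() code in cpp_mha_template.py below). With the output now registered as a real graph buffer via V.graph.register_buffer, the name reaching store() is a graph-style "buf*" name, so the check can be restored. A minimal sketch of the invariant, with illustrative names only (not code from this commit):

    # Illustrative: store targets are expected to be graph-registered buffers.
    def check_store_target(name: str) -> None:
        # Mirrors the re-enabled check in CppKernel.store().
        assert "buf" in name, f"unexpected store target: {name}"

    check_store_target("buf0")      # registered subgraph output -> passes
    # check_store_target("arg0_1")  # placeholder input name -> would fail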

torch/_inductor/codegen/cpp_mha_template.py (+35 -5)

@@ -19,7 +19,7 @@
 {{template.header().getvalue()}}
 #include <ATen/native/CPUBlas.h>

-{%- set kernel_args = {"query": query, "key": key, "value": value, "kv_indices": kv_indices} %}
+{%- set kernel_args = {"query": query, "key": key, "value": value, "kv_indices": kv_indices, "mask_other": mask_mod_other_buffers} %}
 {{kernel.def_kernel(inputs=kernel_args, outputs={"output": output})}}
 {
   // kv page size, q and kv split size

@@ -166,7 +166,10 @@
   auto in_ptr1 = b_.data();
   auto in_ptr2 = h_.data();
   auto in_ptr3 = q_.data();
-  auto in_ptr4 = k_.data();
+  auto in_ptr10 = k_.data();
+{%- if mask_mod_other_buffers %}
+  auto in_ptr4 = mask_other;
+{%- endif %}
   accum_t* out_ptr0 = in_ptr0;
   {{template.modification(score_mod)}}
 }

@@ -183,6 +186,9 @@
   auto in_ptr1 = h_.data();
   auto in_ptr2 = q_.data();
   auto in_ptr3 = k_.data();
+{%- if mask_mod_other_buffers %}
+  auto in_ptr4 = mask_other;
+{%- endif %}
   std::vector<int64_t> temp = {0};
   int64_t* out_ptr0 = temp.data();
   {{template.modification(mask_mod)}}

@@ -295,12 +301,33 @@ def modification(self, subgraph_buffer):
         from ..utils import sympy_index_symbol_with_prefix, SymT
         from ..virtualized import ops, V

-
-        # TODO: what should be the output name??
-        output_name = "arg0_1"
+        output_name = "buf0"
+        V.graph.register_buffer(subgraph_buffer)

         from .cpp import CppKernel, CppKernelProxy, KernelGroup
         kernel_group = KernelGroup()
+        kernel_input_args = {
+            "arg0_1": "in_ptr0",
+            "arg1_1": "in_ptr1",
+            "arg2_1": "in_ptr2",
+            "arg3_1": "in_ptr3",
+            "arg10_1": "in_ptr10",
+            "arg4_1": "in_ptr4",
+        }
+
+        kernel_output_args = {
+            "buf0": "out_ptr0"
+        }
+
+        args = kernel_group.args
+        for name, inp in kernel_input_args.items():
+            args.input_buffers[name] = inp
+
+        for name, inp in kernel_output_args.items():
+            args.output_buffers[name] = inp
+
+        kernel_group.args = args
+
         cpp_kernel_proxy = CppKernelProxy(kernel_group)
         bodies = []
         var_sizes_list = []

@@ -407,11 +434,14 @@ def render( # type: ignore[override,return]

         if template_buffer_node is not None:
             buf_out = template_buffer_node
+        has_other_buffer = len(self.input_nodes) == 6
         options = dict(
             query=query,
             key=key,
             value=value,
             kv_indices=self.input_nodes[3],
+            score_mod_other_buffers=self.input_nodes[4] if has_other_buffer else None,
+            mask_mod_other_buffers=self.input_nodes[5] if has_other_buffer else None,
             scale=self.scale,
             size_per_thread=size_per_thread,
             accumulate_dtype=torch.float,
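
Read together with the pointer declarations earlier in the template, the pre-seeded kernel_group.args mapping in modification() amounts to the following correspondence between subgraph placeholder names and template-local pointers (an illustrative annotation inferred from the hunks above, not code from the commit):

    # Illustrative only: meaning of each placeholder -> pointer pairing.
    placeholder_to_pointer = {
        "arg0_1": "in_ptr0",    # score value being modified (aliased by out_ptr0)
        "arg1_1": "in_ptr1",    # batch index b_
        "arg2_1": "in_ptr2",    # head index h_
        "arg3_1": "in_ptr3",    # query index q_
        "arg10_1": "in_ptr10",  # kv index k_ (renamed from in_ptr4 to free that slot)
        "arg4_1": "in_ptr4",    # optional "other" buffer (mask_other), declared only
                                # when mask_mod_other_buffers is set
    }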

torch/_inductor/kernel/flex_attention.py (+7 -2)

@@ -778,7 +778,7 @@ def flex_attention(
             ("arg1_1", torch.int32),
             ("arg2_1", torch.int32),
             ("arg3_1", torch.int32),
-            ("arg4_1", torch.int32),
+            ("arg10_1", torch.int32),  # TODO: fix the random picked name here: arg10_1
         ]
     ]
     subgraph_buffer = build_subgraph_buffer(

@@ -812,9 +812,14 @@ def flex_attention(
         stride=out_strides,
     )
     choices: List[Any] = []
+    input_nodes = [query, key, value, kv_indices]
+    if score_mod_other_buffers and mask_mod_other_buffers:
+        assert len(score_mod_other_buffers) == 1
+        assert len(mask_mod_other_buffers) == 1
+        input_nodes += [score_mod_other_buffers[0], mask_mod_other_buffers[0]]
     CppMHATemplate.add_choices(
         choices=choices,
-        input_nodes=[query, key, value, kv_indices],
+        input_nodes=input_nodes,
         layout=layout,
         scale=scale,
         score_mod=subgraph_buffer,
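
For context, the user-level scenario these flex_attention.py changes target is a score_mod/mask_mod pair that each capture exactly one extra tensor, which the lowering then forwards as score_mod_other_buffers[0] and mask_mod_other_buffers[0]. A hedged usage sketch (assumed shapes and names; whether the CPU template is actually selected depends on the autotuning configuration):

    import torch
    from torch.nn.attention.flex_attention import flex_attention, create_block_mask

    B, H, S, D = 1, 4, 128, 64
    q, k, v = (torch.randn(B, H, S, D) for _ in range(3))

    bias = torch.randn(S)                   # becomes score_mod_other_buffers[0]
    keep = torch.ones(S, dtype=torch.bool)  # becomes mask_mod_other_buffers[0]

    def score_mod(score, b, h, q_idx, kv_idx):
        return score + bias[kv_idx]

    def mask_mod(b, h, q_idx, kv_idx):
        return keep[kv_idx]

    block_mask = create_block_mask(mask_mod, B, H, S, S, device="cpu")
    out = torch.compile(flex_attention)(q, k, v, score_mod=score_mod, block_mask=block_mask)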
