|
19 | 19 | {{template.header().getvalue()}}
|
20 | 20 | #include <ATen/native/CPUBlas.h>
|
21 | 21 |
|
22 |
| -{%- set kernel_args = {"query": query, "key": key, "value": value, "kv_indices": kv_indices} %} |
| 22 | +{%- set kernel_args = {"query": query, "key": key, "value": value, "kv_indices": kv_indices, "mask_other": mask_mod_other_buffers} %} |
23 | 23 | {{kernel.def_kernel(inputs=kernel_args, outputs={"output": output})}}
|
24 | 24 | {
|
25 | 25 | // kv page size, q and kv split size
|
|
166 | 166 | auto in_ptr1 = b_.data();
|
167 | 167 | auto in_ptr2 = h_.data();
|
168 | 168 | auto in_ptr3 = q_.data();
|
169 |
| - auto in_ptr4 = k_.data(); |
| 169 | + auto in_ptr10 = k_.data(); |
| 170 | + {%- if mask_mod_other_buffers %} |
| 171 | + auto in_ptr4 = mask_other; |
| 172 | + {%- endif %} |
170 | 173 | accum_t* out_ptr0 = in_ptr0;
|
171 | 174 | {{template.modification(score_mod)}}
|
172 | 175 | }
|
|
183 | 186 | auto in_ptr1 = h_.data();
|
184 | 187 | auto in_ptr2 = q_.data();
|
185 | 188 | auto in_ptr3 = k_.data();
|
| 189 | + {%- if mask_mod_other_buffers %} |
| 190 | + auto in_ptr4 = mask_other; |
| 191 | + {%- endif %} |
186 | 192 | std::vector<int64_t> temp = {0};
|
187 | 193 | int64_t* out_ptr0 = temp.data();
|
188 | 194 | {{template.modification(mask_mod)}}
|
@@ -295,12 +301,33 @@ def modification(self, subgraph_buffer):
|
295 | 301 | from ..utils import sympy_index_symbol_with_prefix, SymT
|
296 | 302 | from ..virtualized import ops, V
|
297 | 303 |
|
298 |
| - |
299 |
| - # TODO: what should be the output name?? |
300 |
| - output_name = "arg0_1" |
| 304 | + output_name = "buf0" |
| 305 | + V.graph.register_buffer(subgraph_buffer) |
301 | 306 |
|
302 | 307 | from .cpp import CppKernel, CppKernelProxy, KernelGroup
|
303 | 308 | kernel_group = KernelGroup()
|
| 309 | + kernel_input_args = { |
| 310 | + "arg0_1": "in_ptr0", |
| 311 | + "arg1_1": "in_ptr1", |
| 312 | + "arg2_1": "in_ptr2", |
| 313 | + "arg3_1": "in_ptr3", |
| 314 | + "arg10_1": "in_ptr10", |
| 315 | + "arg4_1": "in_ptr4", |
| 316 | + } |
| 317 | + |
| 318 | + kernel_output_args = { |
| 319 | + "buf0": "out_ptr0" |
| 320 | + } |
| 321 | + |
| 322 | + args = kernel_group.args |
| 323 | + for name, inp in kernel_input_args.items(): |
| 324 | + args.input_buffers[name] = inp |
| 325 | + |
| 326 | + for name, inp in kernel_output_args.items(): |
| 327 | + args.output_buffers[name] = inp |
| 328 | + |
| 329 | + kernel_group.args = args |
| 330 | + |
304 | 331 | cpp_kernel_proxy = CppKernelProxy(kernel_group)
|
305 | 332 | bodies = []
|
306 | 333 | var_sizes_list = []
|
@@ -407,11 +434,14 @@ def render( # type: ignore[override,return]
|
407 | 434 |
|
408 | 435 | if template_buffer_node is not None:
|
409 | 436 | buf_out = template_buffer_node
|
| 437 | + has_other_buffer = len(self.input_nodes) == 6 |
410 | 438 | options = dict(
|
411 | 439 | query=query,
|
412 | 440 | key=key,
|
413 | 441 | value=value,
|
414 | 442 | kv_indices=self.input_nodes[3],
|
| 443 | + score_mod_other_buffers=self.input_nodes[4] if has_other_buffer else None, |
| 444 | + mask_mod_other_buffers=self.input_nodes[5] if has_other_buffer else None, |
415 | 445 | scale=self.scale,
|
416 | 446 | size_per_thread=size_per_thread,
|
417 | 447 | accumulate_dtype=torch.float,
|
|
0 commit comments