[RFC] Implement caching for user defined triton kernels (pytorch#140326)
This PR adds caching for user defined triton kernels by putting the transitive closure of the kernel source code in node.meta along with the constant arguments.
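
For context, a minimal sketch of the kind of user defined triton kernel this caching targets, mirroring the `add_kernel` pattern used in the tests below (the kernel and wrapper names here are illustrative, not part of the PR):
```python
import torch
import triton
import triton.language as tl


@triton.jit
def add_kernel(in_ptr0, in_ptr1, out_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
    # Standard elementwise-add Triton kernel: each program handles one block.
    pid = tl.program_id(axis=0)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements
    x = tl.load(in_ptr0 + offsets, mask=mask)
    y = tl.load(in_ptr1 + offsets, mask=mask)
    tl.store(out_ptr + offsets, x + y, mask=mask)


def fn(x, y):
    # x and y are expected to be CUDA tensors.
    n_elements = x.numel()
    grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
    add_kernel[grid](x, y, x, n_elements, BLOCK_SIZE=4)
    return x


# With this PR, the FX graph containing the wrapped kernel can hit the
# inductor cache instead of always bypassing it.
compiled_fn = torch.compile(fn, fullgraph=True)
```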

One HUGE hack we do here: a node looks like
```
triton_kernel_wrapper_functional_proxy = torch.ops.higher_order.triton_kernel_wrapper_functional(kernel_idx = 0, constant_args_idx = 1, grid = [(1, 1, 1)], tma_descriptor_metadata = {}, kwargs = {'in_ptr0': arg0_1, 'in_ptr1': arg1_1, 'out_ptr': arg0_1}, tensors_to_clone = ['out_ptr']);
```
so we use a regex to remove the `kernel_idx = 0, constant_args_idx = 1` parts, as they are not relevant to the cache hash. This is horrible, and I'd eventually like to stop using pickle as the hashing mechanism, but that is a longer project.
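
A minimal sketch of that normalization (the PR does this inside `FxGraphCachePickler._reduce_graph_module`; the standalone helper name here is hypothetical): the generated FX code is rewritten before hashing, since the kernel source and constant args are captured separately via node.meta.
```python
import re


def _strip_side_table_indices(code: str) -> str:
    # Hypothetical helper mirroring the regex normalization in the diff below:
    # drop the dynamo side-table indices from the generated FX code, since the
    # actual kernel source and constant args are hashed separately.
    code = re.sub(r"kernel_idx = \d+", "", code)
    code = re.sub(r"constant_args_idx = \d+", "", code)
    return code
```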

Differential Revision: [D65895744](https://our.internmc.facebook.com/intern/diff/D65895744)
Pull Request resolved: pytorch#140326
Approved by: https://github.com/zou3519
oulgen authored and pytorchmergebot committed Nov 16, 2024
1 parent 48a55b8 commit a173186
Showing 6 changed files with 196 additions and 26 deletions.
107 changes: 89 additions & 18 deletions test/inductor/test_codecache.py
@@ -49,7 +49,7 @@
if HAS_TRITON:
import triton # @manual

from torch.testing._internal.triton_utils import add_kernel
from torch.testing._internal.triton_utils import add_kernel, sub_kernel

torch._dynamo.config.fake_tensor_cache_enabled = True
torch._dynamo.config.fake_tensor_cache_crosscheck_enabled = True
@@ -494,13 +494,41 @@ def fn2(q, k, v):
@config.patch({"fx_graph_cache": True})
@config.patch({"fx_graph_remote_cache": False})
@parametrize("bundle_triton", (False, True))
@parametrize("grad", (False, True))
def test_triton_higher_order_op_bypass(self, bundle_triton, grad):
def test_higher_order_op_bypass(self, bundle_triton):
"""
Verify that we bypass the cache when we have a triton higher order ops
Verify that we bypass the cache when we have higher order ops
and that bundler start/end works with a cache bypass.
"""

def fn(x):
def true_fn(x: torch.Tensor):
return x.cos()

def false_fn(x: torch.Tensor):
return x.sin()

return torch.cond(x.shape[0], true_fn, false_fn, (x,))

with config.patch(bundle_triton_into_fx_graph_cache=bundle_triton):
compiled_fn = torch.compile(fn, dynamic=True, fullgraph=True)

x = torch.randn(4, 4, device=GPU_TYPE)
result = compiled_fn(x)

self.assertEqual(counters["inductor"]["fxgraph_cache_miss"], 0)
self.assertEqual(counters["inductor"]["fxgraph_cache_hit"], 0)
self.assertGreater(counters["inductor"]["fxgraph_cache_bypass"], 0)

@requires_gpu()
@requires_triton()
@config.patch({"fx_graph_cache": True})
@config.patch({"fx_graph_remote_cache": False})
@parametrize("bundle_triton", (False, True))
def test_triton_higher_order_op(self, bundle_triton):
"""
Verify that we can cache a user defined triton kernel higher order op
"""

def fn(x, y):
n_elements = x.numel()
grid = lambda meta: ( # noqa: E731
@@ -509,18 +537,54 @@ def fn(x, y):
add_kernel[grid](x, y, x, n_elements, BLOCK_SIZE=4)
return x

def fn2(x, y):
n_elements = x.numel()
grid = lambda meta: ( # noqa: E731
triton.cdiv(n_elements, meta["BLOCK_SIZE"]),
)
sub_kernel[grid](x, y, x, n_elements, BLOCK_SIZE=4)
return x

with config.patch(bundle_triton_into_fx_graph_cache=bundle_triton):
compiled_fn = torch.compile(fn, fullgraph=True)
compiled_fn2 = torch.compile(fn2, fullgraph=True)

x = torch.randn(4, device=GPU_TYPE)
y = torch.randn(4, device=GPU_TYPE)

x = torch.randn(4, device=GPU_TYPE, requires_grad=grad)
y = torch.randn(4, device=GPU_TYPE, requires_grad=grad)
result = compiled_fn(x, y)
if grad:
result.sum().backward()

self.assertEqual(counters["inductor"]["fxgraph_cache_miss"], 0)
self.assertEqual(counters["inductor"]["fxgraph_cache_miss"], 1)
self.assertEqual(counters["inductor"]["fxgraph_cache_hit"], 0)
self.assertGreater(counters["inductor"]["fxgraph_cache_bypass"], 0)
self.assertEqual(counters["inductor"]["fxgraph_cache_bypass"], 0)

# A second call should hit. (First reset so in-memory guards
# don't prevent compilation).
self.reset()

# Clean PyCodeCache and triton kernels
PyCodeCache.cache_clear()
shutil.rmtree(os.path.join(cache_dir(), "triton"), ignore_errors=True)

result = compiled_fn(x, y)

self.assertEqual(counters["inductor"]["fxgraph_cache_miss"], 1)
self.assertEqual(counters["inductor"]["fxgraph_cache_hit"], 1)
self.assertEqual(counters["inductor"]["fxgraph_cache_bypass"], 0)

# A second call should hit. (First reset so in-memory guards
# don't prevent compilation).
self.reset()

# Clean PyCodeCache and triton kernels
PyCodeCache.cache_clear()
shutil.rmtree(os.path.join(cache_dir(), "triton"), ignore_errors=True)

result = compiled_fn2(x, y)

self.assertEqual(counters["inductor"]["fxgraph_cache_miss"], 2)
self.assertEqual(counters["inductor"]["fxgraph_cache_hit"], 1)
self.assertEqual(counters["inductor"]["fxgraph_cache_bypass"], 0)

@config.patch({"fx_graph_cache": True})
@config.patch({"fx_graph_remote_cache": False})
@@ -808,15 +872,16 @@ def test_tensor_constants(self):
self.assertFalse(GraphLowering.can_inline_constant(large))

# By default, we hash the metadata and values independent of the size.
pickler = FxGraphCachePickler()
gm = torch.fx.GraphModule({}, torch.fx.Graph())
pickler = FxGraphCachePickler(gm)

data = pickler.dumps(small)
self.assertIsInstance(pickle.loads(data), TensorMetadataAndValues)
data = pickler.dumps(large)
self.assertIsInstance(pickle.loads(data), TensorMetadataAndValues)

# If include_non_inlined=False, we only hash the values of small tensors.
pickler = FxGraphCachePickler(False)
pickler = FxGraphCachePickler(gm, False)

data = pickler.dumps(small)
self.assertIsInstance(pickle.loads(data), TensorMetadataAndValues)
@@ -827,7 +892,8 @@ def test_hash_fake_tensors(self):
"""
Test hashing (pickling) FakeTensors with various characteristics.
"""
pickler = FxGraphCachePickler()
gm = torch.fx.GraphModule({}, torch.fx.Graph())
pickler = FxGraphCachePickler(gm)
with torch._subclasses.FakeTensorMode():
# Verify that FakeTensors get pickled into a TensorMetadata:
data = pickler.dumps(torch.randn(1))
@@ -933,7 +999,8 @@ def test_hash_kwargs(self):
Test the special handling of the kwargs when hashing, i.e.,
ordering of the kwargs dict and any set arguments.
"""
pickler = FxGraphCachePickler()
gm = torch.fx.GraphModule({}, torch.fx.Graph())
pickler = FxGraphCachePickler(gm)

# Dict order of the kwargs should not affect hashes.
details1 = FxGraphHashDetails(None, [], {"a": 0, "z": 1}, [])
@@ -981,7 +1048,8 @@ def test_hash_config_changes(self):
with config.patch({"max_autotune": True}):
details3 = FxGraphHashDetails(None, [], {}, [])

pickler = FxGraphCachePickler()
gm = torch.fx.GraphModule({}, torch.fx.Graph())
pickler = FxGraphCachePickler(gm)

self.assertEqual(
pickler.dumps(details1),
@@ -1016,7 +1084,8 @@ def uuid(self) -> Optional[Union[bytes, str]]:
custom_pass._uuid = "2"
details3 = FxGraphHashDetails(None, [], {}, [])

pickler = FxGraphCachePickler()
gm = torch.fx.GraphModule({}, torch.fx.Graph())
pickler = FxGraphCachePickler(gm)

self.assertEqual(
pickler.dumps(details1),
@@ -1031,8 +1100,9 @@ def test_bypass_unsupported(self):
"""
Test _reduce_unsupported
"""
gm = torch.fx.GraphModule({}, torch.fx.Graph())
with self.assertRaises(BypassFxGraphCache):
FxGraphCachePickler().dumps(
FxGraphCachePickler(gm).dumps(
torch.fx.experimental._backward_state.BackwardState()
)

@@ -1047,7 +1117,8 @@ def test_stable_strings(self):

self.assertNotEqual(id(s1), id(s2))

pickler = FxGraphCachePickler()
gm = torch.fx.GraphModule({}, torch.fx.Graph())
pickler = FxGraphCachePickler(gm)
self.assertEqual(
pickler.dumps([s1, s1]),
pickler.dumps([s1, s2]),
6 changes: 3 additions & 3 deletions torch/_functorch/_aot_autograd/autograd_cache.py
@@ -227,8 +227,8 @@ def __init__(


class AOTAutogradCachePickler(FxGraphCachePickler):
def __init__(self):
super().__init__()
def __init__(self, gm: torch.fx.GraphModule):
super().__init__(gm)
self.dispatch_table: Dict
self.dispatch_table.update(
{
@@ -275,7 +275,7 @@ def autograd_cache_key(
"""
check_cacheable(gm)
details = AOTAutogradCacheDetails(gm, example_inputs, config, fx_config)
pickler = AOTAutogradCachePickler()
pickler = AOTAutogradCachePickler(gm)
# The prefix distinguishes among the other kinds of objects we cache
key = "a" + pickler.get_hash(details)
debug_lines = pickler.debug_lines(details)
24 changes: 22 additions & 2 deletions torch/_higher_order_ops/triton_kernel_wrap.py
@@ -613,7 +613,7 @@ def identify_mutated_tensors(
# Used for wrapping a Triton Kernel
class TritonKernelWrapperMutation(HigherOrderOperator):
def __init__(self) -> None:
super().__init__("triton_kernel_wrapper_mutation", cacheable=False)
super().__init__("triton_kernel_wrapper_mutation", cacheable=True)

def __call__(
self,
@@ -638,7 +638,7 @@ def __call__(
# Used for wrapping a Triton Kernel in a functional manner
class TritonKernelWrapperFunctional(HigherOrderOperator):
def __init__(self) -> None:
super().__init__("triton_kernel_wrapper_functional", cacheable=False)
super().__init__("triton_kernel_wrapper_functional", cacheable=True)

def __call__(
self,
@@ -774,6 +774,26 @@ def trace_triton_kernel_wrapper(
proxy_args,
name=func_overload.__name__ + "_proxy",
)

from triton.runtime.autotuner import Autotuner

from torch._inductor.codegen.wrapper import (
user_defined_triton_kernel_transitive_closure_source_code,
)

kernel = kernel_side_table.get_kernel(proxy_args["kernel_idx"])
if isinstance(kernel, Autotuner):
kernel = kernel.fn

kernel_source = user_defined_triton_kernel_transitive_closure_source_code(kernel)
constant_args = kernel_side_table.get_constant_args(proxy_args["constant_args_idx"])
# we add this to the node here so that it gets included in the inductor cache key
# when the graph is pickled
out_proxy.node.meta["user_defined_triton_kernel_source_and_constant_args"] = (
kernel_source,
constant_args,
)

ret = track_tensor_tree(out, out_proxy, constant=None, tracer=proxy_mode.tracer)
return ret

66 changes: 64 additions & 2 deletions torch/_inductor/codecache.py
@@ -7,6 +7,7 @@
import hashlib
import importlib
import io
import itertools
import json
import logging
import os
@@ -525,7 +526,12 @@ class FxGraphCachePickler(pickle.Pickler):
data that allow us to compute a stable, but safe hash.
"""

def __init__(self, include_non_inlined: bool = True) -> None:
def __init__(
self,
gm: torch.fx.GraphModule,
include_non_inlined: bool = True,
has_user_defined_triton_kernels: bool = False,
) -> None:
"""
Create an FX graph pickler. If include_non_inlined=True, then pickling will
include the _values_ for all Tensors. (Note that any tensors are constants
@@ -548,6 +554,11 @@ def __init__(self, include_non_inlined: bool = True) -> None:
),
}
)
if has_user_defined_triton_kernels:
# Need to use runtime type as GraphModule generates a singleton in __new__ function
self.dispatch_table[gm.__class__] = functools.partial(
self._reduce_graph_module
)

# Run with pickler.fast so it doesn't intern strings, making the hash result more predictable
# TODO: pickler.fast is technically deprecated. Will this work on new python versions?
@@ -614,6 +625,25 @@ def _reduce_unsupported(self, s: Any) -> NoReturn:
"""
raise BypassFxGraphCache("Reduce unsupported")

def _reduce_graph_module(
self, gm: torch.fx.GraphModule
) -> Tuple[Any, Tuple[Dict[str, Any], str]]:
"""
Custom reducer for GraphModule that strips data irrelevant to hashing for
user defined triton kernels.
Essentially, this is a huge hack: user defined triton kernels keep a
dynamo-time side table, and the arguments to the call_function node are
indices into this side table. These arguments are not needed for hashing
purposes, since we include the kernel source code in the cache key, and
the indices are prone to giving false negatives due to ordering.
"""
fn, (data, imports) = gm.__reduce__()
code = data["_code"]
code = re.sub(r"kernel_idx = \d+", "", code)
code = re.sub(r"constant_args_idx = \d+", "", code)
data["_code"] = code
return fn, (data, imports)

def dumps(self, obj: Any) -> bytes:
"""
Pickle an object and return a byte string.
@@ -775,6 +805,35 @@ def __init__(
else:
self.fx_kwargs[k] = v

from torch._higher_order_ops.triton_kernel_wrap import (
triton_kernel_wrapper_functional,
triton_kernel_wrapper_mutation,
)

# Node meta will not be part of gm's reduce function, so let's remember
# the kernel source code separately
self.user_defined_triton_source: List[Any] = []
if gm is not None:
for module in gm.modules():
if not isinstance(module, torch.fx.GraphModule):
continue
for node in itertools.chain(
module.graph.find_nodes(
op="call_function", target=triton_kernel_wrapper_functional
),
module.graph.find_nodes(
op="call_function", target=triton_kernel_wrapper_mutation
),
):
data = node.meta.get(
"user_defined_triton_kernel_source_and_constant_args", None
)
if data is None:
raise AssertionError(
"TritonKernelWrapper does not contain source code meta"
)
self.user_defined_triton_source.append(data)

# Alignment checks
self.inputs_to_check = inputs_to_check

@@ -833,7 +892,10 @@ def compiled_fx_graph_hash(
include_non_inlined = not has_frozen_params(gm)

details = FxGraphHashDetails(gm, example_inputs, fx_kwargs, inputs_to_check)
pickler = FxGraphCachePickler(include_non_inlined)
has_user_defined_triton_kernels = len(details.user_defined_triton_source) != 0
pickler = FxGraphCachePickler(
gm, include_non_inlined, has_user_defined_triton_kernels
)
# The prefix distinguishes among the other kinds of objects we
# cache in this module.
key = "f" + pickler.get_hash(details)
2 changes: 1 addition & 1 deletion torch/_inductor/codegen/wrapper.py
@@ -241,7 +241,7 @@ def writeline(line: str, example_grid: Optional[str] = None):
def user_defined_triton_kernel_transitive_closure_source_code(kernel) -> str:
"""
Given a triton kernel function pointer collect the transitive closure of
its dependancies
its dependencies
"""
compile_wrapper = IndentedBuffer()
compile_wrapper.splice(kernel.src, strip=True)