Skip to content

Commit

Permalink
Don't modify Triton configs when using heuristics.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 548521388
  • Loading branch information
chr1sj0nes authored and The jax_triton Authors committed Jul 16, 2023
1 parent d98c427 commit f5b7705
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 2 deletions.
11 changes: 9 additions & 2 deletions jax_triton/triton_lib.py
Original file line number Diff line number Diff line change
Expand Up @@ -297,10 +297,17 @@ def prune_configs(configs, named_args):
configs = [triton.Config({}, num_warps=num_warps, num_stages=num_stages)]

if isinstance(fn, autotuner.Heuristics):
updated_configs = []
for config in configs:
kwargs = config.kwargs.copy()
for name, heuristic in fn.values.items():
all_args = {**named_args, **metaparams, **config.kwargs}
config.kwargs[name] = heuristic(all_args)
kwargs[name] = heuristic({**named_args, **metaparams, **kwargs})
updated_configs.append(
triton.Config(
kwargs, num_warps=config.num_warps, num_stages=config.num_stages
)
)
configs = updated_configs
fn = fn.fn

if not isinstance(fn, triton.JITFunction):
Expand Down
23 changes: 23 additions & 0 deletions tests/triton_call_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -458,6 +458,29 @@ def do_matmul(m, n, k):
_ = do_matmul(m=128, n=128, k=160)
self.assertEqual(heuristic_returned_values, [True, True, True, False])

def test_heuristics_does_not_modify_autotune_configs(self):
def heuristic_fn(args):
return args["K"] % args["BLOCK_SIZE_K"] == 0

heuristics = {"K_EXACTLY_DIVISIBLE_BY_BLOCK": heuristic_fn}
autotune_config = triton.Config({"BLOCK_SIZE_K": 32}, num_warps=1)
kernel = triton.autotune([autotune_config], key=("M", "N", "K"))(
triton.heuristics(heuristics)(matmul_kernel)
)

def do_matmul(m, n, k):
x, y = create_random_inputs([m, k], [k, n])
return matmul(
x,
y,
kernel=kernel,
BLOCK_SIZE_M=32,
BLOCK_SIZE_N=32,
)

_ = do_matmul(m=128, n=128, k=128)
self.assertEqual(autotune_config.kwargs, {"BLOCK_SIZE_K": 32})

def test_autotune_with_input_output_aliasing(self):
autotune_configs = [
triton.Config({"BLOCK_SIZE": 32}, num_warps=1),
Expand Down

0 comments on commit f5b7705

Please sign in to comment.