Skip to content

Commit

Permalink
Update benchmark x_vals ranges and use log scaling for performance …
Browse files Browse the repository at this point in the history
…testing
  • Loading branch information
voltjia committed Jan 12, 2025
1 parent 11d09dd commit 9d5c731
Show file tree
Hide file tree
Showing 6 changed files with 11 additions and 6 deletions.
2 changes: 1 addition & 1 deletion add.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ def grid(meta):
@triton.testing.perf_report(
triton.testing.Benchmark(
x_names=["size"],
x_vals=[2**i for i in range(12, 28, 1)],
x_vals=[2**i for i in range(18, 28)],
x_log=True,
line_arg="provider",
line_vals=["ninetoothed", "torch", "triton"],
Expand Down
3 changes: 2 additions & 1 deletion attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -224,7 +224,8 @@ def grid(meta):
@triton.testing.perf_report(
triton.testing.Benchmark(
x_names=["seq_len"],
x_vals=[2**i for i in range(10, 15)],
x_vals=[2**i for i in range(6, 16)],
x_log=True,
line_arg="provider",
line_vals=["ninetoothed", "torch", "triton"],
line_names=["NineToothed", "PyTorch", "Triton"],
Expand Down
3 changes: 2 additions & 1 deletion conv2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -216,7 +216,8 @@ def grid(meta):
@triton.testing.perf_report(
triton.testing.Benchmark(
x_names=["n"],
x_vals=[2**i for i in range(11)],
x_vals=[2**i for i in range(1, 11)],
x_log=True,
line_arg="provider",
line_vals=["ninetoothed", "torch", "triton"],
line_names=["NineToothed", "PyTorch", "Triton"],
Expand Down
3 changes: 2 additions & 1 deletion matmul.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,7 +167,8 @@ def grid(meta):
@triton.testing.perf_report(
triton.testing.Benchmark(
x_names=["m", "n", "k"],
x_vals=[128 * i for i in range(2, 33)],
x_vals=[2**i for i in range(3, 13)],
x_log=True,
line_arg="provider",
line_vals=["ninetoothed", "torch", "triton"],
line_names=["NineToothed", "PyTorch", "Triton"],
Expand Down
3 changes: 2 additions & 1 deletion rms_norm.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,8 @@ def triton_rms_norm(input, eps=1e-5):
@triton.testing.perf_report(
triton.testing.Benchmark(
x_names=["n"],
x_vals=[512 * i for i in range(2, 32)],
x_vals=[2**i for i in range(5, 15)],
x_log=True,
line_arg="provider",
line_vals=["ninetoothed", "torch", "triton"],
line_names=["NineToothed", "PyTorch", "Triton"],
Expand Down
3 changes: 2 additions & 1 deletion softmax.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,8 @@ def triton_softmax(input):
@triton.testing.perf_report(
triton.testing.Benchmark(
x_names=["n"],
x_vals=[128 * i for i in range(2, 100)],
x_vals=[2**i for i in range(5, 15)],
x_log=True,
line_arg="provider",
line_vals=["ninetoothed", "torch", "triton"],
line_names=["NineToothed", "PyTorch", "Triton"],
Expand Down

0 comments on commit 9d5c731

Please sign in to comment.