Add correctness verification during performance testing
voltjia committed Jan 12, 2025
1 parent ee28ccb commit 11d09dd
Showing 6 changed files with 37 additions and 1 deletion.

add.py: 6 additions & 0 deletions
@@ -97,6 +97,12 @@ def benchmark(size, provider):
     rhs = torch.rand(size, device="cuda", dtype=torch.float16)
     quantiles = [0.5, 0.2, 0.8]
 
+    ninetoothed_output = add(lhs, rhs)
+    torch_output = lhs + rhs
+    triton_output = triton_add(lhs, rhs)
+    assert torch.allclose(ninetoothed_output, torch_output)
+    assert torch.allclose(ninetoothed_output, triton_output, atol=0, rtol=0)
+
     if provider == "ninetoothed":
         ms, min_ms, max_ms = triton.testing.do_bench(
             lambda: add(lhs, rhs), quantiles=quantiles

attention.py: 6 additions & 0 deletions
@@ -242,6 +242,12 @@ def benchmark(seq_len, provider):
     k = torch.randn(shape, dtype=dtype, device="cuda")
     v = torch.randn(shape, dtype=dtype, device="cuda")
 
+    ninetoothed_output = attention(q, k, v)
+    torch_output = F.scaled_dot_product_attention(q, k, v, scale=1)
+    triton_output = triton_attention(q, k, v)
+    assert torch.allclose(ninetoothed_output, torch_output, atol=0.025, rtol=0.025)
+    assert torch.allclose(ninetoothed_output, triton_output, atol=0, rtol=0)
+
     if provider == "ninetoothed":
         ms = triton.testing.do_bench(lambda: attention(q, k, v))
     elif provider == "torch":

conv2d.py: 6 additions & 0 deletions
@@ -233,6 +233,12 @@ def benchmark(n, provider):
     input = torch.randn((n, c, h, w), dtype=dtype, device="cuda")
     filter = torch.randn((k, c, r, s), dtype=dtype, device="cuda")
 
+    ninetoothed_output = conv2d(input, filter)
+    torch_output = F.conv2d(input, filter)
+    triton_output = triton_conv2d(input, filter)
+    assert torch.allclose(ninetoothed_output, torch_output, atol=0.01, rtol=0.01)
+    assert torch.allclose(ninetoothed_output, triton_output, atol=0, rtol=0)
+
     if provider == "ninetoothed":
         ms = triton.testing.do_bench(lambda: conv2d(input, filter))
     elif provider == "torch":

matmul.py: 6 additions & 0 deletions
@@ -182,6 +182,12 @@ def benchmark(m, n, k, provider):
     rhs = torch.randn((k, n), device="cuda", dtype=torch.float16)
     quantiles = [0.5, 0.2, 0.8]
 
+    ninetoothed_output = matmul(lhs, rhs)
+    torch_output = torch.matmul(lhs, rhs)
+    triton_output = triton_matmul(lhs, rhs)
+    assert torch.allclose(ninetoothed_output, torch_output, atol=0.025, rtol=0.025)
+    assert torch.allclose(ninetoothed_output, triton_output, atol=0, rtol=0)
+
     if provider == "ninetoothed":
         ms, min_ms, max_ms = triton.testing.do_bench(
             lambda: matmul(lhs, rhs), quantiles=quantiles

rms_norm.py: 6 additions & 0 deletions
@@ -102,6 +102,12 @@ def triton_rms_norm(input, eps=1e-5):
 def benchmark(m, n, provider):
     input = torch.randn(m, n, dtype=torch.float16, device="cuda")
 
+    ninetoothed_output = rms_norm(input)
+    torch_output = F.rms_norm(input, input.shape[-1:])
+    triton_output = triton_rms_norm(input)
+    assert torch.allclose(ninetoothed_output, torch_output, atol=0.001, rtol=0.005)
+    assert torch.allclose(ninetoothed_output, triton_output, atol=0, rtol=0)
+
     if provider == "ninetoothed":
         ms = triton.testing.do_bench(lambda: rms_norm(input))
     elif provider == "torch":

softmax.py: 7 additions & 1 deletion
@@ -78,7 +78,7 @@ def triton_softmax(input):
 print(ninetoothed_output)
 print(torch_output)
 print(triton_output)
-if torch.allclose(ninetoothed_output, torch_output, atol=1e-5):
+if torch.allclose(ninetoothed_output, torch_output, atol=0.001):
     print("✅ NineToothed and PyTorch match.")
 else:
     print("❌ NineToothed and PyTorch differ.")
@@ -105,6 +105,12 @@ def benchmark(m, n, provider):
     stream = torch.cuda.Stream()
     torch.cuda.set_stream(stream)
 
+    ninetoothed_output = softmax(input)
+    torch_output = torch.softmax(input, axis=-1)
+    triton_output = triton_softmax(input)
+    assert torch.allclose(ninetoothed_output, torch_output, atol=0.001)
+    assert torch.allclose(ninetoothed_output, triton_output, atol=0, rtol=0)
+
     if provider == "ninetoothed":
         ms = triton.testing.do_bench(lambda: softmax(input))
     elif provider == "torch":
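
The check added in each file follows the same pattern: compute the NineToothed, PyTorch, and Triton outputs once, assert that they agree via torch.allclose (exact agreement between NineToothed and Triton, a small or default tolerance against PyTorch), and only then time the kernel with triton.testing.do_bench. Below is a minimal, self-contained sketch of that pattern, using a plain element-wise add as a stand-in for the add and triton_add kernels defined in add.py; the stand-in kernel bodies and the 4096-element size are illustrative assumptions, not code from this commit.

import torch
import triton.testing


def add(lhs, rhs):
    # Hypothetical stand-in for the NineToothed kernel in add.py.
    return lhs + rhs


def triton_add(lhs, rhs):
    # Hypothetical stand-in for the Triton kernel in add.py.
    return lhs + rhs


def benchmark(size=4096):
    lhs = torch.rand(size, device="cuda", dtype=torch.float16)
    rhs = torch.rand(size, device="cuda", dtype=torch.float16)

    # Verify correctness first: all providers must produce matching outputs.
    ninetoothed_output = add(lhs, rhs)
    torch_output = lhs + rhs
    triton_output = triton_add(lhs, rhs)
    assert torch.allclose(ninetoothed_output, torch_output)
    assert torch.allclose(ninetoothed_output, triton_output, atol=0, rtol=0)

    # Only after the checks pass is the kernel timed.
    return triton.testing.do_bench(lambda: add(lhs, rhs))

Keeping the assertions inside benchmark means an incorrect kernel fails loudly before any timing numbers are reported.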