Skip to content

Commit

Permalink
Use the half-precision floating-point format as the data type for arguments
Browse files Browse the repository at this point in the history
  • Loading branch information
voltjia committed Jan 10, 2025
1 parent 0d220a5 commit 5ee0bb9
Show file tree
Hide file tree
Showing 3 changed files with 14 additions and 11 deletions.
9 changes: 5 additions & 4 deletions add.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,8 +59,9 @@ def grid(meta):

torch.manual_seed(0)
size = 98432
lhs = torch.rand(size, device="cuda")
rhs = torch.rand(size, device="cuda")
dtype = torch.float16
lhs = torch.rand(size, dtype=dtype, device="cuda")
rhs = torch.rand(size, dtype=dtype, device="cuda")
ninetoothed_output = add(lhs, rhs)
torch_output = lhs + rhs
triton_output = triton_add(lhs, rhs)
Expand Down Expand Up @@ -92,8 +93,8 @@ def grid(meta):
)
)
def benchmark(size, provider):
lhs = torch.rand(size, device="cuda", dtype=torch.float32)
rhs = torch.rand(size, device="cuda", dtype=torch.float32)
lhs = torch.rand(size, device="cuda", dtype=torch.float16)
rhs = torch.rand(size, device="cuda", dtype=torch.float16)
quantiles = [0.5, 0.2, 0.8]

if provider == "ninetoothed":
Expand Down
10 changes: 6 additions & 4 deletions conv2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,8 +200,9 @@ def grid(meta):
torch.manual_seed(0)
n, c, h, w = 4, 3, 224, 224
k, _, r, s = 8, c, 3, 3
input = torch.randn(n, c, h, w, device="cuda")
filter = torch.randn(k, c, r, s, device="cuda")
dtype = torch.float16
input = torch.randn(n, c, h, w, dtype=dtype, device="cuda")
filter = torch.randn(k, c, r, s, dtype=dtype, device="cuda")
ninetoothed_output = conv2d(input, filter)
torch_output = F.conv2d(input, filter)
triton_output = triton_conv2d(input, filter)
Expand Down Expand Up @@ -233,8 +234,9 @@ def grid(meta):
def benchmark(h, w, provider):
n, c, _, _ = 64, 3, h, w
k, _, r, s = 64, c, 3, 3
input = torch.randn((n, c, h, w), device="cuda")
filter = torch.randn((k, c, r, s), device="cuda")
dtype = torch.float16
input = torch.randn((n, c, h, w), dtype=dtype, device="cuda")
filter = torch.randn((k, c, r, s), dtype=dtype, device="cuda")

if provider == "ninetoothed":
ms = triton.testing.do_bench(lambda: conv2d(input, filter))
Expand Down
6 changes: 3 additions & 3 deletions softmax.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,14 +72,14 @@ def triton_softmax(input):


torch.manual_seed(0)
input = torch.randn(1823, 781, device="cuda")
input = torch.randn(1823, 781, dtype=torch.float16, device="cuda")
ninetoothed_output = softmax(input)
torch_output = torch.softmax(input, axis=-1)
triton_output = triton_softmax(input)
print(ninetoothed_output)
print(torch_output)
print(triton_output)
if torch.allclose(ninetoothed_output, torch_output):
if torch.allclose(ninetoothed_output, torch_output, atol=1e-5):
print("✅ NineToothed and PyTorch match.")
else:
print("❌ NineToothed and PyTorch differ.")
Expand All @@ -103,7 +103,7 @@ def triton_softmax(input):
)
)
def benchmark(m, n, provider):
input = torch.randn(m, n, device="cuda", dtype=torch.float32)
input = torch.randn(m, n, device="cuda", dtype=torch.float16)
stream = torch.cuda.Stream()
torch.cuda.set_stream(stream)

Expand Down

0 comments on commit 5ee0bb9

Please sign in to comment.