Use ninetoothed.make in matmul.py
voltjia committed Dec 27, 2024
1 parent f2cf316 commit cb9fa62
Showing 1 changed file with 26 additions and 20 deletions.
matmul.py (26 additions, 20 deletions)

@@ -5,37 +5,43 @@
 import triton.language as tl
 from ninetoothed import Symbol, Tensor
 
-BLOCK_SIZE_M = Symbol("BLOCK_SIZE_M", meta=True)
-BLOCK_SIZE_N = Symbol("BLOCK_SIZE_N", meta=True)
-BLOCK_SIZE_K = Symbol("BLOCK_SIZE_K", meta=True)
+def arrangement(lhs, rhs, output):
+    BLOCK_SIZE_M = Symbol("BLOCK_SIZE_M", meta=True)
+    BLOCK_SIZE_N = Symbol("BLOCK_SIZE_N", meta=True)
+    BLOCK_SIZE_K = Symbol("BLOCK_SIZE_K", meta=True)
 
-output_tiled = Tensor(2).tile((BLOCK_SIZE_M, BLOCK_SIZE_N))
+    output_tiled = output.tile((BLOCK_SIZE_M, BLOCK_SIZE_N))
 
-lhs_tiled = (
-    Tensor(2)
-    .tile((BLOCK_SIZE_M, BLOCK_SIZE_K))
-    .tile((1, -1))
-    .expand((-1, output_tiled.shape[1]))
-)
-lhs_tiled.dtype = lhs_tiled.dtype.squeeze(0)
+    lhs_tiled = (
+        lhs.tile((BLOCK_SIZE_M, BLOCK_SIZE_K))
+        .tile((1, -1))
+        .expand((-1, output_tiled.shape[1]))
+    )
+    lhs_tiled.dtype = lhs_tiled.dtype.squeeze(0)
 
-rhs_tiled = (
-    Tensor(2)
-    .tile((BLOCK_SIZE_K, BLOCK_SIZE_N))
-    .tile((-1, 1))
-    .expand((output_tiled.shape[0], -1))
-)
-rhs_tiled.dtype = rhs_tiled.dtype.squeeze(1)
+    rhs_tiled = (
+        rhs.tile((BLOCK_SIZE_K, BLOCK_SIZE_N))
+        .tile((-1, 1))
+        .expand((output_tiled.shape[0], -1))
+    )
+    rhs_tiled.dtype = rhs_tiled.dtype.squeeze(1)
+
+    return lhs_tiled, rhs_tiled, output_tiled
 
 
-@ninetoothed.jit
-def matmul_kernel(lhs: lhs_tiled, rhs: rhs_tiled, output: output_tiled):
+def application(lhs, rhs, output):
     accumulator = ntl.zeros(output.shape, dtype=ntl.float32)
     for k in range(lhs.shape[0]):
         accumulator += ntl.dot(lhs[k], rhs[k])
     output = accumulator.to(ntl.float16)
 
 
+matmul_kernel = ninetoothed.make(
+    arrangement, application, (Tensor(2), Tensor(2), Tensor(2))
+)
+
+
 def matmul(lhs, rhs):
     output = torch.empty(
         (lhs.shape[0], rhs.shape[1]), device=lhs.device, dtype=torch.float16
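For context, the `matmul` wrapper is truncated in the hunk above; it is expected to launch `matmul_kernel` on the operands and return `output`. A minimal smoke test of the refactored kernel might look like the following sketch. It is not part of the commit: the import path, shapes, and tolerances are illustrative assumptions, and it presumes a CUDA device is available.

    import torch

    # Hypothetical import: assumes matmul.py is on the module path.
    from matmul import matmul

    # Half-precision inputs, matching the float16 output dtype in the diff.
    lhs = torch.randn(512, 256, device="cuda", dtype=torch.float16)
    rhs = torch.randn(256, 1024, device="cuda", dtype=torch.float16)

    out = matmul(lhs, rhs)
    ref = torch.matmul(lhs, rhs)

    # The kernel accumulates in float32 before casting to float16,
    # so a loose fp16 tolerance should suffice (assumed values).
    assert torch.allclose(out, ref, atol=1e-2, rtol=1e-2)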
