diff --git a/matmul.py b/matmul.py
index 4880a0a..1b0f04b 100644
--- a/matmul.py
+++ b/matmul.py
@@ -5,37 +5,43 @@ import triton.language as tl
 
 from ninetoothed import Symbol, Tensor
 
-BLOCK_SIZE_M = Symbol("BLOCK_SIZE_M", meta=True)
-BLOCK_SIZE_N = Symbol("BLOCK_SIZE_N", meta=True)
-BLOCK_SIZE_K = Symbol("BLOCK_SIZE_K", meta=True)
 
-output_tiled = Tensor(2).tile((BLOCK_SIZE_M, BLOCK_SIZE_N))
+def arrangement(lhs, rhs, output):
+    BLOCK_SIZE_M = Symbol("BLOCK_SIZE_M", meta=True)
+    BLOCK_SIZE_N = Symbol("BLOCK_SIZE_N", meta=True)
+    BLOCK_SIZE_K = Symbol("BLOCK_SIZE_K", meta=True)
 
-lhs_tiled = (
-    Tensor(2)
-    .tile((BLOCK_SIZE_M, BLOCK_SIZE_K))
-    .tile((1, -1))
-    .expand((-1, output_tiled.shape[1]))
-)
-lhs_tiled.dtype = lhs_tiled.dtype.squeeze(0)
+    output_tiled = output.tile((BLOCK_SIZE_M, BLOCK_SIZE_N))
 
-rhs_tiled = (
-    Tensor(2)
-    .tile((BLOCK_SIZE_K, BLOCK_SIZE_N))
-    .tile((-1, 1))
-    .expand((output_tiled.shape[0], -1))
-)
-rhs_tiled.dtype = rhs_tiled.dtype.squeeze(1)
+    lhs_tiled = (
+        lhs.tile((BLOCK_SIZE_M, BLOCK_SIZE_K))
+        .tile((1, -1))
+        .expand((-1, output_tiled.shape[1]))
+    )
+    lhs_tiled.dtype = lhs_tiled.dtype.squeeze(0)
+
+    rhs_tiled = (
+        rhs.tile((BLOCK_SIZE_K, BLOCK_SIZE_N))
+        .tile((-1, 1))
+        .expand((output_tiled.shape[0], -1))
+    )
+    rhs_tiled.dtype = rhs_tiled.dtype.squeeze(1)
 
+    return lhs_tiled, rhs_tiled, output_tiled
 
-@ninetoothed.jit
-def matmul_kernel(lhs: lhs_tiled, rhs: rhs_tiled, output: output_tiled):
+
+def application(lhs, rhs, output):
     accumulator = ntl.zeros(output.shape, dtype=ntl.float32)
     for k in range(lhs.shape[0]):
         accumulator += ntl.dot(lhs[k], rhs[k])
     output = accumulator.to(ntl.float16)
 
 
+matmul_kernel = ninetoothed.make(
+    arrangement, application, (Tensor(2), Tensor(2), Tensor(2))
+)
+
+
 def matmul(lhs, rhs):
     output = torch.empty(
         (lhs.shape[0], rhs.shape[1]), device=lhs.device, dtype=torch.float16
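
A minimal usage sketch of the refactored wrapper, not part of the diff: it assumes the truncated matmul() body goes on to launch matmul_kernel and return output, that the module is importable as matmul.py (the import path here is hypothetical), and that a CUDA device is available:

    import torch

    from matmul import matmul  # hypothetical import path for this example

    lhs = torch.randn(512, 256, device="cuda", dtype=torch.float16)
    rhs = torch.randn(256, 128, device="cuda", dtype=torch.float16)

    out = matmul(lhs, rhs)
    ref = torch.matmul(lhs, rhs)

    # application() accumulates in float32 and casts to float16 at the end,
    # so the result should track torch.matmul closely but not bit-exactly.
    assert torch.allclose(out, ref, atol=1e-2, rtol=1e-2)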