Matmul tutorial - pad weights only (#19)
Adds an option to apply padding only to matrix B.

This allows exploring potential speedups by limiting padding to the weights, which is a reasonably common strategy in, e.g., ML inference.
Full padding still has to occur when the K dimension is padded, in order to avoid a dimension mismatch and/or to meet the power-of-two size requirement.
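For context, a minimal standalone sketch of the padding scheme described above, assuming `a` is the activation matrix (M x K) and `b` the weight matrix (K x N); the helper name `pad_weights_only` and the `k_block`/`col_pad` parameters are illustrative, not the tutorial's API:

```python
# Illustrative sketch only (names and defaults are assumptions, not the tutorial's code).
import torch
import torch.nn.functional as F

def pad_weights_only(a: torch.Tensor, b: torch.Tensor, k_block: int, col_pad: int = 32):
    M, K = a.shape
    Kb, N = b.shape
    assert K == Kb
    # Round K up to the next multiple of k_block.
    k_dim_pad = ((K + k_block - 1) // k_block) * k_block - K
    if k_dim_pad != 0:
        # K changes, so both operands must be zero-padded to keep the shapes compatible.
        a = F.pad(a, (0, k_dim_pad, 0, 0), mode='constant', value=0)
        b = F.pad(b, (0, 0, 0, k_dim_pad), mode='constant', value=0)
    # Cache padding (extra columns) can stay limited to the weights.
    b = F.pad(b, (0, col_pad, 0, 0), mode='constant', value=0)
    return a, b
```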
adam-smnk authored and Devjiu committed Nov 13, 2024
1 parent e9e24f3 commit 1e7d9f3
Showing 1 changed file with 14 additions and 11 deletions.
25 changes: 14 additions & 11 deletions python/tutorials/03-matrix-multiplication-cpu.py
@@ -170,6 +170,7 @@
 CACHE_PADDING = False
 PREPROCESS_EXTERNAL = False
 XSMM_PAD = False
+PAD_B_ONLY = False
 
 @triton.jit
 def matmul_kernel(
@@ -320,28 +321,30 @@ def matmul_preprocess_input(a: torch.Tensor, b: torch.Tensor, c: torch.Tensor):
     k_block = min(triton.next_power_of_2(K), 1024)
 
     if XSMM_PAD:
-        padding_size = (((K + k_block - 1) // k_block) * k_block) - K
+        k_dim_pad = (((K + k_block - 1) // k_block) * k_block) - K
         col_pad = 32 if CACHE_PADDING else 0
-        a_scratch.resize_(M, K + padding_size + col_pad)
-        b_scratch.resize_(K + padding_size, N + col_pad)
-        xsmm_py.fastZeroPad2D(a, a_scratch)
+        a_scratch.resize_(M, K + k_dim_pad + col_pad)
+        b_scratch.resize_(K + k_dim_pad, N + col_pad)
+        if not PAD_B_ONLY or k_dim_pad != 0:
+            xsmm_py.fastZeroPad2D(a, a_scratch)
+            a = a_scratch
         xsmm_py.fastZeroPad2D(b, b_scratch)
-        K = K + padding_size
-        a = a_scratch
         b = b_scratch
+        K = K + k_dim_pad
     else:
         if K_DIM_PADDING or DYNAMIC_K_BLOCK:
-            padding_size = (((K + k_block - 1) // k_block) * k_block) - K
-            if padding_size != 0:
-                a = torch.nn.functional.pad(a, (0, padding_size, 0, 0), mode='constant', value=0)
-                b = torch.nn.functional.pad(b, (0, 0, 0, padding_size), mode='constant', value=0)
+            k_dim_pad = (((K + k_block - 1) // k_block) * k_block) - K
+            if k_dim_pad != 0:
+                a = torch.nn.functional.pad(a, (0, k_dim_pad, 0, 0), mode='constant', value=0)
+                b = torch.nn.functional.pad(b, (0, 0, 0, k_dim_pad), mode='constant', value=0)
                 K = a.shape[1]
 
     # TODO: Check if padding is needed at all.
     # Currently, cache padding is most useful together with dynamic K blocking
     # to ensure that stride is non-power-of-two to improve cache behavior.
     if CACHE_PADDING:
-        a = torch.nn.functional.pad(a, (0, 32, 0, 0), mode='constant', value=0)
+        if not PAD_B_ONLY:
+            a = torch.nn.functional.pad(a, (0, 32, 0, 0), mode='constant', value=0)
         b = torch.nn.functional.pad(b, (0, 32, 0, 0), mode='constant', value=0)
 
     #TODO: Currently masked load is not supported yet.
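A rough illustration of the two regimes the new PAD_B_ONLY flag distinguishes (the shapes and k_block value below are made up for this example):

```python
k_block = 1024

# K already a multiple of k_block: k_dim_pad == 0, so with PAD_B_ONLY only B
# receives the extra cache-padding columns and A keeps its original layout.
K = 2048
assert ((K + k_block - 1) // k_block) * k_block - K == 0

# K not a multiple of k_block: k_dim_pad != 0, so both A and B must be padded
# along K regardless of PAD_B_ONLY to keep the matmul dimensions consistent.
K = 1000
assert ((K + k_block - 1) // k_block) * k_block - K == 24
```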
