From 1e7d9f3ad35e2bf8a1ae7fef3ac4c14b847be594 Mon Sep 17 00:00:00 2001
From: Adam Siemieniuk
Date: Wed, 13 Nov 2024 15:11:17 +0100
Subject: [PATCH] Matmul tutorial - pad weights only (#19)

Adds an option to apply padding only to matrix B. This allows exploring
potential speedups by limiting padding to the weights, which is a
reasonably common strategy in, e.g., ML inference.

Full padding still has to occur when the K dimension is padded, to avoid
a dimension mismatch and/or to meet the power-of-two size requirement.
---
 .../tutorials/03-matrix-multiplication-cpu.py | 25 +++++++++++--------
 1 file changed, 14 insertions(+), 11 deletions(-)

diff --git a/python/tutorials/03-matrix-multiplication-cpu.py b/python/tutorials/03-matrix-multiplication-cpu.py
index 83b98eea12b78..5ee392efa46ce 100644
--- a/python/tutorials/03-matrix-multiplication-cpu.py
+++ b/python/tutorials/03-matrix-multiplication-cpu.py
@@ -170,6 +170,7 @@
 CACHE_PADDING = False
 PREPROCESS_EXTERNAL = False
 XSMM_PAD = False
+PAD_B_ONLY = False
 
 @triton.jit
 def matmul_kernel(
@@ -320,28 +321,30 @@ def matmul_preprocess_input(a: torch.Tensor, b: torch.Tensor, c: torch.Tensor):
     k_block = min(triton.next_power_of_2(K), 1024)
 
     if XSMM_PAD:
-        padding_size = (((K + k_block - 1) // k_block) * k_block) - K
+        k_dim_pad = (((K + k_block - 1) // k_block) * k_block) - K
         col_pad = 32 if CACHE_PADDING else 0
-        a_scratch.resize_(M, K + padding_size + col_pad)
-        b_scratch.resize_(K + padding_size, N + col_pad)
-        xsmm_py.fastZeroPad2D(a, a_scratch)
+        a_scratch.resize_(M, K + k_dim_pad + col_pad)
+        b_scratch.resize_(K + k_dim_pad, N + col_pad)
+        if not PAD_B_ONLY or k_dim_pad != 0:
+            xsmm_py.fastZeroPad2D(a, a_scratch)
+            a = a_scratch
         xsmm_py.fastZeroPad2D(b, b_scratch)
-        K = K + padding_size
-        a = a_scratch
         b = b_scratch
+        K = K + k_dim_pad
     else:
         if K_DIM_PADDING or DYNAMIC_K_BLOCK:
-            padding_size = (((K + k_block - 1) // k_block) * k_block) - K
-            if padding_size != 0:
-                a = torch.nn.functional.pad(a, (0, padding_size, 0, 0), mode='constant', value=0)
-                b = torch.nn.functional.pad(b, (0, 0, 0, padding_size), mode='constant', value=0)
+            k_dim_pad = (((K + k_block - 1) // k_block) * k_block) - K
+            if k_dim_pad != 0:
+                a = torch.nn.functional.pad(a, (0, k_dim_pad, 0, 0), mode='constant', value=0)
+                b = torch.nn.functional.pad(b, (0, 0, 0, k_dim_pad), mode='constant', value=0)
                 K = a.shape[1]
 
         # TODO: Check if padding is needed at all.
         # Currently, cache padding is most useful together with dynamic K blocking
         # to ensure that stride is non-power-of-two to improve cache behavior.
         if CACHE_PADDING:
-            a = torch.nn.functional.pad(a, (0, 32, 0, 0), mode='constant', value=0)
+            if not PAD_B_ONLY:
+                a = torch.nn.functional.pad(a, (0, 32, 0, 0), mode='constant', value=0)
             b = torch.nn.functional.pad(b, (0, 32, 0, 0), mode='constant', value=0)
 
     #TODO: Currently masked load is not supported yet.
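
For reference, the pad-B-only logic of the non-XSMM branch above can be summarized as a minimal standalone sketch in plain PyTorch (no Triton/XSMM). The helper name pad_inputs and its parameters k_block and pad_b_only are illustrative assumptions, not the tutorial's actual API; only the padding arithmetic mirrors the patch.

# Minimal sketch of the pad-B-only strategy, assuming plain PyTorch tensors.
# `pad_inputs`, `k_block`, and `pad_b_only` are illustrative names.
import torch

def pad_inputs(a, b, k_block, pad_b_only=True):
    assert a.shape[1] == b.shape[0], "inner dimensions must match"
    K = a.shape[1]
    # Round K up to a multiple of k_block. If K needs padding, both A and B
    # must be padded, otherwise their inner dimensions would no longer match.
    k_dim_pad = ((K + k_block - 1) // k_block) * k_block - K
    if k_dim_pad != 0:
        a = torch.nn.functional.pad(a, (0, k_dim_pad, 0, 0), value=0)
        b = torch.nn.functional.pad(b, (0, 0, 0, k_dim_pad), value=0)
    # Cache (column) padding only changes the row stride, so it can be
    # restricted to the weights B when pad_b_only is set.
    if not pad_b_only:
        a = torch.nn.functional.pad(a, (0, 32, 0, 0), value=0)
    b = torch.nn.functional.pad(b, (0, 32, 0, 0), value=0)
    return a, b

a = torch.randn(128, 100)
b = torch.randn(100, 256)
a_pad, b_pad = pad_inputs(a, b, k_block=64)
print(a_pad.shape, b_pad.shape)  # torch.Size([128, 128]) torch.Size([128, 288])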