RMSNorm Blocked Implementation
Rahul Batra committed Sep 24, 2024
1 parent e13fc4c commit 44e9360
Showing 1 changed file with 36 additions and 29 deletions.
65 changes: 36 additions & 29 deletions python/perf-kernels/rmsnorm.py
@@ -46,35 +46,47 @@ def get_autotune_config():
 
 @triton.autotune(configs=get_autotune_config(), key=['n_rows', 'n_cols'], use_cuda_graph=True)
 @triton.jit
-def rms_kernel(output_ptr, input_ptr, g_ptr, input_row_stride, output_row_stride, n_rows, n_cols, epsilon,
+def rms_kernel(output_ptr, input_ptr, g_ptr, input_row_stride, output_row_stride, n_rows, n_cols, eps,
                BLOCK_SIZE: tl.constexpr):
     row_start = tl.program_id(0)
-    row_step = tl.num_programs(0)
-    col_offsets = tl.arange(0, BLOCK_SIZE)
-    mask = col_offsets < n_cols
-    for row_idx in tl.range(row_start, n_rows, row_step):
-        row_start_ptr = input_ptr + row_idx * input_row_stride
+    row_idx = row_start
+
+    #Calculate squared mean by block
+    row_start_ptr = input_ptr + row_idx * input_row_stride
+    row_sum = 0.0
+    for b in tl.range(0, n_cols, BLOCK_SIZE):
+        col_offsets = b + tl.arange(0, BLOCK_SIZE)
+        input_ptrs = row_start_ptr + col_offsets
+        mask = col_offsets < n_cols
+        row_block = tl.load(input_ptrs, mask=mask, other=0.0, cache_modifier=".cg")
+        row_block = row_block * row_block  #square every value in the block
+        row_sum += tl.sum(row_block, axis=-1) / n_cols  #sum across the block, divide by n_cols, add to the running sum
+
+    row_norm = row_sum + eps
+    row_norm = tl.rsqrt(row_norm)
+
+    #Blocked normalization
+    output_row_start_ptr = output_ptr + row_idx * output_row_stride
+    for b in tl.range(0, n_cols, BLOCK_SIZE):
+        col_offsets = b + tl.arange(0, BLOCK_SIZE)
         input_ptrs = row_start_ptr + col_offsets
         input_ptrs = tl.multiple_of(input_ptrs, (16, ))
-        row = tl.load(input_ptrs, mask=mask, other=0.0, cache_modifier=".cg")
-        g = tl.load(g_ptr + col_offsets, mask=mask, other=0.0, cache_modifier=".cg")
-        row_norm = row * row  #square each value
-        row_norm = tl.sum(row_norm, axis=-1)  #sum across columns (axis=-1)
-        row_norm = row_norm / n_cols  #divide by n_cols
-        row_norm = row_norm + epsilon  #add epsilon
-        row_norm = tl.rsqrt(row_norm)  #take rsqrt, this is the normalization value
-        rms_norm = row * row_norm  #multiply each x by the normalization value
-        rms_norm = rms_norm * g  #element-wise multiplication with g
-
-        output_row_start_ptr = output_ptr + row_idx * output_row_stride
+        mask = col_offsets < n_cols
+        row_block = tl.load(input_ptrs, mask=mask, other=0.0, cache_modifier=".cg")  #load block of input
+        g = tl.load(g_ptr + col_offsets, mask=mask, other=0.0, cache_modifier=".cg")  #load block of g
+        output = row_block * row_norm  #multiply each value by the normalization value
+        output = output * g  #element-wise multiplication with g
 
         output_ptrs = output_row_start_ptr + col_offsets
         output_ptrs = tl.multiple_of(output_ptrs, (16, ))
-        tl.store(output_ptrs, rms_norm, mask=mask)
+        tl.store(output_ptrs, output, mask=mask)
 
 
 def rmsnorm(x, epsilon=1e-6):
     n_rows, n_cols = x.shape
-    BLOCK_SIZE = triton.next_power_of_2(n_cols)
+    #Restricting BLOCK_SIZE to 64KB is an important optimization. Otherwise,
+    #performance can drop significantly for larger n_cols.
+    MAX_FUSED_SIZE = 65536 // x.element_size()
+    BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(n_cols))
 
     y = torch.empty_like(x, device='cuda')
     g = torch.ones((1, n_cols), device='cuda')
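
For readers skimming the diff: the new kernel makes two passes over each row in BLOCK_SIZE-wide column blocks, a reduction pass that accumulates the mean of squares, then a normalization pass that re-reads each block and scales it by rsqrt(mean_sq + eps) and by g. A minimal PyTorch sketch of the same two-pass computation (the function name, block size, and single-row framing are illustrative, not part of the kernel):

    import torch

    def blocked_rmsnorm_row(x_row, g, eps=1e-6, block=4):
        """Illustrative two-pass RMSNorm over one row, processed in column blocks."""
        n_cols = x_row.numel()
        # Pass 1: accumulate the mean of squares one block at a time.
        row_sum = 0.0
        for b in range(0, n_cols, block):
            blk = x_row[b:b + block]
            row_sum += (blk * blk).sum().item() / n_cols
        inv_rms = (row_sum + eps) ** -0.5  # rsqrt of the mean square
        # Pass 2: re-read each block, scale by inv_rms and by g.
        out = torch.empty_like(x_row)
        for b in range(0, n_cols, block):
            out[b:b + block] = x_row[b:b + block] * inv_rms * g[b:b + block]
        return out

    x = torch.randn(10)
    ref = x / torch.sqrt(torch.mean(x * x) + 1e-6)  # unblocked reference
    assert torch.allclose(blocked_rmsnorm_row(x, torch.ones(10)), ref, atol=1e-5)

The equivalence holds because the per-block partial sums of x*x/n_cols add up to the full mean of squares, so the blocked reduction reproduces the unblocked result up to floating-point reordering.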
@@ -87,21 +99,15 @@ def rmsnorm(x, epsilon=1e-6):
 
 
 def run_rmsnorm(M, N):
+    print(f"Running RMSNorm for shape ({M}, {N})")
     torch.manual_seed(0)
     x = torch.randn(M, N, device='cuda')
     y_triton = rmsnorm(x)
 
     return y_triton
 
 
-@pytest.mark.parametrize('M, N', [
-    (1, 4),
-    (2, 10),
-    (8192, 4096),
-    (4096, 8192),
-    (1, 8192),
-    (873, 1245),
-])
+@pytest.mark.parametrize('M, N', [(1, 4), (2, 10), (8192, 4096), (4096, 8192), (1, 8192), (873, 1245), (1, 98304)])
 def test_rmsnorm(M, N):
     torch.manual_seed(0)
     x = torch.randn(M, N, device='cuda')
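
One note on the parametrization: the new (1, 98304) shape looks chosen to exercise the blocked path, since next_power_of_2(98304) exceeds the 64KB cap introduced above and the kernel's per-block loops must actually iterate. A quick check of the arithmetic, assuming fp32 inputs (torch.randn's default, 4 bytes per element):

    import triton

    MAX_FUSED_SIZE = 65536 // 4               # 16384 elements per block for fp32
    needed = triton.next_power_of_2(98304)    # 131072
    BLOCK_SIZE = min(MAX_FUSED_SIZE, needed)  # capped at 16384
    print(98304 // BLOCK_SIZE)                # 6 full blocks per row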
@@ -110,6 +116,7 @@ def test_rmsnorm(M, N):
     rms_norm = torch.nn.RMSNorm(N, device='cuda')
     y_torch = rms_norm(x)
 
+    print(f"y_triton={y_triton}")
     assert torch.allclose(y_triton, y_torch), (y_triton, y_torch)
 
 
