"""
Copyright (c) 2025, Amazon.com. All Rights Reserved
"""
import pytest
from nki_samples.reference.double_row_matmul import quantized_double_row_matmul
from neuronxcc.nki import benchmark, baremetal, simulate_kernel
import neuronxcc.nki.language as nl
import numpy as np
xfail = pytest.mark.arch_specific_xfail
bench_func = benchmark(warmup=5, iters=10)(quantized_double_row_matmul)
def reshape(matrix):
    """
    Interleaves the [128,512] tiles of every 2 consecutive tile rows.

    A [K,N] matrix is reshaped into [K//2, 2*N], where K must be divisible by 256
    (an even number of 128-row tiles) and N must be divisible by 512.

    E.g. if Tij is the (i,j)-th tile and assuming a matrix with 4x4 [128,512] tiles,
    the reshaped matrix looks as follows:

    # T11 T12 T13 T14
    # T21 T22 T23 T24  reshape   T11 T21 T12 T22 T13 T23 T14 T24
    # T31 T32 T33 T34 --------> T31 T41 T32 T42 T33 T43 T34 T44
    # T41 T42 T43 T44
    """
    K, N = matrix.shape
    TILE_K = 128
    TILE_N = 512

    assert K % (2 * TILE_K) == 0
    assert N % TILE_N == 0

    result = np.zeros((K // 2, 2 * N))
    for k in range(0, K // TILE_K, 2):
        for n in range(N // TILE_N):
            # Get 2 tiles in the same tile column and consecutive tile rows.
            tile1 = matrix[k * TILE_K:(k + 1) * TILE_K, n * TILE_N:(n + 1) * TILE_N]
            tile2 = matrix[(k + 1) * TILE_K:(k + 2) * TILE_K, n * TILE_N:(n + 1) * TILE_N]

            # Place the 2 tiles in the same tile row side by side.
            result[(k // 2) * TILE_K:(k // 2 + 1) * TILE_K, n * TILE_N * 2:n * TILE_N * 2 + TILE_N] = tile1
            result[(k // 2) * TILE_K:(k // 2 + 1) * TILE_K, n * TILE_N * 2 + TILE_N:(n + 1) * TILE_N * 2] = tile2

    return result
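
# A minimal sanity check for the reshape helper, added here for illustration
# (not part of the original suite): with the smallest legal input of 2x2 tiles,
# vertically adjacent tiles should land side by side in the output.
def test_reshape_helper():
    TILE_K, TILE_N = 128, 512
    matrix = np.arange(2 * TILE_K * 2 * TILE_N, dtype=np.float64).reshape(2 * TILE_K, 2 * TILE_N)
    result = reshape(matrix)
    assert result.shape == (TILE_K, 4 * TILE_N)
    # Output tile columns hold, in order: T11, T21, T12, T22.
    assert np.array_equal(result[:, :TILE_N], matrix[:TILE_K, :TILE_N])
    assert np.array_equal(result[:, TILE_N:2 * TILE_N], matrix[TILE_K:, :TILE_N])
    assert np.array_equal(result[:, 2 * TILE_N:3 * TILE_N], matrix[:TILE_K, TILE_N:])
    assert np.array_equal(result[:, 3 * TILE_N:], matrix[TILE_K:, TILE_N:])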
def column_wise_quantize(matrix):
    """
    Quantizes a matrix column by column.

    Returns a column-wise scale broadcast to (128, matrix.shape[1]) and the quantized matrix.
    """
    # Maximum representable magnitude of the FP8 e4m3 data type on NeuronCore.
    FP8_RANGE = 240

    column_wise_max = np.max(np.abs(matrix), axis=0, keepdims=True)
    column_wise_scale = column_wise_max / FP8_RANGE
    matrix_quantized = matrix / column_wise_scale

    # Replicate the per-column scale across 128 partition rows.
    column_wise_scale = np.broadcast_to(column_wise_scale, (128, matrix.shape[1]))

    return column_wise_scale, matrix_quantized
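
# Illustrative round-trip check, not part of the original file: because the
# quantization above is a plain division (the actual FP8 cast happens later in
# the tests), dequantizing with the returned scale should recover the input.
def test_column_wise_quantize_round_trip():
    matrix = np.random.rand(256, 64)
    scale, quantized = column_wise_quantize(matrix)
    assert scale.shape == (128, 64)
    # Every broadcast row carries the same per-column scale.
    assert np.allclose(quantized * scale[0], matrix)
    # The largest quantized magnitude sits at the FP8 range boundary.
    assert np.isclose(np.max(np.abs(quantized)), 240)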
class TestDoubleRowMatmul:
    @xfail(fail=['trn1'])
    @pytest.mark.parametrize("M, K, N, dtype, TILES_IN_BLOCK_M, TILES_IN_BLOCK_N, TILES_IN_BLOCK_K, max_p99_latency", [
        [512, 16 * 1024, 1024, nl.bfloat16, 2, 2, 16, 320],
    ])
    def test_double_row_matmul_perf(self, M, K, N, dtype, TILES_IN_BLOCK_M, TILES_IN_BLOCK_N, TILES_IN_BLOCK_K, max_p99_latency):
        # Initialize random inputs.
        lhs = np.random.rand(M, K)
        rhs = np.random.rand(K, N)

        # Quantize rhs.
        rhs_scale, rhs_quantized = column_wise_quantize(rhs)
        rhs_quantized_reshaped = reshape(rhs_quantized)

        # Cast to the target data types (rhs is pre-quantized, hence cast to FP8).
        lhs = nl.static_cast(lhs, dtype)
        rhs_scale = nl.static_cast(rhs_scale, dtype)
        rhs_quantized_reshaped = nl.static_cast(rhs_quantized_reshaped, nl.float8_e4m3)

        # Latency checks.
        bench_func(lhs, rhs_quantized_reshaped, rhs_scale, TILES_IN_BLOCK_M, TILES_IN_BLOCK_N, TILES_IN_BLOCK_K)
        latency_res = bench_func.benchmark_result.nc_latency
        p99_latency = latency_res.get_latency_percentile(99)
        assert p99_latency <= max_p99_latency
    @xfail(fail=['trn1'])
    @pytest.mark.simulation
    @pytest.mark.parametrize("M, K, N, dtype, TILES_IN_BLOCK_M, TILES_IN_BLOCK_N, TILES_IN_BLOCK_K", [
        [512, 16 * 1024, 1024, nl.bfloat16, 2, 2, 16],
        [512, 16 * 1024, 1024, nl.bfloat16, 4, 1, 32],
        [512, 16 * 1024, 1024, nl.bfloat16, 4, 2, 128],
    ])
    def test_double_row_matmul_numerical(self, simulation_only, M, K, N, dtype, TILES_IN_BLOCK_M, TILES_IN_BLOCK_N, TILES_IN_BLOCK_K):
        # Initialize random inputs.
        lhs = np.random.rand(M, K)
        rhs = np.random.rand(K, N)

        # Reference CPU result.
        result_golden = np.matmul(lhs, rhs)

        # Quantize rhs.
        rhs_scale, rhs_quantized = column_wise_quantize(rhs)
        rhs_quantized_reshaped = reshape(rhs_quantized)

        # Cast to the target data types (rhs is pre-quantized, hence cast to FP8).
        lhs = nl.static_cast(lhs, dtype)
        rhs_scale = nl.static_cast(rhs_scale, dtype)
        rhs_quantized_reshaped = nl.static_cast(rhs_quantized_reshaped, nl.float8_e4m3)

        # Numerical accuracy checks.
        numeric_func = baremetal(quantized_double_row_matmul)
        if simulation_only:
            result_nki = simulate_kernel(numeric_func, lhs, rhs_quantized_reshaped, rhs_scale, TILES_IN_BLOCK_M, TILES_IN_BLOCK_N, TILES_IN_BLOCK_K)
        else:
            result_nki = numeric_func(lhs, rhs_quantized_reshaped, rhs_scale, TILES_IN_BLOCK_M, TILES_IN_BLOCK_N, TILES_IN_BLOCK_K)

        # Cast result_nki from BF16 back to FP32 to compare the NumPy and NKI results.
        result_nki = result_nki.astype(np.float32)
        assert np.allclose(result_golden, result_nki, rtol=2e-2)
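
# Illustrative NumPy emulation of the quantization scheme, not the NKI kernel
# itself: since the scale is per output column, dequantization commutes with
# the matmul, i.e. lhs @ rhs == (lhs @ rhs_quantized) * scale row-wise. This is
# the identity a scaled-matmul kernel can exploit.
def test_dequantization_commutes_with_matmul():
    lhs = np.random.rand(64, 256)
    rhs = np.random.rand(256, 1024)
    rhs_scale, rhs_quantized = column_wise_quantize(rhs)
    emulated = np.matmul(lhs, rhs_quantized) * rhs_scale[0]
    assert np.allclose(emulated, np.matmul(lhs, rhs))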