Commit a9fbd10

Add fused_transpose_quant op

1 parent 670cbd9

3 files changed: +231 -0 lines changed
slm/model_zoo/gpt-3/external_ops/fused_quanted_ops/fused_transpose_quant.cu

Lines changed: 193 additions & 0 deletions
#include "quant_utils.h"

template <typename T, int VecSize>
struct __align__(sizeof(T) * VecSize) VecType {
  T val[VecSize];
  __host__ __device__ inline T& operator[](size_t i) { return val[i]; }
  __host__ __device__ inline const T& operator[](size_t i) const {
    return val[i];
  }
};
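
// Write one transposed 128x128 tile from shared memory to out [N, C, L].
// The i-loop covers the 128 output rows (4 iterations x 32 threads in y);
// stores along L are vectorized by VecSize, which always divides L, so
// checking the first element of each vector suffices for the L tail.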
template <typename OutT, int VecSize>
__device__ void StoreQuantResult(OutT* out,
                                 const OutT shm[][129],
                                 int64_t L,
                                 int64_t C) {
  for (int i = 0; i < 4; i++) {
    int64_t idx_n = blockIdx.z;
    int64_t idx_h =
        static_cast<int64_t>(blockIdx.x) * 128 + (i * 32 + threadIdx.y);
    int64_t idx_l = static_cast<int64_t>(blockIdx.y) * 128;
    int64_t idx = (idx_n * C + idx_h) * L + idx_l;

    for (int j = 0; j < 4; j += VecSize) {
      int64_t block_offset = j * 32 + threadIdx.x * VecSize;
      if (idx_l + block_offset < L) {
        using StoreT = VecType<OutT, VecSize>;
        StoreT data;
        for (int k = 0; k < VecSize; k++) {
          data[k] = shm[i * 32 + threadIdx.y][block_offset + k];
        }
        *reinterpret_cast<StoreT*>(out + (idx + block_offset)) = data;
      }
    }
  }
}
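
// Pick the widest vector width (4, 2, or 1) that divides L so the
// vectorized stores along L stay aligned and never cross the end of a row.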
__device__ int GetVecSize(int64_t n) {
  return n % 4 == 0 ? 4 : (n % 2 == 0 ? 2 : 1);
}
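
// Launched with a 32x32 thread block; __launch_bounds__(1024) tells the
// compiler the exact block size so it can budget registers accordingly.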
template <typename OutT>
__global__ void __launch_bounds__(1024)
    FusedTransposeQuantKernel(const phi::bfloat16* __restrict__ X,
                              OutT* __restrict__ out,
                              float* __restrict__ scale,
                              int64_t L,
                              int64_t C) {
  __shared__ OutT shm[128][129];
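  // The extra column (129 instead of 128) staggers rows across shared memory
  // banks so the column-wise reads in the transposed store do not conflict.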

  int64_t offset_n = blockIdx.z;
  int64_t offset_l = static_cast<int64_t>(blockIdx.y) * 128 + threadIdx.y;
  int64_t offset_h = static_cast<int64_t>(blockIdx.x) * 128 + threadIdx.x * 4;
  int64_t offset = (offset_n * L + offset_l) * C + offset_h;

  for (int i = 0; i < 4; i++) {
    int64_t idx = offset + i * 32 * C;

    if (offset_l + i * 32 < L) {
      // Load x [N, L, C], 128 elements per warp
      using LoadT = VecType<__nv_bfloat16, 4>;
      const LoadT input = *reinterpret_cast<const LoadT*>(X + idx);

      float input_fp32[4];
      for (int j = 0; j < 4; j++) {
        input_fp32[j] = static_cast<float>(input[j]);
      }

      // Find the maximum of every 128 elements in dim C.
      float max = fabsf(input_fp32[0]);
      for (int j = 1; j < 4; j++) {
        max = fmaxf(max, fabsf(input_fp32[j]));
      }
      for (int offset = 16; offset > 0; offset >>= 1) {
        max = fmaxf(max, __shfl_xor_sync(0xffffffff, max, offset));
      }
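
      // After the butterfly shuffle, every lane holds the absolute max of
      // the 128 consecutive C elements (4 per lane x 32 lanes) this warp
      // loaded.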
      // Compute the scale of max.
      const float scale_on_fp32_to_outputT =
          ComputeScale<__nv_bfloat16, OutT>(max, 0.0f);
      const float scale_on_fp8_to_inputT = __frcp_rn(scale_on_fp32_to_outputT);

      // Scale X and transpose into shared memory.
      for (int j = 0; j < 4; j++) {
        float output_scaled = input_fp32[j] * scale_on_fp8_to_inputT;
        shm[threadIdx.x * 4 + j][i * 32 + threadIdx.y] =
            static_cast<OutT>(output_scaled);
      }

      // Store scale [N, L, C/128].
      if (threadIdx.x == 0) {
        scale[idx / 128] = scale_on_fp32_to_outputT;
      }
    }
  }

  __syncthreads();

  // Store y [N, C, L]
  int vec_size = GetVecSize(L);
  if (vec_size == 4) {
    StoreQuantResult<OutT, 4>(out, shm, L, C);
  } else if (vec_size == 2) {
    StoreQuantResult<OutT, 2>(out, shm, L, C);
  } else {
    StoreQuantResult<OutT, 1>(out, shm, L, C);
  }
}

/**
 * Fused transpose and quantization.
 *
 * Inputs:
 *   X    : [*, L, C], bfloat16
 *
 * Outputs:
 *   out  : [*, C, L], float8
 *   scale: [*, L, C/128], float32
 *
 * Equivalent python code:
 *   def fused_transpose_quant(x):
 *       N, L, C = paddle.shape(x)
 *       x = x.reshape([N, L, C // 128, 128]).astype('float32')
 *       scale = ComputeScale(x.abs().max(axis=-1))
 *       x = (x / scale.unsqueeze(-1)).astype('float8_e4m3fn')
 *       out = x.reshape([N, L, C]).transpose(0, 2, 1)
 *       return out, scale
 */
std::vector<paddle::Tensor> fused_transpose_quant(const paddle::Tensor& X) {
  PD_CHECK(X.dtype() == paddle::DataType::BFLOAT16);

  std::vector<int64_t> shape = X.shape();
  PD_CHECK(shape.size() >= 2);

  int64_t seq_len = shape[shape.size() - 2];
  int64_t hidden_size = shape[shape.size() - 1];
  int64_t batch_size = X.numel() / (seq_len * hidden_size);

  PADDLE_ENFORCE_LE(batch_size,
                    65535,
                    common::errors::InvalidArgument(
                        "The batch size (shape[0:-2]) of Input(X) must be no "
                        "larger than 65535."));
  PADDLE_ENFORCE_LE(seq_len,
                    65535 * 128,
                    common::errors::InvalidArgument(
                        "The sequence length (shape[-2]) of Input(X) must be "
                        "no larger than 65535 * 128."));
  PADDLE_ENFORCE_LE(hidden_size,
                    ((1L << 31) - 1) * 128,
                    common::errors::InvalidArgument(
                        "The hidden size (shape[-1]) of Input(X) must be no "
                        "larger than (2^31 - 1) * 128."));
  PADDLE_ENFORCE_EQ(hidden_size % 128,
                    0,
                    common::errors::InvalidArgument(
                        "The hidden size (shape[-1]) of Input(X) must be a "
                        "multiple of 128."));

  // Allocate for out and scale
  std::vector<int64_t> out_shape = shape;
  out_shape[shape.size() - 2] = hidden_size;
  out_shape[shape.size() - 1] = seq_len;
  paddle::Tensor out =
      paddle::empty(out_shape, paddle::DataType::FLOAT8_E4M3FN, X.place());

  std::vector<int64_t> scale_shape = shape;
  scale_shape[shape.size() - 1] = hidden_size / 128;
  paddle::Tensor scale =
      paddle::empty(scale_shape, paddle::DataType::FLOAT32, X.place());

  // Skip 0-size
  if (batch_size == 0 || seq_len == 0 || hidden_size == 0) {
    return {out, scale};
  }

  // Launch kernel
  dim3 grid(hidden_size / 128, (seq_len + 127) / 128, batch_size);
  dim3 block(32, 32);
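  // grid.x walks the C/128 column tiles, grid.y the ceil(L/128) row tiles,
  // grid.z the flattened batch; each block transposes one 128x128 tile.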

  FusedTransposeQuantKernel<<<grid, block>>>(X.data<phi::bfloat16>(),
                                             out.data<phi::float8_e4m3fn>(),
                                             scale.data<float>(),
                                             seq_len,
                                             hidden_size);

  return {out, scale};
}

PD_BUILD_OP(fused_transpose_quant)
    .Inputs({"X"})
    .Outputs({"output", "scale"})
    .SetKernelFn(PD_KERNEL(fused_transpose_quant));
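
For readers tracing the semantics: the doc comment's "equivalent python code" can be expanded into a runnable reference. ComputeScale lives in quant_utils.h and is not part of this diff, so the formula below (mapping each 128-element block's absolute max onto the float8_e4m3fn range, whose largest finite value is 448) is an assumption, as is the epsilon guard against all-zero blocks.

import paddle

FP8_E4M3_MAX = 448.0  # assumption: ComputeScale targets the e4m3 max

def fused_transpose_quant_ref(x):
    # x: [N, L, C] bfloat16, with C a multiple of 128
    N, L, C = x.shape
    x = x.reshape([N, L, C // 128, 128]).astype('float32')
    amax = x.abs().max(axis=-1)            # per-block max, [N, L, C/128]
    scale = amax / FP8_E4M3_MAX            # assumed ComputeScale formula
    scale = paddle.clip(scale, min=1e-12)  # assumption: avoid divide-by-zero
    y = (x / scale.unsqueeze(-1)).astype('float8_e4m3fn')
    out = y.reshape([N, L, C]).transpose([0, 2, 1])  # [N, C, L]
    return out, scale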

slm/model_zoo/gpt-3/external_ops/setup_fp8.py

Lines changed: 1 addition & 0 deletions

@@ -41,6 +41,7 @@ def setup_fused_quant_ops():
         "fused_quanted_ops/fused_act_dequant.cu",
         "fused_quanted_ops/fused_act_dequant_transpose_act_quant.cu",
         "fused_quanted_ops/fused_spaq.cu",
+        "fused_quanted_ops/fused_transpose_quant.cu",
     ],
     extra_compile_args={
         "cxx": ["-O3", "-w", "-Wno-abi", "-fPIC", "-std=c++17"],
Lines changed: 37 additions & 0 deletions
import FusedQuantOps as FQO
import numpy as np

import paddle


def restore_transpose_quant(out, scale):
    out = out.transpose([0, 2, 1]).astype('float32')
    scale = paddle.repeat_interleave(scale, repeats=128, axis=-1)
    x = out * scale
    return x


def test_fused_transpose_quant(batch_size, seq_len, hidden_size):
    print(batch_size, seq_len, hidden_size)
    x = paddle.randn([batch_size, seq_len, hidden_size], dtype='bfloat16')
    x = paddle.clip(x, min=-50, max=50)

    out, scale = FQO.fused_transpose_quant(x)

    x_fp32 = x.astype('float32')
    x_restored = restore_transpose_quant(out, scale)

    np.testing.assert_allclose(
        x_fp32, x_restored, rtol=0.01, atol=0.3
    )  # Truncation error exists, hence atol=0.3; in practice it is usually around 1e-6.


def run():
    for batch_size in [1, 4]:
        for seq_len in [1, 257, 2114, 4096]:
            for hidden_size in [2048, 7168]:
                test_fused_transpose_quant(batch_size, seq_len, hidden_size)


if __name__ == "__main__":
    run()
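
The seq_len sweep covers every store path in the kernel: 4096 takes the VecSize=4 stores, 2114 the VecSize=2 path, and 1 and 257 the scalar path, while the values that are not multiples of 128 also exercise the partial-tile bounds check.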
