diff --git a/slm/model_zoo/gpt-3/external_ops/fused_quanted_ops/fused_transpose_quant.cu b/slm/model_zoo/gpt-3/external_ops/fused_quanted_ops/fused_transpose_quant.cu
new file mode 100644
index 000000000000..66f7109ee80a
--- /dev/null
+++ b/slm/model_zoo/gpt-3/external_ops/fused_quanted_ops/fused_transpose_quant.cu
@@ -0,0 +1,193 @@
+#include "quant_utils.h"
+
+template <typename T, int VecSize>
+struct __align__(sizeof(T) * VecSize) VecType {
+  T val[VecSize];
+  __host__ __device__ inline T& operator[](size_t i) { return val[i]; }
+  __host__ __device__ inline const T& operator[](size_t i) const {
+    return val[i];
+  }
+};
+
+template <typename OutT, int VecSize>
+__device__ void StoreQuantResult(OutT* out,
+                                 const OutT shm[][129],
+                                 int64_t L,
+                                 int64_t C) {
+  for (int i = 0; i < 4; i++) {
+    int64_t idx_n = blockIdx.z;
+    int64_t idx_h =
+        static_cast<int64_t>(blockIdx.x) * 128 + (i * 32 + threadIdx.y);
+    int64_t idx_l = static_cast<int64_t>(blockIdx.y) * 128;
+    int64_t idx = (idx_n * C + idx_h) * L + idx_l;
+
+    for (int j = 0; j < 4; j += VecSize) {
+      int64_t block_offset = j * 32 + threadIdx.x * VecSize;
+      if (idx_l + block_offset < L) {
+        using StoreT = VecType<OutT, VecSize>;
+        StoreT data;
+        for (int k = 0; k < VecSize; k++) {
+          data[k] = shm[i * 32 + threadIdx.y][block_offset + k];
+        }
+        *reinterpret_cast<StoreT*>(out + (idx + block_offset)) = data;
+      }
+    }
+  }
+}
+
+__device__ int GetVecSize(int64_t n) {
+  return n % 4 == 0 ? 4 : (n % 2 == 0 ? 2 : 1);
+}
+
+template <typename OutT>
+__global__ void __launch_bounds__(1024)
+    FusedTransposeQuantKernel(const phi::bfloat16* __restrict__ X,
+                              OutT* __restrict__ out,
+                              float* __restrict__ scale,
+                              int64_t L,
+                              int64_t C) {
+  __shared__ OutT shm[128][129];
+
+  int64_t offset_n = blockIdx.z;
+  int64_t offset_l = static_cast<int64_t>(blockIdx.y) * 128 + threadIdx.y;
+  int64_t offset_h = static_cast<int64_t>(blockIdx.x) * 128 + threadIdx.x * 4;
+  int64_t offset = (offset_n * L + offset_l) * C + offset_h;
+
+  for (int i = 0; i < 4; i++) {
+    int64_t idx = offset + i * 32 * C;
+
+    if (offset_l + i * 32 < L) {
+      // Load x [N, L, C], 128 elements per warp
+      using LoadT = VecType<__nv_bfloat16, 4>;
+      const LoadT input = *reinterpret_cast<const LoadT*>(X + idx);
+
+      float input_fp32[4];
+      for (int j = 0; j < 4; j++) {
+        input_fp32[j] = static_cast<float>(input[j]);
+      }
+
+      // Find the maximum of every 128 elements in dim C.
+      float max = fabsf(input_fp32[0]);
+      for (int j = 1; j < 4; j++) {
+        max = fmaxf(max, fabsf(input_fp32[j]));
+      }
+      for (int offset = 16; offset > 0; offset >>= 1) {
+        max = fmaxf(max, __shfl_xor_sync(0xffffffff, max, offset));
+      }
+
+      // Compute the scale of max.
+      const float scale_on_fp32_to_outputT =
+          ComputeScale<__nv_bfloat16, OutT>(max, 0.0f);
+      const float scale_on_fp8_to_inputT = __frcp_rn(scale_on_fp32_to_outputT);
+
+      // Scale X and transpose into shared memory.
+      for (int j = 0; j < 4; j++) {
+        float output_scaled = input_fp32[j] * scale_on_fp8_to_inputT;
+        shm[threadIdx.x * 4 + j][i * 32 + threadIdx.y] =
+            static_cast<OutT>(output_scaled);
+      }
+
+      // Store scale [N, L, C/128].
+      if (threadIdx.x == 0) {
+        scale[idx / 128] = scale_on_fp32_to_outputT;
+      }
+    }
+  }
+
+  __syncthreads();
+
+  // Store y [N, C, L]
+  int vec_size = GetVecSize(L);
+  if (vec_size == 4) {
+    StoreQuantResult<OutT, 4>(out, shm, L, C);
+  } else if (vec_size == 2) {
+    StoreQuantResult<OutT, 2>(out, shm, L, C);
+  } else {
+    StoreQuantResult<OutT, 1>(out, shm, L, C);
+  }
+}
+
+/**
+ * Fused transpose and quantization.
+ *
+ * Inputs:
+ *   X    : [*, L, C], bfloat16
+ *
+ * Outputs:
+ *   out  : [*, C, L], float8
+ *   scale: [*, L, C/128], float32
+ *
+ * Equivalent python code:
+ *   def fused_transpose_quant(x):
+ *       N, L, C = paddle.shape(x)
+ *       x = x.reshape([N, L, C // 128, 128]).astype('float32')
+ *       scale = ComputeScale(x.abs().max(axis=-1))
+ *       x = (x / scale.unsqueeze(-1)).astype('float8_e4m3fn')
+ *       out = x.reshape([N, L, C]).transpose(0, 2, 1)
+ *       return out, scale
+ */
+std::vector<paddle::Tensor> fused_transpose_quant(const paddle::Tensor& X) {
+  PD_CHECK(X.dtype() == paddle::DataType::BFLOAT16);
+
+  std::vector<int64_t> shape = X.shape();
+  PD_CHECK(shape.size() >= 2);
+
+  int64_t seq_len = shape[shape.size() - 2];
+  int64_t hidden_size = shape[shape.size() - 1];
+  int64_t batch_size = X.numel() / (seq_len * hidden_size);
+
+  PADDLE_ENFORCE_LE(batch_size,
+                    65535,
+                    common::errors::InvalidArgument(
+                        "The batch size (shape[0:-2]) of Input(X) must be no "
+                        "larger than 65535."));
+  PADDLE_ENFORCE_LE(seq_len,
+                    65535 * 128,
+                    common::errors::InvalidArgument(
+                        "The sequence length (shape[-2]) of Input(X) must be "
+                        "no larger than 65535 * 128."));
+  PADDLE_ENFORCE_LE(hidden_size,
+                    ((1L << 31) - 1) * 128,
+                    common::errors::InvalidArgument(
+                        "The hidden size (shape[-1]) of Input(X) must be no "
+                        "larger than (2^31 - 1) * 128."));
+  PADDLE_ENFORCE_EQ(hidden_size % 128,
+                    0,
+                    common::errors::InvalidArgument(
+                        "The hidden size (shape[-1]) of Input(X) must be a "
+                        "multiple of 128."));
+
+  // Allocate for out and scale
+  std::vector<int64_t> out_shape = shape;
+  out_shape[shape.size() - 2] = hidden_size;
+  out_shape[shape.size() - 1] = seq_len;
+  paddle::Tensor out =
+      paddle::empty(out_shape, paddle::DataType::FLOAT8_E4M3FN, X.place());
+
+  std::vector<int64_t> scale_shape = shape;
+  scale_shape[shape.size() - 1] = hidden_size / 128;
+  paddle::Tensor scale =
+      paddle::empty(scale_shape, paddle::DataType::FLOAT32, X.place());
+
+  // Skip 0-size inputs
+  if (batch_size == 0 || seq_len == 0 || hidden_size == 0) {
+    return {out, scale};
+  }
+
+  // Launch kernel
+  dim3 grid(hidden_size / 128, (seq_len + 127) / 128, batch_size);
+  dim3 block(32, 32);
+
+  FusedTransposeQuantKernel<<<grid, block, 0, X.stream()>>>(
+      X.data<phi::bfloat16>(),
+      out.data<phi::float8_e4m3fn>(),
+      scale.data<float>(),
+      seq_len,
+      hidden_size);
+
+  return {out, scale};
+}
+
+PD_BUILD_OP(fused_transpose_quant)
+    .Inputs({"X"})
+    .Outputs({"output", "scale"})
+    .SetKernelFn(PD_KERNEL(fused_transpose_quant));
diff --git a/slm/model_zoo/gpt-3/external_ops/setup_fp8.py b/slm/model_zoo/gpt-3/external_ops/setup_fp8.py
index 5528ffab489b..72139036f5a7 100644
--- a/slm/model_zoo/gpt-3/external_ops/setup_fp8.py
+++ b/slm/model_zoo/gpt-3/external_ops/setup_fp8.py
@@ -41,6 +41,7 @@ def setup_fused_quant_ops():
             "fused_quanted_ops/fused_act_dequant.cu",
             "fused_quanted_ops/fused_act_dequant_transpose_act_quant.cu",
             "fused_quanted_ops/fused_spaq.cu",
+            "fused_quanted_ops/fused_transpose_quant.cu",
         ],
         extra_compile_args={
             "cxx": ["-O3", "-w", "-Wno-abi", "-fPIC", "-std=c++17"],
diff --git a/tests/ops/test_fused_transpose_quant.py b/tests/ops/test_fused_transpose_quant.py
new file mode 100644
index 000000000000..2f0a78b9804d
--- /dev/null
+++ b/tests/ops/test_fused_transpose_quant.py
@@ -0,0 +1,37 @@
+import FusedQuantOps as FQO
+import numpy as np
+
+import paddle
+
+
+def restore_transpose_quant(out, scale):
+    out = out.transpose([0, 2, 1]).astype('float32')
+    scale = paddle.repeat_interleave(scale, repeats=128, axis=-1)
+    x = out * scale
+    return x
+
+
+def test_fused_transpose_quant(batch_size, seq_len, hidden_size):
+    print(batch_size, seq_len, hidden_size)
+    x = paddle.randn([batch_size, seq_len, hidden_size], dtype='bfloat16')
+    x = paddle.clip(x, min=-50, max=50)
+
+    out, scale = FQO.fused_transpose_quant(x)
+
+    x_fp32 = x.astype('float32')
+    x_restored = restore_transpose_quant(out, scale)
+
+    np.testing.assert_allclose(
+        x_fp32, x_restored, rtol=0.01, atol=0.3
+    )  # fp8 quantization introduces truncation error, so atol is 0.3 instead of the usual ~1e-6
+
+
+def run():
+    for batch_size in [1, 4]:
+        for seq_len in [1, 257, 2114, 4096]:
+            for hidden_size in [2048, 7168]:
+                test_fused_transpose_quant(batch_size, seq_len, hidden_size)
+
+
+if __name__ == "__main__":
+    run()
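Reviewer note (not part of the patch): quant_utils.h, which provides ComputeScale, is not included in this diff. As a rough sketch only, assuming ComputeScale maps the per-128-element absolute maximum onto the float8_e4m3fn range (largest finite magnitude 448), the scale written by the kernel and consumed by restore_transpose_quant behaves roughly like the NumPy code below; the names FP8_E4M3_MAX, compute_scale_sketch and transpose_quant_reference, and the eps floor, are illustrative and not taken from the actual header.

import numpy as np

FP8_E4M3_MAX = 448.0  # largest finite float8_e4m3fn magnitude

def compute_scale_sketch(amax, eps=1e-10):
    # Hypothetical stand-in for ComputeScale<__nv_bfloat16, OutT>: pick a
    # dequantization scale so that amax / scale lands at the fp8 maximum.
    return np.maximum(amax, eps) / FP8_E4M3_MAX

def transpose_quant_reference(x):
    # NumPy mirror of the "Equivalent python code" comment above, without the
    # final fp8 cast (rounding error is not modelled). x: [N, L, C] ndarray.
    n, l, c = x.shape
    x = x.reshape(n, l, c // 128, 128).astype(np.float32)
    scale = compute_scale_sketch(np.abs(x).max(axis=-1))  # [N, L, C/128]
    q = x / scale[..., None]                               # values within +-448
    out = q.reshape(n, l, c).transpose(0, 2, 1)            # [N, C, L]
    return out, scale

The CUDA kernel produces the same result in a single pass: each warp loads 128 consecutive C-elements, reduces their absolute maximum with __shfl_xor_sync, writes the scaled values into a 128x129 shared-memory tile (the extra column avoids bank conflicts), and StoreQuantResult then writes the tile back transposed using the widest vector width (4, 2 or 1) that divides L.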
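As a concrete launch example, the largest case exercised by the test (batch_size=4, seq_len=4096, hidden_size=7168) gives grid = (7168/128, 4096/128, 4) = (56, 32, 4) with 32x32 threads per block, i.e. one block per 128x128 tile of each [L, C] plane; the PADDLE_ENFORCE_LE checks above simply keep these grid dimensions within CUDA's limits (2^31 - 1 for x, 65535 for y and z). Once FusedQuantOps has been rebuilt via setup_fp8.py, the test runs standalone:

    python tests/ops/test_fused_transpose_quant.py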