#include "quant_utils.h"

// Fixed-size vector type whose alignment matches its payload, used for
// vectorized global memory loads and stores.
template <typename T, int VecSize>
struct __align__(sizeof(T) * VecSize) VecType {
  T val[VecSize];
  __host__ __device__ inline T& operator[](size_t i) { return val[i]; }
  __host__ __device__ inline const T& operator[](size_t i) const {
    return val[i];
  }
};

// Write the quantized 128x128 tile from shared memory to the transposed
// output out [N, C, L], using vectorized stores of width VecSize along L.
template <typename OutT, int VecSize>
__device__ void StoreQuantResult(OutT* out,
                                 const OutT shm[][129],
                                 int64_t L,
                                 int64_t C) {
  for (int i = 0; i < 4; i++) {
    int64_t idx_n = blockIdx.z;
    int64_t idx_h =
        static_cast<int64_t>(blockIdx.x) * 128 + (i * 32 + threadIdx.y);
    int64_t idx_l = static_cast<int64_t>(blockIdx.y) * 128;
    int64_t idx = (idx_n * C + idx_h) * L + idx_l;

    for (int j = 0; j < 4; j += VecSize) {
      int64_t block_offset = j * 32 + threadIdx.x * VecSize;
      if (idx_l + block_offset < L) {
        using StoreT = VecType<OutT, VecSize>;
        StoreT data;
        for (int k = 0; k < VecSize; k++) {
          data[k] = shm[i * 32 + threadIdx.y][block_offset + k];
        }
        *reinterpret_cast<StoreT*>(out + (idx + block_offset)) = data;
      }
    }
  }
}

// Pick the widest vector width (4, 2, or 1) that evenly divides n.
__device__ int GetVecSize(int64_t n) {
  return n % 4 == 0 ? 4 : (n % 2 == 0 ? 2 : 1);
}

template <typename OutT>
__global__ void __launch_bounds__(1024)
    FusedTransposeQuantKernel(const phi::bfloat16* __restrict__ X,
                              OutT* __restrict__ out,
                              float* __restrict__ scale,
                              int64_t L,
                              int64_t C) {
  // 128x128 tile; the extra column reduces shared memory bank conflicts.
  __shared__ OutT shm[128][129];

  int64_t offset_n = blockIdx.z;
  int64_t offset_l = static_cast<int64_t>(blockIdx.y) * 128 + threadIdx.y;
  int64_t offset_h = static_cast<int64_t>(blockIdx.x) * 128 + threadIdx.x * 4;
  int64_t offset = (offset_n * L + offset_l) * C + offset_h;

  for (int i = 0; i < 4; i++) {
    int64_t idx = offset + i * 32 * C;

    if (offset_l + i * 32 < L) {
      // Load x [N, L, C], 128 elements per warp (4 bf16 per thread).
      using LoadT = VecType<__nv_bfloat16, 4>;
      const LoadT input = *reinterpret_cast<const LoadT*>(X + idx);

      float input_fp32[4];
      for (int j = 0; j < 4; j++) {
        input_fp32[j] = static_cast<float>(input[j]);
      }

      // Find the absolute maximum of every 128 elements in dim C
      // via a warp-level butterfly reduction.
      float max = fabsf(input_fp32[0]);
      for (int j = 1; j < 4; j++) {
        max = fmaxf(max, fabsf(input_fp32[j]));
      }
      for (int mask = 16; mask > 0; mask >>= 1) {
        max = fmaxf(max, __shfl_xor_sync(0xffffffff, max, mask));
      }

      // Compute the quantization scale from the maximum.
      const float scale_on_fp32_to_outputT =
          ComputeScale<__nv_bfloat16, OutT>(max, 0.0f);
      const float scale_on_fp8_to_inputT = __frcp_rn(scale_on_fp32_to_outputT);

      // Scale X and transpose into shared memory.
      for (int j = 0; j < 4; j++) {
        float output_scaled = input_fp32[j] * scale_on_fp8_to_inputT;
        shm[threadIdx.x * 4 + j][i * 32 + threadIdx.y] =
            static_cast<OutT>(output_scaled);
      }

      // Store scale [N, L, C/128], one value per 128-element group.
      if (threadIdx.x == 0) {
        scale[idx / 128] = scale_on_fp32_to_outputT;
      }
    }
  }

  __syncthreads();

  // Store y [N, C, L] with the widest vectorization that L allows.
  int vec_size = GetVecSize(L);
  if (vec_size == 4) {
    StoreQuantResult<OutT, 4>(out, shm, L, C);
  } else if (vec_size == 2) {
    StoreQuantResult<OutT, 2>(out, shm, L, C);
  } else {
    StoreQuantResult<OutT, 1>(out, shm, L, C);
  }
}

/**
 * Fused transpose and quantization.
 *
 * Inputs:
 *   X    : [*, L, C], bfloat16
 *
 * Outputs:
 *   out  : [*, C, L], float8
 *   scale: [*, L, C/128], float32
 *
 * Equivalent Python code:
 *   def fused_transpose_quant(x):
 *       N, L, C = paddle.shape(x)
 *       x = x.reshape([N, L, C // 128, 128]).astype('float32')
 *       scale = ComputeScale(x.abs().max(axis=-1))
 *       x = (x / scale.unsqueeze(-1)).astype('float8_e4m3fn')
 *       out = x.reshape([N, L, C]).transpose([0, 2, 1])
 *       return out, scale
 */
std::vector<paddle::Tensor> fused_transpose_quant(const paddle::Tensor& X) {
  PD_CHECK(X.dtype() == paddle::DataType::BFLOAT16);

  std::vector<int64_t> shape = X.shape();
  PD_CHECK(shape.size() >= 2);

  int64_t seq_len = shape[shape.size() - 2];
  int64_t hidden_size = shape[shape.size() - 1];
  int64_t batch_size = X.numel() / (seq_len * hidden_size);

  PADDLE_ENFORCE_LE(batch_size,
                    65535,
                    common::errors::InvalidArgument(
                        "The batch size (shape[0:-2]) of Input(X) must be no "
                        "larger than 65535."));
  PADDLE_ENFORCE_LE(seq_len,
                    65535 * 128,
                    common::errors::InvalidArgument(
                        "The sequence length (shape[-2]) of Input(X) must be "
                        "no larger than 65535 * 128."));
  PADDLE_ENFORCE_LE(hidden_size,
                    ((1L << 31) - 1) * 128,
                    common::errors::InvalidArgument(
                        "The hidden size (shape[-1]) of Input(X) must be no "
                        "larger than (2^31 - 1) * 128."));
  PADDLE_ENFORCE_EQ(hidden_size % 128,
                    0,
                    common::errors::InvalidArgument(
                        "The hidden size (shape[-1]) of Input(X) must be "
                        "multiple of 128."));

  // Allocate for out and scale
  std::vector<int64_t> out_shape = shape;
  out_shape[shape.size() - 2] = hidden_size;
  out_shape[shape.size() - 1] = seq_len;
  paddle::Tensor out =
      paddle::empty(out_shape, paddle::DataType::FLOAT8_E4M3FN, X.place());

  std::vector<int64_t> scale_shape = shape;
  scale_shape[shape.size() - 1] = hidden_size / 128;
  paddle::Tensor scale =
      paddle::empty(scale_shape, paddle::DataType::FLOAT32, X.place());

  // Skip 0-size
  if (batch_size == 0 || seq_len == 0 || hidden_size == 0) {
    return {out, scale};
  }

  // Launch kernel
  dim3 grid(hidden_size / 128, (seq_len + 127) / 128, batch_size);
  dim3 block(32, 32);

  FusedTransposeQuantKernel<<<grid, block>>>(X.data<phi::bfloat16>(),
                                             out.data<phi::float8_e4m3fn>(),
                                             scale.data<float>(),
                                             seq_len,
                                             hidden_size);

  return {out, scale};
}

PD_BUILD_OP(fused_transpose_quant)
    .Inputs({"X"})
    .Outputs({"output", "scale"})
    .SetKernelFn(PD_KERNEL(fused_transpose_quant));
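
/**
 * Example usage (a minimal sketch, following the doc-comment convention
 * above): once this file is compiled as a Paddle custom op, e.g. with
 * paddle.utils.cpp_extension.load, the op can be called from Python as
 * shown below. The module name `fused_quant_ops` and the source file name
 * are placeholders, not part of this change.
 *
 *   import paddle
 *   from paddle.utils.cpp_extension import load
 *
 *   fused_quant_ops = load(
 *       name='fused_quant_ops',
 *       sources=['fused_transpose_quant.cu'])
 *
 *   # [N, L, C] with C a multiple of 128
 *   x = paddle.randn([2, 256, 512]).astype('bfloat16')
 *   out, scale = fused_quant_ops.fused_transpose_quant(x)
 *   # out  : [2, 512, 256], float8_e4m3fn
 *   # scale: [2, 256, 4],   float32
 */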