diff --git a/csrc/activation_kernels.cu b/csrc/activation_kernels.cu
index 41176c801fb..e7b140edf8e 100644
--- a/csrc/activation_kernels.cu
+++ b/csrc/activation_kernels.cu
@@ -8,7 +8,7 @@
 #include "dispatch_utils.h"
 
 #ifdef USE_ROCM
-  #include "quantization/fp8/amd/hip_float8.h"
+  #include "quantization/fp8/amd/quant_utils.cuh"
 #endif
 
 namespace vllm {
@@ -48,7 +48,10 @@ __global__ void scaled_act_and_mul_kernel(
     const scalar_t y = VLLM_LDG(&input[token_idx * 2 * d + d + idx]);
     float r = ACT_FN(x) * y * scale;
     out[token_idx * d + idx] = c10::Float8_e4m3fnuz(
-        hip_fp8(r).data, c10::Float8_e4m3fnuz::from_bits());
+        __hip_cvt_float_to_fp8(r,
+                               fp8::fp8_type::__default_saturation,
+                               fp8::fp8_type::__default_interpret),
+        c10::Float8_e4m3fnuz::from_bits());
   }
 }
 #endif
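Annotation (reviewer note, not part of the patch): this series replaces the hand-rolled `hip_fp8` wrapper with the FP8 types and conversion intrinsics that ship in ROCm's `<hip/hip_fp8.h>`. Below is a minimal sketch of the conversion the updated kernel epilogue performs, assuming the fnuz interpretation used on MI300-class GPUs; `quantize_activation` is a hypothetical helper, not a function in the tree:

```cpp
#include <hip/hip_fp8.h>
#include <cstdint>

__device__ uint8_t quantize_activation(float act, float gate, float scale) {
  float r = act * gate * scale;  // fused activation * gate, pre-scaled
  // Saturating float -> fp8 e4m3 fnuz conversion; the __default_* constants
  // are static members of the fp8 type, as used throughout this patch.
  return __hip_cvt_float_to_fp8(r, __hip_fp8_e4m3_fnuz::__default_saturation,
                                __hip_fp8_e4m3_fnuz::__default_interpret);
}
```

With the default saturation mode, out-of-range values clamp to the e4m3 fnuz maximum of ±240 rather than producing inf/NaN, matching the `fmed3` clamp in the code being deleted below.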
diff --git a/csrc/quantization/fp8/amd/hip_float8.h b/csrc/quantization/fp8/amd/hip_float8.h
deleted file mode 100644
index f9c80fcdec5..00000000000
--- a/csrc/quantization/fp8/amd/hip_float8.h
+++ /dev/null
@@ -1,137 +0,0 @@
-#pragma once
-
-#ifdef __HIPCC__
-  #include <hip/hip_runtime.h>
-#else
-  #include <type_traits>
-  #include <cmath>
-  #include <cstdint>
-  #include <iostream>
-#endif
-
-#include "hip_float8_impl.h"
-
-struct alignas(1) hip_fp8 {
-  struct from_bits_t {};
-  HIP_FP8_HOST_DEVICE static constexpr from_bits_t from_bits() {
-    return from_bits_t();
-  }
-  uint8_t data;
-
-  hip_fp8() = default;
-  HIP_FP8_HOST_DEVICE constexpr hip_fp8(const hip_fp8&) = default;
-  HIP_FP8_HOST_DEVICE constexpr hip_fp8(uint8_t v) = delete;
-  explicit HIP_FP8_HOST_DEVICE constexpr hip_fp8(uint8_t v, from_bits_t)
-      : data(v) {}
-
-#ifdef __HIP__MI300__
-  // NOTE: ON-DEVICE... always optimal bias
-  explicit HIP_FP8_DEVICE hip_fp8(float v)
-      : data(hip_fp8_impl::to_fp8_from_fp32(v)) {}
-
-  explicit HIP_FP8_DEVICE hip_fp8(_Float16 v)
-      : hip_fp8(static_cast<float>(v)) {}
-
-  // Host only implementation using s/w simulation
-  explicit HIP_FP8_HOST
-#else   // __HIP__MI300__
-  // both Host and DEVICE for non-MI300 using s/w simulation
-  explicit HIP_FP8_HOST_DEVICE
-#endif  // __HIP__MI300__
-  hip_fp8(float v) {
-    data = hip_fp8_impl::to_float8<4, 3, float, true /*negative_zero_nan*/,
-                                   true /*clip*/>(v);
-  }
-
-  explicit HIP_FP8_HOST_DEVICE hip_fp8(double v)
-      : hip_fp8(static_cast<float>(v)) {}
-
-#ifdef __HIP__MI300__
-  // upcast using device specific intrinsic
-  explicit inline HIP_FP8_DEVICE operator float() const {
-    float fval;
-    uint32_t i32val = static_cast<uint32_t>(data);
-
-    // upcast
-    asm volatile("v_cvt_f32_fp8 %0, %1 src0_sel:BYTE_0"
-                 : "=v"(fval)
-                 : "v"(i32val));
-
-    return fval;
-  }
-
-  explicit inline HIP_FP8_HOST operator float() const
-#else   // __HIP__MI300__
-  explicit inline HIP_FP8_HOST_DEVICE operator float() const
-#endif  // __HIP__MI300__
-  {
-    return hip_fp8_impl::from_float8<4, 3, float, true /*negative_zero_nan*/>(
-        data);
-  }
-};
-
-namespace std {
-inline hip_fp8 sin(hip_fp8 a) { return hip_fp8(sinf(float(a))); }
-inline hip_fp8 cos(hip_fp8 a) { return hip_fp8(cosf(float(a))); }
-HIP_FP8_HOST_DEVICE constexpr hip_fp8 real(const hip_fp8& a) { return a; }
-}  // namespace std
-
-// Special operator overloading
-inline std::ostream& operator<<(std::ostream& os, const hip_fp8& f8) {
-  return os << float(f8);
-}
-
-// all + operator overloading with mixed types
-// mixed types, always converts to f32, does computation in f32, and returns
-// float
-inline HIP_FP8_HOST_DEVICE float operator+(const float fa, hip_fp8 b) {
-  return (fa + float(b));
-}
-
-inline HIP_FP8_HOST_DEVICE float operator+(hip_fp8 a, const float fb) {
-  return (float(a) + fb);
-}
-
-inline HIP_FP8_HOST_DEVICE hip_fp8 operator+(hip_fp8 a, hip_fp8 b) {
-  return hip_fp8(float(a) + float(b));
-}
-
-inline HIP_FP8_HOST_DEVICE hip_fp8& operator+=(hip_fp8& a, hip_fp8 b) {
-  return a = hip_fp8(float(a) + float(b));
-}
-
-// overloading multiplication, always returns float,
-inline HIP_FP8_HOST_DEVICE float operator*(hip_fp8 a, hip_fp8 b) {
-  return float(a) * float(b);
-}
-
-inline HIP_FP8_HOST_DEVICE float operator*(float a, hip_fp8 b) {
-  return (a * float(b));
-}
-
-inline HIP_FP8_HOST_DEVICE float operator*(hip_fp8 a, float b) {
-  return (float(a) * b);
-}
-
-inline HIP_FP8_HOST_DEVICE float operator*(int32_t a, hip_fp8 b) {
-  return ((float)a * float(b));
-}
-
-inline HIP_FP8_HOST_DEVICE float operator*(double a, hip_fp8 b) {
-  return ((float)a * float(b));
-}
-
-// overloading for compare
-inline HIP_FP8_HOST_DEVICE bool operator==(hip_fp8 a, hip_fp8 b) {
-  return (a.data == b.data);
-}
-inline HIP_FP8_HOST_DEVICE bool operator!=(hip_fp8 a, hip_fp8 b) {
-  return (a.data != b.data);
-}
-
-inline HIP_FP8_HOST_DEVICE bool operator>=(hip_fp8 a, hip_fp8 b) {
-  return static_cast<float>(a) >= static_cast<float>(b);
-}
-inline HIP_FP8_HOST_DEVICE bool operator>(hip_fp8 a, hip_fp8 b) {
-  return static_cast<float>(a) > static_cast<float>(b);
-}
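Annotation (not part of the patch): for review convenience, a rough old-to-new correspondence between the deleted `hip_fp8` wrapper above and the ROCm `<hip/hip_fp8.h>` API the rest of the series switches to. This is an illustrative sketch; `api_mapping` is a hypothetical name:

```cpp
#include <hip/hip_fp8.h>
#include <cstdint>

__device__ void api_mapping(float x, uint8_t bits) {
  // old: hip_fp8 f8{bits, hip_fp8::from_bits()};
  __hip_fp8_e4m3_fnuz f8;
  f8.__x = bits;  // __x is the raw storage byte

  // old: float f = static_cast<float>(f8);  (v_cvt_f32_fp8 on MI300)
  float f = static_cast<float>(f8);

  // old: uint8_t q = hip_fp8(x).data;  (clamp to +/-240, then cvt_pk_fp8_f32)
  uint8_t q = __hip_cvt_float_to_fp8(
      x, __hip_fp8_e4m3_fnuz::__default_saturation,
      __hip_fp8_e4m3_fnuz::__default_interpret);
  (void)f;
  (void)q;
}
```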
diff --git a/csrc/quantization/fp8/amd/hip_float8_impl.h b/csrc/quantization/fp8/amd/hip_float8_impl.h
deleted file mode 100644
index 90251c35395..00000000000
--- a/csrc/quantization/fp8/amd/hip_float8_impl.h
+++ /dev/null
@@ -1,316 +0,0 @@
-#pragma once
-
-#if defined(__HIPCC__) && \
-    (defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__))
-  #define __HIP__MI300__
-#endif
-
-#ifdef __HIPCC__
-  #define HIP_FP8_HOST_DEVICE __host__ __device__
-  #define HIP_FP8_HOST __host__
-  #define HIP_FP8_DEVICE __device__
-#else
-  #define HIP_FP8_HOST_DEVICE
-  #define HIP_FP8_HOST
-  #define HIP_FP8_DEVICE
-#endif
-
-namespace hip_fp8_impl {
-
-#ifdef __HIP__MI300__
-HIP_FP8_DEVICE uint8_t to_fp8_from_fp32(float v) {
-  uint8_t i8data;
-  union {
-    float fval;
-    uint32_t i32val;
-    uint8_t i8val[4];  // NOTE: not endian independent
-  } val;
-
-  uint32_t ival = 0;
-  val.fval = v;
-
-  if ((val.i32val & 0x7F800000) !=
-      0x7F800000) {  /// propagate NAN/INF, no clipping
-    val.fval = __builtin_amdgcn_fmed3f(val.fval, 240.0, -240.0);
-  }
-
-  ival = __builtin_amdgcn_cvt_pk_fp8_f32(val.fval, val.fval, ival,
-                                         false);  // false -> WORD0
-  val.i32val = ival;
-  i8data = val.i8val[0];
-
-  return i8data;
-}
-#endif  // __HIP__MI300__
-
-HIP_FP8_HOST inline int clz(uint32_t x) { return __builtin_clz(x); }
-#if defined(__HIPCC__) || defined(__CUDA_ARCH__)
-HIP_FP8_DEVICE inline int clz(uint32_t x) { return __clz(x); }
-#endif
-
-template <int wm, int we, typename T, bool negative_zero_nan, bool clip>
-HIP_FP8_HOST_DEVICE uint8_t to_float8(T _x, bool stoch = false,
-                                      uint32_t rng = 0) {
-#ifdef __HIPCC__
-  constexpr bool is_half = std::is_same<T, _Float16>::value;
-#else
-  constexpr bool is_half = false;
-#endif
-  constexpr bool is_float = std::is_same<T, float>::value;
-  static_assert(wm + we == 7, "wm+we==7");
-  static_assert(is_half || is_float, "Only half and float can be cast to f8");
-
-  const int mfmt = (sizeof(T) == 4) ? 23 : 10;
-  uint32_t x;
-  if (sizeof(T) == 4) {
-    x = reinterpret_cast<uint32_t&>(_x);
-  } else {
-    x = reinterpret_cast<uint16_t&>(_x);
-  }
-
-  uint32_t head, mantissa;
-  int exponent, bias;
-  uint32_t sign;
-
-  if (sizeof(T) == 4) {
-    head = x & 0xFF800000;
-    mantissa = x & 0x7FFFFF;
-    exponent = (head >> 23) & 0xFF;
-    sign = head >> 31;
-    bias = 127;
-  } else {
-    head = x & 0xFC00;
-    mantissa = x & 0x3FF;
-    exponent = (head >> 10) & 0x1F;
-    sign = head >> 15;
-    bias = 15;
-  }
-
-  uint32_t signed_inf = (sign << 7) + (((1 << we) - 1) << wm);
-
-  // Deal with inf and NaNs
-  if (negative_zero_nan) {
-    if (sizeof(T) == 4) {
-      if ((x & 0x7F800000) == 0x7F800000) {
-        return 0x80;
-      }
-    } else {
-      // if(__hisinf(x) || __hisnan(x))
-      if ((x & 0x7C00) == 0x7C00) {
-        return 0x80;
-      }
-    }
-  } else {
-    if (sizeof(T) == 4) {
-      if ((x & 0x7F800000) == 0x7F800000) {
-        return signed_inf + (mantissa != 0 ? 1 : 0);
-      }
-    } else {
-      if ((x & 0x7C00) == 0x7C00) {
-        return signed_inf + (mantissa != 0 ? 1 : 0);
-      }
-    }
-  }
-  if (x == 0) {
-    return 0;
-  }
-
-  // First need to check if it is normal or denorm as there is a difference of
-  // implicit 1. Then need to adjust the exponent to align with the F8
-  // exponent, in the meanwhile, shift the mantissa. Then for stochastic
-  // rounding, add rng to mantissa and truncate. And for RNE, no need to add
-  // rng. Then probably need to check whether there is carry and adjust
-  // exponent and mantissa again
-
-  // For IEEE bias mode, the bias is 2^(k-1) -1 where k is the width of exponent
-  // bits
-  const int f8_bias = (1 << (we - 1)) - 1 + (negative_zero_nan ? 1 : 0);
-  const int f8_denormal_act_exponent =
-      1 - f8_bias;  // actual exponent of f8 denormal
-  // act_exponent is the actual exponent of fp32/fp16 (after subtracting bias)
-  // f8_exponent is the converted f8 exponent with bias encoding
-  // exponent_diff is the diff between fp32/fp16 exponent and f8 exponent,
-  // the difference needs to be adjusted and mantissa shifted
-  int act_exponent, f8_exponent, exponent_diff;
-
-  if (exponent == 0) {  // fp32/fp16 is in denormal.
-    /* fp32 denormal is below 2^-127 so it is usually not a concern here, we
-mostly concern fp16 here. In this case, f8 is usually in denormal. But there
-could be exceptions. fp16 denormal has exponent bias 15 while bf8 with NANOO has
-exponent bias 16. It means that there are some numbers in fp16 denormal but they
-are bf8 (NANOO) normals - smallest bf8 (NANOO) normal is 2^-15. fp16 numbers
-where exponent==0 (actual exponent -14) and highest bit of mantissa is 1 are bf8
-(NANOO) normal. In this case, the fp16 mantissa should be shift left by 1 */
-    act_exponent = exponent - bias + 1;
-    exponent_diff =
-        f8_denormal_act_exponent -
-        act_exponent;  // actual exponent is exponent-bias+1 as it is denormal
-  } else {  // fp32/fp16 is normal with implicit 1
-    act_exponent = exponent - bias;
-    if (act_exponent <= f8_denormal_act_exponent) {
-      /* This is the case where fp32/fp16 is normal but it is in f8 denormal
-range. For example fp8 nanoo mode, denormal exponent is -7, but if the
-fp32/fp16 actual exponent is -7, it is actually larger due to the implicit 1,
-Therefore it needs to be adjusted to -6 and mantissa shift right by 1.
-So for fp32/fp16, exponent -8 is the cut point to convert to fp8 nanoo */
-      exponent_diff = f8_denormal_act_exponent - act_exponent;
-    } else {  // both fp32/fp16 and f8 are in normal range
-      exponent_diff = 0;  // exponent_diff=0 does not mean there is no
-                          // difference for this case, act_exponent could be
-                          // larger. Just that it does not need shift mantissa
-    }
-    mantissa += (1 << mfmt);  // Add the implicit 1 into mantissa
-  }
-
-  bool midpoint = (mantissa & ((1 << (mfmt - wm + exponent_diff)) - 1)) ==
-                  static_cast<uint32_t>(1 << (mfmt - wm + exponent_diff - 1));
-  /* This part is a bit tricky. The judgment of whether it is a tie needs to be
-  done before we shift right as shift right could rip off some residual part
-  and make something not midpoint look like midpoint. For example, the fp16
-  number 0x1002 (0 00100 0000000010), it is larger than midpoint, but after
-  shift right by 4 bits, it would look like midpoint.
-*/
-
-  if (exponent_diff > 0) {
-    mantissa >>= exponent_diff;
-  } else if (exponent_diff == -1) {
-    mantissa <<= -exponent_diff;
-  }
-  bool implicit_one = mantissa & (1 << mfmt);
-  // if there is no implicit 1, it means the f8 is denormal and need to adjust
-  // to denorm exponent
-  f8_exponent = (act_exponent + exponent_diff) /*actual f8 exponent*/ +
-                f8_bias - (implicit_one ? 0 : 1);
-
-  // Now we have the exponent and mantissa adjusted
-  uint32_t drop_mask = (1 << (mfmt - wm)) - 1;
-  bool odd = mantissa & (1 << (mfmt - wm));  // if the least significant bit
-                                             // that is not truncated is 1
-  mantissa +=
-      (stoch ? rng : (midpoint ? (odd ? mantissa : mantissa - 1) : mantissa)) &
-      drop_mask;
-
-  // Now we deal with overflow
-  if (f8_exponent == 0) {
-    if ((1 << mfmt) & mantissa) {
-      f8_exponent = 1;  // denormal overflow to become normal, promote exponent
-    }
-  } else {
-    if ((1 << (mfmt + 1)) & mantissa) {
-      mantissa >>= 1;
-      f8_exponent++;
-    }
-  }
-
-  mantissa >>= (mfmt - wm);
-
-  // above range: quantize to maximum possible float of the same sign
-  const int max_exp = (1 << we) - (negative_zero_nan ? 1 : 2);
-  if (f8_exponent > max_exp) {
-    if (clip) {
-      mantissa = (1 << wm) - 1;
-      f8_exponent = max_exp;
-    } else {
-      return signed_inf;
-    }
-  }
-
-  if (f8_exponent == 0 && mantissa == 0) {
-    return negative_zero_nan ? 0 : (sign << 7);
-  }
-  mantissa &= (1 << wm) - 1;
-  return (sign << 7) | (f8_exponent << wm) | mantissa;
-}
-
-template <int wm, int we, typename T, bool negative_zero_nan>
-inline HIP_FP8_HOST_DEVICE T from_float8(uint8_t x) {
-#ifdef __HIPCC__
-  constexpr bool is_half = std::is_same<T, _Float16>::value;
-#else
-  constexpr bool is_half = false;
-#endif
-  constexpr bool is_float = std::is_same<T, float>::value;
-  static_assert(is_half || is_float, "only half and float are supported");
-
-  constexpr int weo = is_half ? 5 : 8;
-  constexpr int wmo = is_half ? 10 : (is_float ? 23 : 7);
-
-  T fInf, fNegInf, fNaN, fNeg0;
-
-#ifdef __HIPCC__
-  if (is_half) {
-    const uint16_t ihInf = 0x7C00;
-    const uint16_t ihNegInf = 0xFC00;
-    const uint16_t ihNaN = 0x7C01;
-    const uint16_t ihNeg0 = 0x8000;
-    fInf = reinterpret_cast<const _Float16&>(ihInf);
-    fNegInf = reinterpret_cast<const _Float16&>(ihNegInf);
-    fNaN = reinterpret_cast<const _Float16&>(ihNaN);
-    fNeg0 = reinterpret_cast<const _Float16&>(ihNeg0);
-  } else
-#endif
-      if (is_float) {
-    const uint32_t ifInf = 0x7F800000;
-    const uint32_t ifNegInf = 0xFF800000;
-    const uint32_t ifNaN = 0x7F800001;
-    const uint32_t ifNeg0 = 0x80000000;
-    fInf = reinterpret_cast<const float&>(ifInf);
-    fNegInf = reinterpret_cast<const float&>(ifNegInf);
-    fNaN = reinterpret_cast<const float&>(ifNaN);
-    fNeg0 = reinterpret_cast<const float&>(ifNeg0);
-  }
-
-  if (x == 0) {
-    return 0;
-  }
-
-  uint32_t sign = x >> 7;
-  uint32_t mantissa = x & ((1 << wm) - 1);
-  int exponent = (x & 0x7F) >> wm;
-  if (negative_zero_nan) {
-    if (x == 0x80) {
-      return fNaN;
-    }
-  } else {
-    if (x == 0x80) {
-      return fNeg0;
-    }
-    if (exponent == ((1 << we) - 1)) {
-      return (mantissa == 0) ? (sign ? fNegInf : fInf) : fNaN;
-    }
-  }
-  typename std::conditional<sizeof(T) == 2, uint16_t, uint32_t>::type retval;
-  if (we == 5 && is_half && !negative_zero_nan) {
-    retval = x << 8;
-    return reinterpret_cast<const T&>(retval);
-  }
-
-  const int exp_low_cutoff =
-      (1 << (weo - 1)) - (1 << (we - 1)) + 1 - (negative_zero_nan ? 1 : 0);
-
-  // subnormal input
-  if (exponent == 0) {
-    // guaranteed mantissa!=0 since cases 0x0 and 0x80 are handled above
-    int sh = 1 + clz(mantissa) - (32 - wm);
-    mantissa <<= sh;
-    exponent += 1 - sh;
-    mantissa &= ((1 << wm) - 1);
-  }
-  exponent += exp_low_cutoff - 1;
-  mantissa <<= wmo - wm;
-
-  // subnormal output (occurs when T=half, we=5, negative_zero_nan=true)
-  if (exponent <= 0) {
-    mantissa |= 1 << wmo;
-    mantissa >>= 1 - exponent;
-    exponent = 0;
-  }
-
-  if (sizeof(T) == 2) {
-    retval = (sign << 15) | (exponent << 10) | mantissa;
-  } else {
-    retval = (sign << 31) | (exponent << 23) | mantissa;
-  }
-  return reinterpret_cast<const T&>(retval);
-}
-
-}  // namespace hip_fp8_impl
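Annotation (not part of the patch): the file deleted above implemented IEEE-style quantization (round-to-nearest-even, optional stochastic rounding, denormal handling) in portable C++. As a reference for what the e4m3 fnuz encoding it produced means numerically, here is a small host-side decoder consistent with the deleted `from_float8<4, 3, float, true>` path; `decode_e4m3_fnuz` is a hypothetical helper written for illustration:

```cpp
#include <cmath>
#include <cstdint>

// Decode one e4m3 fnuz byte: value = (-1)^s * 2^(exp - 8) * 1.mmm,
// with the negative-zero bit pattern 0x80 reserved for NaN.
float decode_e4m3_fnuz(uint8_t x) {
  if (x == 0) return 0.0f;
  if (x == 0x80) return NAN;  // fnuz: "negative zero" encodes NaN
  int sign = x >> 7;
  int exp = (x >> 3) & 0xF;   // 4 exponent bits
  int man = x & 0x7;          // 3 mantissa bits
  float mag;
  if (exp == 0) {
    mag = std::ldexp(man / 8.0f, -7);            // subnormal: no implicit 1
  } else {
    mag = std::ldexp(1.0f + man / 8.0f, exp - 8);  // fnuz exponent bias is 8
  }
  return sign ? -mag : mag;
}
```

The fnuz bias of 8 (versus 7 for OCP e4m3) is what makes 240 the largest finite value, which is why the deleted device path clamps to ±240 with `fmed3` before converting.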
diff --git a/csrc/quantization/fp8/amd/quant_utils.cuh b/csrc/quantization/fp8/amd/quant_utils.cuh
index 4b77817f2df..b2196b8ed51 100644
--- a/csrc/quantization/fp8/amd/quant_utils.cuh
+++ b/csrc/quantization/fp8/amd/quant_utils.cuh
@@ -1,5 +1,5 @@
 #pragma once
-#include "hip_float8.h"
+#include <hip/hip_fp8.h>
 
 #include <hip/hip_fp16.h>
 #include <hip/hip_bf16.h>
@@ -24,39 +24,31 @@ __inline__ __device__ Tout scaled_vec_conversion(const Tin& x,
   return x;
 }
 
+  #if HIP_FP8_TYPE_FNUZ
+using fp8_type = __hip_fp8_e4m3_fnuz;
+using fp8x2_type = __hip_fp8x2_e4m3_fnuz;
+  #elif HIP_FP8_TYPE_OCP
+using fp8_type = __hip_fp8_e4m3;
+using fp8x2_type = __hip_fp8x2_e4m3;
+  #endif
+
 // fp8 -> half
 template <>
 __inline__ __device__ uint16_t
 vec_conversion<uint16_t, uint8_t>(const uint8_t& a) {
-  hip_fp8 f8{a, hip_fp8::from_bits()};
-  __half_raw res;
-  res.data = static_cast<float>(f8);
-  return res.x;
+  return __hip_cvt_fp8_to_halfraw(a, fp8_type::__default_interpret).x;
 }
 
 // fp8x2 -> half2
 template <>
 __inline__ __device__ uint32_t
 vec_conversion<uint32_t, uint16_t>(const uint16_t& a) {
-  #if defined(__HIP__MI300__)
-  const auto& f2 = __builtin_amdgcn_cvt_pk_f32_fp8(a, 0);
   union {
     __half2_raw h2r;
     uint32_t ui32;
   } tmp;
-  tmp.h2r.x.data = f2[0];
-  tmp.h2r.y.data = f2[1];
+  tmp.h2r = __hip_cvt_fp8x2_to_halfraw2(a, fp8_type::__default_interpret);
   return tmp.ui32;
-  #else
-  union {
-    uint16_t u16[2];
-    uint32_t u32;
-  } tmp;
-
-  tmp.u16[0] = vec_conversion<uint16_t, uint8_t>(static_cast<uint8_t>(a));
-  tmp.u16[1] =
-      vec_conversion<uint16_t, uint8_t>(static_cast<uint8_t>(a >> 8U));
-  return tmp.u32;
-  #endif
 }
 
 // fp8x4 -> half2x2
@@ -89,9 +81,9 @@ using __nv_bfloat16 = __hip_bfloat16;
 template <>
 __inline__ __device__ __nv_bfloat16
 vec_conversion<__nv_bfloat16, uint8_t>(const uint8_t& a) {
-  hip_fp8 f8{a, hip_fp8::from_bits()};
-  float f{f8};
-  return __float2bfloat16(f);
+  fp8_type f8;
+  f8.__x = a;
+  return __float2bfloat16(static_cast<float>(f8));
 }
 
 using __nv_bfloat162 = __hip_bfloat162;
@@ -133,26 +125,18 @@ __inline__ __device__ bf16_8_t vec_conversion<bf16_8_t, uint2>(const uint2& a) {
 // fp8 -> float
 template <>
 __inline__ __device__ float vec_conversion<float, uint8_t>(const uint8_t& a) {
-  hip_fp8 fp8{a, hip_fp8::from_bits()};
-  return static_cast<float>(fp8);
+  fp8_type f8;
+  f8.__x = a;
+  return static_cast<float>(f8);
 }
 
 // fp8x2 -> float2
 template <>
 __inline__ __device__ float2
 vec_conversion<float2, uint16_t>(const uint16_t& a) {
-  #if defined(__HIP__MI300__)
-  float2 res;
-  const auto& f2 = __builtin_amdgcn_cvt_pk_f32_fp8(a, 0);
-  res.x = f2[0];
-  res.y = f2[1];
-  return res;
-  #else
-  float2 res;
-  res.x = vec_conversion<float, uint8_t>(static_cast<uint8_t>(a));
-  res.y = vec_conversion<float, uint8_t>(static_cast<uint8_t>(a >> 8U));
-  return res;
-  #endif
+  fp8x2_type f8x2;
+  f8x2.__x = a;
+  return static_cast<float2>(f8x2);
}
 
 // fp8x4 -> float4
@@ -165,6 +149,15 @@ vec_conversion<Float4_, uint32_t>(const uint32_t& a) {
   return res;
 }
 
+// fp8x4 -> float4
+template <>
+__inline__ __device__ float4
+vec_conversion<float4, uint32_t>(const uint32_t& a) {
+  Float4_ tmp = vec_conversion<Float4_, uint32_t>(a);
+  float4 res = make_float4(tmp.x.x, tmp.x.y, tmp.y.x, tmp.y.y);
+  return res;
+}
+
 // fp8x8 -> float8
 template <>
 __inline__ __device__ Float8_ vec_conversion<Float8_, uint2>(const uint2& a) {
@@ -185,33 +178,36 @@ __inline__ __device__ uint8_t vec_conversion<uint8_t, uint16_t>(const uint16_t& a) {
   __half_raw tmp;
   tmp.x = a;
+  return __hip_cvt_halfraw_to_fp8(tmp, fp8_type::__default_saturation,
+                                  fp8_type::__default_interpret);
+}
-
-  hip_fp8 f8{static_cast<float>(tmp.data)};
-  return f8.data;
+template <>
+__inline__ __device__ uint16_t
+vec_conversion<uint16_t, uint32_t>(const uint32_t& a) {
+  union {
+    uint32_t ui32;
+    __half2_raw h2r;
+  } tmp;
+  tmp.ui32 = a;
+  return __hip_cvt_halfraw2_to_fp8x2(tmp.h2r, fp8_type::__default_saturation,
+                                     fp8_type::__default_interpret);
 }
 
 // bf16 -> fp8
 template <>
 __inline__ __device__ uint8_t
 vec_conversion<uint8_t, __nv_bfloat16>(const __nv_bfloat16& a) {
-  hip_fp8 res{__bfloat162float(a)};
-  return res.data;
+  return __hip_cvt_float_to_fp8(__bfloat162float(a),
+                                fp8_type::__default_saturation,
+                                fp8_type::__default_interpret);
 }
 
 // float -> fp8
 template <>
 __inline__ __device__ uint8_t vec_conversion<uint8_t, float>(const float& a) {
-  hip_fp8 f8(a);
-  return f8.data;
-}
-
-// fp8x4 -> float4
-template <>
-__inline__ __device__ float4
-vec_conversion<float4, uint32_t>(const uint32_t& a) {
-  Float4_ tmp = vec_conversion<Float4_, uint32_t>(a);
-  float4 res = make_float4(tmp.x.x, tmp.x.y, tmp.y.x, tmp.y.y);
-  return res;
+  return __hip_cvt_float_to_fp8(a, fp8_type::__default_saturation,
+                                fp8_type::__default_interpret);
 }
 
 // float2 -> half2
@@ -303,79 +299,15 @@
 */
 
-// fp8 -> half
-template <>
-__inline__ __device__ uint16_t
-scaled_vec_conversion<uint16_t, uint8_t>(const uint8_t& a, float scale) {
-  hip_fp8 f8{a, hip_fp8::from_bits()};
-  __half_raw res;
-  res.data = static_cast<float>(f8) * scale;
-  return res.x;
-}
-
-// fp8x2 -> half2
-template <>
-__inline__ __device__ uint32_t
-scaled_vec_conversion<uint32_t, uint16_t>(const uint16_t& a, float scale) {
-  #if defined(__HIP__MI300__)
-  const auto& f2 = __builtin_amdgcn_cvt_pk_f32_fp8(a, 0);
-  union {
-    __half2_raw h2r;
-    uint32_t ui32;
-  } tmp;
-  tmp.h2r.x.data = f2[0] * scale;
-  tmp.h2r.y.data = f2[1] * scale;
-  return tmp.ui32;
-  #else
-  union {
-    uint16_t u16[2];
-    uint32_t u32;
-  } tmp;
-
-  tmp.u16[0] =
-      scaled_vec_conversion<uint16_t, uint8_t>(static_cast<uint8_t>(a), scale);
-  tmp.u16[1] = scaled_vec_conversion<uint16_t, uint8_t>(
-      static_cast<uint8_t>(a >> 8U), scale);
-  return tmp.u32;
-  #endif
-}
-
-// fp8x4 -> half2x2
-template <>
-__inline__ __device__ uint2
-scaled_vec_conversion<uint2, uint32_t>(const uint32_t& a, float scale) {
-  union {
-    uint2 u32x2;
-    uint32_t u32[2];
-  } tmp;
-  tmp.u32[0] = scaled_vec_conversion<uint32_t, uint16_t>((uint16_t)a, scale);
-  tmp.u32[1] =
-      scaled_vec_conversion<uint32_t, uint16_t>((uint16_t)(a >> 16U), scale);
-  return tmp.u32x2;
-}
-
-// fp8x8 -> half2x4
-template <>
-__inline__ __device__ uint4 scaled_vec_conversion<uint4, uint2>(const uint2& a,
-                                                                float scale) {
-  union {
-    uint4 u64x2;
-    uint2 u64[2];
-  } tmp;
-  tmp.u64[0] = scaled_vec_conversion<uint2, uint32_t>(a.x, scale);
-  tmp.u64[1] = scaled_vec_conversion<uint2, uint32_t>(a.y, scale);
-  return tmp.u64x2;
-}
-
 using __nv_bfloat16 = __hip_bfloat16;
 
 // fp8 -> __nv_bfloat16
 template <>
 __inline__ __device__ __nv_bfloat16
 scaled_vec_conversion<__nv_bfloat16, uint8_t>(const uint8_t& a, float scale) {
-  hip_fp8 f8{a, hip_fp8::from_bits()};
-  float f{f8};
-  return __float2bfloat16(f * scale);
+  fp8_type f8;
+  f8.__x = a;
+  return __float2bfloat16(static_cast<float>(f8) * scale);
 }
 
 // fp8x2 -> __nv_bfloat162
@@ -420,27 +352,18 @@ scaled_vec_conversion<bf16_8_t, uint2>(const uint2& a, float scale) {
 template <>
 __inline__ __device__ float scaled_vec_conversion<float, uint8_t>(
     const uint8_t& a, float scale) {
-  hip_fp8 fp8{a, hip_fp8::from_bits()};
-  return static_cast<float>(fp8) * scale;
+  fp8_type f8;
+  f8.__x = a;
+  return static_cast<float>(f8) * scale;
 }
 
 // fp8x2 -> float2
 template <>
 __inline__ __device__ float2
 scaled_vec_conversion<float2, uint16_t>(const uint16_t& a, float scale) {
-  #if defined(__HIP__MI300__)
-  float2 res;
-  const auto& f2 = __builtin_amdgcn_cvt_pk_f32_fp8(a, 0);
-  res.x = f2[0] * scale;
-  res.y = f2[1] * scale;
-  return res;
-  #else
-  float2 res;
-  res.x = scaled_vec_conversion<float, uint8_t>(static_cast<uint8_t>(a), scale);
-  res.y = scaled_vec_conversion<float, uint8_t>(static_cast<uint8_t>(a >> 8U),
-                                                scale);
-  return res;
-  #endif
+  fp8x2_type f8x2;
+  f8x2.__x = a;
+  return static_cast<float2>(f8x2) * scale;
 }
 
 // fp8x4 -> float4
@@ -476,56 +399,80 @@ scaled_vec_conversion<Float8_, uint2>(const uint2& a, float scale) {
   return res;
 }
 
+// fp8 -> half
+template <>
+__inline__ __device__ uint16_t
+scaled_vec_conversion<uint16_t, uint8_t>(const uint8_t& a, float scale) {
+  __half_raw res;
+  res.data = scaled_vec_conversion<float, uint8_t>(a, scale);
+  return res.x;
+}
+
+// fp8x2 -> half2
+template <>
+__inline__ __device__ uint32_t
+scaled_vec_conversion<uint32_t, uint16_t>(const uint16_t& a, float scale) {
+  union {
+    __half2_raw h2r;
+    uint32_t ui32;
+  } tmp;
+  tmp.h2r = __hip_cvt_fp8x2_to_halfraw2(a, fp8_type::__default_interpret);
+  tmp.h2r.x.data *= scale;
+  tmp.h2r.y.data *= scale;
+  return tmp.ui32;
+}
+
+// fp8x4 -> half2x2
+template <>
+__inline__ __device__ uint2
+scaled_vec_conversion<uint2, uint32_t>(const uint32_t& a, float scale) {
+  union {
+    uint2 u32x2;
+    uint32_t u32[2];
+  } tmp;
+  tmp.u32[0] = scaled_vec_conversion<uint32_t, uint16_t>((uint16_t)a, scale);
+  tmp.u32[1] =
+      scaled_vec_conversion<uint32_t, uint16_t>((uint16_t)(a >> 16U), scale);
+  return tmp.u32x2;
+}
+
+// fp8x8 -> half2x4
+template <>
+__inline__ __device__ uint4 scaled_vec_conversion<uint4, uint2>(const uint2& a,
+                                                                float scale) {
+  union {
+    uint4 u64x2;
+    uint2 u64[2];
+  } tmp;
+  tmp.u64[0] = scaled_vec_conversion<uint2, uint32_t>(a.x, scale);
+  tmp.u64[1] = scaled_vec_conversion<uint2, uint32_t>(a.y, scale);
+  return tmp.u64x2;
+}
+
 // half -> fp8
 template <>
 __inline__ __device__ uint8_t
 scaled_vec_conversion<uint8_t, uint16_t>(const uint16_t& a, float scale) {
   __half_raw tmp;
   tmp.x = a;
-
-  hip_fp8 f8{static_cast<float>(tmp.data / scale)};
-  return f8.data;
+  tmp.data /= scale;
+  return __hip_cvt_halfraw_to_fp8(tmp, fp8_type::__default_saturation,
+                                  fp8_type::__default_interpret);
 }
 
 // halfx2 -> fp8x2
 template <>
 __inline__ __device__ uint16_t
 scaled_vec_conversion<uint16_t, uint32_t>(const uint32_t& a, float scale) {
-  #ifdef __HIP__MI300__
-  union {
-    uint32_t ui32;
-    __half2_raw h2r;
-  } tmp;
-  tmp.ui32 = a;
-
-  union {
-    uint32_t ui32;
-    float f;
-  } f1, f2;
-  f1.f = tmp.h2r.x.data / scale;
-  f2.f = tmp.h2r.y.data / scale;
-  if ((f1.ui32 & 0x7F800000) != 0x7F800000) {
-    f1.f = __builtin_amdgcn_fmed3f(f1.f, 240.0, -240.0);
-  }
-  if ((f2.ui32 & 0x7F800000) != 0x7F800000) {
-    f2.f = __builtin_amdgcn_fmed3f(f2.f, 240.0, -240.0);
-  }
-  return __builtin_amdgcn_cvt_pk_fp8_f32(f1.f, f2.f, 0, 0);
-  #else
   union {
     uint32_t ui32;
     __half2_raw h2r;
   } tmp;
   tmp.ui32 = a;
-
-  union {
-    uint8_t ui8[2];
-    uint16_t ui16;
-  } res;
-  res.ui8[0] = scaled_vec_conversion<uint8_t, uint16_t>(tmp.h2r.x.x, scale);
-  res.ui8[1] = scaled_vec_conversion<uint8_t, uint16_t>(tmp.h2r.y.x, scale);
-  return res.ui16;
-  #endif
+  tmp.h2r.x.data /= scale;
+  tmp.h2r.y.data /= scale;
+  return __hip_cvt_halfraw2_to_fp8x2(tmp.h2r, fp8_type::__default_saturation,
+                                     fp8_type::__default_interpret);
 }
 
 // half2x2 -> fp8x4
@@ -560,8 +507,9 @@ __inline__ __device__ uint2 scaled_vec_conversion<uint2, uint4>(const uint4& a,
 template <>
 __inline__ __device__ uint8_t scaled_vec_conversion<uint8_t, __nv_bfloat16>(
    const __nv_bfloat16& a, float scale) {
-  hip_fp8 res{__bfloat162float(a) / scale};
-  return res.data;
+  return __hip_cvt_float_to_fp8(__bfloat162float(a) / scale,
+                                fp8_type::__default_saturation,
+                                fp8_type::__default_interpret);
 }
 
 // bf16x2 -> fp8x2
@@ -604,37 +552,16 @@ scaled_vec_conversion<uint2, bf16_8_t>(const bf16_8_t& a, float scale) {
 template <>
 __inline__ __device__ uint8_t
 scaled_vec_conversion<uint8_t, float>(const float& a, float scale) {
-  hip_fp8 f8(a);
-  return f8.data;
+  return __hip_cvt_float_to_fp8(a / scale, fp8_type::__default_saturation,
+                                fp8_type::__default_interpret);
 }
 
 // floatx2 -> fp8x2
 template <>
 __inline__ __device__ uint16_t
 scaled_vec_conversion<uint16_t, float2>(const float2& a, float scale) {
-  #ifdef __HIP__MI300__
-  union {
-    uint32_t ui32;
-    float f;
-  } f1, f2;
-  f1.f = a.x / scale;
-  f2.f = a.y / scale;
-  if ((f1.ui32 & 0x7F800000) != 0x7F800000) {
-    f1.f = __builtin_amdgcn_fmed3f(f1.f, 240.0, -240.0);
-  }
-  if ((f2.ui32 & 0x7F800000) != 0x7F800000) {
-    f2.f = __builtin_amdgcn_fmed3f(f2.f, 240.0, -240.0);
-  }
-  return __builtin_amdgcn_cvt_pk_fp8_f32(f1.f, f2.f, 0, 0);
-  #else
-  union {
-    uint8_t ui8[2];
-    uint16_t ui16;
-  } tmp;
-  tmp.ui8[0] = scaled_vec_conversion<uint8_t, float>(a.x, scale);
-  tmp.ui8[1] = scaled_vec_conversion<uint8_t, float>(a.y, scale);
-  return tmp.ui16;
-  #endif
+  return __hip_cvt_float2_to_fp8x2(a / scale, fp8_type::__default_saturation,
+                                   fp8_type::__default_interpret);
 }
 
 // floatx4 -> fp8x4
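Annotation (not part of the patch): the packed conversions above all reduce to the same few intrinsics. A self-contained sketch of the two-element round trip (quantize with a scale, dequantize back), under the fnuz interpretation selected by `HIP_FP8_TYPE_FNUZ`; `roundtrip_fp8x2` is a hypothetical helper:

```cpp
#include <hip/hip_fp8.h>

__device__ float2 roundtrip_fp8x2(float2 v, float scale) {
  // quantize: divide by scale, then pack two fp8 values into 16 bits
  __hip_fp8x2_storage_t packed = __hip_cvt_float2_to_fp8x2(
      make_float2(v.x / scale, v.y / scale),
      __hip_fp8_e4m3_fnuz::__default_saturation,
      __hip_fp8_e4m3_fnuz::__default_interpret);
  // dequantize: unpack to float2 via the fp8x2 wrapper, then rescale
  __hip_fp8x2_e4m3_fnuz f8x2;
  f8x2.__x = packed;
  float2 r = static_cast<float2>(f8x2);
  return make_float2(r.x * scale, r.y * scale);
}
```

Note the convention used throughout the file: quantization divides by `scale`, dequantization multiplies by it.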
diff --git a/csrc/quantization/fp8/common.cuh b/csrc/quantization/fp8/common.cuh
index bdfa43a80e7..ffcb3177b61 100644
--- a/csrc/quantization/fp8/common.cuh
+++ b/csrc/quantization/fp8/common.cuh
@@ -12,7 +12,7 @@ C10_HOST_DEVICE constexpr auto FP8_E4M3_MAX =
     std::numeric_limits<FP8_TYPE>::max();
 #else
   #include <c10/util/Float8_e4m3fnuz.h>
-  #include "amd/hip_float8.h"
+  #include "amd/quant_utils.cuh"
 using FP8_TYPE = c10::Float8_e4m3fnuz;
 // Using the default max value from pytorch (240.0) will cause accuracy
 // issue when running dynamic quantization. Here use 224.0f for rocm.
@@ -47,8 +47,10 @@ __device__ __forceinline__ FP8_TYPE scaled_fp8_conversion(float const val,
     return static_cast<FP8_TYPE>(r);
 #else
   // Use hardware cvt instruction for fp8 on rocm
-  return c10::Float8_e4m3fnuz(hip_fp8(r).data,
-                              c10::Float8_e4m3fnuz::from_bits());
+  return c10::Float8_e4m3fnuz(
+      __hip_cvt_float_to_fp8(r, fp8::fp8_type::__default_saturation,
+                             fp8::fp8_type::__default_interpret),
+      c10::Float8_e4m3fnuz::from_bits());
 #endif
 }
diff --git a/csrc/rocm/attention.cu b/csrc/rocm/attention.cu
index 01b29428131..5e707251828 100644
--- a/csrc/rocm/attention.cu
+++ b/csrc/rocm/attention.cu
@@ -17,10 +17,12 @@
 #include <torch/all.h>
 #include <ATen/cuda/CUDAContext.h>
 #include <c10/cuda/CUDAGuard.h>
+#include <hip/hip_fp8.h>
 #include <hip/hip_bf16.h>
 
 #include "cuda_compat.h"
 
 #include <algorithm>
+
 #include "../attention/dtype_fp8.cuh"
 #include "../quantization/fp8/amd/quant_utils.cuh"
@@ -1518,7 +1520,9 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_reduce_kernel(
   acc *= out_scale;
   OUTT* out_ptr = out + seq_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE;
   if constexpr (std::is_same<OUTT, bit8_t>::value) {
-    out_ptr[threadIdx.x] = hip_fp8(acc).data;
+    out_ptr[threadIdx.x] =
+        __hip_cvt_float_to_fp8(acc, vllm::fp8::fp8_type::__default_saturation,
+                               vllm::fp8::fp8_type::__default_interpret);
   } else {
     out_ptr[threadIdx.x] = from_float(acc);
   }
diff --git a/tests/kernels/test_cache.py b/tests/kernels/test_cache.py
index 6f909b6803d..f10df8012df 100644
--- a/tests/kernels/test_cache.py
+++ b/tests/kernels/test_cache.py
@@ -151,19 +151,20 @@ def test_reshape_and_cache(
                                                 device)
     key_cache, value_cache = key_caches[0], value_caches[0]
 
+    # Compute data-dependent scales for the fp8 KV cache.
+    k_scale = (key.amax() / 64.0).to(torch.float32)
+    v_scale = (value.amax() / 64.0).to(torch.float32)
+
     # Clone the KV caches.
     if kv_cache_dtype == "fp8":
         cloned_key_cache = torch.empty_like(key_cache, dtype=torch.float16)
-        ops.convert_fp8(cloned_key_cache, key_cache)
+        ops.convert_fp8(cloned_key_cache, key_cache, k_scale.item())
         cloned_value_cache = torch.empty_like(value_cache,
                                               dtype=torch.float16)
-        ops.convert_fp8(cloned_value_cache, value_cache)
+        ops.convert_fp8(cloned_value_cache, value_cache, v_scale.item())
     else:
        cloned_key_cache = key_cache.clone()
        cloned_value_cache = value_cache.clone()
 
-    # Using default kv_scale
-    k_scale = v_scale = torch.tensor(1.0, dtype=torch.float32, device=device)
-
     # Call the reshape_and_cache kernel.
     opcheck(torch.ops._C_cache_ops.reshape_and_cache,
             (key, value, key_cache, value_cache, slot_mapping, kv_cache_dtype,
@@ -174,9 +175,9 @@ def test_reshape_and_cache(
 
     if kv_cache_dtype == "fp8":
         result_key_cache = torch.empty_like(key_cache, dtype=torch.float16)
-        ops.convert_fp8(result_key_cache, key_cache)
+        ops.convert_fp8(result_key_cache, key_cache, k_scale.item())
         result_value_cache = torch.empty_like(value_cache,
                                               dtype=torch.float16)
-        ops.convert_fp8(result_value_cache, value_cache)
+        ops.convert_fp8(result_value_cache, value_cache, v_scale.item())
 
     # Run the reference implementation.
     reshaped_key = key.reshape(num_tokens, *key_cache[0, :, :, 0, :].shape)
@@ -260,15 +261,16 @@ def test_reshape_and_cache_flash(
     del key_caches
     del value_caches
 
-    k_scale = (key.amax() / 256.0).to(torch.float32)
-    v_scale = (value.amax() / 256.0).to(torch.float32)
+    k_scale = (key.amax() / 64.0).to(torch.float32)
+    v_scale = (value.amax() / 64.0).to(torch.float32)
 
     # Clone the KV caches.
     if kv_cache_dtype == "fp8":
         cloned_key_cache = torch.empty_like(key_cache, dtype=torch.float16)
-        ops.convert_fp8(cloned_key_cache, key_cache, k_scale, kv_cache_dtype)
+        ops.convert_fp8(cloned_key_cache, key_cache, k_scale.item(),
+                        kv_cache_dtype)
         cloned_value_cache = torch.empty_like(value_cache,
                                               dtype=torch.float16)
-        ops.convert_fp8(cloned_value_cache, value_cache, v_scale,
+        ops.convert_fp8(cloned_value_cache, value_cache, v_scale.item(),
                         kv_cache_dtype)
     else:
         cloned_key_cache = key_cache.clone()