From df7f8a35865679482e32a52180d739bbcf821691 Mon Sep 17 00:00:00 2001 From: Umang Yadav Date: Thu, 9 Nov 2023 23:23:25 +0000 Subject: [PATCH 001/115] changes for the FP8 ref implementation --- src/api/include/migraphx/migraphx.h | 3 +- src/include/migraphx/half.hpp | 13 + src/include/migraphx/migraphx_f8_impl.hpp | 320 +++++++++++++ src/include/migraphx/migraphx_float8.hpp | 431 ++++++++++++++++++ src/include/migraphx/shape.hpp | 4 +- src/include/migraphx/type_traits.hpp | 30 +- src/py/migraphx_py.cpp | 13 +- .../include/migraphx/gpu/device/types.hpp | 24 +- test/CMakeLists.txt | 1 + test/float_equal.cpp | 12 +- test/fp8e4m3fnuz.cpp | 233 ++++++++++ test/gpu/jit.cpp | 22 +- tools/api/migraphx.h | 7 +- 13 files changed, 1082 insertions(+), 31 deletions(-) create mode 100644 src/include/migraphx/migraphx_f8_impl.hpp create mode 100644 src/include/migraphx/migraphx_float8.hpp create mode 100644 test/fp8e4m3fnuz.cpp diff --git a/src/api/include/migraphx/migraphx.h b/src/api/include/migraphx/migraphx.h index cde517f0b28..c8467c67f30 100644 --- a/src/api/include/migraphx/migraphx.h +++ b/src/api/include/migraphx/migraphx.h @@ -44,7 +44,8 @@ m(int32_type, int32_t) \ m(int64_type, int64_t) \ m(uint32_type, uint32_t) \ - m(uint64_type, uint64_t) + m(uint64_type, uint64_t) \ + m(fp8e4m3fnuz_type, migraphx_fp8::fp8e4m3fnuz) // clang-format on #ifdef __cplusplus diff --git a/src/include/migraphx/half.hpp b/src/include/migraphx/half.hpp index 10cc7e4289c..de692dd16e3 100644 --- a/src/include/migraphx/half.hpp +++ b/src/include/migraphx/half.hpp @@ -27,6 +27,7 @@ #include #include +#include namespace migraphx { inline namespace MIGRAPHX_INLINE_NS { @@ -67,6 +68,18 @@ struct common_type : std::common_type // NOLINT { }; +template <> +struct common_type +{ + using type = float; +}; + +template <> +struct common_type +{ + using type = float; +}; + template <> struct common_type { diff --git a/src/include/migraphx/migraphx_f8_impl.hpp b/src/include/migraphx/migraphx_f8_impl.hpp new 
file mode 100644 index 00000000000..03e828483fd --- /dev/null +++ b/src/include/migraphx/migraphx_f8_impl.hpp @@ -0,0 +1,320 @@ +/* ************************************************************************ + * Copyright (C) 2016-2023 Advanced Micro Devices, Inc. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell cop- + * ies of the Software, and to permit persons to whom the Software is furnished + * to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IM- + * PLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS + * FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR + * COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER + * IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNE- + * CTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + * + * ************************************************************************ */ + +#ifndef MIGRAPHX_FP8_IMPL_HPP +#define MIGRAPHX_FP8_IMPL_HPP +#if defined(__clang__) +#pragma clang diagnostic push +#pragma clang diagnostic ignored "-Wreserved-identifier" +#endif + +#define CONST_FOLD(x) (__builtin_constant_p(x) ? 
(x) : (x)) +namespace migraphx_f8_impl { +namespace detail { +template +struct conditional +{ + using type = T; +}; + +template +struct conditional +{ + using type = F; +}; + +template +inline constexpr To bit_cast(From fr) noexcept +{ + static_assert(sizeof(To) == sizeof(From)); +#if defined(__GNUC__) and !defined(__clang__) + To x = CONST_FOLD(*reinterpret_cast(&fr)); +#else + To x = __builtin_bit_cast(To, fr); +#endif + return x; +} +} // namespace detail + +template +constexpr uint8_t cast_to_f8(T _x, bool stoch, uint32_t rng) +{ + + static_assert(wm + we == 7, "wm+we==7"); + + const int mfmt = (sizeof(T) == 4) ? 23 : 10; + typename detail::conditional::type x; + + if constexpr(sizeof(T) == 4) + x = detail::bit_cast(_x); + else + x = detail::bit_cast(_x); + + uint32_t head, mantissa; + int exponent, bias; + uint32_t sign; + + if constexpr(sizeof(T) == 4) + { + head = x & 0xFF800000; + mantissa = x & 0x7FFFFF; + exponent = (head >> 23) & 0xFF; + sign = head >> 31; + bias = 127; + } + else + { + head = x & 0xFC00; + mantissa = x & 0x3FF; + exponent = (head >> 10) & 0x1F; + sign = head >> 15; + bias = 15; + } + + uint32_t signed_inf = (sign << 7) + (((1 << we) - 1) << wm); + + // Deal with inf and NaNs + if(negative_zero_nan) + { + if(sizeof(T) == 4) + { + if((x & 0x7F800000) == 0x7F800000) + return 0x80; + } + else + { + // if(__hisinf(x) || __hisnan(x)) + if((x & 0x7C00) == 0x7C00) + return 0x80; + } + } + else + { + if(sizeof(T) == 4) + { + if((x & 0x7F800000) == 0x7F800000) + return signed_inf + (mantissa != 0 ? 1 : 0); + } + else + { + if((x & 0x7C00) == 0x7C00) + return signed_inf + (mantissa != 0 ? 
1 : 0); + } + } + // handle positive zero + if(x == 0) + return 0; + // handle negative zero + if((sizeof(T) == 4 and x == 0x80000000) or (sizeof(T) == 2 and x == 0x8000)) + { + if(negative_zero_nan) + { + return 0; + } + else + { + return 0x80; + } + } + + // First need to check if it is normal or denorm as there is a difference of implict 1 + // Then need to adjust the exponent to align with the F8 exponent, in the meanwhile, shift + // The mantissa. Then for stochastic rounding, add rng to mantissa and truncate. And for + // RNE, no need to add rng. Then probably need to check whether there is carry and adjust + // exponent and mantissa again + + // For IEEE bias mode, the bias is 2^(k-1) -1 where k is the width of exponent bits + const int f8_bias = (1 << (we - 1)) - 1 + (negative_zero_nan ? 1 : 0); + const int f8_denormal_act_exponent = 1 - f8_bias; // actual exponent of f8 denormal + // act_exponent is the actual exponent of fp32/fp16 (after subtracting bias) + // f8_exponent is the converted f8 exponent with bias encoding + // exponent_diff is the diff between fp32/fp16 exponent and f8 exponent, + // the difference needs to be adjusted and mantissa shifted + int act_exponent, f8_exponent, exponent_diff; + + if(exponent == 0) + { // fp32/fp16 is in denormal. + /* fp32 denormal is below 2^-127 so it is usually not a concern here, we mostly concern fp16 +here. In this case, f8 is usually in denormal. But there could be exceptions. fp16 denormal has +exponent bias 15 while bf8 with NANOO has exponent bias 16. It means that there are some numbers in +fp16 denormal but they are bf8 (NANOO) normals - smallest bf8 (NANOO) normal is 2^-15. fp16 numbers +where exponent==0 (actual exponent -14) and highest bit of mantissa is 1 are bf8 (NANOO) normal. 
In +this case, the fp16 mantissa should be shift left by 1 */ + act_exponent = exponent - bias + 1; + exponent_diff = f8_denormal_act_exponent - + act_exponent; // actual exponent is exponent-bias+1 as it is denormal + } + else + { // fp32/fp16 is normal with implicit 1 + act_exponent = exponent - bias; + if(act_exponent <= f8_denormal_act_exponent) + { + /* This is the case where fp32/fp16 is normal but it is in f8 denormal range. + For example fp8 nanoo mode, denormal exponent is -7, but if the fp32/fp16 + actual exponent is -7, it is actually larger due to the implict 1, + Therefore it needs to be adjust to -6 and mantissa shift right by 1. + So for fp32/fp16, exponent -8 is the cut point to convert to fp8 nanoo */ + exponent_diff = f8_denormal_act_exponent - act_exponent; + } + else + { // both fp32/fp16 and f8 are in normal range + exponent_diff = + 0; // exponent_diff=0 does not mean there is no difference for this case, + // act_exponent could be larger. Just that it does not need shift mantissa + } + mantissa += (1 << mfmt); // Add the implicit 1 into mantissa + } + + bool midpoint = (mantissa & ((1 << (mfmt - wm + exponent_diff)) - 1)) == + (1 << (mfmt - wm + exponent_diff - 1)); + /* This part is a bit tricky. The judgment of whether it is a tie needs to be done before we + shift right as shift right could rip off some residual part and make something not midpoint look + like midpoint. For example, the fp16 number 0x1002 (0 00100 0000000010), it is larger than + midpoint, but after shift right by 4 bits, it would look like midpoint. +*/ + + if(exponent_diff > 0) + mantissa >>= exponent_diff; + else if(exponent_diff == -1) + mantissa <<= -exponent_diff; + bool implicit_one = mantissa & (1 << mfmt); + // if there is no implict 1, it means the f8 is denormal and need to adjust to denorm exponent + f8_exponent = + (act_exponent + exponent_diff) /*actual f8 exponent*/ + f8_bias - (implicit_one ? 
0 : 1); + + // Now we have the exponent and mantissa adjusted + uint32_t drop_mask = (1 << (mfmt - wm)) - 1; + bool odd = + mantissa & (1 << (mfmt - wm)); // if the least significant bit that is not truncated is 1 + mantissa += (stoch ? rng : (midpoint ? (odd ? mantissa : mantissa - 1) : mantissa)) & drop_mask; + + // Now we deal with overflow + if(f8_exponent == 0) + { + if((1 << mfmt) & mantissa) + { + f8_exponent = 1; // denormal overflow to become normal, promote exponent + } + } + else + { + if((1 << (mfmt + 1)) & mantissa) + { + mantissa >>= 1; + f8_exponent++; + } + } + + mantissa >>= (mfmt - wm); + + // above range: quantize to maximum possible float of the same sign + const int max_exp = (1 << we) - (negative_zero_nan ? 1 : 2); + if(f8_exponent > max_exp) + { + if(clip) + { + mantissa = (1 << wm) - 1; + f8_exponent = max_exp; + } + else + { + return signed_inf; + } + } + + if(f8_exponent == 0 && mantissa == 0) + return negative_zero_nan ? 0 : (sign << 7); + mantissa &= (1 << wm) - 1; + return (sign << 7) | (f8_exponent << wm) | mantissa; +} + +template +constexpr T cast_from_f8(uint8_t x) +{ + constexpr int weo = 8; + constexpr int wmo = 23; + + T fInf, fNegInf, fNaN, fNeg0; + uint32_t ifInf = 0x7F800000; + uint32_t ifNegInf = 0xFF800000; + uint32_t ifNaN = 0x7F800001; + uint32_t ifNeg0 = 0x80000000; + // TODO: need to change T for half but right now it would never called with half + fInf = detail::bit_cast(ifInf); + fNegInf = detail::bit_cast(ifNegInf); + fNaN = detail::bit_cast(ifNaN); + fNeg0 = detail::bit_cast(ifNeg0); + + if(x == 0) + return 0; + + uint32_t sign = x >> 7; + uint32_t mantissa = x & ((1 << wm) - 1); + int exponent = (x & 0x7F) >> wm; + if(negative_zero_nan) + { + if(x == 0x80) + return fNaN; + } + else + { + if(x == 0x80) + return fNeg0; + if(exponent == ((1 << we) - 1)) + return (mantissa == 0) ? (sign ? 
fNegInf : fInf) : fNaN; + } + typename detail::conditional::type retval; + + const int exp_low_cutoff = (1 << (weo - 1)) - (1 << (we - 1)) + 1 - (negative_zero_nan ? 1 : 0); + + // subnormal input + if(exponent == 0) + { + // guaranteed mantissa!=0 since cases 0x0 and 0x80 are handled above + int sh = 1 + __builtin_clz(mantissa) - (32 - wm); + mantissa <<= sh; + exponent += 1 - sh; + mantissa &= ((1 << wm) - 1); + } + exponent += exp_low_cutoff - 1; + mantissa <<= wmo - wm; + + // subnormal output (occurs when T=half, we=5, negative_zero_nan=true) + if(exponent <= 0) + { + mantissa |= 1 << wmo; + mantissa >>= 1 - exponent; + exponent = 0; + } + + if(sizeof(T) == 2) + retval = (sign << 15) | (exponent << 10) | mantissa; + else + retval = (sign << 31) | (exponent << 23) | mantissa; + return detail::bit_cast(retval); +} + +} // namespace migraphx_f8_impl +#if defined(__clang__) +#pragma clang diagnostic pop +#endif +#endif // MIGRAPHX_FP8_IMPL_HPP diff --git a/src/include/migraphx/migraphx_float8.hpp b/src/include/migraphx/migraphx_float8.hpp new file mode 100644 index 00000000000..25d8ecdfd17 --- /dev/null +++ b/src/include/migraphx/migraphx_float8.hpp @@ -0,0 +1,431 @@ +/* ************************************************************************ + * Copyright (C) 2016-2023 Advanced Micro Devices, Inc. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell cop- + * ies of the Software, and to permit persons to whom the Software is furnished + * to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. 
+ * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IM- + * PLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS + * FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR + * COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER + * IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNE- + * CTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + * + * ************************************************************************ */ + +#ifndef MIGRAPHX_GUARD_RTGLIB_FLOAT8_HPP +#define MIGRAPHX_GUARD_RTGLIB_FLOAT8_HPP +#if defined(__clang__) +#pragma clang diagnostic push +#pragma clang diagnostic ignored "-Wold-style-cast" +#pragma clang diagnostic ignored "-Wfloat-equal" +#pragma clang diagnostic ignored "-Wmacro-redefined" +#pragma clang diagnostic ignored "-Wc++20-extensions" +#endif // __clang__ + +#ifndef MIGRAPHX_FP8_FNUZ +#define MIGRAPHX_FP8_FNUZ true +#endif // MIGRAPHX_FP8_FNUZ + +// We are clipping in down conversion by default +#define MIGRAPHX_F8_DOWNCAST_CLIPPING 1 +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace migraphx_f8_impl { + +template +constexpr uint8_t cast_to_f8(T _x, bool stoch = false, uint32_t rng = 0); + +template +constexpr T cast_from_f8(uint8_t x); + +} // namespace migraphx_f8_impl + +#include + +namespace migraphx_fp8 { + +enum class migraphx_f8_rounding_mode +{ + standard, // standard rounding is doing RNE -- round to nearest even + stochastic +}; + +enum class f8_type +{ + bf8 = 0, // s1e5m2 + fp8 = 1 // s1e4m3 +}; + +template +class numeric_limits; + +template +struct float8 +{ + uint8_t data; + // default constructor + constexpr float8() = default; + // default copy constructor + constexpr float8(const float8& y) = default; + struct from_bits_t + { + }; + static constexpr from_bits_t from_bits() { return from_bits_t(); } + + 
explicit constexpr float8(uint8_t bits, from_bits_t) : data(bits) {} + + explicit constexpr float8(float v, + migraphx_fp8::migraphx_f8_rounding_mode rm = + migraphx_fp8::migraphx_f8_rounding_mode::standard, + uint32_t rng = 0) + { + if constexpr(T == migraphx_fp8::f8_type::fp8) + { +#ifdef MIGRAPHX_F8_DOWNCAST_CLIPPING + data = migraphx_f8_impl:: + cast_to_f8<3, 4, float, MIGRAPHX_FP8_FNUZ /*negative_zero_nan*/, true /*clip*/>( + v, (rm == migraphx_fp8::migraphx_f8_rounding_mode::stochastic), rng); +#else // MIGRAPHX_F8_DOWNCAST_CLIPPING + data = migraphx_f8_impl:: + cast_to_f8<3, 4, float, MIGRAPHX_FP8_FNUZ /*negative_zero_nan*/, false /*clip*/>( + v, (rm == migraphx_fp8::migraphx_f8_rounding_mode::stochastic), rng); +#endif // MIGRAPHX_F8_DOWNCAST_CLIPPING + } + else + { +#ifdef MIGRAPHX_F8_DOWNCAST_CLIPPING + data = migraphx_f8_impl:: + cast_to_f8<2, 5, float, MIGRAPHX_FP8_FNUZ /*negative_zero_nan*/, true /*clip*/>( + v, (rm == migraphx_fp8::migraphx_f8_rounding_mode::stochastic), rng); +#else // MIGRAPHX_F8_DOWNCAST_CLIPPING + data = migraphx_f8_impl:: + cast_to_f8<2, 5, float, MIGRAPHX_FP8_FNUZ /*negative_zero_nan*/, false /*clip*/>( + v, (rm == migraphx_fp8::migraphx_f8_rounding_mode::stochastic), rng); +#endif // rocblas_F8_downcast_clipping} + } + } + + inline constexpr operator float() const + { + if constexpr(T == migraphx_fp8::f8_type::fp8) + { + return migraphx_f8_impl:: + cast_from_f8<3, 4, float, MIGRAPHX_FP8_FNUZ /*negative_zero_nan*/>(data); + } // else + return migraphx_f8_impl::cast_from_f8<2, 5, float, MIGRAPHX_FP8_FNUZ /*negative_zero_nan*/>( + data); + } + + inline constexpr bool is_zero() const + { + if constexpr(MIGRAPHX_FP8_FNUZ) + { + return data == 0x00; + } + else + { + return (data == 0x00) || (data == 0x80); + } + } + + inline constexpr bool is_nan() const + { + if constexpr(MIGRAPHX_FP8_FNUZ) + { + return data == 0x80; + } + else + { + if(T == migraphx_fp8::f8_type::bf8) + { + return (data == 0x7d) || (data == 0x7e) || (data == 0x7f) 
|| (data == 0xfd) || + (data == 0xfe) || (data == 0xff); + } + else + { + return (data == 0x79) || (data == 0x7a) || (data == 0x7b) || (data == 0x7c) || + (data == 0x7d) || (data == 0x7e) || (data == 0x7f) || (data == 0xf9) || + (data == 0xfa) || (data == 0xfb) || (data == 0xfc) || (data == 0xfd) || + (data == 0xfe) || (data == 0xff); + } + } + } + + inline constexpr bool is_inf() const + { + if constexpr(MIGRAPHX_FP8_FNUZ) + { + return data == 0x80; + } + else + { + if(T == migraphx_fp8::f8_type::bf8) + { + return (data == 0x7c) || (data == 0xfc); + } + else + { + return (data == 0x78) || (data == 0xf8); + } + } + } + +#define MIGRAPHX_FP8_UNARY_OP(unary_op, binary_op) \ + constexpr float8& operator unary_op(const float8& rhs) \ + { \ + const auto tmp = static_cast(*this) binary_op static_cast(rhs); \ + *this = static_cast(tmp); \ + return *this; \ + } \ + constexpr float8& operator unary_op(const float& rhs) \ + { \ + const auto tmp = static_cast(*this) binary_op static_cast(rhs); \ + *this = static_cast(tmp); \ + return *this; \ + } + + MIGRAPHX_FP8_UNARY_OP(*=, *) + MIGRAPHX_FP8_UNARY_OP(-=, -) + MIGRAPHX_FP8_UNARY_OP(+=, +) + MIGRAPHX_FP8_UNARY_OP(/=, /) + + inline constexpr float8& operator=(const float8& rhs) = default; + inline constexpr float8& operator=(float8&& rhs) = default; + + inline constexpr float8& operator=(float rhs) + { + *this = static_cast(rhs); + return *this; + } + + inline constexpr bool operator==(const float8& rhs) const + { + // NaN never compares equal to anything, itself included + if(rhs.is_nan() || this->is_nan()) + return false; + // both zero encodings are equal; otherwise equality is exact bit equality + if(rhs.is_zero() && this->is_zero()) + return true; + return this->data == rhs.data; + } + + inline constexpr bool operator<(const float8& rhs) const + { + const auto we = static_cast(*this); + const auto them = static_cast(rhs); + return we < them; + } + + inline constexpr bool operator>(const float8& rhs) const + { + const auto we = static_cast(*this); + const auto
them = static_cast(rhs); + return we > them; + } +}; + +// Special operator overloading +template +inline std::ostream& operator<<(std::ostream& os, const migraphx_fp8::float8& rhs) +{ + return os << static_cast(rhs); +} + +// NOLINTNEXTLINE +#define MIGRAPHX_FP8_BINARY_OP(binary_op, U) \ + template \ + inline constexpr U operator binary_op(const migraphx_fp8::float8& lhs, \ + const migraphx_fp8::float8& rhs) \ + { \ + return U(static_cast(lhs) binary_op static_cast(rhs)); \ + } + +// TODO: these should return floats +MIGRAPHX_FP8_BINARY_OP(*, migraphx_fp8::float8) +MIGRAPHX_FP8_BINARY_OP(-, migraphx_fp8::float8) +MIGRAPHX_FP8_BINARY_OP(/, migraphx_fp8::float8) +MIGRAPHX_FP8_BINARY_OP(+, migraphx_fp8::float8) +// TODO: Comparison ops shouldn't convert to float, maybe need to take care of rounding effects. +MIGRAPHX_FP8_BINARY_OP(==, bool) +MIGRAPHX_FP8_BINARY_OP(>=, bool) +MIGRAPHX_FP8_BINARY_OP(<=, bool) +MIGRAPHX_FP8_BINARY_OP(>, bool) +MIGRAPHX_FP8_BINARY_OP(<, bool) +MIGRAPHX_FP8_BINARY_OP(!=, bool) + +template +inline migraphx_fp8::float8 fabs(migraphx_fp8::float8 v) +{ + v.data = v.data & 0x7f; + return v; +} + +template +constexpr T F8_Max() +{ + return T{0x7F, T::from_bits()}; +} + +template +constexpr T F8_Lowest() +{ + return T{0xFF, T::from_bits()}; +} + +using fp8e4m3fnuz = float8; + +template <> +class numeric_limits> +{ + public: + // TODO :figure out epsilon in Hex to make it constexpr + static constexpr migraphx_fp8::float8 epsilon() + { + return migraphx_fp8::float8( + 0x28, migraphx_fp8::float8<>::from_bits()); + } + + static constexpr migraphx_fp8::float8 quiet_NaN() + { + return migraphx_fp8::float8( + MIGRAPHX_FP8_FNUZ ? 
0x80 : 0x7F, migraphx_fp8::float8<>::from_bits()); + } + + static constexpr migraphx_fp8::float8 max() + { + return migraphx_fp8::F8_Max>(); + } + + // smallest positive normal value: exponent field 1, mantissa 0 + static constexpr migraphx_fp8::float8 min() + { + return migraphx_fp8::float8<>( + 0x08, migraphx_fp8::float8<>::from_bits()); + } + + static constexpr migraphx_fp8::float8 lowest() + { + return migraphx_fp8::F8_Lowest>(); + } + + static constexpr migraphx_fp8::float8 infinity() + { // 0x78 is the +inf pattern is_inf() tests for; FNUZ has no inf and reuses NaN (0x80) + return migraphx_fp8::float8( + MIGRAPHX_FP8_FNUZ ? 0x80 : 0x78, migraphx_fp8::float8<>::from_bits()); + } +}; + +template <> +class numeric_limits> +{ + public: + static constexpr migraphx_fp8::float8 epsilon() + { + return migraphx_fp8::float8( + 0x34, migraphx_fp8::float8::from_bits()); + } + + static constexpr migraphx_fp8::float8 quiet_NaN() + { + return migraphx_fp8::float8( + MIGRAPHX_FP8_FNUZ ? 0x80 : 0x7d, + migraphx_fp8::float8::from_bits()); + } + + static constexpr migraphx_fp8::float8 max() + { + return static_cast>( + migraphx_fp8::F8_Max>()); + } + // smallest positive normal value: exponent field 1, mantissa 0 + static constexpr migraphx_fp8::float8 min() + { + return migraphx_fp8::float8<migraphx_fp8::f8_type::bf8>( + 0x04, migraphx_fp8::float8<migraphx_fp8::f8_type::bf8>::from_bits()); + } + static constexpr migraphx_fp8::float8 lowest() + { + return migraphx_fp8::F8_Lowest>(); + } + + static constexpr migraphx_fp8::float8 infinity() + { + return migraphx_fp8::float8( + MIGRAPHX_FP8_FNUZ ?
0x80 : 0x7c, + migraphx_fp8::float8::from_bits()); + } +}; +} // namespace migraphx_fp8 +// define numeric limits for the new data type +namespace std { +inline bool isfinite(migraphx_fp8::float8 x) // NOLINT +{ + return !x.is_inf() && !x.is_nan(); // finite means neither infinite nor NaN +} + +inline bool isfinite(migraphx_fp8::float8 x) // NOLINT +{ + return !x.is_inf() && !x.is_nan(); // finite means neither infinite nor NaN +} + +inline bool isnan(migraphx_fp8::float8 x) // NOLINT +{ + return x.is_nan(); +} + +inline bool isnan(migraphx_fp8::float8 x) // NOLINT +{ + return x.is_nan(); +} + +template <> +class numeric_limits> + : public migraphx_fp8::numeric_limits> +{ +}; + +template <> +class numeric_limits> + : public migraphx_fp8::numeric_limits> +{ +}; + +template +struct common_type : std::common_type // NOLINT +{ +}; + +template +struct common_type : std::common_type // NOLINT +{ +}; + +template <> +struct common_type +{ + using type = float; +}; + +} // namespace std +// ================================================================================================= +#if defined(__clang__) +#pragma clang diagnostic pop +#endif +#endif // MIGRAPHX_GUARD_RTGLIB_FLOAT8_HPP diff --git a/src/include/migraphx/shape.hpp b/src/include/migraphx/shape.hpp index 3cf5785087c..ef82f300226 100644 --- a/src/include/migraphx/shape.hpp +++ b/src/include/migraphx/shape.hpp @@ -34,6 +34,7 @@ #include #include #include +#include #include #include @@ -60,7 +61,8 @@ struct MIGRAPHX_EXPORT shape m(int32_type, int32_t) \ m(int64_type, int64_t) \ m(uint32_type, uint32_t) \ - m(uint64_type, uint64_t) + m(uint64_type, uint64_t) \ + m(fp8e4m3fnuz_type, migraphx_fp8::fp8e4m3fnuz) // clang-format on #define MIGRAPHX_SHAPE_GENERATE_ENUM_TYPES(x, t) x, diff --git a/src/include/migraphx/type_traits.hpp b/src/include/migraphx/type_traits.hpp index 1512c38f203..8fc3081ef18 100644 --- a/src/include/migraphx/type_traits.hpp +++ b/src/include/migraphx/type_traits.hpp @@ -28,25 +28,45 @@ #include #include #include +#include namespace migraphx { inline namespace MIGRAPHX_INLINE_NS { +#define
MIGRAPHX_DETAIL_DEFINE_TRAIT(trait) \ + template \ + struct trait : std::trait \ + { \ + }; + #define MIGRAPHX_DETAIL_EXTEND_TRAIT_FOR(trait, T) \ - template \ - struct trait : std::trait \ - { \ - }; \ - \ template <> \ struct trait : std::true_type \ { \ }; +MIGRAPHX_DETAIL_DEFINE_TRAIT(is_floating_point); +MIGRAPHX_DETAIL_DEFINE_TRAIT(is_arithmetic); +MIGRAPHX_DETAIL_DEFINE_TRAIT(is_signed); + +template +struct is_same : std::is_same +{ +}; + +template +struct conditional : std::conditional +{ +}; + MIGRAPHX_DETAIL_EXTEND_TRAIT_FOR(is_floating_point, half) MIGRAPHX_DETAIL_EXTEND_TRAIT_FOR(is_signed, half) MIGRAPHX_DETAIL_EXTEND_TRAIT_FOR(is_arithmetic, half) +MIGRAPHX_DETAIL_EXTEND_TRAIT_FOR(is_floating_point, migraphx_fp8::fp8e4m3fnuz) +MIGRAPHX_DETAIL_EXTEND_TRAIT_FOR(is_signed, migraphx_fp8::fp8e4m3fnuz) +MIGRAPHX_DETAIL_EXTEND_TRAIT_FOR(is_arithmetic, migraphx_fp8::fp8e4m3fnuz) + template using accumulator_type = std::conditional_t{}, diff --git a/src/py/migraphx_py.cpp b/src/py/migraphx_py.cpp index 4b6de6c19d0..4014936ef44 100644 --- a/src/py/migraphx_py.cpp +++ b/src/py/migraphx_py.cpp @@ -40,7 +40,7 @@ #include #include #include - +#include #ifdef HAVE_GPU #include #endif @@ -144,6 +144,17 @@ struct npy_format_descriptor static constexpr auto name() { return _("half"); } }; +template <> +struct npy_format_descriptor +{ + static std::string format() + { + // following: https://docs.python.org/3/library/struct.html#format-characters + return "z"; + } + static constexpr auto name() { return _("fp8e4m3fnuz"); } +}; + } // namespace detail } // namespace pybind11 diff --git a/src/targets/gpu/device/include/migraphx/gpu/device/types.hpp b/src/targets/gpu/device/include/migraphx/gpu/device/types.hpp index 355cc4477b1..28a4b2939d7 100644 --- a/src/targets/gpu/device/include/migraphx/gpu/device/types.hpp +++ b/src/targets/gpu/device/include/migraphx/gpu/device/types.hpp @@ -146,20 +146,20 @@ __device__ __host__ T to_hip_type(T x) // Hip doens't support __fp16 
inline __device__ __host__ float to_hip_type(gpu_half x) { return x; } -#define MIGRAPHX_DETAIL_EXTEND_TRAIT_FOR(trait, T) \ - template \ - struct trait : std::trait \ - { \ - }; \ - \ - template <> \ - struct trait : std::true_type \ - { \ +#define MIGRAPHX_DEVICE_DETAIL_EXTEND_TRAIT_FOR(trait, T) \ + template \ + struct trait : std::trait \ + { \ + }; \ + \ + template <> \ + struct trait : std::true_type \ + { \ }; -MIGRAPHX_DETAIL_EXTEND_TRAIT_FOR(is_floating_point, __fp16) -MIGRAPHX_DETAIL_EXTEND_TRAIT_FOR(is_signed, __fp16) -MIGRAPHX_DETAIL_EXTEND_TRAIT_FOR(is_arithmetic, __fp16) +MIGRAPHX_DEVICE_DETAIL_EXTEND_TRAIT_FOR(is_floating_point, __fp16) +MIGRAPHX_DEVICE_DETAIL_EXTEND_TRAIT_FOR(is_signed, __fp16) +MIGRAPHX_DEVICE_DETAIL_EXTEND_TRAIT_FOR(is_arithmetic, __fp16) } // namespace device } // namespace gpu diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt index 3fc9ea4fac7..a6217b0ec2f 100644 --- a/test/CMakeLists.txt +++ b/test/CMakeLists.txt @@ -150,6 +150,7 @@ function(test_headers PREFIX) list(REMOVE_ITEM HEADERS ${CMAKE_SOURCE_DIR}/src/targets/gpu/include/migraphx/gpu/ck.hpp) endif() + list(REMOVE_ITEM HEADERS ${CMAKE_SOURCE_DIR}/src/include/migraphx/migraphx_f8_impl.hpp) foreach(HEADER ${HEADERS}) file(RELATIVE_PATH HEADER_REL ${CMAKE_SOURCE_DIR} ${HEADER}) string(MAKE_C_IDENTIFIER ${HEADER_REL} TEST_NAME) diff --git a/test/float_equal.cpp b/test/float_equal.cpp index 102ee4faf67..0ae10614708 100644 --- a/test/float_equal.cpp +++ b/test/float_equal.cpp @@ -22,6 +22,7 @@ * THE SOFTWARE. 
*/ #include +#include #include #include "test.hpp" @@ -53,7 +54,7 @@ auto test_float_equal(T x, U y) template void test_equality() { - auto x1 = T(0.1); + auto x1 = T(0.125); auto x2 = U(0.0); auto x3 = U(1.0); EXPECT(test_float_equal(x1, x1)); @@ -71,8 +72,12 @@ void test_equality() TEST_CASE_REGISTER(test_equality); TEST_CASE_REGISTER(test_equality); TEST_CASE_REGISTER(test_equality); +TEST_CASE_REGISTER(test_equality); TEST_CASE_REGISTER(test_equality); +TEST_CASE_REGISTER(test_equality); TEST_CASE_REGISTER(test_equality); +TEST_CASE_REGISTER(test_equality); +TEST_CASE_REGISTER(test_equality); template void test_limits() @@ -110,8 +115,13 @@ void test_limits() TEST_CASE_REGISTER(test_limits); TEST_CASE_REGISTER(test_limits); TEST_CASE_REGISTER(test_limits); +TEST_CASE_REGISTER(test_limits); TEST_CASE_REGISTER(test_limits); +TEST_CASE_REGISTER(test_limits); TEST_CASE_REGISTER(test_limits); +TEST_CASE_REGISTER(test_limits); +TEST_CASE_REGISTER(test_limits); + #ifndef _WIN32 // On Windows, types int and long have the same min and max values. TEST_CASE_REGISTER(test_limits); diff --git a/test/fp8e4m3fnuz.cpp b/test/fp8e4m3fnuz.cpp new file mode 100644 index 00000000000..02d7fb77fa5 --- /dev/null +++ b/test/fp8e4m3fnuz.cpp @@ -0,0 +1,233 @@ +/* + * The MIT License (MIT) + * + * Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. 
+ * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + * THE SOFTWARE. + */ +#include +#include +#include +#include +#include +#include "test.hpp" + +#include + +float fp8e4m3fnuz_to_fp32_value(uint8_t input) +{ + constexpr std::array e4m3fnuz_lut = { + 0.0f, 0.0009765625f, 0.001953125f, + 0.0029296875f, 0.00390625f, 0.0048828125f, + 0.005859375f, 0.0068359375f, 0.0078125f, + 0.0087890625f, 0.009765625f, 0.0107421875f, + 0.01171875f, 0.0126953125f, 0.013671875f, + 0.0146484375f, 0.015625f, 0.017578125f, + 0.01953125f, 0.021484375f, 0.0234375f, + 0.025390625f, 0.02734375f, 0.029296875f, + 0.03125f, 0.03515625f, 0.0390625f, + 0.04296875f, 0.046875f, 0.05078125f, + 0.0546875f, 0.05859375f, 0.0625f, + 0.0703125f, 0.078125f, 0.0859375f, + 0.09375f, 0.1015625f, 0.109375f, + 0.1171875f, 0.125f, 0.140625f, + 0.15625f, 0.171875f, 0.1875f, + 0.203125f, 0.21875f, 0.234375f, + 0.25f, 0.28125f, 0.3125f, + 0.34375f, 0.375f, 0.40625f, + 0.4375f, 0.46875f, 0.5f, + 0.5625f, 0.625f, 0.6875f, + 0.75f, 0.8125f, 0.875f, + 0.9375f, 1.0f, 1.125f, + 1.25f, 1.375f, 1.5f, + 1.625f, 1.75f, 1.875f, + 2.0f, 2.25f, 2.5f, + 2.75f, 3.0f, 3.25f, + 3.5f, 3.75f, 4.0f, + 4.5f, 5.0f, 5.5f, + 6.0f, 6.5f, 7.0f, + 7.5f, 8.0f, 9.0f, + 10.0f, 11.0f, 12.0f, + 13.0f, 14.0f, 15.0f, + 16.0f, 18.0f, 20.0f, + 22.0f, 24.0f, 26.0f, + 28.0f, 30.0f, 32.0f, + 36.0f, 40.0f, 44.0f, + 48.0f, 52.0f, 56.0f, + 60.0f, 64.0f, 72.0f, + 80.0f, 88.0f, 96.0f, + 104.0f, 112.0f, 120.0f, + 128.0f, 144.0f, 160.0f, + 176.0f, 192.0f, 208.0f, + 224.0f, 240.0f, std::numeric_limits::quiet_NaN(), + -0.0009765625f, 
-0.001953125f, -0.0029296875f, + -0.00390625f, -0.0048828125f, -0.005859375f, + -0.0068359375f, -0.0078125f, -0.0087890625f, + -0.009765625f, -0.0107421875f, -0.01171875f, + -0.0126953125f, -0.013671875f, -0.0146484375f, + -0.015625f, -0.017578125f, -0.01953125f, + -0.021484375f, -0.0234375f, -0.025390625f, + -0.02734375f, -0.029296875f, -0.03125f, + -0.03515625f, -0.0390625f, -0.04296875f, + -0.046875f, -0.05078125f, -0.0546875f, + -0.05859375f, -0.0625f, -0.0703125f, + -0.078125f, -0.0859375f, -0.09375f, + -0.1015625f, -0.109375f, -0.1171875f, + -0.125f, -0.140625f, -0.15625f, + -0.171875f, -0.1875f, -0.203125f, + -0.21875f, -0.234375f, -0.25f, + -0.28125f, -0.3125f, -0.34375f, + -0.375f, -0.40625f, -0.4375f, + -0.46875f, -0.5f, -0.5625f, + -0.625f, -0.6875f, -0.75f, + -0.8125f, -0.875f, -0.9375f, + -1.0f, -1.125f, -1.25f, + -1.375f, -1.5f, -1.625f, + -1.75f, -1.875f, -2.0f, + -2.25f, -2.5f, -2.75f, + -3.0f, -3.25f, -3.5f, + -3.75f, -4.0f, -4.5f, + -5.0f, -5.5f, -6.0f, + -6.5f, -7.0f, -7.5f, + -8.0f, -9.0f, -10.0f, + -11.0f, -12.0f, -13.0f, + -14.0f, -15.0f, -16.0f, + -18.0f, -20.0f, -22.0f, + -24.0f, -26.0f, -28.0f, + -30.0f, -32.0f, -36.0f, + -40.0f, -44.0f, -48.0f, + -52.0f, -56.0f, -60.0f, + -64.0f, -72.0f, -80.0f, + -88.0f, -96.0f, -104.0f, + -112.0f, -120.0f, -128.0f, + -144.0f, -160.0f, -176.0f, + -192.0f, -208.0f, -224.0f, + -240.0f, + }; + + return e4m3fnuz_lut[input]; +} + +TEST_CASE(test_fp8_cast_to_float) +{ + std::vector bit_vals(256); + std::iota(bit_vals.begin(), bit_vals.end(), 0); + EXPECT(bool{std::all_of(bit_vals.begin(), bit_vals.end(), [](uint8_t bit_val) { + migraphx_fp8::fp8e4m3fnuz fp8_val(bit_val, migraphx_fp8::fp8e4m3fnuz::from_bits()); + if(std::isnan(float(fp8_val)) and std::isnan(fp8e4m3fnuz_to_fp32_value(bit_val))) + { + return true; + } + return migraphx::float_equal(float(fp8_val), fp8e4m3fnuz_to_fp32_value(bit_val)); + })}); +} + +TEST_CASE(test_positive_zero) +{ + float zero = 0.0; + migraphx_fp8::fp8e4m3fnuz fp8_zero(zero); + 
EXPECT(fp8_zero.is_zero()); + EXPECT(migraphx::float_equal(zero, float(fp8_zero))); +} + +TEST_CASE(test_negative_zero) +{ + float nzero = -0.0; + float pzero = 0.0; + migraphx_fp8::fp8e4m3fnuz fp8_nzero(nzero); + EXPECT(fp8_nzero.is_zero()); + // negative zero gets converted to positive zero + EXPECT(migraphx::float_equal(pzero, float(fp8_nzero))); +} + +TEST_CASE(test_nan_1) +{ + float fnan = std::numeric_limits::quiet_NaN(); + migraphx_fp8::fp8e4m3fnuz fp8_nan(fnan); + EXPECT(fp8_nan.is_nan()); + EXPECT(std::isnan(fp8_nan)); +} + +TEST_CASE(test_nan_2) +{ + auto fnan = std::numeric_limits::quiet_NaN(); + migraphx_fp8::fp8e4m3fnuz fp8_nan(fnan.data, migraphx_fp8::fp8e4m3fnuz::from_bits()); + EXPECT(fp8_nan.is_nan()); + EXPECT(std::isnan(fp8_nan)); + EXPECT(std::isnan(float(fp8_nan))); +} + +TEST_CASE(test_infinity_1) +{ + float finf = std::numeric_limits::infinity(); + // no inf in fp8e4m3fnuz + migraphx_fp8::fp8e4m3fnuz fp8_nan(finf); + EXPECT(fp8_nan.is_nan()); + EXPECT(std::isnan(float(fp8_nan))); +} + +TEST_CASE(test_infinity_2) +{ + // no inf in fp8e4m3fnuz, it gets converted to NaNs + migraphx_fp8::fp8e4m3fnuz fp8_nan(std::numeric_limits::infinity()); + EXPECT(fp8_nan.is_nan()); + EXPECT(std::isnan(float(fp8_nan))); +} + +TEST_CASE(test_infinity_3) +{ + // neg inf + float finf = -1.0 * std::numeric_limits::infinity(); + // no inf in fp8e4m3fnuz + migraphx_fp8::fp8e4m3fnuz fp8_nan(finf); + EXPECT(fp8_nan.is_nan()); + EXPECT(std::isnan(float(fp8_nan))); +} + +TEST_CASE(test_numeric_max_1) +{ + float fmax = std::numeric_limits::max(); + migraphx_fp8::fp8e4m3fnuz fp8_max(fmax); + EXPECT(fp8_max == std::numeric_limits::max()); +} + +TEST_CASE(test_numeric_max_2) +{ + // gets clipped to max + float fmax = 2 * std::numeric_limits::max(); + migraphx_fp8::fp8e4m3fnuz fp8_max(fmax); + EXPECT(fp8_max == std::numeric_limits::max()); +} + +TEST_CASE(test_numeric_lowest_1) +{ + float flowest = std::numeric_limits::lowest(); + migraphx_fp8::fp8e4m3fnuz 
fp8_lowest(flowest); + EXPECT(fp8_lowest == std::numeric_limits::lowest()); +} + +TEST_CASE(test_numeric_lowest_2) +{ + // gets clipped to lowest + float fmin = 2.0 * std::numeric_limits::lowest(); + migraphx_fp8::fp8e4m3fnuz fp8_lowest(fmin); + EXPECT(fp8_lowest == std::numeric_limits::lowest()); +} + +int main(int argc, const char* argv[]) { test::run(argc, argv); } diff --git a/test/gpu/jit.cpp b/test/gpu/jit.cpp index b92f1419310..2b407178681 100644 --- a/test/gpu/jit.cpp +++ b/test/gpu/jit.cpp @@ -237,12 +237,12 @@ TEST_CASE(code_object_hip) std::vector expected_inputs = {input, input}; auto co = migraphx::make_op("gpu::code_object", - {{"code_object", migraphx::value::binary{binaries.front()}}, - {"symbol_name", "add_2"}, - {"global", input.elements()}, - {"local", 1024}, - {"expected_inputs", migraphx::to_value(expected_inputs)}, - {"output", migraphx::to_value(input)}}); + {{"code_object", migraphx::value::binary{binaries.front()}}, + {"symbol_name", "add_2"}, + {"global", input.elements()}, + {"local", 1024}, + {"expected_inputs", migraphx::to_value(expected_inputs)}, + {"output", migraphx::to_value(input)}}); migraphx::program p; auto* mm = p.get_main_module(); @@ -348,7 +348,10 @@ TEST_CASE(compile_math) auto vec_sizes = {2, 4, 6}; for(auto&& t : migraphx::shape::types()) { - if(contains({migraphx::shape::bool_type, migraphx::shape::tuple_type}, t)) + if(contains({migraphx::shape::bool_type, + migraphx::shape::fp8e4m3fnuz_type, + migraphx::shape::tuple_type}, + t)) continue; auto name = migraphx::shape::cpp_type(t); if(t == migraphx::shape::half_type) @@ -396,7 +399,10 @@ TEST_CASE(assert_type_min_max) migraphx::gpu::hip_compile_options options; for(auto&& t : migraphx::shape::types()) { - if(contains({migraphx::shape::bool_type, migraphx::shape::tuple_type}, t)) + if(contains({migraphx::shape::bool_type, + migraphx::shape::fp8e4m3fnuz_type, + migraphx::shape::tuple_type}, + t)) continue; auto name = migraphx::shape::cpp_type(t); if(t == 
migraphx::shape::half_type) diff --git a/tools/api/migraphx.h b/tools/api/migraphx.h index 8179cfffd52..ad441280bb7 100644 --- a/tools/api/migraphx.h +++ b/tools/api/migraphx.h @@ -44,7 +44,8 @@ m(int32_type, int32_t) \ m(int64_type, int64_t) \ m(uint32_type, uint32_t) \ - m(uint64_type, uint64_t) + m(uint64_type, uint64_t) \ + m(fp8e4m3fnuz_type, migraphx_fp8::fp8e4m3fnuz) // clang-format on #ifdef __cplusplus @@ -70,7 +71,9 @@ typedef enum } migraphx_shape_datatype_t; #undef MIGRAPHX_SHAPE_GENERATE_ENUM_TYPES -<% generate_c_header() %> +<% + generate_c_header() +%> #ifdef __cplusplus } From 9bc18287635385bc8f21c7f198d01e9981942cb7 Mon Sep 17 00:00:00 2001 From: Umang Yadav Date: Thu, 9 Nov 2023 23:38:14 +0000 Subject: [PATCH 002/115] cppcheck fixes --- src/include/migraphx/migraphx_f8_impl.hpp | 21 +++++------------ src/include/migraphx/migraphx_float8.hpp | 28 ++++++++++++----------- 2 files changed, 21 insertions(+), 28 deletions(-) diff --git a/src/include/migraphx/migraphx_f8_impl.hpp b/src/include/migraphx/migraphx_f8_impl.hpp index 03e828483fd..cb6e879e154 100644 --- a/src/include/migraphx/migraphx_f8_impl.hpp +++ b/src/include/migraphx/migraphx_f8_impl.hpp @@ -22,12 +22,8 @@ #ifndef MIGRAPHX_FP8_IMPL_HPP #define MIGRAPHX_FP8_IMPL_HPP -#if defined(__clang__) -#pragma clang diagnostic push -#pragma clang diagnostic ignored "-Wreserved-identifier" -#endif -#define CONST_FOLD(x) (__builtin_constant_p(x) ? (x) : (x)) +#define MIGRAPHX_CONST_FOLD(x) (__builtin_constant_p(x) ? 
(x) : (x)) namespace migraphx_f8_impl { namespace detail { template @@ -47,11 +43,10 @@ inline constexpr To bit_cast(From fr) noexcept { static_assert(sizeof(To) == sizeof(From)); #if defined(__GNUC__) and !defined(__clang__) - To x = CONST_FOLD(*reinterpret_cast(&fr)); + return MIGRAPHX_CONST_FOLD(*reinterpret_cast(&fr)); #else - To x = __builtin_bit_cast(To, fr); + return __builtin_bit_cast(To, fr); #endif - return x; } } // namespace detail @@ -102,7 +97,6 @@ constexpr uint8_t cast_to_f8(T _x, bool stoch, uint32_t rng) } else { - // if(__hisinf(x) || __hisnan(x)) if((x & 0x7C00) == 0x7C00) return 0x80; } @@ -112,12 +106,12 @@ constexpr uint8_t cast_to_f8(T _x, bool stoch, uint32_t rng) if(sizeof(T) == 4) { if((x & 0x7F800000) == 0x7F800000) - return signed_inf + (mantissa != 0 ? 1 : 0); + return signed_inf + (mantissa != 0 ? 1 : 0); // cppcheck-suppress InvertedLogic } else { if((x & 0x7C00) == 0x7C00) - return signed_inf + (mantissa != 0 ? 1 : 0); + return signed_inf + (mantissa != 0 ? 1 : 0); // cppcheck-suppress InvertedLogic } } // handle positive zero @@ -241,7 +235,7 @@ this case, the fp16 mantissa should be shift left by 1 */ } } - if(f8_exponent == 0 && mantissa == 0) + if(f8_exponent == 0 and mantissa == 0) return negative_zero_nan ? 
0 : (sign << 7); mantissa &= (1 << wm) - 1; return (sign << 7) | (f8_exponent << wm) | mantissa; @@ -314,7 +308,4 @@ constexpr T cast_from_f8(uint8_t x) } } // namespace migraphx_f8_impl -#if defined(__clang__) -#pragma clang diagnostic pop -#endif #endif // MIGRAPHX_FP8_IMPL_HPP diff --git a/src/include/migraphx/migraphx_float8.hpp b/src/include/migraphx/migraphx_float8.hpp index 25d8ecdfd17..96f73ba176a 100644 --- a/src/include/migraphx/migraphx_float8.hpp +++ b/src/include/migraphx/migraphx_float8.hpp @@ -26,7 +26,6 @@ #pragma clang diagnostic push #pragma clang diagnostic ignored "-Wold-style-cast" #pragma clang diagnostic ignored "-Wfloat-equal" -#pragma clang diagnostic ignored "-Wmacro-redefined" #pragma clang diagnostic ignored "-Wc++20-extensions" #endif // __clang__ @@ -36,6 +35,7 @@ // We are clipping in down conversion by default #define MIGRAPHX_F8_DOWNCAST_CLIPPING 1 + #include #include #include @@ -79,7 +79,7 @@ class numeric_limits; template struct float8 { - uint8_t data; + uint8_t data = 0x00; // default constructor constexpr float8() = default; // default copy constructor @@ -141,7 +141,7 @@ struct float8 } else { - return (data == 0x00) || (data == 0x80); + return (data == 0x00) or (data == 0x80); } } @@ -155,15 +155,15 @@ struct float8 { if(T == migraphx_fp8::f8_type::bf8) { - return (data == 0x7d) || (data == 0x7e) || (data == 0x7f) || (data == 0xfd) || - (data == 0xfe) || (data == 0xff); + return (data == 0x7d) or (data == 0x7e) or (data == 0x7f) or (data == 0xfd) or + (data == 0xfe) or (data == 0xff); } else { - return (data == 0x79) || (data == 0x7a) || (data == 0x7b) || (data == 0x7c) || - (data == 0x7d) || (data == 0x7e) || (data == 0x7f) || (data == 0xf9) || - (data == 0xfa) || (data == 0xfb) || (data == 0xfc) || (data == 0xfd) || - (data == 0xfe) || (data == 0xff); + return (data == 0x79) or (data == 0x7a) or (data == 0x7b) or (data == 0x7c) or + (data == 0x7d) or (data == 0x7e) or (data == 0x7f) or (data == 0xf9) or + (data == 0xfa) or 
(data == 0xfb) or (data == 0xfc) or (data == 0xfd) or + (data == 0xfe) or (data == 0xff); } } } @@ -178,11 +178,11 @@ struct float8 { if(T == migraphx_fp8::f8_type::bf8) { - return (data == 0x7c) || (data == 0xfc); + return (data == 0x7c) or (data == 0xfc); } else { - return (data == 0x78) || (data == 0xf8); + return (data == 0x78) or (data == 0xf8); } } } @@ -217,10 +217,10 @@ struct float8 inline constexpr bool operator==(const float8& rhs) const { - if((rhs.is_zero() && this->is_zero()) || + if((rhs.is_zero() and this->is_zero()) or (fabs(rhs - *this) < migraphx_fp8::numeric_limits>::epsilon())) return true; - else if(rhs.is_nan() || rhs.is_inf() || this->is_nan() || this->is_inf()) + else if(rhs.is_nan() or rhs.is_inf() or this->is_nan() or this->is_inf()) return false; return false; @@ -373,6 +373,8 @@ class numeric_limits> } }; } // namespace migraphx_fp8 + +// ================================================================================================= // define numeric limits for the new data type namespace std { inline bool isfinite(migraphx_fp8::float8 x) // NOLINT From 155a2b172dbdfd00d55e52909c699847de80526c Mon Sep 17 00:00:00 2001 From: Umang Yadav Date: Fri, 10 Nov 2023 02:48:41 +0000 Subject: [PATCH 003/115] move FNUZ as template parameter --- src/include/migraphx/migraphx_float8.hpp | 123 ++++++++--------------- 1 file changed, 42 insertions(+), 81 deletions(-) diff --git a/src/include/migraphx/migraphx_float8.hpp b/src/include/migraphx/migraphx_float8.hpp index 96f73ba176a..56d598a84b5 100644 --- a/src/include/migraphx/migraphx_float8.hpp +++ b/src/include/migraphx/migraphx_float8.hpp @@ -29,10 +29,6 @@ #pragma clang diagnostic ignored "-Wc++20-extensions" #endif // __clang__ -#ifndef MIGRAPHX_FP8_FNUZ -#define MIGRAPHX_FP8_FNUZ true -#endif // MIGRAPHX_FP8_FNUZ - // We are clipping in down conversion by default #define MIGRAPHX_F8_DOWNCAST_CLIPPING 1 @@ -73,10 +69,10 @@ enum class f8_type fp8 = 1 // s1e4m3 }; -template +template class 
numeric_limits; -template +template struct float8 { uint8_t data = 0x00; @@ -100,11 +96,11 @@ struct float8 { #ifdef MIGRAPHX_F8_DOWNCAST_CLIPPING data = migraphx_f8_impl:: - cast_to_f8<3, 4, float, MIGRAPHX_FP8_FNUZ /*negative_zero_nan*/, true /*clip*/>( + cast_to_f8<3, 4, float, FNUZ /*negative_zero_nan*/, true /*clip*/>( v, (rm == migraphx_fp8::migraphx_f8_rounding_mode::stochastic), rng); #else // MIGRAPHX_F8_DOWNCAST_CLIPPING data = migraphx_f8_impl:: - cast_to_f8<3, 4, float, MIGRAPHX_FP8_FNUZ /*negative_zero_nan*/, false /*clip*/>( + cast_to_f8<3, 4, float, FNUZ /*negative_zero_nan*/, false /*clip*/>( v, (rm == migraphx_fp8::migraphx_f8_rounding_mode::stochastic), rng); #endif // MIGRAPHX_F8_DOWNCAST_CLIPPING } @@ -112,11 +108,11 @@ struct float8 { #ifdef MIGRAPHX_F8_DOWNCAST_CLIPPING data = migraphx_f8_impl:: - cast_to_f8<2, 5, float, MIGRAPHX_FP8_FNUZ /*negative_zero_nan*/, true /*clip*/>( + cast_to_f8<2, 5, float, FNUZ /*negative_zero_nan*/, true /*clip*/>( v, (rm == migraphx_fp8::migraphx_f8_rounding_mode::stochastic), rng); #else // MIGRAPHX_F8_DOWNCAST_CLIPPING data = migraphx_f8_impl:: - cast_to_f8<2, 5, float, MIGRAPHX_FP8_FNUZ /*negative_zero_nan*/, false /*clip*/>( + cast_to_f8<2, 5, float, FNUZ /*negative_zero_nan*/, false /*clip*/>( v, (rm == migraphx_fp8::migraphx_f8_rounding_mode::stochastic), rng); #endif // rocblas_F8_downcast_clipping} } @@ -126,16 +122,14 @@ struct float8 { if constexpr(T == migraphx_fp8::f8_type::fp8) { - return migraphx_f8_impl:: - cast_from_f8<3, 4, float, MIGRAPHX_FP8_FNUZ /*negative_zero_nan*/>(data); + return migraphx_f8_impl::cast_from_f8<3, 4, float, FNUZ /*negative_zero_nan*/>(data); } // else - return migraphx_f8_impl::cast_from_f8<2, 5, float, MIGRAPHX_FP8_FNUZ /*negative_zero_nan*/>( - data); + return migraphx_f8_impl::cast_from_f8<2, 5, float, FNUZ /*negative_zero_nan*/>(data); } inline constexpr bool is_zero() const { - if constexpr(MIGRAPHX_FP8_FNUZ) + if constexpr(FNUZ) { return data == 0x00; } @@ -147,7 
+141,7 @@ struct float8 inline constexpr bool is_nan() const { - if constexpr(MIGRAPHX_FP8_FNUZ) + if constexpr(FNUZ) { return data == 0x80; } @@ -170,7 +164,7 @@ struct float8 inline constexpr bool is_inf() const { - if constexpr(MIGRAPHX_FP8_FNUZ) + if constexpr(FNUZ) { return data == 0x80; } @@ -218,7 +212,7 @@ struct float8 inline constexpr bool operator==(const float8& rhs) const { if((rhs.is_zero() and this->is_zero()) or - (fabs(rhs - *this) < migraphx_fp8::numeric_limits>::epsilon())) + (fabs(rhs - *this) < migraphx_fp8::numeric_limits>::epsilon())) return true; else if(rhs.is_nan() or rhs.is_inf() or this->is_nan() or this->is_inf()) return false; @@ -289,123 +283,90 @@ constexpr T F8_Lowest() return T{0xFF, T::from_bits()}; } -using fp8e4m3fnuz = float8; +using fp8e4m3fn = float8; +using fp8e5m2 = float8; +using fp8e4m3fnuz = float8; +using fp8e5m2fnuz = float8; template <> -class numeric_limits> +class numeric_limits { public: - // TODO :figure out epsilon in Hex to make it constexpr - static constexpr migraphx_fp8::float8 epsilon() + static constexpr fp8e4m3fnuz epsilon() { - return migraphx_fp8::float8( - 0x28, migraphx_fp8::float8<>::from_bits()); + return fp8e4m3fnuz(0x28, migraphx_fp8::float8<>::from_bits()); } - static constexpr migraphx_fp8::float8 quiet_NaN() - { - return migraphx_fp8::float8( - MIGRAPHX_FP8_FNUZ ? 
0x80 : 0x7F, migraphx_fp8::float8<>::from_bits()); - } + static constexpr fp8e4m3fnuz quiet_NaN() { return fp8e4m3fnuz(0x80, fp8e4m3fnuz::from_bits()); } - static constexpr migraphx_fp8::float8 max() - { - return migraphx_fp8::F8_Max>(); - } + static constexpr fp8e4m3fnuz max() { return migraphx_fp8::F8_Max(); } // TODO figure out Hex value - static migraphx_fp8::float8 min() + static fp8e4m3fnuz min() { - return static_cast>(-1.0f) * - migraphx_fp8::F8_Max>(); + return static_cast(-1.0f) * migraphx_fp8::F8_Max(); } - static constexpr migraphx_fp8::float8 lowest() - { - return migraphx_fp8::F8_Lowest>(); - } + static constexpr fp8e4m3fnuz lowest() { return migraphx_fp8::F8_Lowest(); } - static constexpr migraphx_fp8::float8 infinity() - { - return migraphx_fp8::float8( - MIGRAPHX_FP8_FNUZ ? 0x80 : 0x7F, migraphx_fp8::float8<>::from_bits()); - } + static constexpr fp8e4m3fnuz infinity() { return fp8e4m3fnuz(0x80, fp8e4m3fnuz::from_bits()); } }; template <> -class numeric_limits> +class numeric_limits { public: - static constexpr migraphx_fp8::float8 epsilon() - { - return migraphx_fp8::float8( - 0x34, migraphx_fp8::float8::from_bits()); - } + static constexpr fp8e5m2fnuz epsilon() { return fp8e5m2fnuz(0x34, fp8e5m2fnuz::from_bits()); } - static constexpr migraphx_fp8::float8 quiet_NaN() - { - return migraphx_fp8::float8( - MIGRAPHX_FP8_FNUZ ? 
0x80 : 0x7d, - migraphx_fp8::float8::from_bits()); - } + static constexpr fp8e5m2fnuz quiet_NaN() { return fp8e5m2fnuz(0x80, fp8e5m2fnuz::from_bits()); } - static constexpr migraphx_fp8::float8 max() + static constexpr fp8e5m2fnuz max() { - return static_cast>( - migraphx_fp8::F8_Max>()); + return static_cast(migraphx_fp8::F8_Max()); } // TODO figure out constexpr value - static migraphx_fp8::float8 min() + static fp8e5m2fnuz min() { - return static_cast>(float(-1.0f)) * - migraphx_fp8::F8_Max>(); - } - static constexpr migraphx_fp8::float8 lowest() - { - return migraphx_fp8::F8_Lowest>(); + return static_cast(float(-1.0f)) * migraphx_fp8::F8_Max(); } + static constexpr fp8e5m2fnuz lowest() { return migraphx_fp8::F8_Lowest(); } - static constexpr migraphx_fp8::float8 infinity() - { - return migraphx_fp8::float8( - MIGRAPHX_FP8_FNUZ ? 0x80 : 0x7c, - migraphx_fp8::float8::from_bits()); - } + static constexpr fp8e5m2fnuz infinity() { return fp8e5m2fnuz(0x80, fp8e5m2fnuz::from_bits()); } }; } // namespace migraphx_fp8 // ================================================================================================= // define numeric limits for the new data type namespace std { -inline bool isfinite(migraphx_fp8::float8 x) // NOLINT +inline bool isfinite(migraphx_fp8::fp8e4m3fnuz x) // NOLINT { return x.is_inf(); } -inline bool isfinite(migraphx_fp8::float8 x) // NOLINT +inline bool isfinite(migraphx_fp8::fp8e5m2fnuz x) // NOLINT { return x.is_inf(); } -inline bool isnan(migraphx_fp8::float8 x) // NOLINT +inline bool isnan(migraphx_fp8::fp8e4m3fnuz x) // NOLINT { return x.is_nan(); } -inline bool isnan(migraphx_fp8::float8 x) // NOLINT +inline bool isnan(migraphx_fp8::fp8e5m2fnuz x) // NOLINT { return x.is_nan(); } template <> -class numeric_limits> - : public migraphx_fp8::numeric_limits> +class numeric_limits + : public migraphx_fp8::numeric_limits { }; template <> -class numeric_limits> - : public migraphx_fp8::numeric_limits> +class numeric_limits + : public 
migraphx_fp8::numeric_limits { }; From d9f11e311d6063456f42c45cade774bbfcb1803e Mon Sep 17 00:00:00 2001 From: Umang Yadav Date: Fri, 10 Nov 2023 15:45:53 +0000 Subject: [PATCH 004/115] Fix numeric limits --- src/include/migraphx/migraphx_float8.hpp | 49 ++++++++---------------- 1 file changed, 15 insertions(+), 34 deletions(-) diff --git a/src/include/migraphx/migraphx_float8.hpp b/src/include/migraphx/migraphx_float8.hpp index 56d598a84b5..bc4d96b6c1c 100644 --- a/src/include/migraphx/migraphx_float8.hpp +++ b/src/include/migraphx/migraphx_float8.hpp @@ -271,18 +271,9 @@ inline migraphx_fp8::float8 fabs(migraphx_fp8::float8 v) return v; } -template -constexpr T F8_Max() -{ - return T{0x7F, T::from_bits()}; -} - -template -constexpr T F8_Lowest() -{ - return T{0xFF, T::from_bits()}; -} - +// https://onnx.ai/onnx/technical/float8.html +// these types are not exactly same as GraphCore's FNUZ types. GraphCore's FNUZ types assumes +// exponent bias of 8 and 16 for the FNUZ types, ONNX spec using fp8e4m3fn = float8; using fp8e5m2 = float8; using fp8e4m3fnuz = float8; @@ -292,22 +283,15 @@ template <> class numeric_limits { public: - static constexpr fp8e4m3fnuz epsilon() - { - return fp8e4m3fnuz(0x28, migraphx_fp8::float8<>::from_bits()); - } + static constexpr fp8e4m3fnuz epsilon() { return fp8e4m3fnuz(0x28, fp8e4m3fnuz::from_bits()); } static constexpr fp8e4m3fnuz quiet_NaN() { return fp8e4m3fnuz(0x80, fp8e4m3fnuz::from_bits()); } - static constexpr fp8e4m3fnuz max() { return migraphx_fp8::F8_Max(); } + static constexpr fp8e4m3fnuz max() { return fp8e4m3fnuz(0x7F, fp8e4m3fnuz::from_bits()); } + // this is min value that is not DeNorm. 
DeNorm min is 0x01 + static constexpr fp8e4m3fnuz min() { return fp8e4m3fnuz(0x08, fp8e4m3fnuz::from_bits()); } - // TODO figure out Hex value - static fp8e4m3fnuz min() - { - return static_cast(-1.0f) * migraphx_fp8::F8_Max(); - } - - static constexpr fp8e4m3fnuz lowest() { return migraphx_fp8::F8_Lowest(); } + static constexpr fp8e4m3fnuz lowest() { return fp8e4m3fnuz(0xFF, fp8e4m3fnuz::from_bits()); } static constexpr fp8e4m3fnuz infinity() { return fp8e4m3fnuz(0x80, fp8e4m3fnuz::from_bits()); } }; @@ -320,16 +304,12 @@ class numeric_limits static constexpr fp8e5m2fnuz quiet_NaN() { return fp8e5m2fnuz(0x80, fp8e5m2fnuz::from_bits()); } - static constexpr fp8e5m2fnuz max() - { - return static_cast(migraphx_fp8::F8_Max()); - } - // TODO figure out constexpr value - static fp8e5m2fnuz min() - { - return static_cast(float(-1.0f)) * migraphx_fp8::F8_Max(); - } - static constexpr fp8e5m2fnuz lowest() { return migraphx_fp8::F8_Lowest(); } + static constexpr fp8e5m2fnuz max() { return fp8e5m2fnuz(0x7F, fp8e5m2fnuz::from_bits()); } + // this is min value that is not DeNorm. DeNorm min is 0x01. I am not sure if we want to make + // this distinction. For the floating points we would end up using lowest most of the times. 
+ static constexpr fp8e5m2fnuz min() { return fp8e5m2fnuz(0x4, fp8e5m2fnuz::from_bits()); } + + static constexpr fp8e5m2fnuz lowest() { return fp8e5m2fnuz(0xFF, fp8e5m2fnuz::from_bits()); } static constexpr fp8e5m2fnuz infinity() { return fp8e5m2fnuz(0x80, fp8e5m2fnuz::from_bits()); } }; @@ -338,6 +318,7 @@ class numeric_limits // ================================================================================================= // define numeric limits for the new data type namespace std { + inline bool isfinite(migraphx_fp8::fp8e4m3fnuz x) // NOLINT { return x.is_inf(); From 4e9d51f0314fb3e068eff4d180595dce0d2572bc Mon Sep 17 00:00:00 2001 From: Umang Yadav Date: Fri, 10 Nov 2023 22:10:10 +0000 Subject: [PATCH 005/115] Working FNUZ and FN --- src/include/migraphx/migraphx_f8_impl.hpp | 70 ++++++-- src/include/migraphx/migraphx_float8.hpp | 133 +++++++------- test/fp8e4m3fn.cpp | 202 ++++++++++++++++++++++ test/fp8e4m3fnuz.cpp | 12 +- 4 files changed, 335 insertions(+), 82 deletions(-) create mode 100644 test/fp8e4m3fn.cpp diff --git a/src/include/migraphx/migraphx_f8_impl.hpp b/src/include/migraphx/migraphx_f8_impl.hpp index cb6e879e154..f55a52ec12d 100644 --- a/src/include/migraphx/migraphx_f8_impl.hpp +++ b/src/include/migraphx/migraphx_f8_impl.hpp @@ -86,6 +86,11 @@ constexpr uint8_t cast_to_f8(T _x, bool stoch, uint32_t rng) } uint32_t signed_inf = (sign << 7) + (((1 << we) - 1) << wm); + uint32_t signed_max = (sign << 7) + ((((1 << we) - 1) << wm) + ((1 << wm) - 1)); + if(not negative_zero_nan) + { + signed_max = (wm == 2) ? (signed_max - 4) : (signed_max - 1); + } // Deal with inf and NaNs if(negative_zero_nan) @@ -103,15 +108,50 @@ constexpr uint8_t cast_to_f8(T _x, bool stoch, uint32_t rng) } else { - if(sizeof(T) == 4) + // calculate most common NaN mantissa for FP8, which is all Ones in binary + uint32_t nan_mantissa = 1; + for(auto i = 1; i < wm; ++i) { - if((x & 0x7F800000) == 0x7F800000) - return signed_inf + (mantissa != 0 ? 
1 : 0); // cppcheck-suppress InvertedLogic + nan_mantissa |= (nan_mantissa << 1); } - else + // TODO: abstract duplicate branches + if(sizeof(T) == 4 and ((x & 0x7F800000) == 0x7F800000)) { - if((x & 0x7C00) == 0x7C00) - return signed_inf + (mantissa != 0 ? 1 : 0); // cppcheck-suppress InvertedLogic + // infinity + if(mantissa == 0) + { + if(sign == 0) + { + return (wm == 2) ? 0x7B : 0x7E; + } + else + { + return (wm == 2) ? 0xFB : 0xFE; + } + } + else + { // NaNs + return signed_inf + nan_mantissa; + } + } + else if(sizeof(T) == 2 and ((x & 0x7C00) == 0x7C00)) + { + // infinity + if(mantissa == 0) + { + if(sign == 0) + { + return (wm == 2) ? 0x7B : 0x7E; + } + else + { + return (wm == 2) ? 0xFB : 0xFE; + } + } + else + { // NaNs + return signed_inf + nan_mantissa; + } } } // handle positive zero @@ -222,16 +262,24 @@ this case, the fp16 mantissa should be shift left by 1 */ // above range: quantize to maximum possible float of the same sign const int max_exp = (1 << we) - (negative_zero_nan ? 1 : 2); + // TODO: this is ugly, need better way to handle out of range values if(f8_exponent > max_exp) { if(clip) { - mantissa = (1 << wm) - 1; - f8_exponent = max_exp; + return signed_max; } else { - return signed_inf; + if(negative_zero_nan) + { + return 0x80; + } + else + { + uint32_t tmp_signed_max = (sign << 7) + ((((1 << we) - 1) << wm) + ((1 << wm) - 1)); + return (wm == 2) ? signed_inf : tmp_signed_max; + } } } @@ -273,8 +321,10 @@ constexpr T cast_from_f8(uint8_t x) { if(x == 0x80) return fNeg0; - if(exponent == ((1 << we) - 1)) + if(exponent == ((1 << we) - 1) and wm == 2) return (mantissa == 0) ? (sign ? 
fNegInf : fInf) : fNaN; + else if(wm == 3 and (x == 0x7F or x == 0xFF)) + return fNaN; } typename detail::conditional::type retval; diff --git a/src/include/migraphx/migraphx_float8.hpp b/src/include/migraphx/migraphx_float8.hpp index bc4d96b6c1c..dd111bbc4d2 100644 --- a/src/include/migraphx/migraphx_float8.hpp +++ b/src/include/migraphx/migraphx_float8.hpp @@ -79,7 +79,7 @@ struct float8 // default constructor constexpr float8() = default; // default copy constructor - constexpr float8(const float8& y) = default; + constexpr float8(const float8& y) = default; struct from_bits_t { }; @@ -149,15 +149,12 @@ struct float8 { if(T == migraphx_fp8::f8_type::bf8) { - return (data == 0x7d) or (data == 0x7e) or (data == 0x7f) or (data == 0xfd) or - (data == 0xfe) or (data == 0xff); + return (data == 0x7D) or (data == 0x7E) or (data == 0x7F) or (data == 0xFD) or + (data == 0xFE) or (data == 0xFF); } else { - return (data == 0x79) or (data == 0x7a) or (data == 0x7b) or (data == 0x7c) or - (data == 0x7d) or (data == 0x7e) or (data == 0x7f) or (data == 0xf9) or - (data == 0xfa) or (data == 0xfb) or (data == 0xfc) or (data == 0xfd) or - (data == 0xfe) or (data == 0xff); + return (data == 0x7F) or (data == 0xFF); } } } @@ -172,11 +169,12 @@ struct float8 { if(T == migraphx_fp8::f8_type::bf8) { - return (data == 0x7c) or (data == 0xfc); + return (data == 0x7C) or (data == 0xFC); } else { - return (data == 0x78) or (data == 0xf8); + // no infinities in e4m3fn, represent them as NaNs + return (data == 0x7F) or (data == 0xFF); } } } @@ -211,12 +209,12 @@ struct float8 inline constexpr bool operator==(const float8& rhs) const { - if((rhs.is_zero() and this->is_zero()) or - (fabs(rhs - *this) < migraphx_fp8::numeric_limits>::epsilon())) + if(rhs.is_zero() and this->is_zero()) return true; else if(rhs.is_nan() or rhs.is_inf() or this->is_nan() or this->is_inf()) return false; - + else if(this->data == rhs.data) + return true; return false; } @@ -272,8 +270,6 @@ inline 
migraphx_fp8::float8 fabs(migraphx_fp8::float8 v) } // https://onnx.ai/onnx/technical/float8.html -// these types are not exactly same as GraphCore's FNUZ types. GraphCore's FNUZ types assumes -// exponent bias of 8 and 16 for the FNUZ types, ONNX spec using fp8e4m3fn = float8; using fp8e5m2 = float8; using fp8e4m3fnuz = float8; @@ -282,6 +278,8 @@ using fp8e5m2fnuz = float8; template <> class numeric_limits { + static constexpr bool has_infinity = false; + public: static constexpr fp8e4m3fnuz epsilon() { return fp8e4m3fnuz(0x28, fp8e4m3fnuz::from_bits()); } @@ -292,13 +290,30 @@ class numeric_limits static constexpr fp8e4m3fnuz min() { return fp8e4m3fnuz(0x08, fp8e4m3fnuz::from_bits()); } static constexpr fp8e4m3fnuz lowest() { return fp8e4m3fnuz(0xFF, fp8e4m3fnuz::from_bits()); } +}; + +template <> +class numeric_limits +{ + static constexpr bool has_infinity = false; - static constexpr fp8e4m3fnuz infinity() { return fp8e4m3fnuz(0x80, fp8e4m3fnuz::from_bits()); } + public: + static constexpr fp8e4m3fn epsilon() { return fp8e4m3fn(0x20, fp8e4m3fn::from_bits()); } + + static constexpr fp8e4m3fn quiet_NaN() { return fp8e4m3fn(0x7F, fp8e4m3fn::from_bits()); } + + static constexpr fp8e4m3fn max() { return fp8e4m3fn(0x7E, fp8e4m3fn::from_bits()); } + // this is min value that is not DeNorm. 
DeNorm min is 0x01 + static constexpr fp8e4m3fn min() { return fp8e4m3fn(0x08, fp8e4m3fn::from_bits()); } + + static constexpr fp8e4m3fn lowest() { return fp8e4m3fn(0xFE, fp8e4m3fn::from_bits()); } }; template <> class numeric_limits { + static constexpr bool has_infinity = false; + public: static constexpr fp8e5m2fnuz epsilon() { return fp8e5m2fnuz(0x34, fp8e5m2fnuz::from_bits()); } @@ -310,62 +325,56 @@ class numeric_limits static constexpr fp8e5m2fnuz min() { return fp8e5m2fnuz(0x4, fp8e5m2fnuz::from_bits()); } static constexpr fp8e5m2fnuz lowest() { return fp8e5m2fnuz(0xFF, fp8e5m2fnuz::from_bits()); } - - static constexpr fp8e5m2fnuz infinity() { return fp8e5m2fnuz(0x80, fp8e5m2fnuz::from_bits()); } }; -} // namespace migraphx_fp8 - -// ================================================================================================= -// define numeric limits for the new data type -namespace std { - -inline bool isfinite(migraphx_fp8::fp8e4m3fnuz x) // NOLINT -{ - return x.is_inf(); -} - -inline bool isfinite(migraphx_fp8::fp8e5m2fnuz x) // NOLINT -{ - return x.is_inf(); -} - -inline bool isnan(migraphx_fp8::fp8e4m3fnuz x) // NOLINT -{ - return x.is_nan(); -} - -inline bool isnan(migraphx_fp8::fp8e5m2fnuz x) // NOLINT -{ - return x.is_nan(); -} template <> -class numeric_limits - : public migraphx_fp8::numeric_limits +class numeric_limits { -}; + public: + static constexpr fp8e5m2 epsilon() { return fp8e5m2(0x34, fp8e5m2::from_bits()); } + // 7D, 7E, 7F are positive NaNs and FD, FE, FF are negative NaNs + static constexpr fp8e5m2 quiet_NaN() { return fp8e5m2(0xFF, fp8e5m2::from_bits()); } -template <> -class numeric_limits - : public migraphx_fp8::numeric_limits -{ -}; + static constexpr fp8e5m2 max() { return fp8e5m2(0x7B, fp8e5m2::from_bits()); } + // this is min value that is not DeNorm. DeNorm min is 0x01. I am not sure if we want to make + // this distinction. For the floating points we would end up using lowest most of the times. 
+ static constexpr fp8e5m2 min() { return fp8e5m2(0x4, fp8e5m2::from_bits()); } -template -struct common_type : std::common_type // NOLINT -{ + static constexpr fp8e5m2 lowest() { return fp8e5m2(0xFB, fp8e5m2::from_bits()); } + // 7C and FC both are infinity + static constexpr fp8e5m2 infinity() { return fp8e5m2(0x7C, fp8e5m2::from_bits()); } }; +} // namespace migraphx_fp8 -template -struct common_type : std::common_type // NOLINT -{ -}; +// ================================================================================================= +// define numeric limits for the new data type +namespace std { -template <> -struct common_type -{ - using type = float; -}; +#define MIGRAPHX_FP8_STD_OVERLOADS(T) \ + inline bool isfinite(T x) { return x.is_inf(); } \ + inline bool isnan(T x) { return x.is_nan(); } \ + template <> \ + class numeric_limits : public migraphx_fp8::numeric_limits \ + { \ + }; \ + template \ + struct common_type : std::common_type \ + { \ + }; \ + template \ + struct common_type : std::common_type \ + { \ + }; \ + template <> \ + struct common_type \ + { \ + using type = T; \ + }; + +MIGRAPHX_FP8_STD_OVERLOADS(migraphx_fp8::fp8e4m3fn) +MIGRAPHX_FP8_STD_OVERLOADS(migraphx_fp8::fp8e5m2) +MIGRAPHX_FP8_STD_OVERLOADS(migraphx_fp8::fp8e4m3fnuz) +MIGRAPHX_FP8_STD_OVERLOADS(migraphx_fp8::fp8e5m2fnuz) } // namespace std // ================================================================================================= diff --git a/test/fp8e4m3fn.cpp b/test/fp8e4m3fn.cpp new file mode 100644 index 00000000000..cf3fe79c167 --- /dev/null +++ b/test/fp8e4m3fn.cpp @@ -0,0 +1,202 @@ +/* + * The MIT License (MIT) + * + * Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved. 
+ * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + * THE SOFTWARE. 
+ */ +#include +#include +#include +#include +#include +#include "test.hpp" + +#include + +float fp8e4m3fn_to_fp32_value(uint8_t input) +{ + constexpr std::array e4m3fnuz_lut = { + 0.0, 0.001953125, 0.00390625, 0.005859375, + 0.0078125, 0.009765625, 0.01171875, 0.013671875, + 0.015625, 0.017578125, 0.01953125, 0.021484375, + 0.0234375, 0.025390625, 0.02734375, 0.029296875, + 0.03125, 0.03515625, 0.0390625, 0.04296875, + 0.046875, 0.05078125, 0.0546875, 0.05859375, + 0.0625, 0.0703125, 0.078125, 0.0859375, + 0.09375, 0.1015625, 0.109375, 0.1171875, + 0.125, 0.140625, 0.15625, 0.171875, + 0.1875, 0.203125, 0.21875, 0.234375, + 0.25, 0.28125, 0.3125, 0.34375, + 0.375, 0.40625, 0.4375, 0.46875, + 0.5, 0.5625, 0.625, 0.6875, + 0.75, 0.8125, 0.875, 0.9375, + 1.0, 1.125, 1.25, 1.375, + 1.5, 1.625, 1.75, 1.875, + 2.0, 2.25, 2.5, 2.75, + 3.0, 3.25, 3.5, 3.75, + 4.0, 4.5, 5.0, 5.5, + 6.0, 6.5, 7.0, 7.5, + 8.0, 9.0, 10.0, 11.0, + 12.0, 13.0, 14.0, 15.0, + 16.0, 18.0, 20.0, 22.0, + 24.0, 26.0, 28.0, 30.0, + 32.0, 36.0, 40.0, 44.0, + 48.0, 52.0, 56.0, 60.0, + 64.0, 72.0, 80.0, 88.0, + 96.0, 104.0, 112.0, 120.0, + 128.0, 144.0, 160.0, 176.0, + 192.0, 208.0, 224.0, 240.0, + 256.0, 288.0, 320.0, 352.0, + 384.0, 416.0, 448.0, std::numeric_limits::quiet_NaN(), + -0.0, -0.001953125, -0.00390625, -0.005859375, + -0.0078125, -0.009765625, -0.01171875, -0.013671875, + -0.015625, -0.017578125, -0.01953125, -0.021484375, + -0.0234375, -0.025390625, -0.02734375, -0.029296875, + -0.03125, -0.03515625, -0.0390625, -0.04296875, + -0.046875, -0.05078125, -0.0546875, -0.05859375, + -0.0625, -0.0703125, -0.078125, -0.0859375, + -0.09375, -0.1015625, -0.109375, -0.1171875, + -0.125, -0.140625, -0.15625, -0.171875, + -0.1875, -0.203125, -0.21875, -0.234375, + -0.25, -0.28125, -0.3125, -0.34375, + -0.375, -0.40625, -0.4375, -0.46875, + -0.5, -0.5625, -0.625, -0.6875, + -0.75, -0.8125, -0.875, -0.9375, + -1.0, -1.125, -1.25, -1.375, + -1.5, -1.625, -1.75, -1.875, + -2.0, -2.25, -2.5, -2.75, + -3.0, 
-3.25, -3.5, -3.75, + -4.0, -4.5, -5.0, -5.5, + -6.0, -6.5, -7.0, -7.5, + -8.0, -9.0, -10.0, -11.0, + -12.0, -13.0, -14.0, -15.0, + -16.0, -18.0, -20.0, -22.0, + -24.0, -26.0, -28.0, -30.0, + -32.0, -36.0, -40.0, -44.0, + -48.0, -52.0, -56.0, -60.0, + -64.0, -72.0, -80.0, -88.0, + -96.0, -104.0, -112.0, -120.0, + -128.0, -144.0, -160.0, -176.0, + -192.0, -208.0, -224.0, -240.0, + -256.0, -288.0, -320.0, -352.0, + -384.0, -416.0, -448.0, std::numeric_limits::quiet_NaN(), + + }; + + return e4m3fnuz_lut[input]; +} + +TEST_CASE(test_fp8_cast_to_float) +{ + std::vector bit_vals(256); + std::iota(bit_vals.begin(), bit_vals.end(), 0); + EXPECT(bool{std::all_of(bit_vals.begin(), bit_vals.end(), [](uint8_t bit_val) { + migraphx_fp8::fp8e4m3fn fp8_val(bit_val, migraphx_fp8::fp8e4m3fn::from_bits()); + if(std::isnan(float(fp8_val)) and std::isnan(fp8e4m3fn_to_fp32_value(bit_val))) + { + return true; + } + return migraphx::float_equal(float(fp8_val), fp8e4m3fn_to_fp32_value(bit_val)); + })}); +} + +TEST_CASE(test_positive_zero) +{ + float zero = 0.0; + migraphx_fp8::fp8e4m3fn fp8_zero(zero); + EXPECT(fp8_zero.is_zero()); + EXPECT(migraphx::float_equal(zero, float(fp8_zero))); +} + +TEST_CASE(test_negative_zero) +{ + float nzero = -0.0; + migraphx_fp8::fp8e4m3fn fp8_nzero(nzero); + EXPECT(fp8_nzero.is_zero()); + // negative zero is preserved for fp8e4m3fn + EXPECT(migraphx::float_equal(nzero, float(fp8_nzero))); +} + +TEST_CASE(test_nan_1) +{ + float fnan = std::numeric_limits::quiet_NaN(); + migraphx_fp8::fp8e4m3fn fp8_nan(fnan); + EXPECT(fp8_nan.is_nan()); + EXPECT(std::isnan(fp8_nan)); +} + +TEST_CASE(test_nan_2) +{ + auto fnan = std::numeric_limits::quiet_NaN(); + migraphx_fp8::fp8e4m3fn fp8_nan(fnan.data, migraphx_fp8::fp8e4m3fn::from_bits()); + EXPECT(fp8_nan.is_nan()); + EXPECT(std::isnan(fp8_nan)); + EXPECT(std::isnan(float(fp8_nan))); +} + +TEST_CASE(test_infinity_1) +{ + float finf = std::numeric_limits::infinity(); + // no inf in fp8e4m3fn, it gets clipped to max() + 
migraphx_fp8::fp8e4m3fn fp8_max(finf); + EXPECT(fp8_max == std::numeric_limits::max()); +} + +TEST_CASE(test_infinity_2) +{ + // neg inf + float finf = -1.0 * std::numeric_limits::infinity(); + // no inf in fp8e4m3fn, it gets clipped to lowest + migraphx_fp8::fp8e4m3fn fp8_lowest(finf); + EXPECT(bool{fp8_lowest == std::numeric_limits::lowest()}); +} + +TEST_CASE(test_numeric_max_1) +{ + float fmax = std::numeric_limits::max(); + migraphx_fp8::fp8e4m3fn fp8_max(fmax); + EXPECT(fp8_max == std::numeric_limits::max()); +} + +TEST_CASE(test_numeric_max_2) +{ + // gets clipped to max + float fmax = 2 * std::numeric_limits::max(); + migraphx_fp8::fp8e4m3fn fp8_max(fmax); + EXPECT(fp8_max == std::numeric_limits::max()); +} + +TEST_CASE(test_numeric_lowest_1) +{ + float flowest = std::numeric_limits::lowest(); + migraphx_fp8::fp8e4m3fn fp8_lowest(flowest); + EXPECT(fp8_lowest == std::numeric_limits::lowest()); +} + +TEST_CASE(test_numeric_lowest_2) +{ + // gets clipped to lowest + float fmin = 2.0 * std::numeric_limits::lowest(); + migraphx_fp8::fp8e4m3fn fp8_lowest(fmin); + EXPECT(fp8_lowest == std::numeric_limits::lowest()); +} + +TEST_CASE(test_max_eq_lowest) {} +int main(int argc, const char* argv[]) { test::run(argc, argv); } diff --git a/test/fp8e4m3fnuz.cpp b/test/fp8e4m3fnuz.cpp index 02d7fb77fa5..d5c6a753d19 100644 --- a/test/fp8e4m3fnuz.cpp +++ b/test/fp8e4m3fnuz.cpp @@ -176,25 +176,17 @@ TEST_CASE(test_nan_2) TEST_CASE(test_infinity_1) { float finf = std::numeric_limits::infinity(); - // no inf in fp8e4m3fnuz + // no inf in fp8e4m3fnuz it gets clipped to Nans migraphx_fp8::fp8e4m3fnuz fp8_nan(finf); EXPECT(fp8_nan.is_nan()); EXPECT(std::isnan(float(fp8_nan))); } TEST_CASE(test_infinity_2) -{ - // no inf in fp8e4m3fnuz, it gets converted to NaNs - migraphx_fp8::fp8e4m3fnuz fp8_nan(std::numeric_limits::infinity()); - EXPECT(fp8_nan.is_nan()); - EXPECT(std::isnan(float(fp8_nan))); -} - -TEST_CASE(test_infinity_3) { // neg inf float finf = -1.0 * 
std::numeric_limits::infinity(); - // no inf in fp8e4m3fnuz + // no inf in fp8e4m3fnuz it gets clipped to NaNs migraphx_fp8::fp8e4m3fnuz fp8_nan(finf); EXPECT(fp8_nan.is_nan()); EXPECT(std::isnan(float(fp8_nan))); From 7639c28097a85f87d85549a30666845d2f816f58 Mon Sep 17 00:00:00 2001 From: Umang Yadav Date: Fri, 10 Nov 2023 22:12:48 +0000 Subject: [PATCH 006/115] use float equal --- test/fp8e4m3fn.cpp | 6 +++++- test/fp8e4m3fnuz.cpp | 5 +++++ 2 files changed, 10 insertions(+), 1 deletion(-) diff --git a/test/fp8e4m3fn.cpp b/test/fp8e4m3fn.cpp index cf3fe79c167..0601ac696c5 100644 --- a/test/fp8e4m3fn.cpp +++ b/test/fp8e4m3fn.cpp @@ -198,5 +198,9 @@ TEST_CASE(test_numeric_lowest_2) EXPECT(fp8_lowest == std::numeric_limits::lowest()); } -TEST_CASE(test_max_eq_lowest) {} +TEST_CASE(test_max_eq_lowest) +{ + EXPECT(migraphx::float_equal(std::numeric_limits::lowest(), + -1 * std::numeric_limits::max())); +} int main(int argc, const char* argv[]) { test::run(argc, argv); } diff --git a/test/fp8e4m3fnuz.cpp b/test/fp8e4m3fnuz.cpp index d5c6a753d19..d1e3cd72542 100644 --- a/test/fp8e4m3fnuz.cpp +++ b/test/fp8e4m3fnuz.cpp @@ -222,4 +222,9 @@ TEST_CASE(test_numeric_lowest_2) EXPECT(fp8_lowest == std::numeric_limits::lowest()); } +TEST_CASE(test_max_eq_lowest) +{ + EXPECT(migraphx::float_equal(std::numeric_limits::lowest(), + -1 * std::numeric_limits::max())); +} int main(int argc, const char* argv[]) { test::run(argc, argv); } From a6372c50ee5c3f4c6e6b2ee7d4475b6bd3d2ea7f Mon Sep 17 00:00:00 2001 From: Umang Yadav Date: Fri, 10 Nov 2023 22:35:36 +0000 Subject: [PATCH 007/115] add test for fp8e5m2 --- test/fp8e5m2.cpp | 403 +++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 403 insertions(+) create mode 100644 test/fp8e5m2.cpp diff --git a/test/fp8e5m2.cpp b/test/fp8e5m2.cpp new file mode 100644 index 00000000000..b4916ec35e8 --- /dev/null +++ b/test/fp8e5m2.cpp @@ -0,0 +1,403 @@ +/* + * The MIT License (MIT) + * + * Copyright (c) 2015-2022 Advanced Micro Devices, 
Inc. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + * THE SOFTWARE. 
+ */ +#include +#include +#include +#include +#include +#include "test.hpp" + +#include + +float fp8e5m2_to_fp32_value(uint8_t input) +{ + constexpr std::array e4m3fnuz_lut = {{ + 0.0, + 1.52587890625e-05, + 3.0517578125e-05, + 4.57763671875e-05, + 6.103515625e-05, + 7.62939453125e-05, + 9.1552734375e-05, + 0.0001068115234375, + 0.0001220703125, + 0.000152587890625, + 0.00018310546875, + 0.000213623046875, + 0.000244140625, + 0.00030517578125, + 0.0003662109375, + 0.00042724609375, + 0.00048828125, + 0.0006103515625, + 0.000732421875, + 0.0008544921875, + 0.0009765625, + 0.001220703125, + 0.00146484375, + 0.001708984375, + 0.001953125, + 0.00244140625, + 0.0029296875, + 0.00341796875, + 0.00390625, + 0.0048828125, + 0.005859375, + 0.0068359375, + 0.0078125, + 0.009765625, + 0.01171875, + 0.013671875, + 0.015625, + 0.01953125, + 0.0234375, + 0.02734375, + 0.03125, + 0.0390625, + 0.046875, + 0.0546875, + 0.0625, + 0.078125, + 0.09375, + 0.109375, + 0.125, + 0.15625, + 0.1875, + 0.21875, + 0.25, + 0.3125, + 0.375, + 0.4375, + 0.5, + 0.625, + 0.75, + 0.875, + 1.0, + 1.25, + 1.5, + 1.75, + 2.0, + 2.5, + 3.0, + 3.5, + 4.0, + 5.0, + 6.0, + 7.0, + 8.0, + 10.0, + 12.0, + 14.0, + 16.0, + 20.0, + 24.0, + 28.0, + 32.0, + 40.0, + 48.0, + 56.0, + 64.0, + 80.0, + 96.0, + 112.0, + 128.0, + 160.0, + 192.0, + 224.0, + 256.0, + 320.0, + 384.0, + 448.0, + 512.0, + 640.0, + 768.0, + 896.0, + 1024.0, + 1280.0, + 1536.0, + 1792.0, + 2048.0, + 2560.0, + 3072.0, + 3584.0, + 4096.0, + 5120.0, + 6144.0, + 7168.0, + 8192.0, + 10240.0, + 12288.0, + 14336.0, + 16384.0, + 20480.0, + 24576.0, + 28672.0, + 32768.0, + 40960.0, + 49152.0, + 57344.0, + std::numeric_limits::infinity(), + std::numeric_limits::quiet_NaN(), + std::numeric_limits::quiet_NaN(), + std::numeric_limits::quiet_NaN(), + -0.0, + -1.52587890625e-05, + -3.0517578125e-05, + -4.57763671875e-05, + -6.103515625e-05, + -7.62939453125e-05, + -9.1552734375e-05, + -0.0001068115234375, + -0.0001220703125, + -0.000152587890625, + 
-0.00018310546875, + -0.000213623046875, + -0.000244140625, + -0.00030517578125, + -0.0003662109375, + -0.00042724609375, + -0.00048828125, + -0.0006103515625, + -0.000732421875, + -0.0008544921875, + -0.0009765625, + -0.001220703125, + -0.00146484375, + -0.001708984375, + -0.001953125, + -0.00244140625, + -0.0029296875, + -0.00341796875, + -0.00390625, + -0.0048828125, + -0.005859375, + -0.0068359375, + -0.0078125, + -0.009765625, + -0.01171875, + -0.013671875, + -0.015625, + -0.01953125, + -0.0234375, + -0.02734375, + -0.03125, + -0.0390625, + -0.046875, + -0.0546875, + -0.0625, + -0.078125, + -0.09375, + -0.109375, + -0.125, + -0.15625, + -0.1875, + -0.21875, + -0.25, + -0.3125, + -0.375, + -0.4375, + -0.5, + -0.625, + -0.75, + -0.875, + -1.0, + -1.25, + -1.5, + -1.75, + -2.0, + -2.5, + -3.0, + -3.5, + -4.0, + -5.0, + -6.0, + -7.0, + -8.0, + -10.0, + -12.0, + -14.0, + -16.0, + -20.0, + -24.0, + -28.0, + -32.0, + -40.0, + -48.0, + -56.0, + -64.0, + -80.0, + -96.0, + -112.0, + -128.0, + -160.0, + -192.0, + -224.0, + -256.0, + -320.0, + -384.0, + -448.0, + -512.0, + -640.0, + -768.0, + -896.0, + -1024.0, + -1280.0, + -1536.0, + -1792.0, + -2048.0, + -2560.0, + -3072.0, + -3584.0, + -4096.0, + -5120.0, + -6144.0, + -7168.0, + -8192.0, + -10240.0, + -12288.0, + -14336.0, + -16384.0, + -20480.0, + -24576.0, + -28672.0, + -32768.0, + -40960.0, + -49152.0, + -57344.0, + -1.0 * std::numeric_limits::infinity(), + std::numeric_limits::quiet_NaN(), + std::numeric_limits::quiet_NaN(), + std::numeric_limits::quiet_NaN(), + } + + }; + + return e4m3fnuz_lut[input]; +} + +TEST_CASE(test_fp8_cast_to_float) +{ + std::vector bit_vals(256); + std::iota(bit_vals.begin(), bit_vals.end(), 0); + EXPECT(bool{std::all_of(bit_vals.begin(), bit_vals.end(), [](uint8_t bit_val) { + migraphx_fp8::fp8e5m2 fp8_val(bit_val, migraphx_fp8::fp8e5m2::from_bits()); + if(std::isnan(float(fp8_val)) and std::isnan(fp8e5m2_to_fp32_value(bit_val))) + { + return true; + } + else 
if(std::isinf(float(fp8_val)) and std::isinf(fp8e5m2_to_fp32_value(bit_val))) + { + return true; + } + return migraphx::float_equal(float(fp8_val), fp8e5m2_to_fp32_value(bit_val)); + })}); +} + +TEST_CASE(test_positive_zero) +{ + float zero = 0.0; + migraphx_fp8::fp8e5m2 fp8_zero(zero); + EXPECT(fp8_zero.is_zero()); + EXPECT(migraphx::float_equal(zero, float(fp8_zero))); +} + +TEST_CASE(test_negative_zero) +{ + float nzero = -0.0; + migraphx_fp8::fp8e5m2 fp8_nzero(nzero); + EXPECT(fp8_nzero.is_zero()); + // negative zero is preserved for fp8e5m2 + EXPECT(migraphx::float_equal(nzero, float(fp8_nzero))); +} + +TEST_CASE(test_nan_1) +{ + float fnan = std::numeric_limits::quiet_NaN(); + migraphx_fp8::fp8e5m2 fp8_nan(fnan); + EXPECT(fp8_nan.is_nan()); + EXPECT(std::isnan(fp8_nan)); +} + +TEST_CASE(test_nan_2) +{ + auto fnan = std::numeric_limits::quiet_NaN(); + migraphx_fp8::fp8e5m2 fp8_nan(fnan.data, migraphx_fp8::fp8e5m2::from_bits()); + EXPECT(fp8_nan.is_nan()); + EXPECT(std::isnan(fp8_nan)); + EXPECT(std::isnan(float(fp8_nan))); +} + +TEST_CASE(test_infinity_1) +{ + // float infinity should get clipped to max + float finf = std::numeric_limits::infinity(); + migraphx_fp8::fp8e5m2 fp8_max(finf); + EXPECT(fp8_max == std::numeric_limits::max()); +} + +TEST_CASE(test_infinity_2) +{ + // neg inf + float finf = -1.0 * std::numeric_limits::infinity(); + // no inf in fp8e5m2, it gets clipped to lowest + migraphx_fp8::fp8e5m2 fp8_lowest(finf); + EXPECT(bool{fp8_lowest == std::numeric_limits::lowest()}); +} + +TEST_CASE(test_numeric_max_1) +{ + float fmax = std::numeric_limits::max(); + migraphx_fp8::fp8e5m2 fp8_max(fmax); + EXPECT(fp8_max == std::numeric_limits::max()); +} + +TEST_CASE(test_numeric_max_2) +{ + // gets clipped to max + float fmax = 2 * std::numeric_limits::max(); + migraphx_fp8::fp8e5m2 fp8_max(fmax); + EXPECT(fp8_max == std::numeric_limits::max()); +} + +TEST_CASE(test_numeric_lowest_1) +{ + float flowest = std::numeric_limits::lowest(); + 
migraphx_fp8::fp8e5m2 fp8_lowest(flowest); + EXPECT(fp8_lowest == std::numeric_limits::lowest()); +} + +TEST_CASE(test_numeric_lowest_2) +{ + // gets clipped to lowest + float fmin = 2.0 * std::numeric_limits::lowest(); + migraphx_fp8::fp8e5m2 fp8_lowest(fmin); + EXPECT(fp8_lowest == std::numeric_limits::lowest()); +} + +TEST_CASE(test_max_eq_lowest) +{ + EXPECT(migraphx::float_equal(std::numeric_limits::lowest(), + -1 * std::numeric_limits::max())); +} +int main(int argc, const char* argv[]) { test::run(argc, argv); } From 439ea40dcba8634ffa8a81b7e1092ebea55caf8b Mon Sep 17 00:00:00 2001 From: Umang Yadav Date: Fri, 10 Nov 2023 22:54:15 +0000 Subject: [PATCH 008/115] add test for fp8e5m2fnuz --- test/fp8e5m2fnuz.cpp | 400 +++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 400 insertions(+) create mode 100644 test/fp8e5m2fnuz.cpp diff --git a/test/fp8e5m2fnuz.cpp b/test/fp8e5m2fnuz.cpp new file mode 100644 index 00000000000..da447ab1996 --- /dev/null +++ b/test/fp8e5m2fnuz.cpp @@ -0,0 +1,400 @@ +/* + * The MIT License (MIT) + * + * Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + * THE SOFTWARE. + */ +#include +#include +#include +#include +#include +#include "test.hpp" + +#include + +float fp8e5m2fnuz_to_fp32_value(uint8_t input) +{ + constexpr std::array e4m3fnuz_lut = { + 0.0, + 7.62939453125e-06, + 1.52587890625e-05, + 2.288818359375e-05, + 3.0517578125e-05, + 3.814697265625e-05, + 4.57763671875e-05, + 5.340576171875e-05, + 6.103515625e-05, + 7.62939453125e-05, + 9.1552734375e-05, + 0.0001068115234375, + 0.0001220703125, + 0.000152587890625, + 0.00018310546875, + 0.000213623046875, + 0.000244140625, + 0.00030517578125, + 0.0003662109375, + 0.00042724609375, + 0.00048828125, + 0.0006103515625, + 0.000732421875, + 0.0008544921875, + 0.0009765625, + 0.001220703125, + 0.00146484375, + 0.001708984375, + 0.001953125, + 0.00244140625, + 0.0029296875, + 0.00341796875, + 0.00390625, + 0.0048828125, + 0.005859375, + 0.0068359375, + 0.0078125, + 0.009765625, + 0.01171875, + 0.013671875, + 0.015625, + 0.01953125, + 0.0234375, + 0.02734375, + 0.03125, + 0.0390625, + 0.046875, + 0.0546875, + 0.0625, + 0.078125, + 0.09375, + 0.109375, + 0.125, + 0.15625, + 0.1875, + 0.21875, + 0.25, + 0.3125, + 0.375, + 0.4375, + 0.5, + 0.625, + 0.75, + 0.875, + 1.0, + 1.25, + 1.5, + 1.75, + 2.0, + 2.5, + 3.0, + 3.5, + 4.0, + 5.0, + 6.0, + 7.0, + 8.0, + 10.0, + 12.0, + 14.0, + 16.0, + 20.0, + 24.0, + 28.0, + 32.0, + 40.0, + 48.0, + 56.0, + 64.0, + 80.0, + 96.0, + 112.0, + 128.0, + 160.0, + 192.0, + 224.0, + 256.0, + 320.0, + 384.0, + 448.0, + 512.0, + 640.0, + 768.0, + 896.0, + 1024.0, + 1280.0, + 1536.0, + 1792.0, + 2048.0, + 2560.0, + 3072.0, + 3584.0, + 4096.0, + 5120.0, + 6144.0, + 7168.0, + 8192.0, + 10240.0, + 12288.0, + 14336.0, + 16384.0, + 20480.0, + 24576.0, + 28672.0, + 32768.0, + 40960.0, + 49152.0, + 
57344.0, + std::numeric_limits::quiet_NaN(), + -7.62939453125e-06, + -1.52587890625e-05, + -2.288818359375e-05, + -3.0517578125e-05, + -3.814697265625e-05, + -4.57763671875e-05, + -5.340576171875e-05, + -6.103515625e-05, + -7.62939453125e-05, + -9.1552734375e-05, + -0.0001068115234375, + -0.0001220703125, + -0.000152587890625, + -0.00018310546875, + -0.000213623046875, + -0.000244140625, + -0.00030517578125, + -0.0003662109375, + -0.00042724609375, + -0.00048828125, + -0.0006103515625, + -0.000732421875, + -0.0008544921875, + -0.0009765625, + -0.001220703125, + -0.00146484375, + -0.001708984375, + -0.001953125, + -0.00244140625, + -0.0029296875, + -0.00341796875, + -0.00390625, + -0.0048828125, + -0.005859375, + -0.0068359375, + -0.0078125, + -0.009765625, + -0.01171875, + -0.013671875, + -0.015625, + -0.01953125, + -0.0234375, + -0.02734375, + -0.03125, + -0.0390625, + -0.046875, + -0.0546875, + -0.0625, + -0.078125, + -0.09375, + -0.109375, + -0.125, + -0.15625, + -0.1875, + -0.21875, + -0.25, + -0.3125, + -0.375, + -0.4375, + -0.5, + -0.625, + -0.75, + -0.875, + -1.0, + -1.25, + -1.5, + -1.75, + -2.0, + -2.5, + -3.0, + -3.5, + -4.0, + -5.0, + -6.0, + -7.0, + -8.0, + -10.0, + -12.0, + -14.0, + -16.0, + -20.0, + -24.0, + -28.0, + -32.0, + -40.0, + -48.0, + -56.0, + -64.0, + -80.0, + -96.0, + -112.0, + -128.0, + -160.0, + -192.0, + -224.0, + -256.0, + -320.0, + -384.0, + -448.0, + -512.0, + -640.0, + -768.0, + -896.0, + -1024.0, + -1280.0, + -1536.0, + -1792.0, + -2048.0, + -2560.0, + -3072.0, + -3584.0, + -4096.0, + -5120.0, + -6144.0, + -7168.0, + -8192.0, + -10240.0, + -12288.0, + -14336.0, + -16384.0, + -20480.0, + -24576.0, + -28672.0, + -32768.0, + -40960.0, + -49152.0, + -57344.0, + }; + + return e4m3fnuz_lut[input]; +} + +TEST_CASE(test_fp8_cast_to_float) +{ + std::vector bit_vals(256); + std::iota(bit_vals.begin(), bit_vals.end(), 0); + EXPECT(bool{std::all_of(bit_vals.begin(), bit_vals.end(), [](uint8_t bit_val) { + migraphx_fp8::fp8e5m2fnuz 
fp8_val(bit_val, migraphx_fp8::fp8e5m2fnuz::from_bits()); + if(std::isnan(float(fp8_val)) and std::isnan(fp8e5m2fnuz_to_fp32_value(bit_val))) + { + return true; + } + return migraphx::float_equal(float(fp8_val), fp8e5m2fnuz_to_fp32_value(bit_val)); + })}); +} + +TEST_CASE(test_positive_zero) +{ + float zero = 0.0; + migraphx_fp8::fp8e5m2fnuz fp8_zero(zero); + EXPECT(fp8_zero.is_zero()); + EXPECT(migraphx::float_equal(zero, float(fp8_zero))); +} + +TEST_CASE(test_negative_zero) +{ + float nzero = -0.0; + float pzero = 0.0; + migraphx_fp8::fp8e5m2fnuz fp8_nzero(nzero); + EXPECT(fp8_nzero.is_zero()); + // negative zero gets converted to positive zero + EXPECT(migraphx::float_equal(pzero, float(fp8_nzero))); +} + +TEST_CASE(test_nan_1) +{ + float fnan = std::numeric_limits::quiet_NaN(); + migraphx_fp8::fp8e5m2fnuz fp8_nan(fnan); + EXPECT(fp8_nan.is_nan()); + EXPECT(std::isnan(fp8_nan)); +} + +TEST_CASE(test_nan_2) +{ + auto fnan = std::numeric_limits::quiet_NaN(); + migraphx_fp8::fp8e5m2fnuz fp8_nan(fnan.data, migraphx_fp8::fp8e5m2fnuz::from_bits()); + EXPECT(fp8_nan.is_nan()); + EXPECT(std::isnan(fp8_nan)); + EXPECT(std::isnan(float(fp8_nan))); +} + +TEST_CASE(test_infinity_1) +{ + float finf = std::numeric_limits::infinity(); + // no inf in fp8e5m2fnuz it gets clipped to Nans + migraphx_fp8::fp8e5m2fnuz fp8_nan(finf); + EXPECT(fp8_nan.is_nan()); + EXPECT(std::isnan(float(fp8_nan))); +} + +TEST_CASE(test_infinity_2) +{ + // neg inf + float finf = -1.0 * std::numeric_limits::infinity(); + // no inf in fp8e5m2fnuz it gets clipped to NaNs + migraphx_fp8::fp8e5m2fnuz fp8_nan(finf); + EXPECT(fp8_nan.is_nan()); + EXPECT(std::isnan(float(fp8_nan))); +} + +TEST_CASE(test_numeric_max_1) +{ + float fmax = std::numeric_limits::max(); + migraphx_fp8::fp8e5m2fnuz fp8_max(fmax); + EXPECT(fp8_max == std::numeric_limits::max()); +} + +TEST_CASE(test_numeric_max_2) +{ + // gets clipped to max + float fmax = 2 * std::numeric_limits::max(); + migraphx_fp8::fp8e5m2fnuz fp8_max(fmax); + 
 EXPECT(fp8_max == std::numeric_limits::max()); +} + +TEST_CASE(test_numeric_lowest_1) +{ + float flowest = std::numeric_limits::lowest(); + migraphx_fp8::fp8e5m2fnuz fp8_lowest(flowest); + EXPECT(fp8_lowest == std::numeric_limits::lowest()); +} + +TEST_CASE(test_numeric_lowest_2) +{ + // gets clipped to lowest + float fmin = 2.0 * std::numeric_limits::lowest(); + migraphx_fp8::fp8e5m2fnuz fp8_lowest(fmin); + EXPECT(fp8_lowest == std::numeric_limits::lowest()); +} + +TEST_CASE(test_max_eq_lowest) +{ + EXPECT(migraphx::float_equal(std::numeric_limits::lowest(), + -1 * std::numeric_limits::max())); +} +int main(int argc, const char* argv[]) { test::run(argc, argv); } From 183db78a1473f60d0760e39a1dd0789f806c9478 Mon Sep 17 00:00:00 2001 From: Umang Yadav Date: Fri, 10 Nov 2023 23:09:18 +0000 Subject: [PATCH 009/115] refactor add some comments --- src/include/migraphx/migraphx_f8_impl.hpp | 89 +++++++++-------------- src/include/migraphx/migraphx_float8.hpp | 7 +- 2 files changed, 39 insertions(+), 57 deletions(-) diff --git a/src/include/migraphx/migraphx_f8_impl.hpp b/src/include/migraphx/migraphx_f8_impl.hpp index f55a52ec12d..91871fe643e 100644 --- a/src/include/migraphx/migraphx_f8_impl.hpp +++ b/src/include/migraphx/migraphx_f8_impl.hpp @@ -85,15 +85,18 @@ constexpr uint8_t cast_to_f8(T _x, bool stoch, uint32_t rng) bias = 15; } - uint32_t signed_inf = (sign << 7) + (((1 << we) - 1) << wm); - uint32_t signed_max = (sign << 7) + ((((1 << we) - 1) << wm) + ((1 << wm) - 1)); + uint32_t signed_inf = (sign << 7) + (((1 << we) - 1) << wm); + uint32_t signed_all_ones = (sign << 7) + ((((1 << we) - 1) << wm) + ((1 << wm) - 1)); + + // Calculate maximum signed value FLT_MAX, FLT_MIN + uint32_t signed_max = signed_all_ones; if(not negative_zero_nan) { signed_max = (wm == 2) ? 
(signed_max - 4) : (signed_max - 1); } // Deal with inf and NaNs - if(negative_zero_nan) + if(negative_zero_nan) // For the FNUZ cases, it is simple just return NaNs { if(sizeof(T) == 4) { @@ -114,27 +117,8 @@ constexpr uint8_t cast_to_f8(T _x, bool stoch, uint32_t rng) { nan_mantissa |= (nan_mantissa << 1); } - // TODO: abstract duplicate branches - if(sizeof(T) == 4 and ((x & 0x7F800000) == 0x7F800000)) - { - // infinity - if(mantissa == 0) - { - if(sign == 0) - { - return (wm == 2) ? 0x7B : 0x7E; - } - else - { - return (wm == 2) ? 0xFB : 0xFE; - } - } - else - { // NaNs - return signed_inf + nan_mantissa; - } - } - else if(sizeof(T) == 2 and ((x & 0x7C00) == 0x7C00)) + if((sizeof(T) == 4 and ((x & 0x7F800000) == 0x7F800000)) or + (sizeof(T) == 2 and ((x & 0x7C00) == 0x7C00))) { // infinity if(mantissa == 0) @@ -160,7 +144,7 @@ constexpr uint8_t cast_to_f8(T _x, bool stoch, uint32_t rng) // handle negative zero if((sizeof(T) == 4 and x == 0x80000000) or (sizeof(T) == 2 and x == 0x8000)) { - if(negative_zero_nan) + if(negative_zero_nan) // For FNUZ types neg zero is just positive zero { return 0; } @@ -170,29 +154,29 @@ constexpr uint8_t cast_to_f8(T _x, bool stoch, uint32_t rng) } } - // First need to check if it is normal or denorm as there is a difference of implict 1 - // Then need to adjust the exponent to align with the F8 exponent, in the meanwhile, shift - // The mantissa. Then for stochastic rounding, add rng to mantissa and truncate. And for - // RNE, no need to add rng. Then probably need to check whether there is carry and adjust - // exponent and mantissa again + /* First need to check if it is normal or denorm as there is a difference of implict 1 + Then need to adjust the exponent to align with the F8 exponent, in the meanwhile, shift + The mantissa. Then for stochastic rounding, add rng to mantissa and truncate. And for + RNE, no need to add rng. 
Then probably need to check whether there is carry and adjust + exponent and mantissa again*/ // For IEEE bias mode, the bias is 2^(k-1) -1 where k is the width of exponent bits const int f8_bias = (1 << (we - 1)) - 1 + (negative_zero_nan ? 1 : 0); const int f8_denormal_act_exponent = 1 - f8_bias; // actual exponent of f8 denormal - // act_exponent is the actual exponent of fp32/fp16 (after subtracting bias) - // f8_exponent is the converted f8 exponent with bias encoding - // exponent_diff is the diff between fp32/fp16 exponent and f8 exponent, - // the difference needs to be adjusted and mantissa shifted + /* act_exponent is the actual exponent of fp32/fp16 (after subtracting bias) + f8_exponent is the converted f8 exponent with bias encoding + exponent_diff is the diff between fp32/fp16 exponent and f8 exponent, + the difference needs to be adjusted and mantissa shifted*/ int act_exponent, f8_exponent, exponent_diff; if(exponent == 0) { // fp32/fp16 is in denormal. /* fp32 denormal is below 2^-127 so it is usually not a concern here, we mostly concern fp16 -here. In this case, f8 is usually in denormal. But there could be exceptions. fp16 denormal has -exponent bias 15 while bf8 with NANOO has exponent bias 16. It means that there are some numbers in -fp16 denormal but they are bf8 (NANOO) normals - smallest bf8 (NANOO) normal is 2^-15. fp16 numbers -where exponent==0 (actual exponent -14) and highest bit of mantissa is 1 are bf8 (NANOO) normal. In -this case, the fp16 mantissa should be shift left by 1 */ + here. In this case, f8 is usually in denormal. But there could be exceptions. fp16 denormal + has exponent bias 15 while bf8 with NANOO has exponent bias 16. It means that there are some + numbers in fp16 denormal but they are bf8 (NANOO) normals - smallest bf8 (NANOO) normal is + 2^-15. fp16 numbers where exponent==0 (actual exponent -14) and highest bit of mantissa is 1 + are bf8 (NANOO) normal. 
In this case, the fp16 mantissa should be shift left by 1 */ act_exponent = exponent - bias + 1; exponent_diff = f8_denormal_act_exponent - act_exponent; // actual exponent is exponent-bias+1 as it is denormal @@ -203,10 +187,10 @@ this case, the fp16 mantissa should be shift left by 1 */ if(act_exponent <= f8_denormal_act_exponent) { /* This is the case where fp32/fp16 is normal but it is in f8 denormal range. - For example fp8 nanoo mode, denormal exponent is -7, but if the fp32/fp16 - actual exponent is -7, it is actually larger due to the implict 1, - Therefore it needs to be adjust to -6 and mantissa shift right by 1. - So for fp32/fp16, exponent -8 is the cut point to convert to fp8 nanoo */ + For example fp8 nanoo mode, denormal exponent is -7, but if the fp32/fp16 + actual exponent is -7, it is actually larger due to the implict 1, + Therefore it needs to be adjust to -6 and mantissa shift right by 1. + So for fp32/fp16, exponent -8 is the cut point to convert to fp8 nanoo */ exponent_diff = f8_denormal_act_exponent - act_exponent; } else @@ -221,10 +205,10 @@ this case, the fp16 mantissa should be shift left by 1 */ bool midpoint = (mantissa & ((1 << (mfmt - wm + exponent_diff)) - 1)) == (1 << (mfmt - wm + exponent_diff - 1)); /* This part is a bit tricky. The judgment of whether it is a tie needs to be done before we - shift right as shift right could rip off some residual part and make something not midpoint look - like midpoint. For example, the fp16 number 0x1002 (0 00100 0000000010), it is larger than - midpoint, but after shift right by 4 bits, it would look like midpoint. -*/ + shift right as shift right could rip off some residual part and make something not midpoint look + like midpoint. For example, the fp16 number 0x1002 (0 00100 0000000010), it is larger than + midpoint, but after shift right by 4 bits, it would look like midpoint. 
+ */ if(exponent_diff > 0) mantissa >>= exponent_diff; @@ -262,7 +246,6 @@ this case, the fp16 mantissa should be shift left by 1 */ // above range: quantize to maximum possible float of the same sign const int max_exp = (1 << we) - (negative_zero_nan ? 1 : 2); - // TODO: this is ugly, need better way to handle out of range values if(f8_exponent > max_exp) { if(clip) @@ -271,15 +254,11 @@ this case, the fp16 mantissa should be shift left by 1 */ } else { + // https://onnx.ai/onnx/technical/float8.html#cast if(negative_zero_nan) - { return 0x80; - } else - { - uint32_t tmp_signed_max = (sign << 7) + ((((1 << we) - 1) << wm) + ((1 << wm) - 1)); - return (wm == 2) ? signed_inf : tmp_signed_max; - } + return (wm == 2) ? signed_inf : signed_all_ones; } } @@ -300,7 +279,7 @@ constexpr T cast_from_f8(uint8_t x) uint32_t ifNegInf = 0xFF800000; uint32_t ifNaN = 0x7F800001; uint32_t ifNeg0 = 0x80000000; - // TODO: need to change T for half but right now it would never called with half + fInf = detail::bit_cast(ifInf); fNegInf = detail::bit_cast(ifNegInf); fNaN = detail::bit_cast(ifNaN); diff --git a/src/include/migraphx/migraphx_float8.hpp b/src/include/migraphx/migraphx_float8.hpp index dd111bbc4d2..42e4a20123f 100644 --- a/src/include/migraphx/migraphx_float8.hpp +++ b/src/include/migraphx/migraphx_float8.hpp @@ -29,7 +29,9 @@ #pragma clang diagnostic ignored "-Wc++20-extensions" #endif // __clang__ -// We are clipping in down conversion by default +// We are clipping/saturation in down conversion by default. Unclipped version is not tested and +// shouldn't be used without having enough tests. 
+// logic is based on clipping table from here : https://onnx.ai/onnx/technical/float8.html#cast #define MIGRAPHX_F8_DOWNCAST_CLIPPING 1 #include @@ -254,7 +256,8 @@ MIGRAPHX_FP8_BINARY_OP(*, migraphx_fp8::float8) MIGRAPHX_FP8_BINARY_OP(-, migraphx_fp8::float8) MIGRAPHX_FP8_BINARY_OP(/, migraphx_fp8::float8) MIGRAPHX_FP8_BINARY_OP(+, migraphx_fp8::float8) -// TODO: Comparison ops shouldn't convert to float, maybe need to take care of rounding effects. +// TODO: Comparison ops shouldn't convert to float, need to check if need to take care of rounding +// effects. MIGRAPHX_FP8_BINARY_OP(==, bool) MIGRAPHX_FP8_BINARY_OP(>=, bool) MIGRAPHX_FP8_BINARY_OP(<=, bool) From ab653aff2a793f04442579940f113dc2debe55ac Mon Sep 17 00:00:00 2001 From: Umang Yadav Date: Mon, 13 Nov 2023 19:36:54 +0000 Subject: [PATCH 010/115] Review updates --- src/api/include/migraphx/migraphx.h | 2 +- src/include/migraphx/bit_cast.hpp | 42 ++++++ .../{migraphx_float8.hpp => float8.hpp} | 138 +++++++++--------- .../{migraphx_f8_impl.hpp => float8_impl.hpp} | 90 ++++++------ src/include/migraphx/half.hpp | 6 +- src/include/migraphx/shape.hpp | 4 +- src/include/migraphx/type_traits.hpp | 18 +-- src/py/migraphx_py.cpp | 4 +- test/CMakeLists.txt | 2 +- test/float_equal.cpp | 18 +-- test/fp8e4m3fn.cpp | 46 +++--- test/fp8e4m3fnuz.cpp | 42 +++--- test/fp8e5m2.cpp | 46 +++--- test/fp8e5m2fnuz.cpp | 42 +++--- tools/api/migraphx.h | 2 +- 15 files changed, 261 insertions(+), 241 deletions(-) create mode 100644 src/include/migraphx/bit_cast.hpp rename src/include/migraphx/{migraphx_float8.hpp => float8.hpp} (80%) rename src/include/migraphx/{migraphx_f8_impl.hpp => float8_impl.hpp} (85%) diff --git a/src/api/include/migraphx/migraphx.h b/src/api/include/migraphx/migraphx.h index c8467c67f30..6a1edd9b6b0 100644 --- a/src/api/include/migraphx/migraphx.h +++ b/src/api/include/migraphx/migraphx.h @@ -45,7 +45,7 @@ m(int64_type, int64_t) \ m(uint32_type, uint32_t) \ m(uint64_type, uint64_t) \ - m(fp8e4m3fnuz_type, 
migraphx_fp8::fp8e4m3fnuz) + m(fp8e4m3fnuz_type, migraphx::fp8::fp8e4m3fnuz) // clang-format on #ifdef __cplusplus diff --git a/src/include/migraphx/bit_cast.hpp b/src/include/migraphx/bit_cast.hpp new file mode 100644 index 00000000000..45913608c1e --- /dev/null +++ b/src/include/migraphx/bit_cast.hpp @@ -0,0 +1,42 @@ +/* ************************************************************************ + * Copyright (C) 2016-2023 Advanced Micro Devices, Inc. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell cop- + * ies of the Software, and to permit persons to whom the Software is furnished + * to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IM- + * PLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS + * FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR + * COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER + * IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNE- + * CTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + * + * ************************************************************************ */ +#ifndef MIGRAPHX_GUARD_RTGLIB_BITCAST_HPP +#define MIGRAPHX_GUARD_RTGLIB_BITCAST_HPP +#include + +#define MIGRAPHX_CONST_FOLD(x) (__builtin_constant_p(x) ? 
(x) : (x)) + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { +template +inline constexpr To bit_cast(From fr) noexcept +{ + static_assert(sizeof(To) == sizeof(From)); +#if defined(__GNUC__) and !defined(__clang__) + return MIGRAPHX_CONST_FOLD(*reinterpret_cast(&fr)); +#else + return __builtin_bit_cast(To, fr); +#endif +} +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx +#endif // MIGRAPHX_GUARD_RTGLIB_BITCAST_HPP diff --git a/src/include/migraphx/migraphx_float8.hpp b/src/include/migraphx/float8.hpp similarity index 80% rename from src/include/migraphx/migraphx_float8.hpp rename to src/include/migraphx/float8.hpp index 42e4a20123f..843ac89bd1d 100644 --- a/src/include/migraphx/migraphx_float8.hpp +++ b/src/include/migraphx/float8.hpp @@ -44,20 +44,12 @@ #include #include #include +#include +#include -namespace migraphx_f8_impl { - -template -constexpr uint8_t cast_to_f8(T _x, bool stoch = false, uint32_t rng = 0); - -template -constexpr T cast_from_f8(uint8_t x); - -} // namespace migraphx_f8_impl - -#include - -namespace migraphx_fp8 { +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { +namespace fp8 { enum class migraphx_f8_rounding_mode { @@ -74,7 +66,7 @@ enum class f8_type template class numeric_limits; -template +template struct float8 { uint8_t data = 0x00; @@ -90,43 +82,43 @@ struct float8 explicit constexpr float8(uint8_t bits, from_bits_t) : data(bits) {} explicit constexpr float8(float v, - migraphx_fp8::migraphx_f8_rounding_mode rm = - migraphx_fp8::migraphx_f8_rounding_mode::standard, + migraphx::fp8::migraphx_f8_rounding_mode rm = + migraphx::fp8::migraphx_f8_rounding_mode::standard, uint32_t rng = 0) { - if constexpr(T == migraphx_fp8::f8_type::fp8) + if constexpr(T == migraphx::fp8::f8_type::fp8) { #ifdef MIGRAPHX_F8_DOWNCAST_CLIPPING - data = migraphx_f8_impl:: + data = migraphx::fp8::impl:: cast_to_f8<3, 4, float, FNUZ /*negative_zero_nan*/, true /*clip*/>( - v, (rm == 
migraphx_fp8::migraphx_f8_rounding_mode::stochastic), rng); + v, (rm == migraphx::fp8::migraphx_f8_rounding_mode::stochastic), rng); #else // MIGRAPHX_F8_DOWNCAST_CLIPPING - data = migraphx_f8_impl:: + data = migraphx::fp8::impl:: cast_to_f8<3, 4, float, FNUZ /*negative_zero_nan*/, false /*clip*/>( - v, (rm == migraphx_fp8::migraphx_f8_rounding_mode::stochastic), rng); + v, (rm == migraphx::fp8::migraphx_f8_rounding_mode::stochastic), rng); #endif // MIGRAPHX_F8_DOWNCAST_CLIPPING } else { #ifdef MIGRAPHX_F8_DOWNCAST_CLIPPING - data = migraphx_f8_impl:: + data = migraphx::fp8::impl:: cast_to_f8<2, 5, float, FNUZ /*negative_zero_nan*/, true /*clip*/>( - v, (rm == migraphx_fp8::migraphx_f8_rounding_mode::stochastic), rng); + v, (rm == migraphx::fp8::migraphx_f8_rounding_mode::stochastic), rng); #else // MIGRAPHX_F8_DOWNCAST_CLIPPING - data = migraphx_f8_impl:: + data = migraphx::fp8::impl:: cast_to_f8<2, 5, float, FNUZ /*negative_zero_nan*/, false /*clip*/>( - v, (rm == migraphx_fp8::migraphx_f8_rounding_mode::stochastic), rng); + v, (rm == migraphx::fp8::migraphx_f8_rounding_mode::stochastic), rng); #endif // rocblas_F8_downcast_clipping} } } inline constexpr operator float() const { - if constexpr(T == migraphx_fp8::f8_type::fp8) + if constexpr(T == migraphx::fp8::f8_type::fp8) { - return migraphx_f8_impl::cast_from_f8<3, 4, float, FNUZ /*negative_zero_nan*/>(data); + return migraphx::fp8::impl::cast_from_f8<3, 4, float, FNUZ /*negative_zero_nan*/>(data); } // else - return migraphx_f8_impl::cast_from_f8<2, 5, float, FNUZ /*negative_zero_nan*/>(data); + return migraphx::fp8::impl::cast_from_f8<2, 5, float, FNUZ /*negative_zero_nan*/>(data); } inline constexpr bool is_zero() const @@ -149,7 +141,7 @@ struct float8 } else { - if(T == migraphx_fp8::f8_type::bf8) + if(T == migraphx::fp8::f8_type::bf8) { return (data == 0x7D) or (data == 0x7E) or (data == 0x7F) or (data == 0xFD) or (data == 0xFE) or (data == 0xFF); @@ -169,7 +161,7 @@ struct float8 } else { - if(T == 
migraphx_fp8::f8_type::bf8) + if(T == migraphx::fp8::f8_type::bf8) { return (data == 0x7C) or (data == 0xFC); } @@ -236,26 +228,26 @@ struct float8 }; // Special operator overloading -template -inline std::ostream& operator<<(std::ostream& os, const migraphx_fp8::float8& rhs) +template +inline std::ostream& operator<<(std::ostream& os, const migraphx::fp8::float8& rhs) { return os << static_cast(rhs); } // NOLINTNEXTLINE -#define MIGRAPHX_FP8_BINARY_OP(binary_op, U) \ - template \ - inline constexpr U operator binary_op(const migraphx_fp8::float8& lhs, \ - const migraphx_fp8::float8& rhs) \ - { \ - return U(static_cast(lhs) binary_op static_cast(rhs)); \ +#define MIGRAPHX_FP8_BINARY_OP(binary_op, U) \ + template \ + inline constexpr U operator binary_op(const migraphx::fp8::float8& lhs, \ + const migraphx::fp8::float8& rhs) \ + { \ + return U(static_cast(lhs) binary_op static_cast(rhs)); \ } // TODO: these should return floats -MIGRAPHX_FP8_BINARY_OP(*, migraphx_fp8::float8) -MIGRAPHX_FP8_BINARY_OP(-, migraphx_fp8::float8) -MIGRAPHX_FP8_BINARY_OP(/, migraphx_fp8::float8) -MIGRAPHX_FP8_BINARY_OP(+, migraphx_fp8::float8) +MIGRAPHX_FP8_BINARY_OP(*, migraphx::fp8::float8) +MIGRAPHX_FP8_BINARY_OP(-, migraphx::fp8::float8) +MIGRAPHX_FP8_BINARY_OP(/, migraphx::fp8::float8) +MIGRAPHX_FP8_BINARY_OP(+, migraphx::fp8::float8) // TODO: Comparison ops shouldn't convert to float, need to check if need to take care of rounding // effects. 
MIGRAPHX_FP8_BINARY_OP(==, bool) @@ -265,18 +257,18 @@ MIGRAPHX_FP8_BINARY_OP(>, bool) MIGRAPHX_FP8_BINARY_OP(<, bool) MIGRAPHX_FP8_BINARY_OP(!=, bool) -template -inline migraphx_fp8::float8 fabs(migraphx_fp8::float8 v) +template +inline migraphx::fp8::float8 fabs(migraphx::fp8::float8 v) { v.data = v.data & 0x7f; return v; } // https://onnx.ai/onnx/technical/float8.html -using fp8e4m3fn = float8; -using fp8e5m2 = float8; -using fp8e4m3fnuz = float8; -using fp8e5m2fnuz = float8; +using fp8e4m3fn = float8; +using fp8e5m2 = float8; +using fp8e4m3fnuz = float8; +using fp8e5m2fnuz = float8; template <> class numeric_limits @@ -347,37 +339,39 @@ class numeric_limits // 7C and FC both are infinity static constexpr fp8e5m2 infinity() { return fp8e5m2(0x7C, fp8e5m2::from_bits()); } }; -} // namespace migraphx_fp8 +} // namespace fp8 +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx // ================================================================================================= // define numeric limits for the new data type namespace std { -#define MIGRAPHX_FP8_STD_OVERLOADS(T) \ - inline bool isfinite(T x) { return x.is_inf(); } \ - inline bool isnan(T x) { return x.is_nan(); } \ - template <> \ - class numeric_limits : public migraphx_fp8::numeric_limits \ - { \ - }; \ - template \ - struct common_type : std::common_type \ - { \ - }; \ - template \ - struct common_type : std::common_type \ - { \ - }; \ - template <> \ - struct common_type \ - { \ - using type = T; \ +#define MIGRAPHX_FP8_STD_OVERLOADS(T) \ + inline bool isfinite(T x) { return x.is_inf(); } \ + inline bool isnan(T x) { return x.is_nan(); } \ + template <> \ + class numeric_limits : public migraphx::fp8::numeric_limits \ + { \ + }; \ + template \ + struct common_type : std::common_type \ + { \ + }; \ + template \ + struct common_type : std::common_type \ + { \ + }; \ + template <> \ + struct common_type \ + { \ + using type = T; \ }; -MIGRAPHX_FP8_STD_OVERLOADS(migraphx_fp8::fp8e4m3fn) 
-MIGRAPHX_FP8_STD_OVERLOADS(migraphx_fp8::fp8e5m2) -MIGRAPHX_FP8_STD_OVERLOADS(migraphx_fp8::fp8e4m3fnuz) -MIGRAPHX_FP8_STD_OVERLOADS(migraphx_fp8::fp8e5m2fnuz) +MIGRAPHX_FP8_STD_OVERLOADS(migraphx::fp8::fp8e4m3fn) +MIGRAPHX_FP8_STD_OVERLOADS(migraphx::fp8::fp8e5m2) +MIGRAPHX_FP8_STD_OVERLOADS(migraphx::fp8::fp8e4m3fnuz) +MIGRAPHX_FP8_STD_OVERLOADS(migraphx::fp8::fp8e5m2fnuz) } // namespace std // ================================================================================================= diff --git a/src/include/migraphx/migraphx_f8_impl.hpp b/src/include/migraphx/float8_impl.hpp similarity index 85% rename from src/include/migraphx/migraphx_f8_impl.hpp rename to src/include/migraphx/float8_impl.hpp index 91871fe643e..70642d1621b 100644 --- a/src/include/migraphx/migraphx_f8_impl.hpp +++ b/src/include/migraphx/float8_impl.hpp @@ -20,49 +20,32 @@ * * ************************************************************************ */ -#ifndef MIGRAPHX_FP8_IMPL_HPP -#define MIGRAPHX_FP8_IMPL_HPP - -#define MIGRAPHX_CONST_FOLD(x) (__builtin_constant_p(x) ? 
(x) : (x)) -namespace migraphx_f8_impl { -namespace detail { -template -struct conditional -{ - using type = T; -}; - -template -struct conditional -{ - using type = F; -}; - -template -inline constexpr To bit_cast(From fr) noexcept -{ - static_assert(sizeof(To) == sizeof(From)); -#if defined(__GNUC__) and !defined(__clang__) - return MIGRAPHX_CONST_FOLD(*reinterpret_cast(&fr)); -#else - return __builtin_bit_cast(To, fr); -#endif -} -} // namespace detail +#ifndef MIGRAPHX_GUARD_RTGLIB_FLOAT8_IMPL_HPP +#define MIGRAPHX_GUARD_RTGLIB_FLOAT8_IMPL_HPP +#include +#include +#include +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { +namespace fp8 { +namespace impl { template -constexpr uint8_t cast_to_f8(T _x, bool stoch, uint32_t rng) +constexpr uint8_t cast_to_f8(T _x, bool stoch = false, uint32_t rng = 0) { - + constexpr bool is_float = std::is_same::value; + // half is not supported for now + constexpr bool is_half = false; static_assert(wm + we == 7, "wm+we==7"); + static_assert(is_float or is_half, "Only float can be cast to f8"); const int mfmt = (sizeof(T) == 4) ? 23 : 10; - typename detail::conditional::type x; + typename std::conditional::type x; if constexpr(sizeof(T) == 4) - x = detail::bit_cast(_x); + x = migraphx::bit_cast(_x); else - x = detail::bit_cast(_x); + x = migraphx::bit_cast(_x); uint32_t head, mantissa; int exponent, bias; @@ -271,19 +254,27 @@ constexpr uint8_t cast_to_f8(T _x, bool stoch, uint32_t rng) template constexpr T cast_from_f8(uint8_t x) { - constexpr int weo = 8; - constexpr int wmo = 23; + // half is not supported for now + constexpr bool is_half = false; + constexpr bool is_float = std::is_same::value; + static_assert(is_float or is_half, "Only float are supported"); + + constexpr int weo = is_half ? 5 : 8; + constexpr int wmo = is_half ? 10 : (is_float ? 
23 : 7); T fInf, fNegInf, fNaN, fNeg0; - uint32_t ifInf = 0x7F800000; - uint32_t ifNegInf = 0xFF800000; - uint32_t ifNaN = 0x7F800001; - uint32_t ifNeg0 = 0x80000000; - fInf = detail::bit_cast(ifInf); - fNegInf = detail::bit_cast(ifNegInf); - fNaN = detail::bit_cast(ifNaN); - fNeg0 = detail::bit_cast(ifNeg0); + if constexpr(is_float) + { + const uint32_t ifInf = 0x7F800000; + const uint32_t ifNegInf = 0xFF800000; + const uint32_t ifNaN = 0x7F800001; + const uint32_t ifNeg0 = 0x80000000; + fInf = migraphx::bit_cast(ifInf); + fNegInf = migraphx::bit_cast(ifNegInf); + fNaN = migraphx::bit_cast(ifNaN); + fNeg0 = migraphx::bit_cast(ifNeg0); + } if(x == 0) return 0; @@ -305,7 +296,7 @@ constexpr T cast_from_f8(uint8_t x) else if(wm == 3 and (x == 0x7F or x == 0xFF)) return fNaN; } - typename detail::conditional::type retval; + typename std::conditional::type retval; const int exp_low_cutoff = (1 << (weo - 1)) - (1 << (we - 1)) + 1 - (negative_zero_nan ? 1 : 0); @@ -333,8 +324,11 @@ constexpr T cast_from_f8(uint8_t x) retval = (sign << 15) | (exponent << 10) | mantissa; else retval = (sign << 31) | (exponent << 23) | mantissa; - return detail::bit_cast(retval); + return migraphx::bit_cast(retval); } -} // namespace migraphx_f8_impl -#endif // MIGRAPHX_FP8_IMPL_HPP +} // namespace impl +} // namespace fp8 +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx +#endif // MIGRAPHX_GUARD_RTGLIB_FLOAT8_IMPL diff --git a/src/include/migraphx/half.hpp b/src/include/migraphx/half.hpp index de692dd16e3..0f6516d9bda 100644 --- a/src/include/migraphx/half.hpp +++ b/src/include/migraphx/half.hpp @@ -27,7 +27,7 @@ #include #include -#include +#include namespace migraphx { inline namespace MIGRAPHX_INLINE_NS { @@ -69,13 +69,13 @@ struct common_type : std::common_type // NOLINT }; template <> -struct common_type +struct common_type { using type = float; }; template <> -struct common_type +struct common_type { using type = float; }; diff --git a/src/include/migraphx/shape.hpp 
b/src/include/migraphx/shape.hpp index ef82f300226..d596398ca78 100644 --- a/src/include/migraphx/shape.hpp +++ b/src/include/migraphx/shape.hpp @@ -34,7 +34,7 @@ #include #include #include -#include +#include #include #include @@ -62,7 +62,7 @@ struct MIGRAPHX_EXPORT shape m(int64_type, int64_t) \ m(uint32_type, uint32_t) \ m(uint64_type, uint64_t) \ - m(fp8e4m3fnuz_type, migraphx_fp8::fp8e4m3fnuz) + m(fp8e4m3fnuz_type, migraphx::fp8::fp8e4m3fnuz) // clang-format on #define MIGRAPHX_SHAPE_GENERATE_ENUM_TYPES(x, t) x, diff --git a/src/include/migraphx/type_traits.hpp b/src/include/migraphx/type_traits.hpp index 8fc3081ef18..44b5e0573cc 100644 --- a/src/include/migraphx/type_traits.hpp +++ b/src/include/migraphx/type_traits.hpp @@ -28,7 +28,7 @@ #include #include #include -#include +#include namespace migraphx { inline namespace MIGRAPHX_INLINE_NS { @@ -49,23 +49,13 @@ MIGRAPHX_DETAIL_DEFINE_TRAIT(is_floating_point); MIGRAPHX_DETAIL_DEFINE_TRAIT(is_arithmetic); MIGRAPHX_DETAIL_DEFINE_TRAIT(is_signed); -template -struct is_same : std::is_same -{ -}; - -template -struct conditional : std::conditional -{ -}; - MIGRAPHX_DETAIL_EXTEND_TRAIT_FOR(is_floating_point, half) MIGRAPHX_DETAIL_EXTEND_TRAIT_FOR(is_signed, half) MIGRAPHX_DETAIL_EXTEND_TRAIT_FOR(is_arithmetic, half) -MIGRAPHX_DETAIL_EXTEND_TRAIT_FOR(is_floating_point, migraphx_fp8::fp8e4m3fnuz) -MIGRAPHX_DETAIL_EXTEND_TRAIT_FOR(is_signed, migraphx_fp8::fp8e4m3fnuz) -MIGRAPHX_DETAIL_EXTEND_TRAIT_FOR(is_arithmetic, migraphx_fp8::fp8e4m3fnuz) +MIGRAPHX_DETAIL_EXTEND_TRAIT_FOR(is_floating_point, migraphx::fp8::fp8e4m3fnuz) +MIGRAPHX_DETAIL_EXTEND_TRAIT_FOR(is_signed, migraphx::fp8::fp8e4m3fnuz) +MIGRAPHX_DETAIL_EXTEND_TRAIT_FOR(is_arithmetic, migraphx::fp8::fp8e4m3fnuz) template using accumulator_type = diff --git a/src/py/migraphx_py.cpp b/src/py/migraphx_py.cpp index 4014936ef44..91af6cf9ded 100644 --- a/src/py/migraphx_py.cpp +++ b/src/py/migraphx_py.cpp @@ -40,7 +40,7 @@ #include #include #include -#include 
+#include #ifdef HAVE_GPU #include #endif @@ -145,7 +145,7 @@ struct npy_format_descriptor }; template <> -struct npy_format_descriptor +struct npy_format_descriptor { static std::string format() { diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt index a6217b0ec2f..33aca123217 100644 --- a/test/CMakeLists.txt +++ b/test/CMakeLists.txt @@ -150,7 +150,7 @@ function(test_headers PREFIX) list(REMOVE_ITEM HEADERS ${CMAKE_SOURCE_DIR}/src/targets/gpu/include/migraphx/gpu/ck.hpp) endif() - list(REMOVE_ITEM HEADERS ${CMAKE_SOURCE_DIR}/src/include/migraphx/migraphx_f8_impl.hpp) + list(REMOVE_ITEM HEADERS ${CMAKE_SOURCE_DIR}/src/include/migraphx/float8_impl.hpp) foreach(HEADER ${HEADERS}) file(RELATIVE_PATH HEADER_REL ${CMAKE_SOURCE_DIR} ${HEADER}) string(MAKE_C_IDENTIFIER ${HEADER_REL} TEST_NAME) diff --git a/test/float_equal.cpp b/test/float_equal.cpp index 0ae10614708..847a929437c 100644 --- a/test/float_equal.cpp +++ b/test/float_equal.cpp @@ -22,7 +22,7 @@ * THE SOFTWARE. */ #include -#include +#include #include #include "test.hpp" @@ -72,12 +72,12 @@ void test_equality() TEST_CASE_REGISTER(test_equality); TEST_CASE_REGISTER(test_equality); TEST_CASE_REGISTER(test_equality); -TEST_CASE_REGISTER(test_equality); +TEST_CASE_REGISTER(test_equality); TEST_CASE_REGISTER(test_equality); -TEST_CASE_REGISTER(test_equality); +TEST_CASE_REGISTER(test_equality); TEST_CASE_REGISTER(test_equality); -TEST_CASE_REGISTER(test_equality); -TEST_CASE_REGISTER(test_equality); +TEST_CASE_REGISTER(test_equality); +TEST_CASE_REGISTER(test_equality); template void test_limits() @@ -115,12 +115,12 @@ void test_limits() TEST_CASE_REGISTER(test_limits); TEST_CASE_REGISTER(test_limits); TEST_CASE_REGISTER(test_limits); -TEST_CASE_REGISTER(test_limits); +TEST_CASE_REGISTER(test_limits); TEST_CASE_REGISTER(test_limits); -TEST_CASE_REGISTER(test_limits); +TEST_CASE_REGISTER(test_limits); TEST_CASE_REGISTER(test_limits); -TEST_CASE_REGISTER(test_limits); -TEST_CASE_REGISTER(test_limits); 
+TEST_CASE_REGISTER(test_limits); +TEST_CASE_REGISTER(test_limits); #ifndef _WIN32 // On Windows, types int and long have the same min and max values. diff --git a/test/fp8e4m3fn.cpp b/test/fp8e4m3fn.cpp index 0601ac696c5..bef04c01d2c 100644 --- a/test/fp8e4m3fn.cpp +++ b/test/fp8e4m3fn.cpp @@ -23,7 +23,7 @@ */ #include #include -#include +#include #include #include #include "test.hpp" @@ -108,7 +108,7 @@ TEST_CASE(test_fp8_cast_to_float) std::vector bit_vals(256); std::iota(bit_vals.begin(), bit_vals.end(), 0); EXPECT(bool{std::all_of(bit_vals.begin(), bit_vals.end(), [](uint8_t bit_val) { - migraphx_fp8::fp8e4m3fn fp8_val(bit_val, migraphx_fp8::fp8e4m3fn::from_bits()); + migraphx::fp8::fp8e4m3fn fp8_val(bit_val, migraphx::fp8::fp8e4m3fn::from_bits()); if(std::isnan(float(fp8_val)) and std::isnan(fp8e4m3fn_to_fp32_value(bit_val))) { return true; @@ -120,7 +120,7 @@ TEST_CASE(test_fp8_cast_to_float) TEST_CASE(test_positive_zero) { float zero = 0.0; - migraphx_fp8::fp8e4m3fn fp8_zero(zero); + migraphx::fp8::fp8e4m3fn fp8_zero(zero); EXPECT(fp8_zero.is_zero()); EXPECT(migraphx::float_equal(zero, float(fp8_zero))); } @@ -128,7 +128,7 @@ TEST_CASE(test_positive_zero) TEST_CASE(test_negative_zero) { float nzero = -0.0; - migraphx_fp8::fp8e4m3fn fp8_nzero(nzero); + migraphx::fp8::fp8e4m3fn fp8_nzero(nzero); EXPECT(fp8_nzero.is_zero()); // negative zero is preserved for fp8e4m3fn EXPECT(migraphx::float_equal(nzero, float(fp8_nzero))); @@ -137,15 +137,15 @@ TEST_CASE(test_negative_zero) TEST_CASE(test_nan_1) { float fnan = std::numeric_limits::quiet_NaN(); - migraphx_fp8::fp8e4m3fn fp8_nan(fnan); + migraphx::fp8::fp8e4m3fn fp8_nan(fnan); EXPECT(fp8_nan.is_nan()); EXPECT(std::isnan(fp8_nan)); } TEST_CASE(test_nan_2) { - auto fnan = std::numeric_limits::quiet_NaN(); - migraphx_fp8::fp8e4m3fn fp8_nan(fnan.data, migraphx_fp8::fp8e4m3fn::from_bits()); + auto fnan = std::numeric_limits::quiet_NaN(); + migraphx::fp8::fp8e4m3fn fp8_nan(fnan.data, 
migraphx::fp8::fp8e4m3fn::from_bits()); EXPECT(fp8_nan.is_nan()); EXPECT(std::isnan(fp8_nan)); EXPECT(std::isnan(float(fp8_nan))); @@ -155,8 +155,8 @@ TEST_CASE(test_infinity_1) { float finf = std::numeric_limits::infinity(); // no inf in fp8e4m3fn, it gets clipped to max() - migraphx_fp8::fp8e4m3fn fp8_max(finf); - EXPECT(fp8_max == std::numeric_limits::max()); + migraphx::fp8::fp8e4m3fn fp8_max(finf); + EXPECT(fp8_max == std::numeric_limits::max()); } TEST_CASE(test_infinity_2) @@ -164,43 +164,43 @@ TEST_CASE(test_infinity_2) // neg inf float finf = -1.0 * std::numeric_limits::infinity(); // no inf in fp8e4m3fn, it gets clipped to lowest - migraphx_fp8::fp8e4m3fn fp8_lowest(finf); - EXPECT(bool{fp8_lowest == std::numeric_limits::lowest()}); + migraphx::fp8::fp8e4m3fn fp8_lowest(finf); + EXPECT(bool{fp8_lowest == std::numeric_limits::lowest()}); } TEST_CASE(test_numeric_max_1) { float fmax = std::numeric_limits::max(); - migraphx_fp8::fp8e4m3fn fp8_max(fmax); - EXPECT(fp8_max == std::numeric_limits::max()); + migraphx::fp8::fp8e4m3fn fp8_max(fmax); + EXPECT(fp8_max == std::numeric_limits::max()); } TEST_CASE(test_numeric_max_2) { // gets clipped to max - float fmax = 2 * std::numeric_limits::max(); - migraphx_fp8::fp8e4m3fn fp8_max(fmax); - EXPECT(fp8_max == std::numeric_limits::max()); + float fmax = 2 * std::numeric_limits::max(); + migraphx::fp8::fp8e4m3fn fp8_max(fmax); + EXPECT(fp8_max == std::numeric_limits::max()); } TEST_CASE(test_numeric_lowest_1) { float flowest = std::numeric_limits::lowest(); - migraphx_fp8::fp8e4m3fn fp8_lowest(flowest); - EXPECT(fp8_lowest == std::numeric_limits::lowest()); + migraphx::fp8::fp8e4m3fn fp8_lowest(flowest); + EXPECT(fp8_lowest == std::numeric_limits::lowest()); } TEST_CASE(test_numeric_lowest_2) { // gets clipped to lowest - float fmin = 2.0 * std::numeric_limits::lowest(); - migraphx_fp8::fp8e4m3fn fp8_lowest(fmin); - EXPECT(fp8_lowest == std::numeric_limits::lowest()); + float fmin = 2.0 * 
std::numeric_limits::lowest(); + migraphx::fp8::fp8e4m3fn fp8_lowest(fmin); + EXPECT(fp8_lowest == std::numeric_limits::lowest()); } TEST_CASE(test_max_eq_lowest) { - EXPECT(migraphx::float_equal(std::numeric_limits::lowest(), - -1 * std::numeric_limits::max())); + EXPECT(migraphx::float_equal(std::numeric_limits::lowest(), + -1 * std::numeric_limits::max())); } int main(int argc, const char* argv[]) { test::run(argc, argv); } diff --git a/test/fp8e4m3fnuz.cpp b/test/fp8e4m3fnuz.cpp index d1e3cd72542..11ab4ada335 100644 --- a/test/fp8e4m3fnuz.cpp +++ b/test/fp8e4m3fnuz.cpp @@ -23,7 +23,7 @@ */ #include #include -#include +#include #include #include #include "test.hpp" @@ -129,7 +129,7 @@ TEST_CASE(test_fp8_cast_to_float) std::vector bit_vals(256); std::iota(bit_vals.begin(), bit_vals.end(), 0); EXPECT(bool{std::all_of(bit_vals.begin(), bit_vals.end(), [](uint8_t bit_val) { - migraphx_fp8::fp8e4m3fnuz fp8_val(bit_val, migraphx_fp8::fp8e4m3fnuz::from_bits()); + migraphx::fp8::fp8e4m3fnuz fp8_val(bit_val, migraphx::fp8::fp8e4m3fnuz::from_bits()); if(std::isnan(float(fp8_val)) and std::isnan(fp8e4m3fnuz_to_fp32_value(bit_val))) { return true; @@ -141,7 +141,7 @@ TEST_CASE(test_fp8_cast_to_float) TEST_CASE(test_positive_zero) { float zero = 0.0; - migraphx_fp8::fp8e4m3fnuz fp8_zero(zero); + migraphx::fp8::fp8e4m3fnuz fp8_zero(zero); EXPECT(fp8_zero.is_zero()); EXPECT(migraphx::float_equal(zero, float(fp8_zero))); } @@ -150,7 +150,7 @@ TEST_CASE(test_negative_zero) { float nzero = -0.0; float pzero = 0.0; - migraphx_fp8::fp8e4m3fnuz fp8_nzero(nzero); + migraphx::fp8::fp8e4m3fnuz fp8_nzero(nzero); EXPECT(fp8_nzero.is_zero()); // negative zero gets converted to positive zero EXPECT(migraphx::float_equal(pzero, float(fp8_nzero))); @@ -159,15 +159,15 @@ TEST_CASE(test_negative_zero) TEST_CASE(test_nan_1) { float fnan = std::numeric_limits::quiet_NaN(); - migraphx_fp8::fp8e4m3fnuz fp8_nan(fnan); + migraphx::fp8::fp8e4m3fnuz fp8_nan(fnan); EXPECT(fp8_nan.is_nan()); 
EXPECT(std::isnan(fp8_nan)); } TEST_CASE(test_nan_2) { - auto fnan = std::numeric_limits::quiet_NaN(); - migraphx_fp8::fp8e4m3fnuz fp8_nan(fnan.data, migraphx_fp8::fp8e4m3fnuz::from_bits()); + auto fnan = std::numeric_limits::quiet_NaN(); + migraphx::fp8::fp8e4m3fnuz fp8_nan(fnan.data, migraphx::fp8::fp8e4m3fnuz::from_bits()); EXPECT(fp8_nan.is_nan()); EXPECT(std::isnan(fp8_nan)); EXPECT(std::isnan(float(fp8_nan))); @@ -177,7 +177,7 @@ TEST_CASE(test_infinity_1) { float finf = std::numeric_limits::infinity(); // no inf in fp8e4m3fnuz it gets clipped to Nans - migraphx_fp8::fp8e4m3fnuz fp8_nan(finf); + migraphx::fp8::fp8e4m3fnuz fp8_nan(finf); EXPECT(fp8_nan.is_nan()); EXPECT(std::isnan(float(fp8_nan))); } @@ -187,7 +187,7 @@ TEST_CASE(test_infinity_2) // neg inf float finf = -1.0 * std::numeric_limits::infinity(); // no inf in fp8e4m3fnuz it gets clipped to NaNs - migraphx_fp8::fp8e4m3fnuz fp8_nan(finf); + migraphx::fp8::fp8e4m3fnuz fp8_nan(finf); EXPECT(fp8_nan.is_nan()); EXPECT(std::isnan(float(fp8_nan))); } @@ -195,36 +195,36 @@ TEST_CASE(test_infinity_2) TEST_CASE(test_numeric_max_1) { float fmax = std::numeric_limits::max(); - migraphx_fp8::fp8e4m3fnuz fp8_max(fmax); - EXPECT(fp8_max == std::numeric_limits::max()); + migraphx::fp8::fp8e4m3fnuz fp8_max(fmax); + EXPECT(fp8_max == std::numeric_limits::max()); } TEST_CASE(test_numeric_max_2) { // gets clipped to max - float fmax = 2 * std::numeric_limits::max(); - migraphx_fp8::fp8e4m3fnuz fp8_max(fmax); - EXPECT(fp8_max == std::numeric_limits::max()); + float fmax = 2 * std::numeric_limits::max(); + migraphx::fp8::fp8e4m3fnuz fp8_max(fmax); + EXPECT(fp8_max == std::numeric_limits::max()); } TEST_CASE(test_numeric_lowest_1) { float flowest = std::numeric_limits::lowest(); - migraphx_fp8::fp8e4m3fnuz fp8_lowest(flowest); - EXPECT(fp8_lowest == std::numeric_limits::lowest()); + migraphx::fp8::fp8e4m3fnuz fp8_lowest(flowest); + EXPECT(fp8_lowest == std::numeric_limits::lowest()); } TEST_CASE(test_numeric_lowest_2) { 
// gets clipped to lowest - float fmin = 2.0 * std::numeric_limits::lowest(); - migraphx_fp8::fp8e4m3fnuz fp8_lowest(fmin); - EXPECT(fp8_lowest == std::numeric_limits::lowest()); + float fmin = 2.0 * std::numeric_limits::lowest(); + migraphx::fp8::fp8e4m3fnuz fp8_lowest(fmin); + EXPECT(fp8_lowest == std::numeric_limits::lowest()); } TEST_CASE(test_max_eq_lowest) { - EXPECT(migraphx::float_equal(std::numeric_limits::lowest(), - -1 * std::numeric_limits::max())); + EXPECT(migraphx::float_equal(std::numeric_limits::lowest(), + -1 * std::numeric_limits::max())); } int main(int argc, const char* argv[]) { test::run(argc, argv); } diff --git a/test/fp8e5m2.cpp b/test/fp8e5m2.cpp index b4916ec35e8..8e2c1a68362 100644 --- a/test/fp8e5m2.cpp +++ b/test/fp8e5m2.cpp @@ -23,7 +23,7 @@ */ #include #include -#include +#include #include #include #include "test.hpp" @@ -301,7 +301,7 @@ TEST_CASE(test_fp8_cast_to_float) std::vector bit_vals(256); std::iota(bit_vals.begin(), bit_vals.end(), 0); EXPECT(bool{std::all_of(bit_vals.begin(), bit_vals.end(), [](uint8_t bit_val) { - migraphx_fp8::fp8e5m2 fp8_val(bit_val, migraphx_fp8::fp8e5m2::from_bits()); + migraphx::fp8::fp8e5m2 fp8_val(bit_val, migraphx::fp8::fp8e5m2::from_bits()); if(std::isnan(float(fp8_val)) and std::isnan(fp8e5m2_to_fp32_value(bit_val))) { return true; @@ -317,7 +317,7 @@ TEST_CASE(test_fp8_cast_to_float) TEST_CASE(test_positive_zero) { float zero = 0.0; - migraphx_fp8::fp8e5m2 fp8_zero(zero); + migraphx::fp8::fp8e5m2 fp8_zero(zero); EXPECT(fp8_zero.is_zero()); EXPECT(migraphx::float_equal(zero, float(fp8_zero))); } @@ -325,7 +325,7 @@ TEST_CASE(test_positive_zero) TEST_CASE(test_negative_zero) { float nzero = -0.0; - migraphx_fp8::fp8e5m2 fp8_nzero(nzero); + migraphx::fp8::fp8e5m2 fp8_nzero(nzero); EXPECT(fp8_nzero.is_zero()); // negative zero is preserved for fp8e5m2 EXPECT(migraphx::float_equal(nzero, float(fp8_nzero))); @@ -334,15 +334,15 @@ TEST_CASE(test_negative_zero) TEST_CASE(test_nan_1) { float fnan = 
std::numeric_limits::quiet_NaN(); - migraphx_fp8::fp8e5m2 fp8_nan(fnan); + migraphx::fp8::fp8e5m2 fp8_nan(fnan); EXPECT(fp8_nan.is_nan()); EXPECT(std::isnan(fp8_nan)); } TEST_CASE(test_nan_2) { - auto fnan = std::numeric_limits::quiet_NaN(); - migraphx_fp8::fp8e5m2 fp8_nan(fnan.data, migraphx_fp8::fp8e5m2::from_bits()); + auto fnan = std::numeric_limits::quiet_NaN(); + migraphx::fp8::fp8e5m2 fp8_nan(fnan.data, migraphx::fp8::fp8e5m2::from_bits()); EXPECT(fp8_nan.is_nan()); EXPECT(std::isnan(fp8_nan)); EXPECT(std::isnan(float(fp8_nan))); @@ -352,8 +352,8 @@ TEST_CASE(test_infinity_1) { // float infinity should get clipped to max float finf = std::numeric_limits::infinity(); - migraphx_fp8::fp8e5m2 fp8_max(finf); - EXPECT(fp8_max == std::numeric_limits::max()); + migraphx::fp8::fp8e5m2 fp8_max(finf); + EXPECT(fp8_max == std::numeric_limits::max()); } TEST_CASE(test_infinity_2) @@ -361,43 +361,43 @@ TEST_CASE(test_infinity_2) // neg inf float finf = -1.0 * std::numeric_limits::infinity(); // no inf in fp8e5m2, it gets clipped to lowest - migraphx_fp8::fp8e5m2 fp8_lowest(finf); - EXPECT(bool{fp8_lowest == std::numeric_limits::lowest()}); + migraphx::fp8::fp8e5m2 fp8_lowest(finf); + EXPECT(bool{fp8_lowest == std::numeric_limits::lowest()}); } TEST_CASE(test_numeric_max_1) { float fmax = std::numeric_limits::max(); - migraphx_fp8::fp8e5m2 fp8_max(fmax); - EXPECT(fp8_max == std::numeric_limits::max()); + migraphx::fp8::fp8e5m2 fp8_max(fmax); + EXPECT(fp8_max == std::numeric_limits::max()); } TEST_CASE(test_numeric_max_2) { // gets clipped to max - float fmax = 2 * std::numeric_limits::max(); - migraphx_fp8::fp8e5m2 fp8_max(fmax); - EXPECT(fp8_max == std::numeric_limits::max()); + float fmax = 2 * std::numeric_limits::max(); + migraphx::fp8::fp8e5m2 fp8_max(fmax); + EXPECT(fp8_max == std::numeric_limits::max()); } TEST_CASE(test_numeric_lowest_1) { float flowest = std::numeric_limits::lowest(); - migraphx_fp8::fp8e5m2 fp8_lowest(flowest); - EXPECT(fp8_lowest == 
std::numeric_limits::lowest()); + migraphx::fp8::fp8e5m2 fp8_lowest(flowest); + EXPECT(fp8_lowest == std::numeric_limits::lowest()); } TEST_CASE(test_numeric_lowest_2) { // gets clipped to lowest - float fmin = 2.0 * std::numeric_limits::lowest(); - migraphx_fp8::fp8e5m2 fp8_lowest(fmin); - EXPECT(fp8_lowest == std::numeric_limits::lowest()); + float fmin = 2.0 * std::numeric_limits::lowest(); + migraphx::fp8::fp8e5m2 fp8_lowest(fmin); + EXPECT(fp8_lowest == std::numeric_limits::lowest()); } TEST_CASE(test_max_eq_lowest) { - EXPECT(migraphx::float_equal(std::numeric_limits::lowest(), - -1 * std::numeric_limits::max())); + EXPECT(migraphx::float_equal(std::numeric_limits::lowest(), + -1 * std::numeric_limits::max())); } int main(int argc, const char* argv[]) { test::run(argc, argv); } diff --git a/test/fp8e5m2fnuz.cpp b/test/fp8e5m2fnuz.cpp index da447ab1996..06f422672ba 100644 --- a/test/fp8e5m2fnuz.cpp +++ b/test/fp8e5m2fnuz.cpp @@ -23,7 +23,7 @@ */ #include #include -#include +#include #include #include #include "test.hpp" @@ -299,7 +299,7 @@ TEST_CASE(test_fp8_cast_to_float) std::vector bit_vals(256); std::iota(bit_vals.begin(), bit_vals.end(), 0); EXPECT(bool{std::all_of(bit_vals.begin(), bit_vals.end(), [](uint8_t bit_val) { - migraphx_fp8::fp8e5m2fnuz fp8_val(bit_val, migraphx_fp8::fp8e5m2fnuz::from_bits()); + migraphx::fp8::fp8e5m2fnuz fp8_val(bit_val, migraphx::fp8::fp8e5m2fnuz::from_bits()); if(std::isnan(float(fp8_val)) and std::isnan(fp8e5m2fnuz_to_fp32_value(bit_val))) { return true; @@ -311,7 +311,7 @@ TEST_CASE(test_fp8_cast_to_float) TEST_CASE(test_positive_zero) { float zero = 0.0; - migraphx_fp8::fp8e5m2fnuz fp8_zero(zero); + migraphx::fp8::fp8e5m2fnuz fp8_zero(zero); EXPECT(fp8_zero.is_zero()); EXPECT(migraphx::float_equal(zero, float(fp8_zero))); } @@ -320,7 +320,7 @@ TEST_CASE(test_negative_zero) { float nzero = -0.0; float pzero = 0.0; - migraphx_fp8::fp8e5m2fnuz fp8_nzero(nzero); + migraphx::fp8::fp8e5m2fnuz fp8_nzero(nzero); 
EXPECT(fp8_nzero.is_zero()); // negative zero gets converted to positive zero EXPECT(migraphx::float_equal(pzero, float(fp8_nzero))); @@ -329,15 +329,15 @@ TEST_CASE(test_negative_zero) TEST_CASE(test_nan_1) { float fnan = std::numeric_limits::quiet_NaN(); - migraphx_fp8::fp8e5m2fnuz fp8_nan(fnan); + migraphx::fp8::fp8e5m2fnuz fp8_nan(fnan); EXPECT(fp8_nan.is_nan()); EXPECT(std::isnan(fp8_nan)); } TEST_CASE(test_nan_2) { - auto fnan = std::numeric_limits::quiet_NaN(); - migraphx_fp8::fp8e5m2fnuz fp8_nan(fnan.data, migraphx_fp8::fp8e5m2fnuz::from_bits()); + auto fnan = std::numeric_limits::quiet_NaN(); + migraphx::fp8::fp8e5m2fnuz fp8_nan(fnan.data, migraphx::fp8::fp8e5m2fnuz::from_bits()); EXPECT(fp8_nan.is_nan()); EXPECT(std::isnan(fp8_nan)); EXPECT(std::isnan(float(fp8_nan))); @@ -347,7 +347,7 @@ TEST_CASE(test_infinity_1) { float finf = std::numeric_limits::infinity(); // no inf in fp8e5m2fnuz it gets clipped to Nans - migraphx_fp8::fp8e5m2fnuz fp8_nan(finf); + migraphx::fp8::fp8e5m2fnuz fp8_nan(finf); EXPECT(fp8_nan.is_nan()); EXPECT(std::isnan(float(fp8_nan))); } @@ -357,7 +357,7 @@ TEST_CASE(test_infinity_2) // neg inf float finf = -1.0 * std::numeric_limits::infinity(); // no inf in fp8e5m2fnuz it gets clipped to NaNs - migraphx_fp8::fp8e5m2fnuz fp8_nan(finf); + migraphx::fp8::fp8e5m2fnuz fp8_nan(finf); EXPECT(fp8_nan.is_nan()); EXPECT(std::isnan(float(fp8_nan))); } @@ -365,36 +365,36 @@ TEST_CASE(test_infinity_2) TEST_CASE(test_numeric_max_1) { float fmax = std::numeric_limits::max(); - migraphx_fp8::fp8e5m2fnuz fp8_max(fmax); - EXPECT(fp8_max == std::numeric_limits::max()); + migraphx::fp8::fp8e5m2fnuz fp8_max(fmax); + EXPECT(fp8_max == std::numeric_limits::max()); } TEST_CASE(test_numeric_max_2) { // gets clipped to max - float fmax = 2 * std::numeric_limits::max(); - migraphx_fp8::fp8e5m2fnuz fp8_max(fmax); - EXPECT(fp8_max == std::numeric_limits::max()); + float fmax = 2 * std::numeric_limits::max(); + migraphx::fp8::fp8e5m2fnuz fp8_max(fmax); + 
EXPECT(fp8_max == std::numeric_limits::max()); } TEST_CASE(test_numeric_lowest_1) { float flowest = std::numeric_limits::lowest(); - migraphx_fp8::fp8e5m2fnuz fp8_lowest(flowest); - EXPECT(fp8_lowest == std::numeric_limits::lowest()); + migraphx::fp8::fp8e5m2fnuz fp8_lowest(flowest); + EXPECT(fp8_lowest == std::numeric_limits::lowest()); } TEST_CASE(test_numeric_lowest_2) { // gets clipped to lowest - float fmin = 2.0 * std::numeric_limits::lowest(); - migraphx_fp8::fp8e5m2fnuz fp8_lowest(fmin); - EXPECT(fp8_lowest == std::numeric_limits::lowest()); + float fmin = 2.0 * std::numeric_limits::lowest(); + migraphx::fp8::fp8e5m2fnuz fp8_lowest(fmin); + EXPECT(fp8_lowest == std::numeric_limits::lowest()); } TEST_CASE(test_max_eq_lowest) { - EXPECT(migraphx::float_equal(std::numeric_limits::lowest(), - -1 * std::numeric_limits::max())); + EXPECT(migraphx::float_equal(std::numeric_limits::lowest(), + -1 * std::numeric_limits::max())); } int main(int argc, const char* argv[]) { test::run(argc, argv); } diff --git a/tools/api/migraphx.h b/tools/api/migraphx.h index ad441280bb7..57441279b18 100644 --- a/tools/api/migraphx.h +++ b/tools/api/migraphx.h @@ -45,7 +45,7 @@ m(int64_type, int64_t) \ m(uint32_type, uint32_t) \ m(uint64_type, uint64_t) \ - m(fp8e4m3fnuz_type, migraphx_fp8::fp8e4m3fnuz) + m(fp8e4m3fnuz_type, migraphx::fp8::fp8e4m3fnuz) // clang-format on #ifdef __cplusplus From 8319e01f68c3099cb05d18b598f92dfe820de95f Mon Sep 17 00:00:00 2001 From: Umang Yadav Date: Tue, 14 Nov 2023 00:04:40 +0000 Subject: [PATCH 011/115] Fix tidy --- src/include/migraphx/bit_cast.hpp | 8 + src/include/migraphx/float8.hpp | 29 ++-- src/include/migraphx/float8_impl.hpp | 219 ++++++++++++--------------- src/targets/gpu/gemm_impl.cpp | 1 + test/fp8e4m3fn.cpp | 9 ++ test/fp8e5m2.cpp | 9 ++ 6 files changed, 141 insertions(+), 134 deletions(-) diff --git a/src/include/migraphx/bit_cast.hpp b/src/include/migraphx/bit_cast.hpp index 45913608c1e..b5fb6d472f6 100644 --- 
a/src/include/migraphx/bit_cast.hpp +++ b/src/include/migraphx/bit_cast.hpp @@ -21,8 +21,13 @@ * ************************************************************************ */ #ifndef MIGRAPHX_GUARD_RTGLIB_BITCAST_HPP #define MIGRAPHX_GUARD_RTGLIB_BITCAST_HPP +#if defined(__GNUC__) && !defined(__clang__) +#pragma GCC diagnostic push +#pragma GCC diagnostic ignored "-Wstrict-aliasing" +#endif #include +// NOLINTNEXTLINE(cppcoreguidelines-macro-usage) #define MIGRAPHX_CONST_FOLD(x) (__builtin_constant_p(x) ? (x) : (x)) namespace migraphx { @@ -39,4 +44,7 @@ inline constexpr To bit_cast(From fr) noexcept } } // namespace MIGRAPHX_INLINE_NS } // namespace migraphx +#if defined(__GNUC__) && !defined(__clang__) +#pragma GCC diagnostic pop +#endif #endif // MIGRAPHX_GUARD_RTGLIB_BITCAST_HPP diff --git a/src/include/migraphx/float8.hpp b/src/include/migraphx/float8.hpp index 843ac89bd1d..3303b09d3fb 100644 --- a/src/include/migraphx/float8.hpp +++ b/src/include/migraphx/float8.hpp @@ -32,6 +32,7 @@ // We are clipping/saturation in down conversion by default. Unclipped version is not tested and // shouldn't be used without having enough tests. 
// logic is based on clipping table from here : https://onnx.ai/onnx/technical/float8.html#cast +// NOLINTNEXTLINE #define MIGRAPHX_F8_DOWNCAST_CLIPPING 1 #include @@ -173,6 +174,7 @@ struct float8 } } +// NOLINTNEXTLINE #define MIGRAPHX_FP8_UNARY_OP(unary_op, binary_op) \ constexpr float8& operator unary_op(const float8& rhs) \ { \ @@ -192,8 +194,8 @@ struct float8 MIGRAPHX_FP8_UNARY_OP(+=, +) MIGRAPHX_FP8_UNARY_OP(/=, /) - inline constexpr float8& operator=(const float8& rhs) = default; - inline constexpr float8& operator=(float8&& rhs) = default; + inline constexpr float8& operator=(const float8& rhs) = default; + inline constexpr float8& operator=(float8&& rhs) noexcept = default; inline constexpr float8& operator=(float rhs) { @@ -203,11 +205,9 @@ struct float8 inline constexpr bool operator==(const float8& rhs) const { - if(rhs.is_zero() and this->is_zero()) - return true; - else if(rhs.is_nan() or rhs.is_inf() or this->is_nan() or this->is_inf()) + if(rhs.is_nan() or rhs.is_inf() or this->is_nan() or this->is_inf()) return false; - else if(this->data == rhs.data) + else if((rhs.is_zero() and this->is_zero()) or (this->data == rhs.data)) return true; return false; } @@ -260,7 +260,7 @@ MIGRAPHX_FP8_BINARY_OP(!=, bool) template inline migraphx::fp8::float8 fabs(migraphx::fp8::float8 v) { - v.data = v.data & 0x7f; + v.data = v.data & 0x7f; // NOLINT return v; } @@ -277,7 +277,7 @@ class numeric_limits public: static constexpr fp8e4m3fnuz epsilon() { return fp8e4m3fnuz(0x28, fp8e4m3fnuz::from_bits()); } - + // NOLINTNEXTLINE static constexpr fp8e4m3fnuz quiet_NaN() { return fp8e4m3fnuz(0x80, fp8e4m3fnuz::from_bits()); } static constexpr fp8e4m3fnuz max() { return fp8e4m3fnuz(0x7F, fp8e4m3fnuz::from_bits()); } @@ -294,7 +294,7 @@ class numeric_limits public: static constexpr fp8e4m3fn epsilon() { return fp8e4m3fn(0x20, fp8e4m3fn::from_bits()); } - + // NOLINTNEXTLINE static constexpr fp8e4m3fn quiet_NaN() { return fp8e4m3fn(0x7F, fp8e4m3fn::from_bits()); } static 
constexpr fp8e4m3fn max() { return fp8e4m3fn(0x7E, fp8e4m3fn::from_bits()); } @@ -312,7 +312,10 @@ class numeric_limits public: static constexpr fp8e5m2fnuz epsilon() { return fp8e5m2fnuz(0x34, fp8e5m2fnuz::from_bits()); } - static constexpr fp8e5m2fnuz quiet_NaN() { return fp8e5m2fnuz(0x80, fp8e5m2fnuz::from_bits()); } + static constexpr fp8e5m2fnuz quiet_NaN() // NOLINT + { + return fp8e5m2fnuz(0x80, fp8e5m2fnuz::from_bits()); + } static constexpr fp8e5m2fnuz max() { return fp8e5m2fnuz(0x7F, fp8e5m2fnuz::from_bits()); } // this is min value that is not DeNorm. DeNorm min is 0x01. I am not sure if we want to make @@ -328,7 +331,7 @@ class numeric_limits public: static constexpr fp8e5m2 epsilon() { return fp8e5m2(0x34, fp8e5m2::from_bits()); } // 7D, 7E, 7F are positive NaNs and FD, FE, FF are negative NaNs - static constexpr fp8e5m2 quiet_NaN() { return fp8e5m2(0xFF, fp8e5m2::from_bits()); } + static constexpr fp8e5m2 quiet_NaN() { return fp8e5m2(0xFF, fp8e5m2::from_bits()); } // NOLINT static constexpr fp8e5m2 max() { return fp8e5m2(0x7B, fp8e5m2::from_bits()); } // this is min value that is not DeNorm. DeNorm min is 0x01. 
I am not sure if we want to make @@ -345,8 +348,8 @@ class numeric_limits // ================================================================================================= // define numeric limits for the new data type +// NOLINTBEGIN namespace std { - #define MIGRAPHX_FP8_STD_OVERLOADS(T) \ inline bool isfinite(T x) { return x.is_inf(); } \ inline bool isnan(T x) { return x.is_nan(); } \ @@ -372,8 +375,8 @@ MIGRAPHX_FP8_STD_OVERLOADS(migraphx::fp8::fp8e4m3fn) MIGRAPHX_FP8_STD_OVERLOADS(migraphx::fp8::fp8e5m2) MIGRAPHX_FP8_STD_OVERLOADS(migraphx::fp8::fp8e4m3fnuz) MIGRAPHX_FP8_STD_OVERLOADS(migraphx::fp8::fp8e5m2fnuz) - } // namespace std +// NOLINTEND // ================================================================================================= #if defined(__clang__) #pragma clang diagnostic pop diff --git a/src/include/migraphx/float8_impl.hpp b/src/include/migraphx/float8_impl.hpp index 70642d1621b..8139e3de482 100644 --- a/src/include/migraphx/float8_impl.hpp +++ b/src/include/migraphx/float8_impl.hpp @@ -30,111 +30,91 @@ inline namespace MIGRAPHX_INLINE_NS { namespace fp8 { namespace impl { -template -constexpr uint8_t cast_to_f8(T _x, bool stoch = false, uint32_t rng = 0) +// NOLINTBEGIN +template +constexpr uint8_t cast_to_f8(T f_x, bool stoch = false, uint32_t rng = 0) { constexpr bool is_float = std::is_same::value; // half is not supported for now constexpr bool is_half = false; - static_assert(wm + we == 7, "wm+we==7"); + static_assert(Wm + We == 7, "Wm+We==7"); static_assert(is_float or is_half, "Only float can be cast to f8"); - const int mfmt = (sizeof(T) == 4) ? 23 : 10; + const uint32_t mfmt = (sizeof(T) == 4) ? 
23 : 10; typename std::conditional::type x; if constexpr(sizeof(T) == 4) - x = migraphx::bit_cast(_x); + x = migraphx::bit_cast(f_x); else - x = migraphx::bit_cast(_x); - - uint32_t head, mantissa; - int exponent, bias; - uint32_t sign; + x = migraphx::bit_cast(f_x); + uint32_t head = 0; + uint32_t mantissa = 0; + int exponent = 0; + uint32_t bias = 0; + uint32_t sign = 0; if constexpr(sizeof(T) == 4) { - head = x & 0xFF800000; - mantissa = x & 0x7FFFFF; - exponent = (head >> 23) & 0xFF; - sign = head >> 31; + head = x & 0xFF800000; // NOLINT + mantissa = x & 0x7FFFFF; // NOLINT + exponent = (head >> 23) & 0xFF; // NOLINT + sign = head >> 31; // NOLINT bias = 127; } else { - head = x & 0xFC00; - mantissa = x & 0x3FF; - exponent = (head >> 10) & 0x1F; - sign = head >> 15; + head = x & 0xFC00; // NOLINT + mantissa = x & 0x3FF; // NOLINT + exponent = (head >> 10) & 0x1F; // NOLINT + sign = head >> 15; // NOLINT bias = 15; } - uint32_t signed_inf = (sign << 7) + (((1 << we) - 1) << wm); - uint32_t signed_all_ones = (sign << 7) + ((((1 << we) - 1) << wm) + ((1 << wm) - 1)); + uint32_t signed_inf = (sign << 7) + (((1 << We) - 1) << Wm); // NOLINT + uint32_t signed_all_ones = (sign << 7) + ((((1 << We) - 1) << Wm) + ((1 << Wm) - 1)); // NOLINT // Calcualte maximum singed value FLT_MAX, FLT_MIN uint32_t signed_max = signed_all_ones; - if(not negative_zero_nan) - { - signed_max = (wm == 2) ? (signed_max - 4) : (signed_max - 1); - } + if(not NegativeZeroNan) + signed_max = (Wm == 2) ? 
(signed_max - 4) : (signed_max - 1); // Deal with inf and NaNs - if(negative_zero_nan) // For the FNUZ cases, it is simple just return NaNs + if(NegativeZeroNan) // For the FNUZ cases, it is simple just return NaNs { - if(sizeof(T) == 4) - { - if((x & 0x7F800000) == 0x7F800000) - return 0x80; - } - else - { - if((x & 0x7C00) == 0x7C00) - return 0x80; - } + if((sizeof(T) == 4 and ((x & 0x7F800000) == 0x7F800000)) or // NOLINT + (sizeof(T) == 2 and ((x & 0x7C00) == 0x7C00))) // NOLINT + return 0x80; } else { // calculate most common NaN mantissa for FP8, which is all Ones in binary uint32_t nan_mantissa = 1; - for(auto i = 1; i < wm; ++i) + for(auto i = 1; i < Wm; ++i) { - nan_mantissa |= (nan_mantissa << 1); + nan_mantissa |= (nan_mantissa << 1); // NOLINT } - if((sizeof(T) == 4 and ((x & 0x7F800000) == 0x7F800000)) or - (sizeof(T) == 2 and ((x & 0x7C00) == 0x7C00))) + if((sizeof(T) == 4 and ((x & 0x7F800000) == 0x7F800000)) or // NOLINT + (sizeof(T) == 2 and ((x & 0x7C00) == 0x7C00))) // NOLINT { // infinity if(mantissa == 0) { if(sign == 0) - { - return (wm == 2) ? 0x7B : 0x7E; - } + return (Wm == 2) ? 0x7B : 0x7E; else - { - return (wm == 2) ? 0xFB : 0xFE; - } + return (Wm == 2) ? 0xFB : 0xFE; } - else - { // NaNs + else // NaNs return signed_inf + nan_mantissa; - } } } // handle positive zero if(x == 0) return 0; // handle negative zero - if((sizeof(T) == 4 and x == 0x80000000) or (sizeof(T) == 2 and x == 0x8000)) + else if((sizeof(T) == 4 and x == 0x80000000) or (sizeof(T) == 2 and x == 0x8000)) { - if(negative_zero_nan) // For FNUZ types neg zero is just positive zero - { - return 0; - } - else - { - return 0x80; - } + return NegativeZeroNan ? 
0 : 0x80; // For FNUZ types neg zero is just positive zero } /* First need to check if it is normal or denorm as there is a difference of implict 1 @@ -144,13 +124,15 @@ constexpr uint8_t cast_to_f8(T _x, bool stoch = false, uint32_t rng = 0) exponent and mantissa again*/ // For IEEE bias mode, the bias is 2^(k-1) -1 where k is the width of exponent bits - const int f8_bias = (1 << (we - 1)) - 1 + (negative_zero_nan ? 1 : 0); + const int f8_bias = (1 << (We - 1u)) - 1 + (NegativeZeroNan ? 1 : 0); // NOLINT const int f8_denormal_act_exponent = 1 - f8_bias; // actual exponent of f8 denormal /* act_exponent is the actual exponent of fp32/fp16 (after subtracting bias) f8_exponent is the converted f8 exponent with bias encoding exponent_diff is the diff between fp32/fp16 exponent and f8 exponent, the difference needs to be adjusted and mantissa shifted*/ - int act_exponent, f8_exponent, exponent_diff; + int act_exponent = 0; + int f8_exponent = 0; + int exponent_diff = 0; if(exponent == 0) { // fp32/fp16 is in denormal. @@ -182,11 +164,11 @@ constexpr uint8_t cast_to_f8(T _x, bool stoch = false, uint32_t rng = 0) 0; // exponent_diff=0 does not mean there is no difference for this case, // act_exponent could be larger. Just that it does not need shift mantissa } - mantissa += (1 << mfmt); // Add the implicit 1 into mantissa + mantissa += (1u << mfmt); // Add the implicit 1 into mantissa } - - bool midpoint = (mantissa & ((1 << (mfmt - wm + exponent_diff)) - 1)) == - (1 << (mfmt - wm + exponent_diff - 1)); + // NOLINTNEXTLINE + bool midpoint = (mantissa & ((1 << (mfmt - Wm + exponent_diff)) - 1)) == + (1 << (mfmt - Wm + exponent_diff - 1)); // NOLINT /* This part is a bit tricky. The judgment of whether it is a tie needs to be done before we shift right as shift right could rip off some residual part and make something not midpoint look like midpoint. 
For example, the fp16 number 0x1002 (0 00100 0000000010), it is larger than @@ -194,64 +176,58 @@ constexpr uint8_t cast_to_f8(T _x, bool stoch = false, uint32_t rng = 0) */ if(exponent_diff > 0) - mantissa >>= exponent_diff; + mantissa >>= exponent_diff; // NOLINT else if(exponent_diff == -1) - mantissa <<= -exponent_diff; - bool implicit_one = mantissa & (1 << mfmt); + mantissa <<= -exponent_diff; // NOLINT + bool implicit_one = mantissa & (1 << mfmt); // NOLINT // if there is no implict 1, it means the f8 is denormal and need to adjust to denorm exponent f8_exponent = (act_exponent + exponent_diff) /*actual f8 exponent*/ + f8_bias - (implicit_one ? 0 : 1); // Now we have the exponent and mantissa adjusted - uint32_t drop_mask = (1 << (mfmt - wm)) - 1; + uint32_t drop_mask = (1u << (mfmt - Wm)) - 1; // NOLINT bool odd = - mantissa & (1 << (mfmt - wm)); // if the least significant bit that is not truncated is 1 - mantissa += (stoch ? rng : (midpoint ? (odd ? mantissa : mantissa - 1) : mantissa)) & drop_mask; + mantissa & (1u << (mfmt - Wm)); // if the least significant bit that is not truncated is 1 + mantissa += (stoch ? rng : (midpoint ? (odd ? mantissa : mantissa - 1) : mantissa)) & // NOLINT + drop_mask; // NOLINT // Now we deal with overflow - if(f8_exponent == 0) + if(f8_exponent == 0 and ((1 << mfmt) & mantissa)) // NOLINT { - if((1 << mfmt) & mantissa) - { - f8_exponent = 1; // denormal overflow to become normal, promote exponent - } + f8_exponent = 1; // denormal overflow to become normal, promote exponent } - else + else if((1 << (mfmt + 1)) & mantissa) // NOLINT { - if((1 << (mfmt + 1)) & mantissa) - { - mantissa >>= 1; - f8_exponent++; - } + mantissa >>= 1; // NOLINT + f8_exponent++; } - mantissa >>= (mfmt - wm); + mantissa >>= (mfmt - Wm); // NOLINT // above range: quantize to maximum possible float of the same sign - const int max_exp = (1 << we) - (negative_zero_nan ? 1 : 2); + const int max_exp = (1 << We) - (NegativeZeroNan ? 
1 : 2); // NOLINT if(f8_exponent > max_exp) { - if(clip) - { + if(Clip) return signed_max; - } else { // https://onnx.ai/onnx/technical/float8.html#cast - if(negative_zero_nan) + if(NegativeZeroNan) return 0x80; else - return (wm == 2) ? signed_inf : signed_all_ones; + return (Wm == 2) ? signed_inf : signed_all_ones; } } if(f8_exponent == 0 and mantissa == 0) - return negative_zero_nan ? 0 : (sign << 7); - mantissa &= (1 << wm) - 1; - return (sign << 7) | (f8_exponent << wm) | mantissa; + return NegativeZeroNan ? 0 : (sign << 7); // NOLINT + mantissa &= (1 << Wm) - 1; // NOLINT + return (sign << 7) | (f8_exponent << Wm) | mantissa; // NOLINT } +// NOLINTEND -template +template constexpr T cast_from_f8(uint8_t x) { // half is not supported for now @@ -261,69 +237,70 @@ constexpr T cast_from_f8(uint8_t x) constexpr int weo = is_half ? 5 : 8; constexpr int wmo = is_half ? 10 : (is_float ? 23 : 7); - - T fInf, fNegInf, fNaN, fNeg0; + // NOLINTNEXTLINE + T f_inf, f_neg_inf, f_nan, f_neg0; if constexpr(is_float) { - const uint32_t ifInf = 0x7F800000; - const uint32_t ifNegInf = 0xFF800000; - const uint32_t ifNaN = 0x7F800001; - const uint32_t ifNeg0 = 0x80000000; - fInf = migraphx::bit_cast(ifInf); - fNegInf = migraphx::bit_cast(ifNegInf); - fNaN = migraphx::bit_cast(ifNaN); - fNeg0 = migraphx::bit_cast(ifNeg0); + const uint32_t if_inf = 0x7F800000; + const uint32_t if_neg_inf = 0xFF800000; + const uint32_t if_nan = 0x7F800001; + const uint32_t if_neg0 = 0x80000000; + f_inf = migraphx::bit_cast(if_inf); + f_neg_inf = migraphx::bit_cast(if_neg_inf); + f_nan = migraphx::bit_cast(if_nan); + f_neg0 = migraphx::bit_cast(if_neg0); } if(x == 0) return 0; - uint32_t sign = x >> 7; - uint32_t mantissa = x & ((1 << wm) - 1); - int exponent = (x & 0x7F) >> wm; - if(negative_zero_nan) + uint32_t sign = x >> 7; // NOLINT + uint32_t mantissa = x & ((1 << Wm) - 1); // NOLINT + int exponent = (x & 0x7F) >> Wm; // NOLINT + if(NegativeZeroNan) { if(x == 0x80) - return fNaN; + return 
f_nan; } else { if(x == 0x80) - return fNeg0; - if(exponent == ((1 << we) - 1) and wm == 2) - return (mantissa == 0) ? (sign ? fNegInf : fInf) : fNaN; - else if(wm == 3 and (x == 0x7F or x == 0xFF)) - return fNaN; + return f_neg0; + if(exponent == ((1 << We) - 1) and Wm == 2) // NOLINT + return (mantissa == 0) ? (sign ? f_neg_inf : f_inf) : f_nan; + else if(Wm == 3 and (x == 0x7F or x == 0xFF)) + return f_nan; } typename std::conditional::type retval; - const int exp_low_cutoff = (1 << (weo - 1)) - (1 << (we - 1)) + 1 - (negative_zero_nan ? 1 : 0); + const int exp_low_cutoff = + (1 << (weo - 1)) - (1 << (We - 1)) + 1 - (NegativeZeroNan ? 1 : 0); // NOLINT // subnormal input if(exponent == 0) { // guaranteed mantissa!=0 since cases 0x0 and 0x80 are handled above - int sh = 1 + __builtin_clz(mantissa) - (32 - wm); - mantissa <<= sh; + int sh = 1 + __builtin_clz(mantissa) - (32 - Wm); + mantissa <<= sh; // NOLINT exponent += 1 - sh; - mantissa &= ((1 << wm) - 1); + mantissa &= ((1 << Wm) - 1); // NOLINT } exponent += exp_low_cutoff - 1; - mantissa <<= wmo - wm; + mantissa <<= wmo - Wm; // NOLINT - // subnormal output (occurs when T=half, we=5, negative_zero_nan=true) + // subnormal output (occurs when T=half, We=5, negative_zero_nan=true) if(exponent <= 0) { - mantissa |= 1 << wmo; - mantissa >>= 1 - exponent; + mantissa |= 1 << wmo; // NOLINT + mantissa >>= 1 - exponent; // NOLINT exponent = 0; } if(sizeof(T) == 2) - retval = (sign << 15) | (exponent << 10) | mantissa; + retval = (sign << 15) | (exponent << 10) | mantissa; // NOLINT else - retval = (sign << 31) | (exponent << 23) | mantissa; + retval = (sign << 31) | (exponent << 23) | mantissa; // NOLINT return migraphx::bit_cast(retval); } diff --git a/src/targets/gpu/gemm_impl.cpp b/src/targets/gpu/gemm_impl.cpp index b4f0881f8d3..4495e21ecac 100644 --- a/src/targets/gpu/gemm_impl.cpp +++ b/src/targets/gpu/gemm_impl.cpp @@ -46,6 +46,7 @@ rocblas_datatype get_type(shape::type_t type) case shape::uint8_type: return 
rocblas_datatype_u8_r; case shape::int32_type: return rocblas_datatype_i32_r; case shape::uint32_type: return rocblas_datatype_u32_r; + case shape::fp8e4m3fnuz_type: case shape::tuple_type: case shape::bool_type: case shape::uint16_type: diff --git a/test/fp8e4m3fn.cpp b/test/fp8e4m3fn.cpp index bef04c01d2c..eb3dd6bb959 100644 --- a/test/fp8e4m3fn.cpp +++ b/test/fp8e4m3fn.cpp @@ -134,6 +134,15 @@ TEST_CASE(test_negative_zero) EXPECT(migraphx::float_equal(nzero, float(fp8_nzero))); } +TEST_CASE(test_pos_zero_eq_neg_zero) +{ + float nzero = -0.0; + float pzero = 0.0; + migraphx::fp8::fp8e5m2 fp8_nzero(nzero); + migraphx::fp8::fp8e5m2 fp8_pzero(pzero); + EXPECT(fp8_nzero == fp8_pzero); +} + TEST_CASE(test_nan_1) { float fnan = std::numeric_limits::quiet_NaN(); diff --git a/test/fp8e5m2.cpp b/test/fp8e5m2.cpp index 8e2c1a68362..03a4adef13d 100644 --- a/test/fp8e5m2.cpp +++ b/test/fp8e5m2.cpp @@ -331,6 +331,15 @@ TEST_CASE(test_negative_zero) EXPECT(migraphx::float_equal(nzero, float(fp8_nzero))); } +TEST_CASE(test_pos_zero_eq_neg_zero) +{ + float nzero = -0.0; + float pzero = 0.0; + migraphx::fp8::fp8e5m2 fp8_nzero(nzero); + migraphx::fp8::fp8e5m2 fp8_pzero(pzero); + EXPECT(fp8_nzero == fp8_pzero); +} + TEST_CASE(test_nan_1) { float fnan = std::numeric_limits::quiet_NaN(); From 9ee0418d5fc03e8b9e61eeb44beda916d5d41062 Mon Sep 17 00:00:00 2001 From: Umang Yadav Date: Tue, 14 Nov 2023 00:31:30 +0000 Subject: [PATCH 012/115] Fix test failure --- src/include/migraphx/float8.hpp | 38 ++++++++++++++++----------------- 1 file changed, 19 insertions(+), 19 deletions(-) diff --git a/src/include/migraphx/float8.hpp b/src/include/migraphx/float8.hpp index 3303b09d3fb..d8afcb75d20 100644 --- a/src/include/migraphx/float8.hpp +++ b/src/include/migraphx/float8.hpp @@ -350,25 +350,25 @@ class numeric_limits // define numeric limits for the new data type // NOLINTBEGIN namespace std { -#define MIGRAPHX_FP8_STD_OVERLOADS(T) \ - inline bool isfinite(T x) { return x.is_inf(); } \ - 
inline bool isnan(T x) { return x.is_nan(); } \ - template <> \ - class numeric_limits : public migraphx::fp8::numeric_limits \ - { \ - }; \ - template \ - struct common_type : std::common_type \ - { \ - }; \ - template \ - struct common_type : std::common_type \ - { \ - }; \ - template <> \ - struct common_type \ - { \ - using type = T; \ +#define MIGRAPHX_FP8_STD_OVERLOADS(T) \ + inline bool isfinite(T x) { return not x.is_inf() and not x.is_nan(); } \ + inline bool isnan(T x) { return x.is_nan(); } \ + template <> \ + class numeric_limits : public migraphx::fp8::numeric_limits \ + { \ + }; \ + template \ + struct common_type : std::common_type \ + { \ + }; \ + template \ + struct common_type : std::common_type \ + { \ + }; \ + template <> \ + struct common_type \ + { \ + using type = T; \ }; MIGRAPHX_FP8_STD_OVERLOADS(migraphx::fp8::fp8e4m3fn) From 355e4f6f19a7e37b131033d5fd2cce064401f8ca Mon Sep 17 00:00:00 2001 From: Umang Yadav Date: Tue, 14 Nov 2023 00:42:29 +0000 Subject: [PATCH 013/115] fix isfinite --- src/include/migraphx/float8.hpp | 10 ++++------ test/fp8e4m3fn.cpp | 14 ++++++++++++++ test/fp8e4m3fnuz.cpp | 14 ++++++++++++++ test/fp8e5m2.cpp | 13 +++++++++++++ test/fp8e5m2fnuz.cpp | 14 ++++++++++++++ 5 files changed, 59 insertions(+), 6 deletions(-) diff --git a/src/include/migraphx/float8.hpp b/src/include/migraphx/float8.hpp index d8afcb75d20..8461fe60cb2 100644 --- a/src/include/migraphx/float8.hpp +++ b/src/include/migraphx/float8.hpp @@ -273,9 +273,8 @@ using fp8e5m2fnuz = float8; template <> class numeric_limits { - static constexpr bool has_infinity = false; - public: + static constexpr bool has_infinity = false; static constexpr fp8e4m3fnuz epsilon() { return fp8e4m3fnuz(0x28, fp8e4m3fnuz::from_bits()); } // NOLINTNEXTLINE static constexpr fp8e4m3fnuz quiet_NaN() { return fp8e4m3fnuz(0x80, fp8e4m3fnuz::from_bits()); } @@ -290,9 +289,8 @@ class numeric_limits template <> class numeric_limits { - static constexpr bool has_infinity = false; - 
public: + static constexpr bool has_infinity = false; static constexpr fp8e4m3fn epsilon() { return fp8e4m3fn(0x20, fp8e4m3fn::from_bits()); } // NOLINTNEXTLINE static constexpr fp8e4m3fn quiet_NaN() { return fp8e4m3fn(0x7F, fp8e4m3fn::from_bits()); } @@ -307,9 +305,8 @@ class numeric_limits template <> class numeric_limits { - static constexpr bool has_infinity = false; - public: + static constexpr bool has_infinity = false; static constexpr fp8e5m2fnuz epsilon() { return fp8e5m2fnuz(0x34, fp8e5m2fnuz::from_bits()); } static constexpr fp8e5m2fnuz quiet_NaN() // NOLINT @@ -329,6 +326,7 @@ template <> class numeric_limits { public: + static constexpr bool has_infinity = true; static constexpr fp8e5m2 epsilon() { return fp8e5m2(0x34, fp8e5m2::from_bits()); } // 7D, 7E, 7F are positive NaNs and FD, FE, FF are negative NaNs static constexpr fp8e5m2 quiet_NaN() { return fp8e5m2(0xFF, fp8e5m2::from_bits()); } // NOLINT diff --git a/test/fp8e4m3fn.cpp b/test/fp8e4m3fn.cpp index eb3dd6bb959..5a73abfc285 100644 --- a/test/fp8e4m3fn.cpp +++ b/test/fp8e4m3fn.cpp @@ -212,4 +212,18 @@ TEST_CASE(test_max_eq_lowest) EXPECT(migraphx::float_equal(std::numeric_limits::lowest(), -1 * std::numeric_limits::max())); } + +TEST_CASE(test_isfinite) +{ + EXPECT(std::isfinite(migraphx::fp8::fp8e4m3fn(0.0))); + EXPECT(std::isfinite(migraphx::fp8::fp8e4m3fn(-0.0))); + EXPECT(not std::isfinite( + migraphx::fp8::fp8e4m3fn(std::numeric_limits::quiet_NaN()))); +} + +TEST_CASE(test_no_infinity) +{ + EXPECT(not bool{std::numeric_limits::has_infinity}); +} + int main(int argc, const char* argv[]) { test::run(argc, argv); } diff --git a/test/fp8e4m3fnuz.cpp b/test/fp8e4m3fnuz.cpp index 11ab4ada335..8f54c131bf5 100644 --- a/test/fp8e4m3fnuz.cpp +++ b/test/fp8e4m3fnuz.cpp @@ -227,4 +227,18 @@ TEST_CASE(test_max_eq_lowest) EXPECT(migraphx::float_equal(std::numeric_limits::lowest(), -1 * std::numeric_limits::max())); } + +TEST_CASE(test_isfinite) +{ + 
EXPECT(std::isfinite(migraphx::fp8::fp8e4m3fnuz(0.0))); + EXPECT(std::isfinite(migraphx::fp8::fp8e4m3fnuz(-0.0))); + EXPECT(not std::isfinite( + migraphx::fp8::fp8e4m3fnuz(std::numeric_limits::quiet_NaN()))); +} + +TEST_CASE(test_no_infinity) +{ + EXPECT(not bool{std::numeric_limits::has_infinity}); +} + int main(int argc, const char* argv[]) { test::run(argc, argv); } diff --git a/test/fp8e5m2.cpp b/test/fp8e5m2.cpp index 03a4adef13d..e43770837e2 100644 --- a/test/fp8e5m2.cpp +++ b/test/fp8e5m2.cpp @@ -409,4 +409,17 @@ TEST_CASE(test_max_eq_lowest) EXPECT(migraphx::float_equal(std::numeric_limits::lowest(), -1 * std::numeric_limits::max())); } + +TEST_CASE(test_isfinite) +{ + EXPECT(std::isfinite(migraphx::fp8::fp8e5m2(0.0))); + EXPECT(std::isfinite(migraphx::fp8::fp8e5m2(-0.0))); + EXPECT(not std::isfinite( + migraphx::fp8::fp8e5m2(std::numeric_limits::infinity()))); + EXPECT(not std::isfinite( + migraphx::fp8::fp8e5m2(-1.0 * std::numeric_limits::infinity()))); + EXPECT(not std::isfinite( + migraphx::fp8::fp8e5m2(std::numeric_limits::quiet_NaN()))); +} + int main(int argc, const char* argv[]) { test::run(argc, argv); } diff --git a/test/fp8e5m2fnuz.cpp b/test/fp8e5m2fnuz.cpp index 06f422672ba..492776c8882 100644 --- a/test/fp8e5m2fnuz.cpp +++ b/test/fp8e5m2fnuz.cpp @@ -397,4 +397,18 @@ TEST_CASE(test_max_eq_lowest) EXPECT(migraphx::float_equal(std::numeric_limits::lowest(), -1 * std::numeric_limits::max())); } + +TEST_CASE(test_isfinite) +{ + EXPECT(std::isfinite(migraphx::fp8::fp8e5m2fnuz(0.0))); + EXPECT(std::isfinite(migraphx::fp8::fp8e5m2fnuz(-0.0))); + EXPECT(not std::isfinite( + migraphx::fp8::fp8e5m2fnuz(std::numeric_limits::quiet_NaN()))); +} + +TEST_CASE(test_no_infinity) +{ + EXPECT(not bool{std::numeric_limits::has_infinity}); +} + int main(int argc, const char* argv[]) { test::run(argc, argv); } From 6aec70330931a92c95c7092532c13067d7c27375 Mon Sep 17 00:00:00 2001 From: Umang Yadav Date: Tue, 14 Nov 2023 14:34:10 +0000 Subject: [PATCH 014/115] fix 
test for neg inf --- test/fp8e5m2.cpp | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/test/fp8e5m2.cpp b/test/fp8e5m2.cpp index e43770837e2..0efc69a1a85 100644 --- a/test/fp8e5m2.cpp +++ b/test/fp8e5m2.cpp @@ -414,12 +414,14 @@ TEST_CASE(test_isfinite) { EXPECT(std::isfinite(migraphx::fp8::fp8e5m2(0.0))); EXPECT(std::isfinite(migraphx::fp8::fp8e5m2(-0.0))); - EXPECT(not std::isfinite( - migraphx::fp8::fp8e5m2(std::numeric_limits::infinity()))); - EXPECT(not std::isfinite( - migraphx::fp8::fp8e5m2(-1.0 * std::numeric_limits::infinity()))); EXPECT(not std::isfinite( migraphx::fp8::fp8e5m2(std::numeric_limits::quiet_NaN()))); + EXPECT(not std::isfinite(std::numeric_limits::infinity())); + // -1.0 * inf is float(-inf) which with clipping/saturation gets converted into fp8::lowest() + EXPECT(std::isfinite( + migraphx::fp8::fp8e5m2(-1.0 * std::numeric_limits::infinity()))); + // fp8(-neg_inf) + EXPECT(not std::isfinite(migraphx::fp8::fp8e5m2(0xFC, migraphx::fp8::fp8e5m2::from_bits()))); } int main(int argc, const char* argv[]) { test::run(argc, argv); } From 12aac372fa93ceca9913e411ec1ca6841ed2d832 Mon Sep 17 00:00:00 2001 From: Umang Yadav Date: Tue, 14 Nov 2023 14:36:48 +0000 Subject: [PATCH 015/115] fix warning --- test/fp8e5m2.cpp | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/test/fp8e5m2.cpp b/test/fp8e5m2.cpp index 0efc69a1a85..7382f9b2329 100644 --- a/test/fp8e5m2.cpp +++ b/test/fp8e5m2.cpp @@ -32,7 +32,7 @@ float fp8e5m2_to_fp32_value(uint8_t input) { - constexpr std::array e4m3fnuz_lut = {{ + constexpr std::array e4m3fnuz_lut = { 0.0, 1.52587890625e-05, 3.0517578125e-05, @@ -285,11 +285,10 @@ float fp8e5m2_to_fp32_value(uint8_t input) -40960.0, -49152.0, -57344.0, - -1.0 * std::numeric_limits::infinity(), + -1.0f * std::numeric_limits::infinity(), std::numeric_limits::quiet_NaN(), std::numeric_limits::quiet_NaN(), std::numeric_limits::quiet_NaN(), - } }; @@ -420,7 +419,6 @@ TEST_CASE(test_isfinite) // -1.0 * 
inf is float(-inf) which with clipping/saturation gets converted into fp8::lowest() EXPECT(std::isfinite( migraphx::fp8::fp8e5m2(-1.0 * std::numeric_limits::infinity()))); - // fp8(-neg_inf) EXPECT(not std::isfinite(migraphx::fp8::fp8e5m2(0xFC, migraphx::fp8::fp8e5m2::from_bits()))); } From 60092324e70f81adf992edc94b8ad845765e5e8b Mon Sep 17 00:00:00 2001 From: Umang Yadav Date: Tue, 14 Nov 2023 15:27:57 +0000 Subject: [PATCH 016/115] add tests --- src/include/migraphx/float8.hpp | 97 ++++++++++++++++++++++----------- test/fp8e4m3fn.cpp | 22 ++++++++ test/fp8e4m3fnuz.cpp | 22 ++++++++ test/fp8e5m2.cpp | 22 ++++++++ test/fp8e5m2fnuz.cpp | 22 ++++++++ 5 files changed, 154 insertions(+), 31 deletions(-) diff --git a/src/include/migraphx/float8.hpp b/src/include/migraphx/float8.hpp index 8461fe60cb2..339d5acc012 100644 --- a/src/include/migraphx/float8.hpp +++ b/src/include/migraphx/float8.hpp @@ -227,49 +227,84 @@ struct float8 } }; +// https://onnx.ai/onnx/technical/float8.html +using fp8e4m3fn = float8; +using fp8e5m2 = float8; +using fp8e4m3fnuz = float8; +using fp8e5m2fnuz = float8; +/* +// NOLINTNEXTLINE +#define MIGRAPHX_FP8_BINARY_OP(binary_op, T, U) \ + inline constexpr U operator binary_op(const T& lhs, const T& rhs) \ + { \ + return U(static_cast(lhs) binary_op static_cast(rhs)); \ + } + +// TODO: these should return floats for binary ops +// NOLINTNEXTLINE +#define MIGRAPHX_FP8_BINARY_OP_GEN_FOR(T) \ + MIGRAPHX_FP8_BINARY_OP(*, T, T) \ + MIGRAPHX_FP8_BINARY_OP(-, T, T) \ + MIGRAPHX_FP8_BINARY_OP(/, T, T) \ + MIGRAPHX_FP8_BINARY_OP(+, T, T) \ + MIGRAPHX_FP8_BINARY_OP(==, T, bool) \ + MIGRAPHX_FP8_BINARY_OP(>=, T, bool) \ + MIGRAPHX_FP8_BINARY_OP(<=, T, bool) \ + MIGRAPHX_FP8_BINARY_OP(>, T, bool) \ + MIGRAPHX_FP8_BINARY_OP(<, T, bool) \ + MIGRAPHX_FP8_BINARY_OP(!=, T, bool) + +MIGRAPHX_FP8_BINARY_OP_GEN_FOR(fp8e5m2) +MIGRAPHX_FP8_BINARY_OP_GEN_FOR(fp8e4m3fn) +MIGRAPHX_FP8_BINARY_OP_GEN_FOR(fp8e5m2fnuz) +MIGRAPHX_FP8_BINARY_OP_GEN_FOR(fp8e4m3fnuz) +*/ + // 
Special operator overloading -template -inline std::ostream& operator<<(std::ostream& os, const migraphx::fp8::float8& rhs) +inline std::ostream& operator<<(std::ostream& os, const fp8e4m3fnuz& rhs) { return os << static_cast(rhs); } -// NOLINTNEXTLINE -#define MIGRAPHX_FP8_BINARY_OP(binary_op, U) \ - template \ - inline constexpr U operator binary_op(const migraphx::fp8::float8& lhs, \ - const migraphx::fp8::float8& rhs) \ - { \ - return U(static_cast(lhs) binary_op static_cast(rhs)); \ - } +inline fp8e4m3fnuz fabs(fp8e4m3fnuz v) +{ + v.data = v.data & 0x7f; // NOLINT + return v; +} +// Special operator overloading +inline std::ostream& operator<<(std::ostream& os, const fp8e4m3fn& rhs) +{ + return os << static_cast(rhs); +} -// TODO: these should return floats -MIGRAPHX_FP8_BINARY_OP(*, migraphx::fp8::float8) -MIGRAPHX_FP8_BINARY_OP(-, migraphx::fp8::float8) -MIGRAPHX_FP8_BINARY_OP(/, migraphx::fp8::float8) -MIGRAPHX_FP8_BINARY_OP(+, migraphx::fp8::float8) -// TODO: Comparison ops shouldn't convert to float, need to check if need to take care of rounding -// effects. 
-MIGRAPHX_FP8_BINARY_OP(==, bool) -MIGRAPHX_FP8_BINARY_OP(>=, bool) -MIGRAPHX_FP8_BINARY_OP(<=, bool) -MIGRAPHX_FP8_BINARY_OP(>, bool) -MIGRAPHX_FP8_BINARY_OP(<, bool) -MIGRAPHX_FP8_BINARY_OP(!=, bool) - -template -inline migraphx::fp8::float8 fabs(migraphx::fp8::float8 v) +inline fp8e4m3fn fabs(fp8e4m3fn v) { v.data = v.data & 0x7f; // NOLINT return v; } -// https://onnx.ai/onnx/technical/float8.html -using fp8e4m3fn = float8; -using fp8e5m2 = float8; -using fp8e4m3fnuz = float8; -using fp8e5m2fnuz = float8; +// Special operator overloading +inline std::ostream& operator<<(std::ostream& os, const fp8e5m2fnuz& rhs) +{ + return os << static_cast(rhs); +} +inline fp8e5m2fnuz fabs(fp8e5m2fnuz v) +{ + v.data = v.data & 0x7f; // NOLINT + return v; +} +// Special operator overloading +inline std::ostream& operator<<(std::ostream& os, const fp8e5m2& rhs) +{ + return os << static_cast(rhs); +} + +inline fp8e5m2 fabs(fp8e5m2 v) +{ + v.data = v.data & 0x7f; // NOLINT + return v; +} template <> class numeric_limits { diff --git a/test/fp8e4m3fn.cpp b/test/fp8e4m3fn.cpp index 5a73abfc285..65ccc1724c3 100644 --- a/test/fp8e4m3fn.cpp +++ b/test/fp8e4m3fn.cpp @@ -226,4 +226,26 @@ TEST_CASE(test_no_infinity) EXPECT(not bool{std::numeric_limits::has_infinity}); } +TEST_CASE(test_binary_ops) +{ + auto a = migraphx::fp8::fp8e5m2(-1.0); + auto b = migraphx::fp8::fp8e5m2(1.0); + auto c = migraphx::fp8::fp8e5m2(0.0); + auto d = migraphx::fp8::fp8e5m2(-0.0); + EXPECT(migraphx::float_equal((c + d), c)); + EXPECT(migraphx::float_equal((c + d), d)); + EXPECT(migraphx::float_equal((a + b), c)); + EXPECT(migraphx::float_equal((a + b), d)); + + auto e = migraphx::fp8::fp8e5m2(10.0); + auto f = migraphx::fp8::fp8e5m2(-10.0); + EXPECT(bool{e > f}); + EXPECT(bool{f < e}); + EXPECT(bool(f <= e)); + EXPECT(bool{e >= f}); + EXPECT(bool{e <= e}); + EXPECT(bool{f >= f}); + EXPECT(not migraphx::float_equal(f, e)); +} + int main(int argc, const char* argv[]) { test::run(argc, argv); } diff --git 
a/test/fp8e4m3fnuz.cpp b/test/fp8e4m3fnuz.cpp index 8f54c131bf5..c69ac3156d6 100644 --- a/test/fp8e4m3fnuz.cpp +++ b/test/fp8e4m3fnuz.cpp @@ -241,4 +241,26 @@ TEST_CASE(test_no_infinity) EXPECT(not bool{std::numeric_limits::has_infinity}); } +TEST_CASE(test_binary_ops) +{ + auto a = migraphx::fp8::fp8e5m2(-1.0); + auto b = migraphx::fp8::fp8e5m2(1.0); + auto c = migraphx::fp8::fp8e5m2(0.0); + auto d = migraphx::fp8::fp8e5m2(-0.0); + EXPECT(migraphx::float_equal((c + d), c)); + EXPECT(migraphx::float_equal((c + d), d)); + EXPECT(migraphx::float_equal((a + b), c)); + EXPECT(migraphx::float_equal((a + b), d)); + + auto e = migraphx::fp8::fp8e5m2(10.0); + auto f = migraphx::fp8::fp8e5m2(-10.0); + EXPECT(bool{e > f}); + EXPECT(bool{f < e}); + EXPECT(bool(f <= e)); + EXPECT(bool{e >= f}); + EXPECT(bool{e <= e}); + EXPECT(bool{f >= f}); + EXPECT(not migraphx::float_equal(f, e)); +} + int main(int argc, const char* argv[]) { test::run(argc, argv); } diff --git a/test/fp8e5m2.cpp b/test/fp8e5m2.cpp index 7382f9b2329..5875bb0ad44 100644 --- a/test/fp8e5m2.cpp +++ b/test/fp8e5m2.cpp @@ -422,4 +422,26 @@ TEST_CASE(test_isfinite) EXPECT(not std::isfinite(migraphx::fp8::fp8e5m2(0xFC, migraphx::fp8::fp8e5m2::from_bits()))); } +TEST_CASE(test_binary_ops) +{ + auto a = migraphx::fp8::fp8e5m2(-1.0); + auto b = migraphx::fp8::fp8e5m2(1.0); + auto c = migraphx::fp8::fp8e5m2(0.0); + auto d = migraphx::fp8::fp8e5m2(-0.0); + EXPECT(migraphx::float_equal((c + d), c)); + EXPECT(migraphx::float_equal((c + d), d)); + EXPECT(migraphx::float_equal((a + b), c)); + EXPECT(migraphx::float_equal((a + b), d)); + + auto e = migraphx::fp8::fp8e5m2(10.0); + auto f = migraphx::fp8::fp8e5m2(-10.0); + EXPECT(bool{e > f}); + EXPECT(bool{f < e}); + EXPECT(bool(f <= e)); + EXPECT(bool{e >= f}); + EXPECT(bool{e <= e}); + EXPECT(bool{f >= f}); + EXPECT(not migraphx::float_equal(f, e)); +} + int main(int argc, const char* argv[]) { test::run(argc, argv); } diff --git a/test/fp8e5m2fnuz.cpp 
b/test/fp8e5m2fnuz.cpp index 492776c8882..864a6c36cec 100644 --- a/test/fp8e5m2fnuz.cpp +++ b/test/fp8e5m2fnuz.cpp @@ -411,4 +411,26 @@ TEST_CASE(test_no_infinity) EXPECT(not bool{std::numeric_limits::has_infinity}); } +TEST_CASE(test_binary_ops) +{ + auto a = migraphx::fp8::fp8e5m2(-1.0); + auto b = migraphx::fp8::fp8e5m2(1.0); + auto c = migraphx::fp8::fp8e5m2(0.0); + auto d = migraphx::fp8::fp8e5m2(-0.0); + EXPECT(migraphx::float_equal((c + d), c)); + EXPECT(migraphx::float_equal((c + d), d)); + EXPECT(migraphx::float_equal((a + b), c)); + EXPECT(migraphx::float_equal((a + b), d)); + + auto e = migraphx::fp8::fp8e5m2(10.0); + auto f = migraphx::fp8::fp8e5m2(-10.0); + EXPECT(bool{e > f}); + EXPECT(bool{f < e}); + EXPECT(bool(f <= e)); + EXPECT(bool{e >= f}); + EXPECT(bool{e <= e}); + EXPECT(bool{f >= f}); + EXPECT(not migraphx::float_equal(f, e)); +} + int main(int argc, const char* argv[]) { test::run(argc, argv); } From 03f71398ce86e2746566933383c580c0203be594 Mon Sep 17 00:00:00 2001 From: Umang Yadav Date: Tue, 14 Nov 2023 15:34:08 +0000 Subject: [PATCH 017/115] Fix tests --- test/fp8e4m3fn.cpp | 12 ++++++------ test/fp8e4m3fnuz.cpp | 12 ++++++------ test/fp8e5m2fnuz.cpp | 12 ++++++------ 3 files changed, 18 insertions(+), 18 deletions(-) diff --git a/test/fp8e4m3fn.cpp b/test/fp8e4m3fn.cpp index 65ccc1724c3..d29c803981a 100644 --- a/test/fp8e4m3fn.cpp +++ b/test/fp8e4m3fn.cpp @@ -228,17 +228,17 @@ TEST_CASE(test_no_infinity) TEST_CASE(test_binary_ops) { - auto a = migraphx::fp8::fp8e5m2(-1.0); - auto b = migraphx::fp8::fp8e5m2(1.0); - auto c = migraphx::fp8::fp8e5m2(0.0); - auto d = migraphx::fp8::fp8e5m2(-0.0); + auto a = migraphx::fp8::fp8e4m3fn(-1.0); + auto b = migraphx::fp8::fp8e4m3fn(1.0); + auto c = migraphx::fp8::fp8e4m3fn(0.0); + auto d = migraphx::fp8::fp8e4m3fn(-0.0); EXPECT(migraphx::float_equal((c + d), c)); EXPECT(migraphx::float_equal((c + d), d)); EXPECT(migraphx::float_equal((a + b), c)); EXPECT(migraphx::float_equal((a + b), d)); - auto e = 
migraphx::fp8::fp8e5m2(10.0); - auto f = migraphx::fp8::fp8e5m2(-10.0); + auto e = migraphx::fp8::fp8e4m3fn(10.0); + auto f = migraphx::fp8::fp8e4m3fn(-10.0); EXPECT(bool{e > f}); EXPECT(bool{f < e}); EXPECT(bool(f <= e)); diff --git a/test/fp8e4m3fnuz.cpp b/test/fp8e4m3fnuz.cpp index c69ac3156d6..c38b7b85857 100644 --- a/test/fp8e4m3fnuz.cpp +++ b/test/fp8e4m3fnuz.cpp @@ -243,17 +243,17 @@ TEST_CASE(test_no_infinity) TEST_CASE(test_binary_ops) { - auto a = migraphx::fp8::fp8e5m2(-1.0); - auto b = migraphx::fp8::fp8e5m2(1.0); - auto c = migraphx::fp8::fp8e5m2(0.0); - auto d = migraphx::fp8::fp8e5m2(-0.0); + auto a = migraphx::fp8::fp8e4m3fnuz(-1.0); + auto b = migraphx::fp8::fp8e4m3fnuz(1.0); + auto c = migraphx::fp8::fp8e4m3fnuz(0.0); + auto d = migraphx::fp8::fp8e4m3fnuz(-0.0); EXPECT(migraphx::float_equal((c + d), c)); EXPECT(migraphx::float_equal((c + d), d)); EXPECT(migraphx::float_equal((a + b), c)); EXPECT(migraphx::float_equal((a + b), d)); - auto e = migraphx::fp8::fp8e5m2(10.0); - auto f = migraphx::fp8::fp8e5m2(-10.0); + auto e = migraphx::fp8::fp8e4m3fnuz(10.0); + auto f = migraphx::fp8::fp8e4m3fnuz(-10.0); EXPECT(bool{e > f}); EXPECT(bool{f < e}); EXPECT(bool(f <= e)); diff --git a/test/fp8e5m2fnuz.cpp b/test/fp8e5m2fnuz.cpp index 864a6c36cec..c8ba868d4cd 100644 --- a/test/fp8e5m2fnuz.cpp +++ b/test/fp8e5m2fnuz.cpp @@ -413,17 +413,17 @@ TEST_CASE(test_no_infinity) TEST_CASE(test_binary_ops) { - auto a = migraphx::fp8::fp8e5m2(-1.0); - auto b = migraphx::fp8::fp8e5m2(1.0); - auto c = migraphx::fp8::fp8e5m2(0.0); - auto d = migraphx::fp8::fp8e5m2(-0.0); + auto a = migraphx::fp8::fp8e5m2fnuz(-1.0); + auto b = migraphx::fp8::fp8e5m2fnuz(1.0); + auto c = migraphx::fp8::fp8e5m2fnuz(0.0); + auto d = migraphx::fp8::fp8e5m2fnuz(-0.0); EXPECT(migraphx::float_equal((c + d), c)); EXPECT(migraphx::float_equal((c + d), d)); EXPECT(migraphx::float_equal((a + b), c)); EXPECT(migraphx::float_equal((a + b), d)); - auto e = migraphx::fp8::fp8e5m2(10.0); - auto f = 
migraphx::fp8::fp8e5m2(-10.0); + auto e = migraphx::fp8::fp8e5m2fnuz(10.0); + auto f = migraphx::fp8::fp8e5m2fnuz(-10.0); EXPECT(bool{e > f}); EXPECT(bool{f < e}); EXPECT(bool(f <= e)); From 1e220c00a20d019e355d5bfb9a6abfc5f6faebf2 Mon Sep 17 00:00:00 2001 From: Umang Yadav Date: Tue, 14 Nov 2023 15:57:21 +0000 Subject: [PATCH 018/115] add stringstream tests --- src/include/migraphx/float8.hpp | 9 +++++---- test/fp8e4m3fn.cpp | 19 +++++++++++++++++++ test/fp8e4m3fnuz.cpp | 19 +++++++++++++++++++ test/fp8e5m2.cpp | 20 ++++++++++++++++++++ test/fp8e5m2fnuz.cpp | 18 ++++++++++++++++++ 5 files changed, 81 insertions(+), 4 deletions(-) diff --git a/src/include/migraphx/float8.hpp b/src/include/migraphx/float8.hpp index 339d5acc012..805afe8e57c 100644 --- a/src/include/migraphx/float8.hpp +++ b/src/include/migraphx/float8.hpp @@ -268,9 +268,10 @@ inline std::ostream& operator<<(std::ostream& os, const fp8e4m3fnuz& rhs) inline fp8e4m3fnuz fabs(fp8e4m3fnuz v) { - v.data = v.data & 0x7f; // NOLINT + v.data = v.data & 0x7F; // NOLINT return v; } + // Special operator overloading inline std::ostream& operator<<(std::ostream& os, const fp8e4m3fn& rhs) { @@ -279,7 +280,7 @@ inline std::ostream& operator<<(std::ostream& os, const fp8e4m3fn& rhs) inline fp8e4m3fn fabs(fp8e4m3fn v) { - v.data = v.data & 0x7f; // NOLINT + v.data = v.data & 0x7F; // NOLINT return v; } @@ -291,7 +292,7 @@ inline std::ostream& operator<<(std::ostream& os, const fp8e5m2fnuz& rhs) inline fp8e5m2fnuz fabs(fp8e5m2fnuz v) { - v.data = v.data & 0x7f; // NOLINT + v.data = v.data & 0x7F; // NOLINT return v; } // Special operator overloading @@ -302,7 +303,7 @@ inline std::ostream& operator<<(std::ostream& os, const fp8e5m2& rhs) inline fp8e5m2 fabs(fp8e5m2 v) { - v.data = v.data & 0x7f; // NOLINT + v.data = v.data & 0x7F; // NOLINT return v; } template <> diff --git a/test/fp8e4m3fn.cpp b/test/fp8e4m3fn.cpp index d29c803981a..874cd4684d1 100644 --- a/test/fp8e4m3fn.cpp +++ b/test/fp8e4m3fn.cpp @@ -248,4 
+248,23 @@ TEST_CASE(test_binary_ops) EXPECT(not migraphx::float_equal(f, e)); } +TEST_CASE(test_fabs) +{ + auto a = migraphx::fp8::fp8e4m3fn(-1.0); + auto b = migraphx::fp8::fp8e4m3fn(1.0); + EXPECT(migraphx::float_equal(b, migraphx::fp8::fabs(a))); +} + +TEST_CASE(test_stream_op) +{ + auto a = migraphx::fp8::fp8e4m3fn(-1.0); + std::stringstream ss; + ss << a; + EXPECT(std::string("-1") == ss.str()); + ss = std::stringstream(); + auto b = std::numeric_limits::quiet_NaN(); + ss << b; + EXPECT(std::string("nan") == ss.str()); +} + int main(int argc, const char* argv[]) { test::run(argc, argv); } diff --git a/test/fp8e4m3fnuz.cpp b/test/fp8e4m3fnuz.cpp index c38b7b85857..b99c09ba1bf 100644 --- a/test/fp8e4m3fnuz.cpp +++ b/test/fp8e4m3fnuz.cpp @@ -263,4 +263,23 @@ TEST_CASE(test_binary_ops) EXPECT(not migraphx::float_equal(f, e)); } +TEST_CASE(test_fabs) +{ + auto a = migraphx::fp8::fp8e4m3fnuz(-1.0); + auto b = migraphx::fp8::fp8e4m3fnuz(1.0); + EXPECT(migraphx::float_equal(b, migraphx::fp8::fabs(a))); +} + +TEST_CASE(test_stream_op) +{ + auto a = migraphx::fp8::fp8e4m3fnuz(-1.0); + std::stringstream ss; + ss << a; + EXPECT(std::string("-1") == ss.str()); + ss = std::stringstream(); + auto b = std::numeric_limits::quiet_NaN(); + ss << b; + EXPECT(std::string("nan") == ss.str()); +} + int main(int argc, const char* argv[]) { test::run(argc, argv); } diff --git a/test/fp8e5m2.cpp b/test/fp8e5m2.cpp index 5875bb0ad44..e9f5fd3f5f9 100644 --- a/test/fp8e5m2.cpp +++ b/test/fp8e5m2.cpp @@ -29,6 +29,7 @@ #include "test.hpp" #include +#include float fp8e5m2_to_fp32_value(uint8_t input) { @@ -444,4 +445,23 @@ TEST_CASE(test_binary_ops) EXPECT(not migraphx::float_equal(f, e)); } +TEST_CASE(test_fabs) +{ + auto a = migraphx::fp8::fp8e5m2(-1.0); + auto b = migraphx::fp8::fp8e5m2(1.0); + EXPECT(migraphx::float_equal(b, migraphx::fp8::fabs(a))); +} + +TEST_CASE(test_stream_op) +{ + auto a = migraphx::fp8::fp8e5m2(-1.0); + std::stringstream ss; + ss << a; + EXPECT(std::string("-1") 
== ss.str()); + ss = std::stringstream(); + auto b = std::numeric_limits::quiet_NaN(); + ss << b; + EXPECT(std::string("nan") == ss.str()); +} + int main(int argc, const char* argv[]) { test::run(argc, argv); } diff --git a/test/fp8e5m2fnuz.cpp b/test/fp8e5m2fnuz.cpp index c8ba868d4cd..157b0d28600 100644 --- a/test/fp8e5m2fnuz.cpp +++ b/test/fp8e5m2fnuz.cpp @@ -433,4 +433,22 @@ TEST_CASE(test_binary_ops) EXPECT(not migraphx::float_equal(f, e)); } +TEST_CASE(test_fabs) +{ + auto a = migraphx::fp8::fp8e5m2fnuz(-1.0); + auto b = migraphx::fp8::fp8e5m2fnuz(1.0); + EXPECT(migraphx::float_equal(b, migraphx::fp8::fabs(a))); +} + +TEST_CASE(test_stream_op) +{ + auto a = migraphx::fp8::fp8e5m2fnuz(-1.0); + std::stringstream ss; + ss << a; + EXPECT(std::string("-1") == ss.str()); + ss = std::stringstream(); + auto b = std::numeric_limits::quiet_NaN(); + ss << b; + EXPECT(std::string("nan") == ss.str()); +} int main(int argc, const char* argv[]) { test::run(argc, argv); } From a83e9dc6ff18a20989a42cd7f01e513455b504cf Mon Sep 17 00:00:00 2001 From: Umang Yadav Date: Wed, 15 Nov 2023 16:24:39 +0000 Subject: [PATCH 019/115] Remove clang diagnostics --- src/include/migraphx/float8.hpp | 27 +++++++++------------------ 1 file changed, 9 insertions(+), 18 deletions(-) diff --git a/src/include/migraphx/float8.hpp b/src/include/migraphx/float8.hpp index 805afe8e57c..b2d6fedc68c 100644 --- a/src/include/migraphx/float8.hpp +++ b/src/include/migraphx/float8.hpp @@ -22,12 +22,6 @@ #ifndef MIGRAPHX_GUARD_RTGLIB_FLOAT8_HPP #define MIGRAPHX_GUARD_RTGLIB_FLOAT8_HPP -#if defined(__clang__) -#pragma clang diagnostic push -#pragma clang diagnostic ignored "-Wold-style-cast" -#pragma clang diagnostic ignored "-Wfloat-equal" -#pragma clang diagnostic ignored "-Wc++20-extensions" -#endif // __clang__ // We are clipping/saturation in down conversion by default. Unclipped version is not tested and // shouldn't be used without having enough tests. 
@@ -52,7 +46,7 @@ namespace migraphx { inline namespace MIGRAPHX_INLINE_NS { namespace fp8 { -enum class migraphx_f8_rounding_mode +enum class rounding_mode { standard, // standard rounding is doing RNE -- round to nearest even stochastic @@ -82,21 +76,21 @@ struct float8 explicit constexpr float8(uint8_t bits, from_bits_t) : data(bits) {} - explicit constexpr float8(float v, - migraphx::fp8::migraphx_f8_rounding_mode rm = - migraphx::fp8::migraphx_f8_rounding_mode::standard, - uint32_t rng = 0) + explicit constexpr float8( + float v, + migraphx::fp8::rounding_mode rm = migraphx::fp8::rounding_mode::standard, + uint32_t rng = 0) { if constexpr(T == migraphx::fp8::f8_type::fp8) { #ifdef MIGRAPHX_F8_DOWNCAST_CLIPPING data = migraphx::fp8::impl:: cast_to_f8<3, 4, float, FNUZ /*negative_zero_nan*/, true /*clip*/>( - v, (rm == migraphx::fp8::migraphx_f8_rounding_mode::stochastic), rng); + v, (rm == migraphx::fp8::rounding_mode::stochastic), rng); #else // MIGRAPHX_F8_DOWNCAST_CLIPPING data = migraphx::fp8::impl:: cast_to_f8<3, 4, float, FNUZ /*negative_zero_nan*/, false /*clip*/>( - v, (rm == migraphx::fp8::migraphx_f8_rounding_mode::stochastic), rng); + v, (rm == migraphx::fp8::rounding_mode::stochastic), rng); #endif // MIGRAPHX_F8_DOWNCAST_CLIPPING } else @@ -104,11 +98,11 @@ struct float8 #ifdef MIGRAPHX_F8_DOWNCAST_CLIPPING data = migraphx::fp8::impl:: cast_to_f8<2, 5, float, FNUZ /*negative_zero_nan*/, true /*clip*/>( - v, (rm == migraphx::fp8::migraphx_f8_rounding_mode::stochastic), rng); + v, (rm == migraphx::fp8::rounding_mode::stochastic), rng); #else // MIGRAPHX_F8_DOWNCAST_CLIPPING data = migraphx::fp8::impl:: cast_to_f8<2, 5, float, FNUZ /*negative_zero_nan*/, false /*clip*/>( - v, (rm == migraphx::fp8::migraphx_f8_rounding_mode::stochastic), rng); + v, (rm == migraphx::fp8::rounding_mode::stochastic), rng); #endif // rocblas_F8_downcast_clipping} } } @@ -412,7 +406,4 @@ MIGRAPHX_FP8_STD_OVERLOADS(migraphx::fp8::fp8e5m2fnuz) } // namespace std // NOLINTEND 
// ================================================================================================= -#if defined(__clang__) -#pragma clang diagnostic pop -#endif #endif // MIGRAPHX_GUARD_RTGLIB_FLOAT8_HPP From 26956f1dd4e019c6569df1917717415edeb06ea7 Mon Sep 17 00:00:00 2001 From: Umang Yadav Date: Wed, 15 Nov 2023 16:28:42 +0000 Subject: [PATCH 020/115] Remove NOLINTS --- src/include/migraphx/float8_impl.hpp | 65 ++++++++++++++-------------- 1 file changed, 32 insertions(+), 33 deletions(-) diff --git a/src/include/migraphx/float8_impl.hpp b/src/include/migraphx/float8_impl.hpp index 8139e3de482..b8895ee5c86 100644 --- a/src/include/migraphx/float8_impl.hpp +++ b/src/include/migraphx/float8_impl.hpp @@ -55,23 +55,23 @@ constexpr uint8_t cast_to_f8(T f_x, bool stoch = false, uint32_t rng = 0) uint32_t sign = 0; if constexpr(sizeof(T) == 4) { - head = x & 0xFF800000; // NOLINT - mantissa = x & 0x7FFFFF; // NOLINT - exponent = (head >> 23) & 0xFF; // NOLINT - sign = head >> 31; // NOLINT + head = x & 0xFF800000; + mantissa = x & 0x7FFFFF; + exponent = (head >> 23) & 0xFF; + sign = head >> 31; bias = 127; } else { - head = x & 0xFC00; // NOLINT - mantissa = x & 0x3FF; // NOLINT - exponent = (head >> 10) & 0x1F; // NOLINT - sign = head >> 15; // NOLINT + head = x & 0xFC00; + mantissa = x & 0x3FF; + exponent = (head >> 10) & 0x1F; + sign = head >> 15; bias = 15; } - uint32_t signed_inf = (sign << 7) + (((1 << We) - 1) << Wm); // NOLINT - uint32_t signed_all_ones = (sign << 7) + ((((1 << We) - 1) << Wm) + ((1 << Wm) - 1)); // NOLINT + uint32_t signed_inf = (sign << 7) + (((1 << We) - 1) << Wm); + uint32_t signed_all_ones = (sign << 7) + ((((1 << We) - 1) << Wm) + ((1 << Wm) - 1)); // Calcualte maximum singed value FLT_MAX, FLT_MIN uint32_t signed_max = signed_all_ones; @@ -81,8 +81,8 @@ constexpr uint8_t cast_to_f8(T f_x, bool stoch = false, uint32_t rng = 0) // Deal with inf and NaNs if(NegativeZeroNan) // For the FNUZ cases, it is simple just return NaNs { - 
if((sizeof(T) == 4 and ((x & 0x7F800000) == 0x7F800000)) or // NOLINT - (sizeof(T) == 2 and ((x & 0x7C00) == 0x7C00))) // NOLINT + if((sizeof(T) == 4 and ((x & 0x7F800000) == 0x7F800000)) or + (sizeof(T) == 2 and ((x & 0x7C00) == 0x7C00))) return 0x80; } else @@ -91,10 +91,10 @@ constexpr uint8_t cast_to_f8(T f_x, bool stoch = false, uint32_t rng = 0) uint32_t nan_mantissa = 1; for(auto i = 1; i < Wm; ++i) { - nan_mantissa |= (nan_mantissa << 1); // NOLINT + nan_mantissa |= (nan_mantissa << 1); } - if((sizeof(T) == 4 and ((x & 0x7F800000) == 0x7F800000)) or // NOLINT - (sizeof(T) == 2 and ((x & 0x7C00) == 0x7C00))) // NOLINT + if((sizeof(T) == 4 and ((x & 0x7F800000) == 0x7F800000)) or + (sizeof(T) == 2 and ((x & 0x7C00) == 0x7C00))) { // infinity if(mantissa == 0) @@ -124,7 +124,7 @@ constexpr uint8_t cast_to_f8(T f_x, bool stoch = false, uint32_t rng = 0) exponent and mantissa again*/ // For IEEE bias mode, the bias is 2^(k-1) -1 where k is the width of exponent bits - const int f8_bias = (1 << (We - 1u)) - 1 + (NegativeZeroNan ? 1 : 0); // NOLINT + const int f8_bias = (1 << (We - 1u)) - 1 + (NegativeZeroNan ? 1 : 0); const int f8_denormal_act_exponent = 1 - f8_bias; // actual exponent of f8 denormal /* act_exponent is the actual exponent of fp32/fp16 (after subtracting bias) f8_exponent is the converted f8 exponent with bias encoding @@ -166,9 +166,9 @@ constexpr uint8_t cast_to_f8(T f_x, bool stoch = false, uint32_t rng = 0) } mantissa += (1u << mfmt); // Add the implicit 1 into mantissa } - // NOLINTNEXTLINE + bool midpoint = (mantissa & ((1 << (mfmt - Wm + exponent_diff)) - 1)) == - (1 << (mfmt - Wm + exponent_diff - 1)); // NOLINT + (1 << (mfmt - Wm + exponent_diff - 1)); /* This part is a bit tricky. The judgment of whether it is a tie needs to be done before we shift right as shift right could rip off some residual part and make something not midpoint look like midpoint. 
For example, the fp16 number 0x1002 (0 00100 0000000010), it is larger than @@ -176,36 +176,35 @@ constexpr uint8_t cast_to_f8(T f_x, bool stoch = false, uint32_t rng = 0) */ if(exponent_diff > 0) - mantissa >>= exponent_diff; // NOLINT + mantissa >>= exponent_diff; else if(exponent_diff == -1) - mantissa <<= -exponent_diff; // NOLINT - bool implicit_one = mantissa & (1 << mfmt); // NOLINT + mantissa <<= -exponent_diff; + bool implicit_one = mantissa & (1 << mfmt); // if there is no implict 1, it means the f8 is denormal and need to adjust to denorm exponent f8_exponent = (act_exponent + exponent_diff) /*actual f8 exponent*/ + f8_bias - (implicit_one ? 0 : 1); // Now we have the exponent and mantissa adjusted - uint32_t drop_mask = (1u << (mfmt - Wm)) - 1; // NOLINT + uint32_t drop_mask = (1u << (mfmt - Wm)) - 1; bool odd = mantissa & (1u << (mfmt - Wm)); // if the least significant bit that is not truncated is 1 - mantissa += (stoch ? rng : (midpoint ? (odd ? mantissa : mantissa - 1) : mantissa)) & // NOLINT - drop_mask; // NOLINT + mantissa += (stoch ? rng : (midpoint ? (odd ? mantissa : mantissa - 1) : mantissa)) & drop_mask; // Now we deal with overflow - if(f8_exponent == 0 and ((1 << mfmt) & mantissa)) // NOLINT + if(f8_exponent == 0 and ((1 << mfmt) & mantissa)) { - f8_exponent = 1; // denormal overflow to become normal, promote exponent + f8_exponent = 1; // denormal overflow to become normal, promote exponent } - else if((1 << (mfmt + 1)) & mantissa) // NOLINT + else if((1 << (mfmt + 1)) & mantissa) { - mantissa >>= 1; // NOLINT + mantissa >>= 1; f8_exponent++; } - mantissa >>= (mfmt - Wm); // NOLINT + mantissa >>= (mfmt - Wm); // above range: quantize to maximum possible float of the same sign - const int max_exp = (1 << We) - (NegativeZeroNan ? 1 : 2); // NOLINT + const int max_exp = (1 << We) - (NegativeZeroNan ? 
1 : 2); if(f8_exponent > max_exp) { if(Clip) @@ -221,9 +220,9 @@ constexpr uint8_t cast_to_f8(T f_x, bool stoch = false, uint32_t rng = 0) } if(f8_exponent == 0 and mantissa == 0) - return NegativeZeroNan ? 0 : (sign << 7); // NOLINT - mantissa &= (1 << Wm) - 1; // NOLINT - return (sign << 7) | (f8_exponent << Wm) | mantissa; // NOLINT + return NegativeZeroNan ? 0 : (sign << 7); + mantissa &= (1 << Wm) - 1; + return (sign << 7) | (f8_exponent << Wm) | mantissa; } // NOLINTEND From 269ce6d1230b44986c4b171aa2b2cad28080eb1f Mon Sep 17 00:00:00 2001 From: Umang Yadav Date: Thu, 16 Nov 2023 01:09:33 +0000 Subject: [PATCH 021/115] Bugfixes and additional tests --- src/include/migraphx/float8_impl.hpp | 17 ++++++++-------- src/py/migraphx_py.cpp | 2 +- test/fp8e4m3fn.cpp | 23 ++++++++++++++++++++- test/fp8e4m3fnuz.cpp | 30 +++++++++++++++++++++++++++- test/fp8e5m2.cpp | 13 +++++++++++- test/fp8e5m2fnuz.cpp | 25 ++++++++++++++++++++++- 6 files changed, 97 insertions(+), 13 deletions(-) diff --git a/src/include/migraphx/float8_impl.hpp b/src/include/migraphx/float8_impl.hpp index b8895ee5c86..2050662a740 100644 --- a/src/include/migraphx/float8_impl.hpp +++ b/src/include/migraphx/float8_impl.hpp @@ -134,15 +134,15 @@ constexpr uint8_t cast_to_f8(T f_x, bool stoch = false, uint32_t rng = 0) int f8_exponent = 0; int exponent_diff = 0; - if(exponent == 0) + if(exponent == 0 and mantissa != 0) { // fp32/fp16 is in denormal. /* fp32 denormal is below 2^-127 so it is usually not a concern here, we mostly concern fp16 here. In this case, f8 is usually in denormal. But there could be exceptions. fp16 denormal - has exponent bias 15 while bf8 with NANOO has exponent bias 16. It means that there are some - numbers in fp16 denormal but they are bf8 (NANOO) normals - smallest bf8 (NANOO) normal is + has exponent bias 15 while bf8 with FNUZ has exponent bias 16. It means that there are some + numbers in fp16 denormal but they are bf8 (FNUZ) normals - smallest bf8 (FNUZ) normal is 2^-15. 
fp16 numbers where exponent==0 (actual exponent -14) and highest bit of mantissa is 1 - are bf8 (NANOO) normal. In this case, the fp16 mantissa should be shift left by 1 */ - act_exponent = exponent - bias + 1; + are bf8 (FNUZ) normal. In this case, the fp16 mantissa should be shift left by 1 */ + act_exponent = 1 - bias; exponent_diff = f8_denormal_act_exponent - act_exponent; // actual exponent is exponent-bias+1 as it is denormal } @@ -152,10 +152,10 @@ constexpr uint8_t cast_to_f8(T f_x, bool stoch = false, uint32_t rng = 0) if(act_exponent <= f8_denormal_act_exponent) { /* This is the case where fp32/fp16 is normal but it is in f8 denormal range. - For example fp8 nanoo mode, denormal exponent is -7, but if the fp32/fp16 + For example fp8 FNUZ mode, denormal exponent is -7, but if the fp32/fp16 actual exponent is -7, it is actually larger due to the implict 1, Therefore it needs to be adjust to -6 and mantissa shift right by 1. - So for fp32/fp16, exponent -8 is the cut point to convert to fp8 nanoo */ + So for fp32/fp16, exponent -8 is the cut point to convert to fp8 FNUZ */ exponent_diff = f8_denormal_act_exponent - act_exponent; } else @@ -204,7 +204,8 @@ constexpr uint8_t cast_to_f8(T f_x, bool stoch = false, uint32_t rng = 0) mantissa >>= (mfmt - Wm); // above range: quantize to maximum possible float of the same sign - const int max_exp = (1 << We) - (NegativeZeroNan ? 1 : 2); + // for e5m2 case, max_exp is 14, since exp = 15 is reserved for Infs and Nans + const int max_exp = (1 << We) - ((NegativeZeroNan or Wm == 3) ? 
1 : 2); if(f8_exponent > max_exp) { if(Clip) diff --git a/src/py/migraphx_py.cpp b/src/py/migraphx_py.cpp index 91af6cf9ded..a12e69d7e80 100644 --- a/src/py/migraphx_py.cpp +++ b/src/py/migraphx_py.cpp @@ -150,7 +150,7 @@ struct npy_format_descriptor static std::string format() { // following: https://docs.python.org/3/library/struct.html#format-characters - return "z"; + return "B"; } static constexpr auto name() { return _("fp8e4m3fnuz"); } }; diff --git a/test/fp8e4m3fn.cpp b/test/fp8e4m3fn.cpp index 874cd4684d1..0fc0ca90c9d 100644 --- a/test/fp8e4m3fn.cpp +++ b/test/fp8e4m3fn.cpp @@ -117,6 +117,27 @@ TEST_CASE(test_fp8_cast_to_float) })}); } +TEST_CASE(test_fp8_cast_from_float) +{ + std::unordered_map test_vals = { + {{512, 0x7e}, {-512, 0xfe}, {448, 0x7e}, {-448, 0xfe}, + {256, 0x78}, {-256, 0xf8}, {240, 0x77}, {-240, 0xf7}, + {1e-07, 0x0}, {1e+07, 0x7e}, {1, 0x38}, {-1, 0xb8}, + {0.1, 0x1d}, {0.11, 0x1e}, {0.111, 0x1e}, {0.1111, 0x1e}, + {-0.1, 0x9d}, {-0.11, 0x9e}, {-0.111, 0x9e}, {-0.1111, 0x9e}, + {0.2, 0x25}, {2, 0x40}, {20, 0x5a}, {200, 0x74}, + {-0.2, 0xa5}, {-2, 0xc0}, {-20, 0xda}, {-200, 0xf4}, + {0.5, 0x30}, {-0.5, 0xb0}, {1.17549e-38, 0x0}, {1.4013e-45, 0x0}, + {0.0078125, 0x4}, {-0.0078125, 0x84}, {0.000976562, 0x0}, {-0.000976562, 0x80}, + {0.000488281, 0x0}, {-0.000488281, 0x80}}}; + + EXPECT(bool{std::all_of(test_vals.begin(), test_vals.end(), [](const auto sample) { + return migraphx::float_equal( + migraphx::fp8::fp8e4m3fn(sample.first), + migraphx::fp8::fp8e4m3fn(sample.second, migraphx::fp8::fp8e4m3fn::from_bits())); + })}); +} + TEST_CASE(test_positive_zero) { float zero = 0.0; @@ -241,7 +262,7 @@ TEST_CASE(test_binary_ops) auto f = migraphx::fp8::fp8e4m3fn(-10.0); EXPECT(bool{e > f}); EXPECT(bool{f < e}); - EXPECT(bool(f <= e)); + EXPECT(bool{f <= e}); EXPECT(bool{e >= f}); EXPECT(bool{e <= e}); EXPECT(bool{f >= f}); diff --git a/test/fp8e4m3fnuz.cpp b/test/fp8e4m3fnuz.cpp index b99c09ba1bf..e86cf8d76a1 100644 --- a/test/fp8e4m3fnuz.cpp 
+++ b/test/fp8e4m3fnuz.cpp @@ -138,6 +138,34 @@ TEST_CASE(test_fp8_cast_to_float) })}); } +TEST_CASE(test_fp8_cast_from_float) +{ + std::unordered_map test_vals = {{256, 0x7f}, {-256, 0xff}, + {240, 0x7f}, {-240, 0xff}, + {1e-07, 0x0}, {1e+07, 0x7f}, + {1, 0x40}, {-1, 0xc0}, + {0.1, 0x25}, {0.11, 0x26}, + {0.111, 0x26}, {0.1111, 0x26}, + {-0.1, 0xa5}, {-0.11, 0xa6}, + {-0.111, 0xa6}, {-0.1111, 0xa6}, + {0.2, 0x2d}, {2, 0x48}, + {20, 0x62}, {200, 0x7c}, + {-0.2, 0xad}, {-2, 0xc8}, + {-20, 0xe2}, {-200, 0xfc}, + {0.5, 0x38}, {-0.5, 0xb8}, + {1.17549e-38, 0x0}, {1.4013e-45, 0x0}, + {0.00390625, 0x4}, {-0.00390625, 0x84}, + {0.00195312, 0x2}, {-0.00195312, 0x82}, + {0.000976562, 0x1}, {-0.000976562, 0x81}, + {0.000488281, 0x0}, {-0.000488281, 0x0}}; + + EXPECT(bool{std::all_of(test_vals.begin(), test_vals.end(), [](const auto sample) { + return migraphx::float_equal( + migraphx::fp8::fp8e4m3fnuz(sample.first), + migraphx::fp8::fp8e4m3fnuz(sample.second, migraphx::fp8::fp8e4m3fnuz::from_bits())); + })}); +} + TEST_CASE(test_positive_zero) { float zero = 0.0; @@ -256,7 +284,7 @@ TEST_CASE(test_binary_ops) auto f = migraphx::fp8::fp8e4m3fnuz(-10.0); EXPECT(bool{e > f}); EXPECT(bool{f < e}); - EXPECT(bool(f <= e)); + EXPECT(bool{f <= e}); EXPECT(bool{e >= f}); EXPECT(bool{e <= e}); EXPECT(bool{f >= f}); diff --git a/test/fp8e5m2.cpp b/test/fp8e5m2.cpp index e9f5fd3f5f9..83daf4040d1 100644 --- a/test/fp8e5m2.cpp +++ b/test/fp8e5m2.cpp @@ -314,6 +314,17 @@ TEST_CASE(test_fp8_cast_to_float) })}); } +TEST_CASE(test_fp8_cast_from_float) +{ + std::unordered_map test_vals = {}; + + EXPECT(bool{std::all_of(test_vals.begin(), test_vals.end(), [](const auto sample) { + return migraphx::float_equal( + migraphx::fp8::fp8e5m2(sample.first), + migraphx::fp8::fp8e5m2(sample.second, migraphx::fp8::fp8e5m2::from_bits())); + })}); +} + TEST_CASE(test_positive_zero) { float zero = 0.0; @@ -438,7 +449,7 @@ TEST_CASE(test_binary_ops) auto f = migraphx::fp8::fp8e5m2(-10.0); EXPECT(bool{e > f}); 
EXPECT(bool{f < e}); - EXPECT(bool(f <= e)); + EXPECT(bool{f <= e}); EXPECT(bool{e >= f}); EXPECT(bool{e <= e}); EXPECT(bool{f >= f}); diff --git a/test/fp8e5m2fnuz.cpp b/test/fp8e5m2fnuz.cpp index 157b0d28600..14be8bc80d7 100644 --- a/test/fp8e5m2fnuz.cpp +++ b/test/fp8e5m2fnuz.cpp @@ -308,6 +308,29 @@ TEST_CASE(test_fp8_cast_to_float) })}); } +TEST_CASE(test_fp8_cast_from_float) +{ + std::unordered_map test_vals = { + {57344, 0x7f}, {-57344, 0xff}, {60000, 0x7f}, {-60000, 0xff}, + {448, 0x63}, {-448, 0xe3}, {256, 0x60}, {-256, 0xe0}, + {240, 0x60}, {-240, 0xe0}, {3.05176e-05, 0x4}, {-3.05176e-05, 0x84}, + {1.52588e-05, 0x2}, {-1.52588e-05, 0x82}, {7.62939e-06, 0x1}, {-7.62939e-06, 0x81}, + {3.81469e-06, 0x0}, {-3.81469e-06, 0x0}, {1e+07, 0x7f}, {1, 0x40}, + {-1, 0xc0}, {0.1, 0x32}, {0.11, 0x33}, {0.111, 0x33}, + {0.1111, 0x33}, {-0.1, 0xb2}, {-0.11, 0xb3}, {-0.111, 0xb3}, + {-0.1111, 0xb3}, {0.2, 0x36}, {2, 0x44}, {20, 0x51}, + {200, 0x5e}, {-0.2, 0xb6}, {-2, 0xc4}, {-20, 0xd1}, + {-200, 0xde}, {0.5, 0x3c}, {-0.5, 0xbc}, {1.17549e-38, 0x0}, + {1.4013e-45, 0x0}, + }; + + EXPECT(bool{std::all_of(test_vals.begin(), test_vals.end(), [](const auto sample) { + return migraphx::float_equal( + migraphx::fp8::fp8e5m2fnuz(sample.first), + migraphx::fp8::fp8e5m2fnuz(sample.second, migraphx::fp8::fp8e5m2fnuz::from_bits())); + })}); +} + TEST_CASE(test_positive_zero) { float zero = 0.0; @@ -426,7 +449,7 @@ TEST_CASE(test_binary_ops) auto f = migraphx::fp8::fp8e5m2fnuz(-10.0); EXPECT(bool{e > f}); EXPECT(bool{f < e}); - EXPECT(bool(f <= e)); + EXPECT(bool{f <= e}); EXPECT(bool{e >= f}); EXPECT(bool{e <= e}); EXPECT(bool{f >= f}); From 6414ee38027a7dddebb3e13012d8ed1c22cdbc27 Mon Sep 17 00:00:00 2001 From: Umang Yadav Date: Thu, 16 Nov 2023 14:02:55 +0000 Subject: [PATCH 022/115] Fix undoing --- test/fp8e5m2.cpp | 42 +++++++++++++++++++++++++++++++++++++++++- 1 file changed, 41 insertions(+), 1 deletion(-) diff --git a/test/fp8e5m2.cpp b/test/fp8e5m2.cpp index 
83daf4040d1..966aeb63d5c 100644 --- a/test/fp8e5m2.cpp +++ b/test/fp8e5m2.cpp @@ -316,7 +316,47 @@ TEST_CASE(test_fp8_cast_to_float) TEST_CASE(test_fp8_cast_from_float) { - std::unordered_map test_vals = {}; + std::unordered_map test_vals = { + {-60000, 0xfb}, + {-57344, 0xfb}, + {-448, 0xdf}, + {-256, 0xdc}, + {-240, 0xdc}, + {-200, 0xda}, + {-20, 0xcd}, + {-2, 0xc0}, + {-1, 0xbc}, + {-0.5, 0xb8}, + {-0.2, 0xb2}, + {-0.1111, 0xaf}, + {-0.111, 0xaf}, + {-0.11, 0xaf}, + {-0.1, 0xae}, + {6.10351e-05, 0x4}, + {-6.10351e-05, 0x84}, + {3.05176e-05, 0x2}, + {-3.05176e-05, 0x82}, + {1.52588e-05, 0x1}, + {-1.52588e-05, 0x81}, + {7.62939e-06, 0x0}, + {-7.62939e-06, 0x80}, + {0.1, 0x2e}, + {0.11, 0x2f}, + {0.111, 0x2f}, + {0.1111, 0x2f}, + {0.2, 0x32}, + {0.5, 0x38}, + {1, 0x3c}, + {2, 0x40}, + {20, 0x4d}, + {200, 0x5a}, + {240, 0x5c}, + {256, 0x5c}, + {448, 0x5f}, + {57344, 0x7b}, + {60000, 0x7b}, + {1e+07, 0x7b}, + }; EXPECT(bool{std::all_of(test_vals.begin(), test_vals.end(), [](const auto sample) { return migraphx::float_equal( From cd26ada8d3326e223ff2de4049e96a6890e3f63e Mon Sep 17 00:00:00 2001 From: Umang Yadav Date: Thu, 16 Nov 2023 16:02:32 +0000 Subject: [PATCH 023/115] Handle underflow case separately to avoid sanitization errors --- src/include/migraphx/float8_impl.hpp | 18 ++++++++++++------ src/py/migraphx_py.cpp | 3 ++- 2 files changed, 14 insertions(+), 7 deletions(-) diff --git a/src/include/migraphx/float8_impl.hpp b/src/include/migraphx/float8_impl.hpp index 2050662a740..f93e5a9a399 100644 --- a/src/include/migraphx/float8_impl.hpp +++ b/src/include/migraphx/float8_impl.hpp @@ -149,13 +149,19 @@ constexpr uint8_t cast_to_f8(T f_x, bool stoch = false, uint32_t rng = 0) else { // fp32/fp16 is normal with implicit 1 act_exponent = exponent - bias; - if(act_exponent <= f8_denormal_act_exponent) + /* + check if FP8 is underflowing to 0.0. Wm is added to check to allow FP8 to go into denorm + range. e.g. 
act_exponent for FP32/16 is -9 and e4m3fnuz has denorm_act exponent = -7 in + that case fp32/16 mantissa can be shifted right by two to make + exponent -7 and then it can be representable as e4m3fnuz denorm. So for fp32/fp16, exponent + -10 is the cut point to convert to e4m3fp8fnuz due to implicit 1 in mantissa. If fp32/16 + act_exponent is less than -10 then it underflows to zero*/ + if(act_exponent < (f8_denormal_act_exponent - Wm)) + { + return NegativeZeroNan ? 0x00 : ((sign) ? 0x80 : 0x00); + } + else if(act_exponent <= f8_denormal_act_exponent) { - /* This is the case where fp32/fp16 is normal but it is in f8 denormal range. - For example fp8 FNUZ mode, denormal exponent is -7, but if the fp32/fp16 - actual exponent is -7, it is actually larger due to the implict 1, - Therefore it needs to be adjust to -6 and mantissa shift right by 1. - So for fp32/fp16, exponent -8 is the cut point to convert to fp8 FNUZ */ exponent_diff = f8_denormal_act_exponent - act_exponent; } else diff --git a/src/py/migraphx_py.cpp b/src/py/migraphx_py.cpp index a12e69d7e80..3da1a51588c 100644 --- a/src/py/migraphx_py.cpp +++ b/src/py/migraphx_py.cpp @@ -150,7 +150,8 @@ struct npy_format_descriptor static std::string format() { // following: https://docs.python.org/3/library/struct.html#format-characters - return "B"; + // TODO: need to figure out correct encoding + return "z"; } static constexpr auto name() { return _("fp8e4m3fnuz"); } }; From 1cf87efbd895623cac5e9b6a45d8b15ddaa3ba91 Mon Sep 17 00:00:00 2001 From: Umang Yadav Date: Thu, 16 Nov 2023 21:48:14 +0000 Subject: [PATCH 024/115] use std::min to avoid sanitization errors --- src/include/migraphx/float8_impl.hpp | 24 ++++++++++-------------- 1 file changed, 10 insertions(+), 14 deletions(-) diff --git a/src/include/migraphx/float8_impl.hpp b/src/include/migraphx/float8_impl.hpp index f93e5a9a399..98c7dc70446 100644 --- a/src/include/migraphx/float8_impl.hpp +++ b/src/include/migraphx/float8_impl.hpp @@ -22,6 +22,8 @@ #ifndef 
MIGRAPHX_GUARD_RTGLIB_FLOAT8_IMPL_HPP #define MIGRAPHX_GUARD_RTGLIB_FLOAT8_IMPL_HPP +#include +#include #include #include #include @@ -149,19 +151,13 @@ constexpr uint8_t cast_to_f8(T f_x, bool stoch = false, uint32_t rng = 0) else { // fp32/fp16 is normal with implicit 1 act_exponent = exponent - bias; - /* - check if FP8 is underflowing to 0.0. Wm is added to check to allow FP8 to go into denorm - range. e.g. act_exponent for FP32/16 is -9 and e4m3fnuz has denorm_act exponent = -7 in - that case fp32/16 mantissa can be shifted right by two to make - exponent -7 and then it can be representable as e4m3fnuz denorm. So for fp32/fp16, exponent - -10 is the cut point to convert to e4m3fp8fnuz due to implicit 1 in mantissa. If fp32/16 - act_exponent is less than -10 then it underflows to zero*/ - if(act_exponent < (f8_denormal_act_exponent - Wm)) - { - return NegativeZeroNan ? 0x00 : ((sign) ? 0x80 : 0x00); - } - else if(act_exponent <= f8_denormal_act_exponent) + if(act_exponent <= f8_denormal_act_exponent) { + /* This is the case where fp32/fp16 is normal but it is in f8 denormal range. + For example fp8 FNUZ mode, denormal exponent is -7, but if the fp32/fp16 + actual exponent is -7, it is actually larger due to the implict 1, + Therefore it needs to be adjust to -6 and mantissa shift right by 1. + So for fp32/fp16, exponent -8 is the cut point to convert to fp8 FNUZ */ exponent_diff = f8_denormal_act_exponent - act_exponent; } else @@ -173,8 +169,8 @@ constexpr uint8_t cast_to_f8(T f_x, bool stoch = false, uint32_t rng = 0) mantissa += (1u << mfmt); // Add the implicit 1 into mantissa } - bool midpoint = (mantissa & ((1 << (mfmt - Wm + exponent_diff)) - 1)) == - (1 << (mfmt - Wm + exponent_diff - 1)); + bool midpoint = (mantissa & ((1u << std::min(32u, mfmt - Wm + exponent_diff)) - 1)) == + (1u << std::min(32u, mfmt - Wm + exponent_diff - 1)); /* This part is a bit tricky. 
The judgment of whether it is a tie needs to be done before we shift right as shift right could rip off some residual part and make something not midpoint look like midpoint. For example, the fp16 number 0x1002 (0 00100 0000000010), it is larger than From 98a838f447b0599af65e2d566468097cbfac66fb Mon Sep 17 00:00:00 2001 From: Umang Yadav Date: Thu, 16 Nov 2023 22:40:32 +0000 Subject: [PATCH 025/115] formatting --- src/py/migraphx_py.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/py/migraphx_py.cpp b/src/py/migraphx_py.cpp index 3da1a51588c..3b95959f98d 100644 --- a/src/py/migraphx_py.cpp +++ b/src/py/migraphx_py.cpp @@ -150,7 +150,7 @@ struct npy_format_descriptor static std::string format() { // following: https://docs.python.org/3/library/struct.html#format-characters - // TODO: need to figure out correct encoding + // TODO: need to figure out correct encoding return "z"; } static constexpr auto name() { return _("fp8e4m3fnuz"); } From 61e4e1d7ee340c1623931e0aea170e12edceed11 Mon Sep 17 00:00:00 2001 From: Umang Yadav Date: Thu, 16 Nov 2023 23:20:06 +0000 Subject: [PATCH 026/115] use 31 for min value --- src/include/migraphx/float8_impl.hpp | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/include/migraphx/float8_impl.hpp b/src/include/migraphx/float8_impl.hpp index 98c7dc70446..178fbd39ce8 100644 --- a/src/include/migraphx/float8_impl.hpp +++ b/src/include/migraphx/float8_impl.hpp @@ -168,9 +168,9 @@ constexpr uint8_t cast_to_f8(T f_x, bool stoch = false, uint32_t rng = 0) } mantissa += (1u << mfmt); // Add the implicit 1 into mantissa } - - bool midpoint = (mantissa & ((1u << std::min(32u, mfmt - Wm + exponent_diff)) - 1)) == - (1u << std::min(32u, mfmt - Wm + exponent_diff - 1)); + // shifting by more than sizeof(T) is undefined behaviour, cap shift to 31 + bool midpoint = (mantissa & ((1u << std::min(31u, mfmt - Wm + exponent_diff)) - 1)) == + (1u << std::min(31u, mfmt - Wm + exponent_diff - 1)); /* This 
part is a bit tricky. The judgment of whether it is a tie needs to be done before we shift right as shift right could rip off some residual part and make something not midpoint look like midpoint. For example, the fp16 number 0x1002 (0 00100 0000000010), it is larger than @@ -178,7 +178,7 @@ constexpr uint8_t cast_to_f8(T f_x, bool stoch = false, uint32_t rng = 0) */ if(exponent_diff > 0) - mantissa >>= exponent_diff; + mantissa >>= std::min(31u, uint32_t(exponent_diff)); else if(exponent_diff == -1) mantissa <<= -exponent_diff; bool implicit_one = mantissa & (1 << mfmt); From a5c38ebea7a0287872f421f872b548bd74e0213b Mon Sep 17 00:00:00 2001 From: Umang Yadav Date: Thu, 16 Nov 2023 23:20:46 +0000 Subject: [PATCH 027/115] add note --- src/include/migraphx/float8_impl.hpp | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/include/migraphx/float8_impl.hpp b/src/include/migraphx/float8_impl.hpp index 178fbd39ce8..1270cacdb9a 100644 --- a/src/include/migraphx/float8_impl.hpp +++ b/src/include/migraphx/float8_impl.hpp @@ -168,7 +168,9 @@ constexpr uint8_t cast_to_f8(T f_x, bool stoch = false, uint32_t rng = 0) } mantissa += (1u << mfmt); // Add the implicit 1 into mantissa } - // shifting by more than sizeof(T) is undefined behaviour, cap shift to 31 + + // need to know whether the number is right in the middle of two adjacent fp8 numbers. use max + // value of 31 to avoid undefined behaviour bool midpoint = (mantissa & ((1u << std::min(31u, mfmt - Wm + exponent_diff)) - 1)) == (1u << std::min(31u, mfmt - Wm + exponent_diff - 1)); /* This part is a bit tricky. 
The judgment of whether it is a tie needs to be done before we From 017d67e26f3371355726b929b47be48691ba87fc Mon Sep 17 00:00:00 2001 From: Umang Yadav Date: Fri, 17 Nov 2023 01:29:14 +0000 Subject: [PATCH 028/115] add some more comments --- src/include/migraphx/float8_impl.hpp | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/src/include/migraphx/float8_impl.hpp b/src/include/migraphx/float8_impl.hpp index 1270cacdb9a..e6423eea83a 100644 --- a/src/include/migraphx/float8_impl.hpp +++ b/src/include/migraphx/float8_impl.hpp @@ -192,6 +192,19 @@ constexpr uint8_t cast_to_f8(T f_x, bool stoch = false, uint32_t rng = 0) uint32_t drop_mask = (1u << (mfmt - Wm)) - 1; bool odd = mantissa & (1u << (mfmt - Wm)); // if the least significant bit that is not truncated is 1 + /* + This part is doing rounding by adding mantissa part that is going to get dropped. + e.g. if the dropped part for less than 0.5 than it would round down. + if the dropped part is more than 0.5 then it would round up by rolling carry to LSB of retained + mantissa. + For the mid point when bit pattern is like this for Odd: `xy1:10000000` for Odd and + `xy0:10000000` for the Even. where `:` is delimiter for dropped v/s retained part. + For the odd case : + this will add xy1:10000000 + 000:10000000 which would roll over carry to LSB of retained + part making it RNE. + For the even case : this will add xy0:10000000 + 000:01111111 which would + round down and keep number Even + */ mantissa += (stoch ? rng : (midpoint ? (odd ? 
mantissa : mantissa - 1) : mantissa)) & drop_mask; // Now we deal with overflow From a9dd42f74f541649b577c212f9caeea1f18b8cde Mon Sep 17 00:00:00 2001 From: Umang Yadav Date: Fri, 17 Nov 2023 14:15:07 +0000 Subject: [PATCH 029/115] port gpu changes --- src/targets/gpu/compile_gen.cpp | 12 +- src/targets/gpu/compile_hip.cpp | 17 +- src/targets/gpu/compile_hip_code_object.cpp | 1 + .../include/migraphx/kernels/bit_cast.hpp | 33 + .../include/migraphx/kernels/float8.hpp | 659 ++++++++++++++++++ .../include/migraphx/kernels/float8_impl.hpp | 329 +++++++++ .../kernels/include/migraphx/kernels/hip.hpp | 2 +- .../kernels/include/migraphx/kernels/math.hpp | 51 +- .../include/migraphx/kernels/type_traits.hpp | 8 +- .../include/migraphx/kernels/types.hpp | 7 +- .../include/migraphx/kernels/vectorize.hpp | 1 + src/targets/gpu/target.cpp | 1 + test/gpu/jit.cpp | 22 +- test/verify/test_abs.cpp | 9 +- test/verify/test_acos.cpp | 9 +- test/verify/test_add.cpp | 9 +- test/verify/test_literal_limits.cpp | 7 +- 17 files changed, 1134 insertions(+), 43 deletions(-) create mode 100644 src/targets/gpu/kernels/include/migraphx/kernels/bit_cast.hpp create mode 100644 src/targets/gpu/kernels/include/migraphx/kernels/float8.hpp create mode 100644 src/targets/gpu/kernels/include/migraphx/kernels/float8_impl.hpp diff --git a/src/targets/gpu/compile_gen.cpp b/src/targets/gpu/compile_gen.cpp index d136d180086..df3c59cc2f7 100644 --- a/src/targets/gpu/compile_gen.cpp +++ b/src/targets/gpu/compile_gen.cpp @@ -54,6 +54,11 @@ vectorize vectorize::elements(std::size_t axis, const std::vector& inputs, const std::vector& sizes) { + // disable vectorization for fp8 types + if(std::any_of(inputs.begin(), inputs.end(), [&](auto ishape) { + return ishape.type() == migraphx::shape::fp8e4m3fnuz_type; + })) + return {1, axis}; if(std::all_of( inputs.begin(), inputs.end(), [&](const auto& s) { return s.lens()[axis] == 1; })) return {1, axis}; @@ -86,6 +91,11 @@ vectorize vectorize::elements(std::size_t axis, 
vectorize vectorize::elements(context& ctx, std::size_t axis, const std::vector& inputs) { + // disable vectorization for fp8 types + if(std::any_of(inputs.begin(), inputs.end(), [&](auto ishape) { + return ishape.type() == migraphx::shape::fp8e4m3fnuz_type; + })) + return {1, axis}; if(inputs.empty()) return {1, axis}; std::size_t n = std::max_element(inputs.begin(), @@ -305,7 +315,7 @@ std::string generate_reduce(const module& m, const std::string& name) std::transform( params.begin(), params.end(), params.begin(), [](auto s) { return "auto " + s; }); return interpolate_string(inner_template, - {{"inner", inner_name}, + {{"inner", inner_name}, {"params", join_strings(params, ", ")}, {"args", join_strings(args, ", ")}, {"call", call_function}}); diff --git a/src/targets/gpu/compile_hip.cpp b/src/targets/gpu/compile_hip.cpp index e58b681b563..51dbe4d48ea 100644 --- a/src/targets/gpu/compile_hip.cpp +++ b/src/targets/gpu/compile_hip.cpp @@ -199,7 +199,7 @@ std::vector> compile_hip_src_with_hiprtc(std::vector& srcs, std::string params, const std std::cout << std::string(src.content) << std::endl; } } - auto fname = fs::path{"migraphx-hiprtc-driver"}; -#ifdef _WIN32 - fname.replace_extension(".exe"); -#endif auto p = dynamic_loader::path(&compile_hip_src_with_hiprtc); - auto driver = p.parent_path() / fname; - - bool found = fs::exists(driver); - if(not found) - { - driver = p.parent_path().parent_path() / "bin" / fname; - found = fs::exists(driver); - } + auto driver = p.parent_path().parent_path() / "bin" / "migraphx-hiprtc-driver"; - if(found) + if(fs::exists(driver)) { value v; v["srcs"] = to_value(hsrcs); diff --git a/src/targets/gpu/compile_hip_code_object.cpp b/src/targets/gpu/compile_hip_code_object.cpp index d2c7dfc8fda..fea46505a8b 100644 --- a/src/targets/gpu/compile_hip_code_object.cpp +++ b/src/targets/gpu/compile_hip_code_object.cpp @@ -197,6 +197,7 @@ operation compile_hip_code_object(const std::string& content, hip_compile_option options.params += " 
-DMIGRAPHX_NGLOBAL=" + std::to_string(options.global); options.params += " -DMIGRAPHX_NLOCAL=" + std::to_string(options.local); + options.params += " -D__HIP_NO_F8_CONVERSIONS__=1"; options.params += " " + join_strings(compiler_warnings(), " "); options.params += " -ftemplate-backtrace-limit=0"; options.params += " -Werror"; diff --git a/src/targets/gpu/kernels/include/migraphx/kernels/bit_cast.hpp b/src/targets/gpu/kernels/include/migraphx/kernels/bit_cast.hpp new file mode 100644 index 00000000000..b77c7dc73f4 --- /dev/null +++ b/src/targets/gpu/kernels/include/migraphx/kernels/bit_cast.hpp @@ -0,0 +1,33 @@ +/* ************************************************************************ + * Copyright (C) 2016-2023 Advanced Micro Devices, Inc. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell cop- + * ies of the Software, and to permit persons to whom the Software is furnished + * to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IM- + * PLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS + * FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR + * COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER + * IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNE- + * CTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 
+ * + * ************************************************************************ */ +#ifndef MIGRAPHX_GUARD_KERNELS_BITCAST_HPP +#define MIGRAPHX_GUARD_KERNELS_BITCAST_HPP + +namespace migraphx { +template +inline constexpr To bit_cast(From fr) noexcept +{ + static_assert(sizeof(To) == sizeof(From)); + return __builtin_bit_cast(To, fr); +} +} // namespace migraphx +#endif // MIGRAPHX_GUARD_KERNELS_BITCAST_HPP diff --git a/src/targets/gpu/kernels/include/migraphx/kernels/float8.hpp b/src/targets/gpu/kernels/include/migraphx/kernels/float8.hpp new file mode 100644 index 00000000000..bcc3ae26a61 --- /dev/null +++ b/src/targets/gpu/kernels/include/migraphx/kernels/float8.hpp @@ -0,0 +1,659 @@ +/* ************************************************************************ + * Copyright (C) 2016-2023 Advanced Micro Devices, Inc. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell cop- + * ies of the Software, and to permit persons to whom the Software is furnished + * to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IM- + * PLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS + * FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR + * COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER + * IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNE- + * CTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 
+ * + * ************************************************************************ */ + +#ifndef MIGRAPHX_GUARD_KERNELS_FLOAT8_HPP +#define MIGRAPHX_GUARD_KERNELS_FLOAT8_HPP +#if defined(__clang__) +#pragma clang diagnostic push +#pragma clang diagnostic ignored "-Wold-style-cast" +#pragma clang diagnostic ignored "-Wfloat-equal" +#pragma clang diagnostic ignored "-Wmacro-redefined" +#pragma clang diagnostic ignored "-Wc++20-extensions" +#endif // __clang__ + +#if(defined(__HIP_PLATFORM_HCC__) || defined(__HIP_PLATFORM_AMD__)) +// need to include hip_runtime.h otherwise it complains about __host__ and __device__ +#if defined(MIGRAPHX_JIT_USE_HIPRTC) +#include +#else +#include +#endif +#define MIGRAPHX_HIP_HOST_DEVICE __host__ __device__ +#define MIGRAPHX_HIP_HOST __host__ +#else +#define MIGRAPHX_HIP_HOST_DEVICE +#define MIGRAPHX_HIP_HOST +#endif // HIP_PLATFORM_AMD + +#define MIGRAPHX_HIP_DEVICE __device__ + +#ifndef MIGRAPHX_FP8_FNUZ +#define MIGRAPHX_FP8_FNUZ true +#endif // MIGRAPHX_FP8_FNUZ + +// We are clipping in down conversion by default +#define MIGRAPHX_F8_DOWNCAST_CLIPPING 1 +#if defined(MIGRAPHX_JIT_USE_HIPRTC) +#include +using uint8_t = migraphx::uint8_t; +using uint16_t = migraphx::uint16_t; +using uint32_t = migraphx::uint32_t; +#else +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#endif + +#include + +namespace migraphx { +namespace fp8 { + +enum class rounding_mode +{ + standard, // standard rounding is doing RNE -- round to nearest even + stochastic +}; + +enum class f8_type +{ + bf8 = 0, // s1e5m2 + fp8 = 1 // s1e4m3 +}; + +template +class numeric_limits; + +template +struct float8 +{ + uint8_t data; + // default constructor + MIGRAPHX_HIP_HOST_DEVICE constexpr float8() = default; + // default copy constructor + MIGRAPHX_HIP_HOST_DEVICE constexpr float8(const float8& y) = default; + struct from_bits_t + { + }; + static constexpr MIGRAPHX_HIP_HOST_DEVICE from_bits_t from_bits() { return 
from_bits_t(); } + + MIGRAPHX_HIP_HOST_DEVICE explicit constexpr float8(uint8_t bits, from_bits_t) : data(bits) {} + +#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__) + // device specific optimized F8 down-conversion code + + template + static MIGRAPHX_HIP_DEVICE uint8_t cast_to_f8_from_f32(float v, uint32_t rng = 0) + { + uint8_t i8data; + union + { + float fval; + uint32_t i32val; + uint8_t i8val[4]; // NOTE: not endian independent + } val; + + uint32_t ival = 0; + val.fval = v; + +#ifdef MIGRAPHX_F8_DOWNCAST_CLIPPING + if constexpr(T == migraphx::fp8::f8_type::fp8) + { + if((val.i32val & 0x7F800000) != 0x7F800000) /// propagate NAN/INF, no clipping + val.fval = __builtin_amdgcn_fmed3f(val.fval, 240.0, -240.0); + } + else + { + if((val.i32val & 0x7F800000) != 0x7F800000) // propagate NAN/INF, no clipping + val.fval = __builtin_amdgcn_fmed3f(val.fval, 57344.0, -57344.0); + } +#endif + if(stochastic_rounding) + { + if constexpr(T == migraphx::fp8::f8_type::fp8) + { + ival = __builtin_amdgcn_cvt_sr_fp8_f32(val.fval, rng, ival, 0); // 0 pos + } + else + { + ival = __builtin_amdgcn_cvt_sr_bf8_f32(val.fval, rng, ival, 0); // 0 pos + } + } + else // RNE CVT + { + if constexpr(T == migraphx::fp8::f8_type::fp8) + { + ival = __builtin_amdgcn_cvt_pk_fp8_f32( + val.fval, val.fval, ival, false); // false -> WORD0 + } + else + { + ival = __builtin_amdgcn_cvt_pk_bf8_f32( + val.fval, val.fval, ival, false); // false -> WORD0} + } + } + val.i32val = ival; + i8data = val.i8val[0]; // little endian + + return i8data; + } +#endif // __gfx940__ + + // constructor from float +#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__) + + // NOTE: ON-DEVICE... 
always optimal bias + explicit MIGRAPHX_HIP_DEVICE + float8(float v, + migraphx::fp8::rounding_mode rm = migraphx::fp8::rounding_mode::standard, + uint32_t rng = 0) + { + // runtime branch, use cast_to_f8_from_f32 if want to avoid it + if(rm == migraphx::fp8::rounding_mode::stochastic) + data = cast_to_f8_from_f32(v, rng); + else + data = cast_to_f8_from_f32(v); + } + + // Host only implementation using s/w simulation + explicit MIGRAPHX_HIP_HOST +#else + // both Host and DEVICE for non-gfx940 using s/w simulation + explicit constexpr MIGRAPHX_HIP_HOST_DEVICE +#endif + float8(float v, + migraphx::fp8::rounding_mode rm = migraphx::fp8::rounding_mode::standard, + uint32_t rng = 0) + { + if constexpr(T == migraphx::fp8::f8_type::fp8) + { +#ifdef MIGRAPHX_F8_DOWNCAST_CLIPPING + data = migraphx::fp8::impl:: + cast_to_f8<3, 4, float, MIGRAPHX_FP8_FNUZ /*negative_zero_nan*/, true /*clip*/>( + v, (rm == migraphx::fp8::rounding_mode::stochastic), rng); +#else // MIGRAPHX_F8_DOWNCAST_CLIPPING + data = migraphx::fp8::impl:: + cast_to_f8<3, 4, float, MIGRAPHX_FP8_FNUZ /*negative_zero_nan*/, false /*clip*/>( + v, (rm == migraphx::fp8::rounding_mode::stochastic), rng); +#endif // MIGRAPHX_F8_DOWNCAST_CLIPPING + } + else + { +#ifdef MIGRAPHX_F8_DOWNCAST_CLIPPING + data = migraphx::fp8::impl:: + cast_to_f8<2, 5, float, MIGRAPHX_FP8_FNUZ /*negative_zero_nan*/, true /*clip*/>( + v, (rm == migraphx::fp8::rounding_mode::stochastic), rng); +#else // MIGRAPHX_F8_DOWNCAST_CLIPPING + data = migraphx::fp8::impl:: + cast_to_f8<2, 5, float, MIGRAPHX_FP8_FNUZ /*negative_zero_nan*/, false /*clip*/>( + v, (rm == migraphx::fp8::rounding_mode::stochastic), rng); +#endif // rocblas_F8_downcast_clipping} + } + } + + /* + // Constructor from half + explicit constexpr MIGRAPHX_HIP_HOST_DEVICE + float8(migraphx::half v, + migraphx::fp8::rounding_mode rm = + migraphx::fp8::rounding_mode::standard, + uint32_t rng = 0) + : float8((float)v, rm, rng) + { + } + + // constructor from int + explicit constexpr 
MIGRAPHX_HIP_HOST_DEVICE + float8(int v, + migraphx::fp8::rounding_mode rm = + migraphx::fp8::rounding_mode::standard, + uint32_t rng = 0) + : float8((float)v, rm, rng) + { + } + + // constructor from double + explicit constexpr MIGRAPHX_HIP_HOST_DEVICE + float8(double v, + migraphx::fp8::rounding_mode rm = + migraphx::fp8::rounding_mode::standard, + uint32_t rng = 0) + : float8((float)v, rm, rng) + { + } + */ + /**/ + // convert to float +// #if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__) +#if 0 // need constexpr operator(). This version can't be constexpr + // upcast using device specific intrinsic + inline MIGRAPHX_HIP_DEVICE operator float() const + { + float fval; + uint32_t i32val = static_cast(data); + + // upcast + if constexpr(T == migraphx::fp8::f8_type::fp8) + { + asm volatile("v_cvt_f32_fp8 %0, %1 src0_sel:BYTE_0" : "=v"(fval) : "v"(i32val)); + } + else + { + asm volatile("v_cvt_f32_bf8 %0, %1 src0_sel:BYTE_0" : "=v"(fval) : "v"(i32val)); + } + + return fval; + } + + inline constexpr MIGRAPHX_HIP_HOST operator float() const +#else // non gfx940 + inline constexpr MIGRAPHX_HIP_HOST_DEVICE operator float() const +#endif + { + if constexpr(T == migraphx::fp8::f8_type::fp8) + { + return migraphx::fp8::impl:: + cast_from_f8<3, 4, float, MIGRAPHX_FP8_FNUZ /*negative_zero_nan*/>(data); + } // else + return migraphx::fp8::impl:: + cast_from_f8<2, 5, float, MIGRAPHX_FP8_FNUZ /*negative_zero_nan*/>(data); + } + + /* + // convert to half + explicit inline MIGRAPHX_HIP_HOST_DEVICE operator migraphx::half() const + { + return migraphx::half(float(*this)); // convert to float, then convert to f16 + } + */ + + // check for zero + inline MIGRAPHX_HIP_HOST_DEVICE constexpr bool is_zero() const + { + if constexpr(MIGRAPHX_FP8_FNUZ) + { + return data == 0x00; + } + else + { + return (data == 0x00) || (data == 0x80); + } + } + + // check for nan + inline MIGRAPHX_HIP_HOST_DEVICE constexpr bool is_nan() const + { + if constexpr(MIGRAPHX_FP8_FNUZ) + { + 
return data == 0x80; + } + else + { + if(T == migraphx::fp8::f8_type::bf8) + { + return (data == 0x7d) || (data == 0x7e) || (data == 0x7f) || (data == 0xfd) || + (data == 0xfe) || (data == 0xff); + } + else + { + return (data == 0x79) || (data == 0x7a) || (data == 0x7b) || (data == 0x7c) || + (data == 0x7d) || (data == 0x7e) || (data == 0x7f) || (data == 0xf9) || + (data == 0xfa) || (data == 0xfb) || (data == 0xfc) || (data == 0xfd) || + (data == 0xfe) || (data == 0xff); + } + } + } + + // check for inf + inline MIGRAPHX_HIP_HOST_DEVICE constexpr bool is_inf() const + { + if constexpr(MIGRAPHX_FP8_FNUZ) + { + return data == 0x80; + } + else + { + if(T == migraphx::fp8::f8_type::bf8) + { + return (data == 0x7c) || (data == 0xfc); + } + else + { + return (data == 0x78) || (data == 0xf8); + } + } + } + +#define MIGRAPHX_FP8_UNARY_OP(unary_op, binary_op) \ + constexpr float8& MIGRAPHX_HIP_HOST_DEVICE operator unary_op(const float8& rhs) \ + { \ + const auto tmp = static_cast(*this) binary_op static_cast(rhs); \ + *this = static_cast(tmp); \ + return *this; \ + } \ + constexpr float8& MIGRAPHX_HIP_HOST_DEVICE operator unary_op(const float& rhs) \ + { \ + const auto tmp = static_cast(*this) binary_op static_cast(rhs); \ + *this = static_cast(tmp); \ + return *this; \ + } + + MIGRAPHX_FP8_UNARY_OP(*=, *) + MIGRAPHX_FP8_UNARY_OP(-=, -) + MIGRAPHX_FP8_UNARY_OP(+=, +) + MIGRAPHX_FP8_UNARY_OP(/=, /) + + inline MIGRAPHX_HIP_HOST_DEVICE constexpr float8& operator=(const float8& rhs) = default; + inline MIGRAPHX_HIP_HOST_DEVICE constexpr float8& operator=(float8&& rhs) = default; + +#if !defined(__HIP_NO_F8_CONVERSIONS__) + // for the device kernels, this needs to be disabled since implicit_conversion op can type cast + // any type to any other type and that results in conflicts in candidate overload resolutions. 
+ inline constexpr float8& MIGRAPHX_HIP_HOST_DEVICE operator=(float rhs) + { + *this = static_cast(rhs); + return *this; + } +#endif + + inline MIGRAPHX_HIP_HOST_DEVICE constexpr bool operator==(const float8& rhs) const + { + if((rhs.is_zero() && this->is_zero()) || + (fabs(rhs - *this) < migraphx::fp8::numeric_limits>::epsilon())) + return true; + else if(rhs.is_nan() || rhs.is_inf() || this->is_nan() || this->is_inf()) + return false; + + return false; + } + + inline MIGRAPHX_HIP_HOST_DEVICE constexpr bool operator<(const float8& rhs) const + { + const auto we = static_cast(*this); + const auto them = static_cast(rhs); + return we < them; + } + + inline MIGRAPHX_HIP_HOST_DEVICE constexpr bool operator>(const float8& rhs) const + { + const auto we = static_cast(*this); + const auto them = static_cast(rhs); + return we > them; + } +}; + +#ifndef MIGRAPHX_JIT_USE_HIPRTC +// Special operator overloading +template +inline std::ostream& operator<<(std::ostream& os, const migraphx::fp8::float8& rhs) +{ + return os << static_cast(rhs); +} +#endif + +// NOLINTNEXTLINE +#define MIGRAPHX_FP8_BINARY_OP(binary_op, U) \ + template \ + inline constexpr U MIGRAPHX_HIP_HOST_DEVICE operator binary_op( \ + const migraphx::fp8::float8& lhs, const migraphx::fp8::float8& rhs) \ + { \ + return U(static_cast(lhs) binary_op static_cast(rhs)); \ + } + +// TODO: these should return floats +MIGRAPHX_FP8_BINARY_OP(*, migraphx::fp8::float8) +MIGRAPHX_FP8_BINARY_OP(-, migraphx::fp8::float8) +MIGRAPHX_FP8_BINARY_OP(/, migraphx::fp8::float8) +MIGRAPHX_FP8_BINARY_OP(+, migraphx::fp8::float8) +// TODO: Comparison ops shouldn't convert to float, maybe need to take care of rounding effects. 
+MIGRAPHX_FP8_BINARY_OP(==, bool) +MIGRAPHX_FP8_BINARY_OP(>=, bool) +MIGRAPHX_FP8_BINARY_OP(<=, bool) +MIGRAPHX_FP8_BINARY_OP(>, bool) +MIGRAPHX_FP8_BINARY_OP(<, bool) +MIGRAPHX_FP8_BINARY_OP(!=, bool) + +template +inline MIGRAPHX_HIP_HOST_DEVICE migraphx::fp8::float8 fabs(migraphx::fp8::float8 v) +{ + v.data = v.data & 0x7f; + return v; +} + +template +MIGRAPHX_HIP_HOST_DEVICE constexpr T F8_Max() +{ + return T{0x7F, T::from_bits()}; +} + +template +MIGRAPHX_HIP_HOST_DEVICE constexpr T F8_Lowest() +{ + return T{0xFF, T::from_bits()}; +} + +using fp8e4m3fnuz = float8; + +template <> +class numeric_limits> +{ + public: + // TODO :figure out epsilon in Hex to make it constexpr + static constexpr MIGRAPHX_HIP_HOST_DEVICE migraphx::fp8::float8 + epsilon() + { + return migraphx::fp8::float8( + 0x28, migraphx::fp8::float8<>::from_bits()); + } + + static constexpr MIGRAPHX_HIP_HOST_DEVICE migraphx::fp8::float8 + quiet_NaN() + { + return migraphx::fp8::float8( + MIGRAPHX_FP8_FNUZ ? 0x80 : 0x7F, migraphx::fp8::float8<>::from_bits()); + } + + static constexpr MIGRAPHX_HIP_HOST_DEVICE migraphx::fp8::float8 + max() + { + return migraphx::fp8::F8_Max>(); + } + + // TODO figure out Hex value + static MIGRAPHX_HIP_HOST_DEVICE migraphx::fp8::float8 min() + { + return static_cast>(-1.0f) * + migraphx::fp8::F8_Max>(); + } + + static constexpr MIGRAPHX_HIP_HOST_DEVICE migraphx::fp8::float8 + lowest() + { + return migraphx::fp8::F8_Lowest>(); + } + + static constexpr MIGRAPHX_HIP_HOST_DEVICE migraphx::fp8::float8 + infinity() + { + return migraphx::fp8::float8( + MIGRAPHX_FP8_FNUZ ? 
0x80 : 0x7F, migraphx::fp8::float8<>::from_bits()); + } +}; + +template <> +class numeric_limits> +{ + public: + static constexpr MIGRAPHX_HIP_HOST_DEVICE migraphx::fp8::float8 + epsilon() + { + return migraphx::fp8::float8( + 0x34, migraphx::fp8::float8::from_bits()); + } + + static constexpr MIGRAPHX_HIP_HOST_DEVICE migraphx::fp8::float8 + quiet_NaN() + { + return migraphx::fp8::float8( + MIGRAPHX_FP8_FNUZ ? 0x80 : 0x7d, + migraphx::fp8::float8::from_bits()); + } + + static constexpr MIGRAPHX_HIP_HOST_DEVICE migraphx::fp8::float8 + max() + { + return static_cast>( + migraphx::fp8::F8_Max>()); + } + // TODO figure out constexpr value + static MIGRAPHX_HIP_HOST_DEVICE migraphx::fp8::float8 min() + { + return static_cast>(float(-1.0f)) * + migraphx::fp8::F8_Max>(); + } + static constexpr MIGRAPHX_HIP_HOST_DEVICE migraphx::fp8::float8 + lowest() + { + return migraphx::fp8::F8_Lowest>(); + } + + static constexpr MIGRAPHX_HIP_HOST_DEVICE migraphx::fp8::float8 + infinity() + { + return migraphx::fp8::float8( + MIGRAPHX_FP8_FNUZ ? 0x80 : 0x7c, + migraphx::fp8::float8::from_bits()); + } +}; +/* +// Use h/w intrinsic and optimized version when __gfx940__ +template {}) && + (migraphx::is_same{} || + migraphx::is_same{})), + int>::type = 0> +inline __host__ __device__ T explicit_downcast(Ta a, uint32_t rng) +{ +#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__) + // NOTE: we are directly calling cast_to_f8_from_f32 instead of constructor to optimize + // away one runtime branch + T val; + if(migraphx::is_same::value) + val.data = migraphx_f8::cast_to_f8_from_f32(float(a), rng); + else + val.data = migraphx_bf8::cast_to_bf8_from_f32(float(a), rng); + return val; +#else // non gfx940 + return T(float(a), + stochastic_rounding ? 
migraphx::fp8::rounding_mode::stochastic + : migraphx::fp8::rounding_mode::standard, + rng); +#endif // __gfx940__ +} + +// NOTE NOTE: The above code is good if we don't consider HIP-GEMM code and only consider +// the quantization However, if we need HIP-GEMM for fall-back, we would need explicit_cast +// handles Tacc=f32 to To=f16/bf16 conversion +template {}) && + !(migraphx::is_same{} || + migraphx::is_same{})), + int>::type = 0> +inline __host__ __device__ T explicit_downcast(Ta a, uint32_t rng) +{ + // the return type is not a F8 types, no SR for those types + // not sure if we have direct conversion, so converting to float first + // no effect if the input type is float + return T(float(a)); +} +*/ +} // namespace fp8 +} // namespace migraphx +// define numeric limits for the new data type +#ifndef MIGRAPHX_JIT_USE_HIPRTC +namespace std { +inline bool isfinite(migraphx::fp8::float8 x) // NOLINT +{ + return x.is_inf(); +} + +inline bool isfinite(migraphx::fp8::float8 x) // NOLINT +{ + return x.is_inf(); +} + +inline bool isnan(migraphx::fp8::float8 x) // NOLINT +{ + return x.is_nan(); +} + +inline bool isnan(migraphx::fp8::float8 x) // NOLINT +{ + return x.is_nan(); +} + +template <> +class numeric_limits> + : public migraphx::fp8::numeric_limits> +{ +}; + +template <> +class numeric_limits> + : public migraphx::fp8::numeric_limits> +{ +}; + +template +struct common_type : std::common_type // NOLINT +{ +}; + +template +struct common_type : std::common_type // NOLINT +{ +}; + +template <> +struct common_type +{ + using type = float; +}; + +} // namespace std +#endif +// ================================================================================================= +#if defined(__clang__) +#pragma clang diagnostic pop +#endif +#endif // MIGRAPHX_GUARD_KERNELS_FLOAT8_HPP diff --git a/src/targets/gpu/kernels/include/migraphx/kernels/float8_impl.hpp b/src/targets/gpu/kernels/include/migraphx/kernels/float8_impl.hpp new file mode 100644 index 
00000000000..ae0af47f155 --- /dev/null +++ b/src/targets/gpu/kernels/include/migraphx/kernels/float8_impl.hpp @@ -0,0 +1,329 @@ +/* ************************************************************************ + * Copyright (C) 2016-2023 Advanced Micro Devices, Inc. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell cop- + * ies of the Software, and to permit persons to whom the Software is furnished + * to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IM- + * PLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS + * FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR + * COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER + * IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNE- + * CTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + * + * ************************************************************************ */ + +#ifndef MIGRAPHX_GUARD_KERNELS_FP8_IMPL_HPP +#define MIGRAPHX_GUARD_KERNELS_FP8_IMPL_HPP +#if defined(__clang__) +#pragma clang diagnostic push +#pragma clang diagnostic ignored "-Wreserved-identifier" +#endif + +#define CONST_FOLD(x) (__builtin_constant_p(x) ? 
(x) : (x)) +namespace migraphx { +namespace detail { +template +struct conditional +{ + using type = T; +}; + +template +struct conditional +{ + using type = F; +}; + +template +inline constexpr To bit_cast(From fr) noexcept +{ + static_assert(sizeof(To) == sizeof(From)); +#if defined(__GNUC__) and !defined(__clang__) + To x = CONST_FOLD(*reinterpret_cast(&fr)); +#else + To x = __builtin_bit_cast(To, fr); +#endif + return x; +} +} // namespace detail + +namespace fp8 { +namespace impl { +// #ifdef __HIP_PLATFORM_HCC__ +// __device__ inline int clz(uint32_t x) { return __clz(x); } +// #else +// __host__ inline int clz(uint32_t x) { return __builtin_clz(x); } +// #endif + +template +MIGRAPHX_HIP_HOST_DEVICE constexpr uint8_t cast_to_f8(T _x, bool stoch, uint32_t rng) +{ + + static_assert(wm + we == 7, "wm+we==7"); + + const int mfmt = (sizeof(T) == 4) ? 23 : 10; + typename migraphx::detail::conditional::type x; + + if constexpr(sizeof(T) == 4) + x = migraphx::detail::bit_cast(_x); + else + x = migraphx::detail::bit_cast(_x); + + uint32_t head, mantissa; + int exponent, bias; + uint32_t sign; + + if constexpr(sizeof(T) == 4) + { + head = x & 0xFF800000; + mantissa = x & 0x7FFFFF; + exponent = (head >> 23) & 0xFF; + sign = head >> 31; + bias = 127; + } + else + { + head = x & 0xFC00; + mantissa = x & 0x3FF; + exponent = (head >> 10) & 0x1F; + sign = head >> 15; + bias = 15; + } + + uint32_t signed_inf = (sign << 7) + (((1 << we) - 1) << wm); + + // Deal with inf and NaNs + if(negative_zero_nan) + { + if(sizeof(T) == 4) + { + if((x & 0x7F800000) == 0x7F800000) + return 0x80; + } + else + { + // if(__hisinf(x) || __hisnan(x)) + if((x & 0x7C00) == 0x7C00) + return 0x80; + } + } + else + { + if(sizeof(T) == 4) + { + if((x & 0x7F800000) == 0x7F800000) + return signed_inf + (mantissa != 0 ? 1 : 0); + } + else + { + if((x & 0x7C00) == 0x7C00) + return signed_inf + (mantissa != 0 ? 
1 : 0); + } + } + // handle positive zero + if(x == 0) + return 0; + // handle negative zero + if((sizeof(T) == 4 and x == 0x80000000) or (sizeof(T) == 2 and x == 0x8000)) + { + if(negative_zero_nan) + { + return 0; + } + else + { + return 0x80; + } + } + + // First need to check if it is normal or denorm as there is a difference of implict 1 + // Then need to adjust the exponent to align with the F8 exponent, in the meanwhile, shift + // The mantissa. Then for stochastic rounding, add rng to mantissa and truncate. And for + // RNE, no need to add rng. Then probably need to check whether there is carry and adjust + // exponent and mantissa again + + // For IEEE bias mode, the bias is 2^(k-1) -1 where k is the width of exponent bits + const int f8_bias = (1 << (we - 1)) - 1 + (negative_zero_nan ? 1 : 0); + const int f8_denormal_act_exponent = 1 - f8_bias; // actual exponent of f8 denormal + // act_exponent is the actual exponent of fp32/fp16 (after subtracting bias) + // f8_exponent is the converted f8 exponent with bias encoding + // exponent_diff is the diff between fp32/fp16 exponent and f8 exponent, + // the difference needs to be adjusted and mantissa shifted + int act_exponent, f8_exponent, exponent_diff; + + if(exponent == 0) + { // fp32/fp16 is in denormal. + /* fp32 denormal is below 2^-127 so it is usually not a concern here, we mostly concern fp16 +here. In this case, f8 is usually in denormal. But there could be exceptions. fp16 denormal has +exponent bias 15 while bf8 with NANOO has exponent bias 16. It means that there are some numbers in +fp16 denormal but they are bf8 (NANOO) normals - smallest bf8 (NANOO) normal is 2^-15. fp16 numbers +where exponent==0 (actual exponent -14) and highest bit of mantissa is 1 are bf8 (NANOO) normal. 
In +this case, the fp16 mantissa should be shift left by 1 */ + act_exponent = exponent - bias + 1; + exponent_diff = f8_denormal_act_exponent - + act_exponent; // actual exponent is exponent-bias+1 as it is denormal + } + else + { // fp32/fp16 is normal with implicit 1 + act_exponent = exponent - bias; + if(act_exponent <= f8_denormal_act_exponent) + { + /* This is the case where fp32/fp16 is normal but it is in f8 denormal range. + For example fp8 nanoo mode, denormal exponent is -7, but if the fp32/fp16 + actual exponent is -7, it is actually larger due to the implict 1, + Therefore it needs to be adjust to -6 and mantissa shift right by 1. + So for fp32/fp16, exponent -8 is the cut point to convert to fp8 nanoo */ + exponent_diff = f8_denormal_act_exponent - act_exponent; + } + else + { // both fp32/fp16 and f8 are in normal range + exponent_diff = + 0; // exponent_diff=0 does not mean there is no difference for this case, + // act_exponent could be larger. Just that it does not need shift mantissa + } + mantissa += (1 << mfmt); // Add the implicit 1 into mantissa + } + + bool midpoint = (mantissa & ((1 << (mfmt - wm + exponent_diff)) - 1)) == + (1 << (mfmt - wm + exponent_diff - 1)); + /* This part is a bit tricky. The judgment of whether it is a tie needs to be done before we + shift right as shift right could rip off some residual part and make something not midpoint look + like midpoint. For example, the fp16 number 0x1002 (0 00100 0000000010), it is larger than + midpoint, but after shift right by 4 bits, it would look like midpoint. +*/ + + if(exponent_diff > 0) + mantissa >>= exponent_diff; + else if(exponent_diff == -1) + mantissa <<= -exponent_diff; + bool implicit_one = mantissa & (1 << mfmt); + // if there is no implict 1, it means the f8 is denormal and need to adjust to denorm exponent + f8_exponent = + (act_exponent + exponent_diff) /*actual f8 exponent*/ + f8_bias - (implicit_one ? 
0 : 1); + + // Now we have the exponent and mantissa adjusted + uint32_t drop_mask = (1 << (mfmt - wm)) - 1; + bool odd = + mantissa & (1 << (mfmt - wm)); // if the least significant bit that is not truncated is 1 + mantissa += (stoch ? rng : (midpoint ? (odd ? mantissa : mantissa - 1) : mantissa)) & drop_mask; + + // Now we deal with overflow + if(f8_exponent == 0) + { + if((1 << mfmt) & mantissa) + { + f8_exponent = 1; // denormal overflow to become normal, promote exponent + } + } + else + { + if((1 << (mfmt + 1)) & mantissa) + { + mantissa >>= 1; + f8_exponent++; + } + } + + mantissa >>= (mfmt - wm); + + // above range: quantize to maximum possible float of the same sign + const int max_exp = (1 << we) - (negative_zero_nan ? 1 : 2); + if(f8_exponent > max_exp) + { + if(clip) + { + mantissa = (1 << wm) - 1; + f8_exponent = max_exp; + } + else + { + return signed_inf; + } + } + + if(f8_exponent == 0 && mantissa == 0) + return negative_zero_nan ? 0 : (sign << 7); + mantissa &= (1 << wm) - 1; + return (sign << 7) | (f8_exponent << wm) | mantissa; +} + +template +MIGRAPHX_HIP_HOST_DEVICE constexpr T cast_from_f8(uint8_t x) +{ + constexpr int weo = 8; + constexpr int wmo = 23; + + T fInf, fNegInf, fNaN, fNeg0; + uint32_t ifInf = 0x7F800000; + uint32_t ifNegInf = 0xFF800000; + uint32_t ifNaN = 0x7F800001; + uint32_t ifNeg0 = 0x80000000; + // TODO: need to change T for half but right now it would never called with half + fInf = migraphx::detail::bit_cast(ifInf); + fNegInf = migraphx::detail::bit_cast(ifNegInf); + fNaN = migraphx::detail::bit_cast(ifNaN); + fNeg0 = migraphx::detail::bit_cast(ifNeg0); + + if(x == 0) + return 0; + + uint32_t sign = x >> 7; + uint32_t mantissa = x & ((1 << wm) - 1); + int exponent = (x & 0x7F) >> wm; + if(negative_zero_nan) + { + if(x == 0x80) + return fNaN; + } + else + { + if(x == 0x80) + return fNeg0; + if(exponent == ((1 << we) - 1)) + return (mantissa == 0) ? (sign ? 
fNegInf : fInf) : fNaN; + } + typename migraphx::detail::conditional::type retval; + + const int exp_low_cutoff = (1 << (weo - 1)) - (1 << (we - 1)) + 1 - (negative_zero_nan ? 1 : 0); + + // subnormal input + if(exponent == 0) + { + // guaranteed mantissa!=0 since cases 0x0 and 0x80 are handled above + int sh = 1 + __builtin_clz(mantissa) - (32 - wm); + mantissa <<= sh; + exponent += 1 - sh; + mantissa &= ((1 << wm) - 1); + } + exponent += exp_low_cutoff - 1; + mantissa <<= wmo - wm; + + // subnormal output (occurs when T=half, we=5, negative_zero_nan=true) + if(exponent <= 0) + { + mantissa |= 1 << wmo; + mantissa >>= 1 - exponent; + exponent = 0; + } + + if(sizeof(T) == 2) + retval = (sign << 15) | (exponent << 10) | mantissa; + else + retval = (sign << 31) | (exponent << 23) | mantissa; + return migraphx::detail::bit_cast(retval); +} +} // namespace impl +} // namespace fp8 +} // namespace migraphx +#if defined(__clang__) +#pragma clang diagnostic pop +#endif +#endif // MIGRAPHX_GUARD_KERNELS_FP8_IMPL_HPP diff --git a/src/targets/gpu/kernels/include/migraphx/kernels/hip.hpp b/src/targets/gpu/kernels/include/migraphx/kernels/hip.hpp index e9407d1ef66..c999487c85b 100644 --- a/src/targets/gpu/kernels/include/migraphx/kernels/hip.hpp +++ b/src/targets/gpu/kernels/include/migraphx/kernels/hip.hpp @@ -24,7 +24,7 @@ #ifndef MIGRAPHX_GUARD_KERNELS_HIP_HPP #define MIGRAPHX_GUARD_KERNELS_HIP_HPP -#ifndef MIGRAPHX_USE_HIPRTC +#ifndef MIGRAPHX_JIT_USE_HIPRTC #include #include #include diff --git a/src/targets/gpu/kernels/include/migraphx/kernels/math.hpp b/src/targets/gpu/kernels/include/migraphx/kernels/math.hpp index fec7a1acf52..50e815a6c8a 100644 --- a/src/targets/gpu/kernels/include/migraphx/kernels/math.hpp +++ b/src/targets/gpu/kernels/include/migraphx/kernels/math.hpp @@ -34,6 +34,9 @@ namespace migraphx { namespace math { constexpr float as_float(migraphx::half x) { return x; } + +constexpr float as_float(migraphx::fp8::fp8e4m3fnuz x) { return x; } + template 
constexpr T as_float(T x) { @@ -57,14 +60,14 @@ constexpr T as_float(T x) // NOLINTNEXTLINE #define MIGRAPHX_DEVICE_MATH_FOR(type, name, fname) \ template ())> \ - auto __device__ name(type x, Ts... xs)->type \ + auto __device__ name(type x, Ts... xs) -> type \ { \ return fname(x, xs...); \ } // NOLINTNEXTLINE #define MIGRAPHX_DEVICE_MATH_BINARY_FOR(type, name, fname) \ - inline auto __device__ name(type x, type y)->type { return fname(x, y); } + inline auto __device__ name(type x, type y) -> type { return fname(x, y); } // NOLINTNEXTLINE #define MIGRAPHX_DEVICE_MATH_HALF(name, fname) \ @@ -72,6 +75,20 @@ constexpr T as_float(T x) auto __device__ name(migraphx::half x, Ts... xs) \ MIGRAPHX_RETURNS(fname(math::as_float(x), math::as_float(xs)...)) +// NOLINTNEXTLINE +#define MIGRAPHX_DEVICE_MATH_FP8(name, fname) \ + template ())> \ + auto __device__ name(migraphx::fp8::fp8e4m3fnuz x, Ts... xs) MIGRAPHX_RETURNS( \ + migraphx::fp8::fp8e4m3fnuz(fname(math::as_float(x), math::as_float(xs)...))) + +// NOLINTNEXTLINE +#define MIGRAPHX_DEVICE_MATH_BINARY_FOR_FP8(name, fname) \ + inline auto __device__ name(migraphx::fp8::fp8e4m3fnuz x, migraphx::fp8::fp8e4m3fnuz y) \ + -> migraphx::fp8::fp8e4m3fnuz \ + { \ + return migraphx::fp8::fp8e4m3fnuz(fname(math::as_float(x), math::as_float(y))); \ + } + // Template with two overloads for math functions, one for half2 type and one for more generic // vectorization where N is 4 or another even number. 
@@ -162,6 +179,33 @@ MIGRAPHX_DEVICE_MATH_HALF(tan, ::tan) MIGRAPHX_DEVICE_MATH_HALF(tanh, ::tanh) MIGRAPHX_DEVICE_MATH_HALF(fmod, ::fmod) +// use float to compute fp8 overload +MIGRAPHX_DEVICE_MATH_FP8(abs, ::abs) +MIGRAPHX_DEVICE_MATH_FP8(acos, ::acos) +MIGRAPHX_DEVICE_MATH_FP8(acosh, ::acosh) +MIGRAPHX_DEVICE_MATH_FP8(asin, ::asin) +MIGRAPHX_DEVICE_MATH_FP8(asinh, ::asinh) +MIGRAPHX_DEVICE_MATH_FP8(atan, ::atan) +MIGRAPHX_DEVICE_MATH_FP8(atanh, ::atanh) +MIGRAPHX_DEVICE_MATH_FP8(ceil, ::ceil) +MIGRAPHX_DEVICE_MATH_FP8(cos, ::cos) +MIGRAPHX_DEVICE_MATH_FP8(cosh, ::cosh) +MIGRAPHX_DEVICE_MATH_FP8(erf, ::erf) +MIGRAPHX_DEVICE_MATH_FP8(exp, ::exp) +MIGRAPHX_DEVICE_MATH_FP8(floor, ::floor) +MIGRAPHX_DEVICE_MATH_FP8(isnan, ::isnan) +MIGRAPHX_DEVICE_MATH_FP8(log, ::log) +MIGRAPHX_DEVICE_MATH_FP8(pow, ::pow) +MIGRAPHX_DEVICE_MATH_FP8(remainder, ::remainder) +MIGRAPHX_DEVICE_MATH_FP8(round, ::round) +MIGRAPHX_DEVICE_MATH_FP8(rsqrt, ::rsqrt) +MIGRAPHX_DEVICE_MATH_FP8(sin, ::sin) +MIGRAPHX_DEVICE_MATH_FP8(sinh, ::sinh) +MIGRAPHX_DEVICE_MATH_FP8(sqrt, ::sqrt) +MIGRAPHX_DEVICE_MATH_FP8(tan, ::tan) +MIGRAPHX_DEVICE_MATH_FP8(tanh, ::tanh) +MIGRAPHX_DEVICE_MATH_FP8(fmod, ::fmod) + // Map math functions to hip half2 functions // The half2 type is defined in include/hip/amd_detail/hip_fp16_gcc.h and is 2 16-bit floats // packed into a 32-bit number. 
See include/hip/amd_detail/hip_fp16_math_fwd.h for the HIP names @@ -195,6 +239,9 @@ MIGRAPHX_DEVICE_MATH_BINARY_FOR(double, min, ::min) MIGRAPHX_DEVICE_MATH_BINARY_FOR(migraphx::half, max, ::__hmax) MIGRAPHX_DEVICE_MATH_BINARY_FOR(migraphx::half, min, ::__hmin) +MIGRAPHX_DEVICE_MATH_BINARY_FOR_FP8(max, ::max) +MIGRAPHX_DEVICE_MATH_BINARY_FOR_FP8(min, ::min) + template ())> constexpr auto max(const T& a, const T& b) { diff --git a/src/targets/gpu/kernels/include/migraphx/kernels/type_traits.hpp b/src/targets/gpu/kernels/include/migraphx/kernels/type_traits.hpp index 890e55837a3..d1642bb1399 100644 --- a/src/targets/gpu/kernels/include/migraphx/kernels/type_traits.hpp +++ b/src/targets/gpu/kernels/include/migraphx/kernels/type_traits.hpp @@ -26,6 +26,7 @@ #include #include +#include namespace migraphx { @@ -230,7 +231,8 @@ constexpr unsigned long int_max(unsigned long n) template {} or is_floating_point{} or - is_same{})> + is_same{} or + is_same{})> constexpr T numeric_max() { if constexpr(is_integral{}) @@ -246,6 +248,8 @@ constexpr T numeric_max() return __FLT_MAX__; else if constexpr(is_same{}) return __FLT16_MAX__; + else if constexpr(is_same{}) + return migraphx::fp8::F8_Max(); else return 0; } @@ -260,6 +264,8 @@ constexpr T numeric_lowest() else return -numeric_max() - 1; } + else if constexpr(is_same{}) + return migraphx::fp8::F8_Lowest(); else { return -numeric_max(); diff --git a/src/targets/gpu/kernels/include/migraphx/kernels/types.hpp b/src/targets/gpu/kernels/include/migraphx/kernels/types.hpp index 4f71d1985a1..6575d5b2bf0 100644 --- a/src/targets/gpu/kernels/include/migraphx/kernels/types.hpp +++ b/src/targets/gpu/kernels/include/migraphx/kernels/types.hpp @@ -23,12 +23,11 @@ */ #ifndef MIGRAPHX_GUARD_AMDMIGRAPHX_KERNELS_TYPES_HPP #define MIGRAPHX_GUARD_AMDMIGRAPHX_KERNELS_TYPES_HPP - #include namespace migraphx { -#if defined(MIGRAPHX_ENABLE_HIPRTC_WORKAROUNDS) and defined(MIGRAPHX_USE_HIPRTC) +#if defined(MIGRAPHX_ENABLE_HIPRTC_WORKAROUNDS) and 
defined(MIGRAPHX_JIT_USE_HIPRTC) using int8_t = signed char; using uint8_t = unsigned char; using int16_t = signed short; @@ -37,7 +36,7 @@ using int32_t = signed int; using uint32_t = unsigned int; using int64_t = signed long long; using uint64_t = unsigned long long; -#elif defined(MIGRAPHX_USE_HIPRTC) +#elif defined(MIGRAPHX_JIT_USE_HIPRTC) using int8_t = __hip_int8_t; using uint8_t = __hip_uint8_t; using int16_t = __hip_int16_t; @@ -55,7 +54,7 @@ using int32_t = std::int32_t; using uint32_t = std::uint32_t; using int64_t = std::int64_t; using uint64_t = std::uint64_t; -#endif // MIGRAPHX_USE_HIPRTC +#endif // MIGRAPHX_JIT_USE_HIPRTC using index_int = uint32_t; using diff_int = int32_t; diff --git a/src/targets/gpu/kernels/include/migraphx/kernels/vectorize.hpp b/src/targets/gpu/kernels/include/migraphx/kernels/vectorize.hpp index b456b5c6e45..b66f88b7383 100644 --- a/src/targets/gpu/kernels/include/migraphx/kernels/vectorize.hpp +++ b/src/targets/gpu/kernels/include/migraphx/kernels/vectorize.hpp @@ -24,6 +24,7 @@ #ifndef MIGRAPHX_GUARD_KERNELS_VECTORIZE_HPP #define MIGRAPHX_GUARD_KERNELS_VECTORIZE_HPP +#include #include #include diff --git a/src/targets/gpu/target.cpp b/src/targets/gpu/target.cpp index 95455ee9ef0..dc1f8cd7991 100644 --- a/src/targets/gpu/target.cpp +++ b/src/targets/gpu/target.cpp @@ -98,6 +98,7 @@ std::vector target::get_passes(migraphx::context& gctx, const compile_opti ctx.set_exhaustive_tune_flag(options.exhaustive_tune); std::set unsupported_types(shape::types().begin(), shape::types().end()); unsupported_types.erase(shape::type_t::float_type); + unsupported_types.erase(shape::type_t::fp8e4m3fnuz_type); unsupported_types.erase(shape::type_t::half_type); unsupported_types.erase(shape::type_t::bool_type); unsupported_types.erase(shape::type_t::int8_type); diff --git a/test/gpu/jit.cpp b/test/gpu/jit.cpp index 2b407178681..30429d28a6a 100644 --- a/test/gpu/jit.cpp +++ b/test/gpu/jit.cpp @@ -144,7 +144,7 @@ extern "C" { __global__ void 
kernel(${type}* p) { auto x = *p; - *p = migraphx::implicit_conversion(migraphx::${invoke}); + *p = implicit_conversion(migraphx::${invoke}); } } @@ -348,18 +348,18 @@ TEST_CASE(compile_math) auto vec_sizes = {2, 4, 6}; for(auto&& t : migraphx::shape::types()) { - if(contains({migraphx::shape::bool_type, - migraphx::shape::fp8e4m3fnuz_type, - migraphx::shape::tuple_type}, - t)) + if(contains({migraphx::shape::bool_type, migraphx::shape::tuple_type}, t)) continue; auto name = migraphx::shape::cpp_type(t); if(t == migraphx::shape::half_type) name.insert(0, "migraphx::"); data_types.push_back(name); - migraphx::transform(vec_sizes, std::back_inserter(data_types), [&](auto i) { - return "migraphx::vec<" + name + ", " + std::to_string(i) + ">"; - }); + if(t != migraphx::shape::fp8e4m3fnuz_type) + { + migraphx::transform(vec_sizes, std::back_inserter(data_types), [&](auto i) { + return "migraphx::vec<" + name + ", " + std::to_string(i) + ">"; + }); + } } migraphx::shape input{migraphx::shape::float_type, {5, 2}}; migraphx::gpu::hip_compile_options options; @@ -399,10 +399,7 @@ TEST_CASE(assert_type_min_max) migraphx::gpu::hip_compile_options options; for(auto&& t : migraphx::shape::types()) { - if(contains({migraphx::shape::bool_type, - migraphx::shape::fp8e4m3fnuz_type, - migraphx::shape::tuple_type}, - t)) + if(contains({migraphx::shape::bool_type, migraphx::shape::tuple_type}, t)) continue; auto name = migraphx::shape::cpp_type(t); if(t == migraphx::shape::half_type) @@ -429,7 +426,6 @@ TEST_CASE(assert_type_min_max) min = std::to_string(as.min()); max = std::to_string(as.max()); } - auto src = migraphx::interpolate_string(assert_template, {{"type", name}, {"max", max}, {"min", min}}); migraphx::shape input{migraphx::shape::float_type, {5, 2}}; diff --git a/test/verify/test_abs.cpp b/test/verify/test_abs.cpp index 435cc64c95c..d2b7a924713 100644 --- a/test/verify/test_abs.cpp +++ b/test/verify/test_abs.cpp @@ -27,14 +27,19 @@ #include #include -struct test_abs : 
verify_program +template +struct test_abs : verify_program> { migraphx::program create_program() const { migraphx::program p; auto* mm = p.get_main_module(); - auto x = mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {4, 3, 3, 3}}); + auto x = mm->add_parameter("x", migraphx::shape{DType, {4, 3, 3, 3}}); mm->add_instruction(migraphx::make_op("abs"), x); return p; } }; + +template struct test_abs; +template struct test_abs; +template struct test_abs; diff --git a/test/verify/test_acos.cpp b/test/verify/test_acos.cpp index 873cc5ffafa..14629662381 100644 --- a/test/verify/test_acos.cpp +++ b/test/verify/test_acos.cpp @@ -27,15 +27,20 @@ #include #include -struct test_acos : verify_program +template +struct test_acos : verify_program> { migraphx::program create_program() const { migraphx::program p; auto* mm = p.get_main_module(); - migraphx::shape s{migraphx::shape::float_type, {16}}; + migraphx::shape s{DType, {16}}; auto x = mm->add_parameter("x", s); mm->add_instruction(migraphx::make_op("acos"), x); return p; } }; + +template struct test_acos; +template struct test_acos; +template struct test_acos; diff --git a/test/verify/test_add.cpp b/test/verify/test_add.cpp index d560767484a..a2031f95c4c 100644 --- a/test/verify/test_add.cpp +++ b/test/verify/test_add.cpp @@ -27,16 +27,21 @@ #include #include -struct test_add : verify_program +template +struct test_add : verify_program> { migraphx::program create_program() const { migraphx::program p; auto* mm = p.get_main_module(); - migraphx::shape s{migraphx::shape::float_type, {3}}; + migraphx::shape s{DType, {8}}; auto x = mm->add_parameter("x", s); auto y = mm->add_parameter("y", s); mm->add_instruction(migraphx::make_op("add"), x, y); return p; } }; + +template struct test_add; +template struct test_add; +template struct test_add; diff --git a/test/verify/test_literal_limits.cpp b/test/verify/test_literal_limits.cpp index f1ac9bf913c..fa0828585e1 100644 --- a/test/verify/test_literal_limits.cpp +++ 
b/test/verify/test_literal_limits.cpp @@ -35,7 +35,11 @@ struct test_literal_limits : verify_program> migraphx::program p; auto* mm = p.get_main_module(); auto input_s = migraphx::shape(Q, {3, 1}); - auto infinity_val = std::numeric_limits::infinity(); + auto infinity_val = std::numeric_limits::max(); + if constexpr(std::numeric_limits::has_infinity) + { + infinity_val = std::numeric_limits::infinity(); + } std::vector s_data{ infinity_val, static_cast(-infinity_val), std::numeric_limits::quiet_NaN()}; @@ -52,3 +56,4 @@ template struct test_literal_limits; template struct test_literal_limits; template struct test_literal_limits; template struct test_literal_limits; +template struct test_literal_limits; From d7339e8aa6efc50e1419c3c3eb4be1fbf7637883 Mon Sep 17 00:00:00 2001 From: Umang Yadav Date: Fri, 17 Nov 2023 14:23:36 +0000 Subject: [PATCH 030/115] use bit cast --- .../include/migraphx/kernels/float8_impl.hpp | 32 +++++-------------- test/gpu/jit.cpp | 2 +- 2 files changed, 9 insertions(+), 25 deletions(-) diff --git a/src/targets/gpu/kernels/include/migraphx/kernels/float8_impl.hpp b/src/targets/gpu/kernels/include/migraphx/kernels/float8_impl.hpp index ae0af47f155..0d9cfcbe0c2 100644 --- a/src/targets/gpu/kernels/include/migraphx/kernels/float8_impl.hpp +++ b/src/targets/gpu/kernels/include/migraphx/kernels/float8_impl.hpp @@ -22,12 +22,12 @@ #ifndef MIGRAPHX_GUARD_KERNELS_FP8_IMPL_HPP #define MIGRAPHX_GUARD_KERNELS_FP8_IMPL_HPP +#include #if defined(__clang__) #pragma clang diagnostic push #pragma clang diagnostic ignored "-Wreserved-identifier" #endif -#define CONST_FOLD(x) (__builtin_constant_p(x) ? 
(x) : (x)) namespace migraphx { namespace detail { template @@ -42,26 +42,10 @@ struct conditional using type = F; }; -template -inline constexpr To bit_cast(From fr) noexcept -{ - static_assert(sizeof(To) == sizeof(From)); -#if defined(__GNUC__) and !defined(__clang__) - To x = CONST_FOLD(*reinterpret_cast(&fr)); -#else - To x = __builtin_bit_cast(To, fr); -#endif - return x; -} } // namespace detail namespace fp8 { namespace impl { -// #ifdef __HIP_PLATFORM_HCC__ -// __device__ inline int clz(uint32_t x) { return __clz(x); } -// #else -// __host__ inline int clz(uint32_t x) { return __builtin_clz(x); } -// #endif template MIGRAPHX_HIP_HOST_DEVICE constexpr uint8_t cast_to_f8(T _x, bool stoch, uint32_t rng) @@ -73,9 +57,9 @@ MIGRAPHX_HIP_HOST_DEVICE constexpr uint8_t cast_to_f8(T _x, bool stoch, uint32_t typename migraphx::detail::conditional::type x; if constexpr(sizeof(T) == 4) - x = migraphx::detail::bit_cast(_x); + x = migraphx::bit_cast(_x); else - x = migraphx::detail::bit_cast(_x); + x = migraphx::bit_cast(_x); uint32_t head, mantissa; int exponent, bias; @@ -267,10 +251,10 @@ MIGRAPHX_HIP_HOST_DEVICE constexpr T cast_from_f8(uint8_t x) uint32_t ifNaN = 0x7F800001; uint32_t ifNeg0 = 0x80000000; // TODO: need to change T for half but right now it would never called with half - fInf = migraphx::detail::bit_cast(ifInf); - fNegInf = migraphx::detail::bit_cast(ifNegInf); - fNaN = migraphx::detail::bit_cast(ifNaN); - fNeg0 = migraphx::detail::bit_cast(ifNeg0); + fInf = migraphx::bit_cast(ifInf); + fNegInf = migraphx::bit_cast(ifNegInf); + fNaN = migraphx::bit_cast(ifNaN); + fNeg0 = migraphx::bit_cast(ifNeg0); if(x == 0) return 0; @@ -318,7 +302,7 @@ MIGRAPHX_HIP_HOST_DEVICE constexpr T cast_from_f8(uint8_t x) retval = (sign << 15) | (exponent << 10) | mantissa; else retval = (sign << 31) | (exponent << 23) | mantissa; - return migraphx::detail::bit_cast(retval); + return migraphx::bit_cast(retval); } } // namespace impl } // namespace fp8 diff --git 
a/test/gpu/jit.cpp b/test/gpu/jit.cpp index 30429d28a6a..b9558ad2495 100644 --- a/test/gpu/jit.cpp +++ b/test/gpu/jit.cpp @@ -144,7 +144,7 @@ extern "C" { __global__ void kernel(${type}* p) { auto x = *p; - *p = implicit_conversion(migraphx::${invoke}); + *p = migraphx::implicit_conversion(migraphx::${invoke}); } } From 60942349cabd13b879175954610cd7eb99da95b8 Mon Sep 17 00:00:00 2001 From: Umang Yadav Date: Fri, 17 Nov 2023 14:38:35 +0000 Subject: [PATCH 031/115] Make FNUZ template param and add numeric limits --- .../include/migraphx/kernels/float8.hpp | 174 +++++++++++------- 1 file changed, 105 insertions(+), 69 deletions(-) diff --git a/src/targets/gpu/kernels/include/migraphx/kernels/float8.hpp b/src/targets/gpu/kernels/include/migraphx/kernels/float8.hpp index bcc3ae26a61..705046e7a32 100644 --- a/src/targets/gpu/kernels/include/migraphx/kernels/float8.hpp +++ b/src/targets/gpu/kernels/include/migraphx/kernels/float8.hpp @@ -46,10 +46,6 @@ #define MIGRAPHX_HIP_DEVICE __device__ -#ifndef MIGRAPHX_FP8_FNUZ -#define MIGRAPHX_FP8_FNUZ true -#endif // MIGRAPHX_FP8_FNUZ - // We are clipping in down conversion by default #define MIGRAPHX_F8_DOWNCAST_CLIPPING 1 #if defined(MIGRAPHX_JIT_USE_HIPRTC) @@ -90,14 +86,14 @@ enum class f8_type template class numeric_limits; -template +template struct float8 { uint8_t data; // default constructor MIGRAPHX_HIP_HOST_DEVICE constexpr float8() = default; // default copy constructor - MIGRAPHX_HIP_HOST_DEVICE constexpr float8(const float8& y) = default; + MIGRAPHX_HIP_HOST_DEVICE constexpr float8(const float8& y) = default; struct from_bits_t { }; @@ -195,11 +191,11 @@ struct float8 { #ifdef MIGRAPHX_F8_DOWNCAST_CLIPPING data = migraphx::fp8::impl:: - cast_to_f8<3, 4, float, MIGRAPHX_FP8_FNUZ /*negative_zero_nan*/, true /*clip*/>( + cast_to_f8<3, 4, float, FNUZ /*negative_zero_nan*/, true /*clip*/>( v, (rm == migraphx::fp8::rounding_mode::stochastic), rng); #else // MIGRAPHX_F8_DOWNCAST_CLIPPING data = migraphx::fp8::impl:: - 
cast_to_f8<3, 4, float, MIGRAPHX_FP8_FNUZ /*negative_zero_nan*/, false /*clip*/>( + cast_to_f8<3, 4, float, FNUZ /*negative_zero_nan*/, false /*clip*/>( v, (rm == migraphx::fp8::rounding_mode::stochastic), rng); #endif // MIGRAPHX_F8_DOWNCAST_CLIPPING } @@ -207,11 +203,11 @@ struct float8 { #ifdef MIGRAPHX_F8_DOWNCAST_CLIPPING data = migraphx::fp8::impl:: - cast_to_f8<2, 5, float, MIGRAPHX_FP8_FNUZ /*negative_zero_nan*/, true /*clip*/>( + cast_to_f8<2, 5, float, FNUZ /*negative_zero_nan*/, true /*clip*/>( v, (rm == migraphx::fp8::rounding_mode::stochastic), rng); #else // MIGRAPHX_F8_DOWNCAST_CLIPPING data = migraphx::fp8::impl:: - cast_to_f8<2, 5, float, MIGRAPHX_FP8_FNUZ /*negative_zero_nan*/, false /*clip*/>( + cast_to_f8<2, 5, float, FNUZ /*negative_zero_nan*/, false /*clip*/>( v, (rm == migraphx::fp8::rounding_mode::stochastic), rng); #endif // rocblas_F8_downcast_clipping} } @@ -278,11 +274,9 @@ struct float8 { if constexpr(T == migraphx::fp8::f8_type::fp8) { - return migraphx::fp8::impl:: - cast_from_f8<3, 4, float, MIGRAPHX_FP8_FNUZ /*negative_zero_nan*/>(data); + return migraphx::fp8::impl::cast_from_f8<3, 4, float, FNUZ /*negative_zero_nan*/>(data); } // else - return migraphx::fp8::impl:: - cast_from_f8<2, 5, float, MIGRAPHX_FP8_FNUZ /*negative_zero_nan*/>(data); + return migraphx::fp8::impl::cast_from_f8<2, 5, float, FNUZ /*negative_zero_nan*/>(data); } /* @@ -296,7 +290,7 @@ struct float8 // check for zero inline MIGRAPHX_HIP_HOST_DEVICE constexpr bool is_zero() const { - if constexpr(MIGRAPHX_FP8_FNUZ) + if constexpr(FNUZ) { return data == 0x00; } @@ -309,7 +303,7 @@ struct float8 // check for nan inline MIGRAPHX_HIP_HOST_DEVICE constexpr bool is_nan() const { - if constexpr(MIGRAPHX_FP8_FNUZ) + if constexpr(FNUZ) { return data == 0x80; } @@ -333,7 +327,7 @@ struct float8 // check for inf inline MIGRAPHX_HIP_HOST_DEVICE constexpr bool is_inf() const { - if constexpr(MIGRAPHX_FP8_FNUZ) + if constexpr(FNUZ) { return data == 0x80; } @@ -458,97 +452,139 
@@ MIGRAPHX_HIP_HOST_DEVICE constexpr T F8_Lowest() return T{0xFF, T::from_bits()}; } -using fp8e4m3fnuz = float8; - +// https://onnx.ai/onnx/technical/float8.html +using fp8e4m3fn = float8; +using fp8e5m2 = float8; +using fp8e4m3fnuz = float8; +using fp8e5m2fnuz = float8; template <> -class numeric_limits> +class numeric_limits { public: - // TODO :figure out epsilon in Hex to make it constexpr - static constexpr MIGRAPHX_HIP_HOST_DEVICE migraphx::fp8::float8 - epsilon() + static constexpr bool has_infinity = false; + static constexpr MIGRAPHX_HIP_HOST_DEVICE fp8e4m3fnuz epsilon() + { + return fp8e4m3fnuz(0x28, fp8e4m3fnuz::from_bits()); + } + // NOLINTNEXTLINE + static constexpr MIGRAPHX_HIP_HOST_DEVICE fp8e4m3fnuz quiet_NaN() { - return migraphx::fp8::float8( - 0x28, migraphx::fp8::float8<>::from_bits()); + return fp8e4m3fnuz(0x80, fp8e4m3fnuz::from_bits()); } - static constexpr MIGRAPHX_HIP_HOST_DEVICE migraphx::fp8::float8 - quiet_NaN() + static constexpr MIGRAPHX_HIP_HOST_DEVICE fp8e4m3fnuz max() + { + return fp8e4m3fnuz(0x7F, fp8e4m3fnuz::from_bits()); + } + // this is min value that is not DeNorm. DeNorm min is 0x01 + static constexpr MIGRAPHX_HIP_HOST_DEVICE fp8e4m3fnuz min() { - return migraphx::fp8::float8( - MIGRAPHX_FP8_FNUZ ? 
0x80 : 0x7F, migraphx::fp8::float8<>::from_bits()); + return fp8e4m3fnuz(0x08, fp8e4m3fnuz::from_bits()); } - static constexpr MIGRAPHX_HIP_HOST_DEVICE migraphx::fp8::float8 - max() + static constexpr MIGRAPHX_HIP_HOST_DEVICE fp8e4m3fnuz lowest() { - return migraphx::fp8::F8_Max>(); + return fp8e4m3fnuz(0xFF, fp8e4m3fnuz::from_bits()); } +}; - // TODO figure out Hex value - static MIGRAPHX_HIP_HOST_DEVICE migraphx::fp8::float8 min() +template <> +class numeric_limits +{ + public: + static constexpr bool has_infinity = false; + static constexpr MIGRAPHX_HIP_HOST_DEVICE fp8e4m3fn epsilon() + { + return fp8e4m3fn(0x20, fp8e4m3fn::from_bits()); + } + // NOLINTNEXTLINE + static constexpr MIGRAPHX_HIP_HOST_DEVICE fp8e4m3fn quiet_NaN() { - return static_cast>(-1.0f) * - migraphx::fp8::F8_Max>(); + return fp8e4m3fn(0x7F, fp8e4m3fn::from_bits()); } - static constexpr MIGRAPHX_HIP_HOST_DEVICE migraphx::fp8::float8 - lowest() + static constexpr MIGRAPHX_HIP_HOST_DEVICE fp8e4m3fn max() + { + return fp8e4m3fn(0x7E, fp8e4m3fn::from_bits()); + } + // this is min value that is not DeNorm. DeNorm min is 0x01 + static constexpr MIGRAPHX_HIP_HOST_DEVICE fp8e4m3fn min() { - return migraphx::fp8::F8_Lowest>(); + return fp8e4m3fn(0x08, fp8e4m3fn::from_bits()); } - static constexpr MIGRAPHX_HIP_HOST_DEVICE migraphx::fp8::float8 - infinity() + static constexpr MIGRAPHX_HIP_HOST_DEVICE fp8e4m3fn lowest() { - return migraphx::fp8::float8( - MIGRAPHX_FP8_FNUZ ? 
0x80 : 0x7F, migraphx::fp8::float8<>::from_bits()); + return fp8e4m3fn(0xFE, fp8e4m3fn::from_bits()); } }; template <> -class numeric_limits> +class numeric_limits { public: - static constexpr MIGRAPHX_HIP_HOST_DEVICE migraphx::fp8::float8 - epsilon() + static constexpr bool has_infinity = false; + static constexpr MIGRAPHX_HIP_HOST_DEVICE fp8e5m2fnuz epsilon() { - return migraphx::fp8::float8( - 0x34, migraphx::fp8::float8::from_bits()); + return fp8e5m2fnuz(0x34, fp8e5m2fnuz::from_bits()); } - static constexpr MIGRAPHX_HIP_HOST_DEVICE migraphx::fp8::float8 - quiet_NaN() + static constexpr MIGRAPHX_HIP_HOST_DEVICE fp8e5m2fnuz quiet_NaN() // NOLINT { - return migraphx::fp8::float8( - MIGRAPHX_FP8_FNUZ ? 0x80 : 0x7d, - migraphx::fp8::float8::from_bits()); + return fp8e5m2fnuz(0x80, fp8e5m2fnuz::from_bits()); } - static constexpr MIGRAPHX_HIP_HOST_DEVICE migraphx::fp8::float8 - max() + static constexpr MIGRAPHX_HIP_HOST_DEVICE fp8e5m2fnuz max() { - return static_cast>( - migraphx::fp8::F8_Max>()); + return fp8e5m2fnuz(0x7F, fp8e5m2fnuz::from_bits()); } - // TODO figure out constexpr value - static MIGRAPHX_HIP_HOST_DEVICE migraphx::fp8::float8 min() + // this is min value that is not DeNorm. DeNorm min is 0x01. I am not sure if we want to make + // this distinction. For the floating points we would end up using lowest most of the times. 
+ static constexpr MIGRAPHX_HIP_HOST_DEVICE fp8e5m2fnuz min() { - return static_cast>(float(-1.0f)) * - migraphx::fp8::F8_Max>(); + return fp8e5m2fnuz(0x4, fp8e5m2fnuz::from_bits()); } - static constexpr MIGRAPHX_HIP_HOST_DEVICE migraphx::fp8::float8 - lowest() + + static constexpr MIGRAPHX_HIP_HOST_DEVICE fp8e5m2fnuz lowest() + { + return fp8e5m2fnuz(0xFF, fp8e5m2fnuz::from_bits()); + } +}; + +template <> +class numeric_limits +{ + public: + static constexpr bool has_infinity = true; + static constexpr MIGRAPHX_HIP_HOST_DEVICE fp8e5m2 epsilon() { - return migraphx::fp8::F8_Lowest>(); + return fp8e5m2(0x34, fp8e5m2::from_bits()); + } + // 7D, 7E, 7F are positive NaNs and FD, FE, FF are negative NaNs + static constexpr MIGRAPHX_HIP_HOST_DEVICE fp8e5m2 quiet_NaN() + { + return fp8e5m2(0xFF, fp8e5m2::from_bits()); + } // NOLINT + + static constexpr MIGRAPHX_HIP_HOST_DEVICE fp8e5m2 max() + { + return fp8e5m2(0x7B, fp8e5m2::from_bits()); + } + // this is min value that is not DeNorm. DeNorm min is 0x01. I am not sure if we want to make + // this distinction. For the floating points we would end up using lowest most of the times. + static constexpr MIGRAPHX_HIP_HOST_DEVICE fp8e5m2 min() + { + return fp8e5m2(0x4, fp8e5m2::from_bits()); } - static constexpr MIGRAPHX_HIP_HOST_DEVICE migraphx::fp8::float8 - infinity() + static constexpr MIGRAPHX_HIP_HOST_DEVICE fp8e5m2 lowest() + { + return fp8e5m2(0xFB, fp8e5m2::from_bits()); + } + // 7C and FC both are infinity + static constexpr MIGRAPHX_HIP_HOST_DEVICE fp8e5m2 infinity() { - return migraphx::fp8::float8( - MIGRAPHX_FP8_FNUZ ? 
0x80 : 0x7c, - migraphx::fp8::float8::from_bits()); + return fp8e5m2(0x7C, fp8e5m2::from_bits()); } }; /* From 78ec77ec59af7a1fa3947bf386d69e029a3a255c Mon Sep 17 00:00:00 2001 From: Umang Yadav Date: Fri, 17 Nov 2023 14:46:29 +0000 Subject: [PATCH 032/115] only compile for device --- .../include/migraphx/kernels/float8.hpp | 115 ++++++++---------- .../include/migraphx/kernels/float8_impl.hpp | 4 +- 2 files changed, 54 insertions(+), 65 deletions(-) diff --git a/src/targets/gpu/kernels/include/migraphx/kernels/float8.hpp b/src/targets/gpu/kernels/include/migraphx/kernels/float8.hpp index 705046e7a32..36fba107a4f 100644 --- a/src/targets/gpu/kernels/include/migraphx/kernels/float8.hpp +++ b/src/targets/gpu/kernels/include/migraphx/kernels/float8.hpp @@ -30,19 +30,12 @@ #pragma clang diagnostic ignored "-Wc++20-extensions" #endif // __clang__ -#if(defined(__HIP_PLATFORM_HCC__) || defined(__HIP_PLATFORM_AMD__)) // need to include hip_runtime.h otherwise it complains about __host__ and __device__ #if defined(MIGRAPHX_JIT_USE_HIPRTC) #include #else #include #endif -#define MIGRAPHX_HIP_HOST_DEVICE __host__ __device__ -#define MIGRAPHX_HIP_HOST __host__ -#else -#define MIGRAPHX_HIP_HOST_DEVICE -#define MIGRAPHX_HIP_HOST -#endif // HIP_PLATFORM_AMD #define MIGRAPHX_HIP_DEVICE __device__ @@ -91,15 +84,15 @@ struct float8 { uint8_t data; // default constructor - MIGRAPHX_HIP_HOST_DEVICE constexpr float8() = default; + MIGRAPHX_HIP_DEVICE constexpr float8() = default; // default copy constructor - MIGRAPHX_HIP_HOST_DEVICE constexpr float8(const float8& y) = default; + MIGRAPHX_HIP_DEVICE constexpr float8(const float8& y) = default; struct from_bits_t { }; - static constexpr MIGRAPHX_HIP_HOST_DEVICE from_bits_t from_bits() { return from_bits_t(); } + static constexpr MIGRAPHX_HIP_DEVICE from_bits_t from_bits() { return from_bits_t(); } - MIGRAPHX_HIP_HOST_DEVICE explicit constexpr float8(uint8_t bits, from_bits_t) : data(bits) {} + MIGRAPHX_HIP_DEVICE explicit constexpr 
float8(uint8_t bits, from_bits_t) : data(bits) {} #if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__) // device specific optimized F8 down-conversion code @@ -176,12 +169,9 @@ struct float8 else data = cast_to_f8_from_f32(v); } - - // Host only implementation using s/w simulation - explicit MIGRAPHX_HIP_HOST #else - // both Host and DEVICE for non-gfx940 using s/w simulation - explicit constexpr MIGRAPHX_HIP_HOST_DEVICE + // DEVICE for non-gfx940 using s/w simulation + explicit constexpr MIGRAPHX_HIP_DEVICE #endif float8(float v, migraphx::fp8::rounding_mode rm = migraphx::fp8::rounding_mode::standard, @@ -215,7 +205,7 @@ struct float8 /* // Constructor from half - explicit constexpr MIGRAPHX_HIP_HOST_DEVICE + explicit constexpr MIGRAPHX_HIP_DEVICE float8(migraphx::half v, migraphx::fp8::rounding_mode rm = migraphx::fp8::rounding_mode::standard, @@ -225,7 +215,7 @@ struct float8 } // constructor from int - explicit constexpr MIGRAPHX_HIP_HOST_DEVICE + explicit constexpr MIGRAPHX_HIP_DEVICE float8(int v, migraphx::fp8::rounding_mode rm = migraphx::fp8::rounding_mode::standard, @@ -235,7 +225,7 @@ struct float8 } // constructor from double - explicit constexpr MIGRAPHX_HIP_HOST_DEVICE + explicit constexpr MIGRAPHX_HIP_DEVICE float8(double v, migraphx::fp8::rounding_mode rm = migraphx::fp8::rounding_mode::standard, @@ -267,9 +257,8 @@ struct float8 return fval; } - inline constexpr MIGRAPHX_HIP_HOST operator float() const #else // non gfx940 - inline constexpr MIGRAPHX_HIP_HOST_DEVICE operator float() const + inline constexpr MIGRAPHX_HIP_DEVICE operator float() const #endif { if constexpr(T == migraphx::fp8::f8_type::fp8) @@ -281,14 +270,14 @@ struct float8 /* // convert to half - explicit inline MIGRAPHX_HIP_HOST_DEVICE operator migraphx::half() const + explicit inline MIGRAPHX_HIP_DEVICE operator migraphx::half() const { return migraphx::half(float(*this)); // convert to float, then convert to f16 } */ // check for zero - inline 
MIGRAPHX_HIP_HOST_DEVICE constexpr bool is_zero() const + inline MIGRAPHX_HIP_DEVICE constexpr bool is_zero() const { if constexpr(FNUZ) { @@ -301,7 +290,7 @@ struct float8 } // check for nan - inline MIGRAPHX_HIP_HOST_DEVICE constexpr bool is_nan() const + inline MIGRAPHX_HIP_DEVICE constexpr bool is_nan() const { if constexpr(FNUZ) { @@ -325,7 +314,7 @@ struct float8 } // check for inf - inline MIGRAPHX_HIP_HOST_DEVICE constexpr bool is_inf() const + inline MIGRAPHX_HIP_DEVICE constexpr bool is_inf() const { if constexpr(FNUZ) { @@ -345,13 +334,13 @@ struct float8 } #define MIGRAPHX_FP8_UNARY_OP(unary_op, binary_op) \ - constexpr float8& MIGRAPHX_HIP_HOST_DEVICE operator unary_op(const float8& rhs) \ + constexpr float8& MIGRAPHX_HIP_DEVICE operator unary_op(const float8& rhs) \ { \ const auto tmp = static_cast(*this) binary_op static_cast(rhs); \ *this = static_cast(tmp); \ return *this; \ } \ - constexpr float8& MIGRAPHX_HIP_HOST_DEVICE operator unary_op(const float& rhs) \ + constexpr float8& MIGRAPHX_HIP_DEVICE operator unary_op(const float& rhs) \ { \ const auto tmp = static_cast(*this) binary_op static_cast(rhs); \ *this = static_cast(tmp); \ @@ -363,20 +352,20 @@ struct float8 MIGRAPHX_FP8_UNARY_OP(+=, +) MIGRAPHX_FP8_UNARY_OP(/=, /) - inline MIGRAPHX_HIP_HOST_DEVICE constexpr float8& operator=(const float8& rhs) = default; - inline MIGRAPHX_HIP_HOST_DEVICE constexpr float8& operator=(float8&& rhs) = default; + inline MIGRAPHX_HIP_DEVICE constexpr float8& operator=(const float8& rhs) = default; + inline MIGRAPHX_HIP_DEVICE constexpr float8& operator=(float8&& rhs) = default; #if !defined(__HIP_NO_F8_CONVERSIONS__) // for the device kernels, this needs to be disabled since implicit_conversion op can type cast // any type to any other type and that results in conflicts in candidate overload resolutions. 
- inline constexpr float8& MIGRAPHX_HIP_HOST_DEVICE operator=(float rhs) + inline constexpr float8& MIGRAPHX_HIP_DEVICE operator=(float rhs) { *this = static_cast(rhs); return *this; } #endif - inline MIGRAPHX_HIP_HOST_DEVICE constexpr bool operator==(const float8& rhs) const + inline MIGRAPHX_HIP_DEVICE constexpr bool operator==(const float8& rhs) const { if((rhs.is_zero() && this->is_zero()) || (fabs(rhs - *this) < migraphx::fp8::numeric_limits>::epsilon())) @@ -387,14 +376,14 @@ struct float8 return false; } - inline MIGRAPHX_HIP_HOST_DEVICE constexpr bool operator<(const float8& rhs) const + inline MIGRAPHX_HIP_DEVICE constexpr bool operator<(const float8& rhs) const { const auto we = static_cast(*this); const auto them = static_cast(rhs); return we < them; } - inline MIGRAPHX_HIP_HOST_DEVICE constexpr bool operator>(const float8& rhs) const + inline MIGRAPHX_HIP_DEVICE constexpr bool operator>(const float8& rhs) const { const auto we = static_cast(*this); const auto them = static_cast(rhs); @@ -412,12 +401,12 @@ inline std::ostream& operator<<(std::ostream& os, const migraphx::fp8::float8 #endif // NOLINTNEXTLINE -#define MIGRAPHX_FP8_BINARY_OP(binary_op, U) \ - template \ - inline constexpr U MIGRAPHX_HIP_HOST_DEVICE operator binary_op( \ - const migraphx::fp8::float8& lhs, const migraphx::fp8::float8& rhs) \ - { \ - return U(static_cast(lhs) binary_op static_cast(rhs)); \ +#define MIGRAPHX_FP8_BINARY_OP(binary_op, U) \ + template \ + inline constexpr U MIGRAPHX_HIP_DEVICE operator binary_op(const migraphx::fp8::float8& lhs, \ + const migraphx::fp8::float8& rhs) \ + { \ + return U(static_cast(lhs) binary_op static_cast(rhs)); \ } // TODO: these should return floats @@ -434,20 +423,20 @@ MIGRAPHX_FP8_BINARY_OP(<, bool) MIGRAPHX_FP8_BINARY_OP(!=, bool) template -inline MIGRAPHX_HIP_HOST_DEVICE migraphx::fp8::float8 fabs(migraphx::fp8::float8 v) +inline MIGRAPHX_HIP_DEVICE migraphx::fp8::float8 fabs(migraphx::fp8::float8 v) { v.data = v.data & 0x7f; return v; } 
template -MIGRAPHX_HIP_HOST_DEVICE constexpr T F8_Max() +MIGRAPHX_HIP_DEVICE constexpr T F8_Max() { return T{0x7F, T::from_bits()}; } template -MIGRAPHX_HIP_HOST_DEVICE constexpr T F8_Lowest() +MIGRAPHX_HIP_DEVICE constexpr T F8_Lowest() { return T{0xFF, T::from_bits()}; } @@ -462,27 +451,27 @@ class numeric_limits { public: static constexpr bool has_infinity = false; - static constexpr MIGRAPHX_HIP_HOST_DEVICE fp8e4m3fnuz epsilon() + static constexpr MIGRAPHX_HIP_DEVICE fp8e4m3fnuz epsilon() { return fp8e4m3fnuz(0x28, fp8e4m3fnuz::from_bits()); } // NOLINTNEXTLINE - static constexpr MIGRAPHX_HIP_HOST_DEVICE fp8e4m3fnuz quiet_NaN() + static constexpr MIGRAPHX_HIP_DEVICE fp8e4m3fnuz quiet_NaN() { return fp8e4m3fnuz(0x80, fp8e4m3fnuz::from_bits()); } - static constexpr MIGRAPHX_HIP_HOST_DEVICE fp8e4m3fnuz max() + static constexpr MIGRAPHX_HIP_DEVICE fp8e4m3fnuz max() { return fp8e4m3fnuz(0x7F, fp8e4m3fnuz::from_bits()); } // this is min value that is not DeNorm. DeNorm min is 0x01 - static constexpr MIGRAPHX_HIP_HOST_DEVICE fp8e4m3fnuz min() + static constexpr MIGRAPHX_HIP_DEVICE fp8e4m3fnuz min() { return fp8e4m3fnuz(0x08, fp8e4m3fnuz::from_bits()); } - static constexpr MIGRAPHX_HIP_HOST_DEVICE fp8e4m3fnuz lowest() + static constexpr MIGRAPHX_HIP_DEVICE fp8e4m3fnuz lowest() { return fp8e4m3fnuz(0xFF, fp8e4m3fnuz::from_bits()); } @@ -493,27 +482,27 @@ class numeric_limits { public: static constexpr bool has_infinity = false; - static constexpr MIGRAPHX_HIP_HOST_DEVICE fp8e4m3fn epsilon() + static constexpr MIGRAPHX_HIP_DEVICE fp8e4m3fn epsilon() { return fp8e4m3fn(0x20, fp8e4m3fn::from_bits()); } // NOLINTNEXTLINE - static constexpr MIGRAPHX_HIP_HOST_DEVICE fp8e4m3fn quiet_NaN() + static constexpr MIGRAPHX_HIP_DEVICE fp8e4m3fn quiet_NaN() { return fp8e4m3fn(0x7F, fp8e4m3fn::from_bits()); } - static constexpr MIGRAPHX_HIP_HOST_DEVICE fp8e4m3fn max() + static constexpr MIGRAPHX_HIP_DEVICE fp8e4m3fn max() { return fp8e4m3fn(0x7E, fp8e4m3fn::from_bits()); } // this is 
min value that is not DeNorm. DeNorm min is 0x01 - static constexpr MIGRAPHX_HIP_HOST_DEVICE fp8e4m3fn min() + static constexpr MIGRAPHX_HIP_DEVICE fp8e4m3fn min() { return fp8e4m3fn(0x08, fp8e4m3fn::from_bits()); } - static constexpr MIGRAPHX_HIP_HOST_DEVICE fp8e4m3fn lowest() + static constexpr MIGRAPHX_HIP_DEVICE fp8e4m3fn lowest() { return fp8e4m3fn(0xFE, fp8e4m3fn::from_bits()); } @@ -524,28 +513,28 @@ class numeric_limits { public: static constexpr bool has_infinity = false; - static constexpr MIGRAPHX_HIP_HOST_DEVICE fp8e5m2fnuz epsilon() + static constexpr MIGRAPHX_HIP_DEVICE fp8e5m2fnuz epsilon() { return fp8e5m2fnuz(0x34, fp8e5m2fnuz::from_bits()); } - static constexpr MIGRAPHX_HIP_HOST_DEVICE fp8e5m2fnuz quiet_NaN() // NOLINT + static constexpr MIGRAPHX_HIP_DEVICE fp8e5m2fnuz quiet_NaN() // NOLINT { return fp8e5m2fnuz(0x80, fp8e5m2fnuz::from_bits()); } - static constexpr MIGRAPHX_HIP_HOST_DEVICE fp8e5m2fnuz max() + static constexpr MIGRAPHX_HIP_DEVICE fp8e5m2fnuz max() { return fp8e5m2fnuz(0x7F, fp8e5m2fnuz::from_bits()); } // this is min value that is not DeNorm. DeNorm min is 0x01. I am not sure if we want to make // this distinction. For the floating points we would end up using lowest most of the times. 
- static constexpr MIGRAPHX_HIP_HOST_DEVICE fp8e5m2fnuz min() + static constexpr MIGRAPHX_HIP_DEVICE fp8e5m2fnuz min() { return fp8e5m2fnuz(0x4, fp8e5m2fnuz::from_bits()); } - static constexpr MIGRAPHX_HIP_HOST_DEVICE fp8e5m2fnuz lowest() + static constexpr MIGRAPHX_HIP_DEVICE fp8e5m2fnuz lowest() { return fp8e5m2fnuz(0xFF, fp8e5m2fnuz::from_bits()); } @@ -556,33 +545,33 @@ class numeric_limits { public: static constexpr bool has_infinity = true; - static constexpr MIGRAPHX_HIP_HOST_DEVICE fp8e5m2 epsilon() + static constexpr MIGRAPHX_HIP_DEVICE fp8e5m2 epsilon() { return fp8e5m2(0x34, fp8e5m2::from_bits()); } // 7D, 7E, 7F are positive NaNs and FD, FE, FF are negative NaNs - static constexpr MIGRAPHX_HIP_HOST_DEVICE fp8e5m2 quiet_NaN() + static constexpr MIGRAPHX_HIP_DEVICE fp8e5m2 quiet_NaN() { return fp8e5m2(0xFF, fp8e5m2::from_bits()); } // NOLINT - static constexpr MIGRAPHX_HIP_HOST_DEVICE fp8e5m2 max() + static constexpr MIGRAPHX_HIP_DEVICE fp8e5m2 max() { return fp8e5m2(0x7B, fp8e5m2::from_bits()); } // this is min value that is not DeNorm. DeNorm min is 0x01. I am not sure if we want to make // this distinction. For the floating points we would end up using lowest most of the times. 
- static constexpr MIGRAPHX_HIP_HOST_DEVICE fp8e5m2 min() + static constexpr MIGRAPHX_HIP_DEVICE fp8e5m2 min() { return fp8e5m2(0x4, fp8e5m2::from_bits()); } - static constexpr MIGRAPHX_HIP_HOST_DEVICE fp8e5m2 lowest() + static constexpr MIGRAPHX_HIP_DEVICE fp8e5m2 lowest() { return fp8e5m2(0xFB, fp8e5m2::from_bits()); } // 7C and FC both are infinity - static constexpr MIGRAPHX_HIP_HOST_DEVICE fp8e5m2 infinity() + static constexpr MIGRAPHX_HIP_DEVICE fp8e5m2 infinity() { return fp8e5m2(0x7C, fp8e5m2::from_bits()); } diff --git a/src/targets/gpu/kernels/include/migraphx/kernels/float8_impl.hpp b/src/targets/gpu/kernels/include/migraphx/kernels/float8_impl.hpp index 0d9cfcbe0c2..e45eb832f17 100644 --- a/src/targets/gpu/kernels/include/migraphx/kernels/float8_impl.hpp +++ b/src/targets/gpu/kernels/include/migraphx/kernels/float8_impl.hpp @@ -48,7 +48,7 @@ namespace fp8 { namespace impl { template -MIGRAPHX_HIP_HOST_DEVICE constexpr uint8_t cast_to_f8(T _x, bool stoch, uint32_t rng) +__device__ constexpr uint8_t cast_to_f8(T _x, bool stoch, uint32_t rng) { static_assert(wm + we == 7, "wm+we==7"); @@ -240,7 +240,7 @@ this case, the fp16 mantissa should be shift left by 1 */ } template -MIGRAPHX_HIP_HOST_DEVICE constexpr T cast_from_f8(uint8_t x) +__device__ constexpr T cast_from_f8(uint8_t x) { constexpr int weo = 8; constexpr int wmo = 23; From 3411649cc80aa14cf971d21899f28332400f2f74 Mon Sep 17 00:00:00 2001 From: Umang Yadav Date: Fri, 17 Nov 2023 15:09:33 +0000 Subject: [PATCH 033/115] remove non-JIT related code --- src/targets/gpu/compile_hip.cpp | 2 +- src/targets/gpu/compile_hip_code_object.cpp | 1 - .../include/migraphx/kernels/float8.hpp | 186 ++++-------------- .../kernels/include/migraphx/kernels/hip.hpp | 2 +- .../include/migraphx/kernels/types.hpp | 6 +- 5 files changed, 41 insertions(+), 156 deletions(-) diff --git a/src/targets/gpu/compile_hip.cpp b/src/targets/gpu/compile_hip.cpp index 51dbe4d48ea..c3e101a3ae1 100644 --- 
a/src/targets/gpu/compile_hip.cpp +++ b/src/targets/gpu/compile_hip.cpp @@ -199,7 +199,7 @@ std::vector> compile_hip_src_with_hiprtc(std::vector -#else -#include -#endif #define MIGRAPHX_HIP_DEVICE __device__ // We are clipping in down conversion by default #define MIGRAPHX_F8_DOWNCAST_CLIPPING 1 -#if defined(MIGRAPHX_JIT_USE_HIPRTC) + #include using uint8_t = migraphx::uint8_t; using uint16_t = migraphx::uint16_t; using uint32_t = migraphx::uint32_t; -#else -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#endif #include @@ -203,38 +186,6 @@ struct float8 } } - /* - // Constructor from half - explicit constexpr MIGRAPHX_HIP_DEVICE - float8(migraphx::half v, - migraphx::fp8::rounding_mode rm = - migraphx::fp8::rounding_mode::standard, - uint32_t rng = 0) - : float8((float)v, rm, rng) - { - } - - // constructor from int - explicit constexpr MIGRAPHX_HIP_DEVICE - float8(int v, - migraphx::fp8::rounding_mode rm = - migraphx::fp8::rounding_mode::standard, - uint32_t rng = 0) - : float8((float)v, rm, rng) - { - } - - // constructor from double - explicit constexpr MIGRAPHX_HIP_DEVICE - float8(double v, - migraphx::fp8::rounding_mode rm = - migraphx::fp8::rounding_mode::standard, - uint32_t rng = 0) - : float8((float)v, rm, rng) - { - } - */ - /**/ // convert to float // #if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__) #if 0 // need constexpr operator(). 
This version can't be constexpr @@ -268,14 +219,6 @@ struct float8 return migraphx::fp8::impl::cast_from_f8<2, 5, float, FNUZ /*negative_zero_nan*/>(data); } - /* - // convert to half - explicit inline MIGRAPHX_HIP_DEVICE operator migraphx::half() const - { - return migraphx::half(float(*this)); // convert to float, then convert to f16 - } - */ - // check for zero inline MIGRAPHX_HIP_DEVICE constexpr bool is_zero() const { @@ -300,15 +243,12 @@ struct float8 { if(T == migraphx::fp8::f8_type::bf8) { - return (data == 0x7d) || (data == 0x7e) || (data == 0x7f) || (data == 0xfd) || - (data == 0xfe) || (data == 0xff); + return (data == 0x7D) or (data == 0x7E) or (data == 0x7F) or (data == 0xFD) or + (data == 0xFE) or (data == 0xFF); } else { - return (data == 0x79) || (data == 0x7a) || (data == 0x7b) || (data == 0x7c) || - (data == 0x7d) || (data == 0x7e) || (data == 0x7f) || (data == 0xf9) || - (data == 0xfa) || (data == 0xfb) || (data == 0xfc) || (data == 0xfd) || - (data == 0xfe) || (data == 0xff); + return (data == 0x7F) or (data == 0xFF); } } } @@ -324,11 +264,12 @@ struct float8 { if(T == migraphx::fp8::f8_type::bf8) { - return (data == 0x7c) || (data == 0xfc); + return (data == 0x7C) or (data == 0xFC); } else { - return (data == 0x78) || (data == 0xf8); + // no infinities in e4m3fn, represent them as NaNs + return (data == 0x7F) or (data == 0xFF); } } } @@ -355,24 +296,12 @@ struct float8 inline MIGRAPHX_HIP_DEVICE constexpr float8& operator=(const float8& rhs) = default; inline MIGRAPHX_HIP_DEVICE constexpr float8& operator=(float8&& rhs) = default; -#if !defined(__HIP_NO_F8_CONVERSIONS__) - // for the device kernels, this needs to be disabled since implicit_conversion op can type cast - // any type to any other type and that results in conflicts in candidate overload resolutions. 
- inline constexpr float8& MIGRAPHX_HIP_DEVICE operator=(float rhs) - { - *this = static_cast(rhs); - return *this; - } -#endif - inline MIGRAPHX_HIP_DEVICE constexpr bool operator==(const float8& rhs) const { - if((rhs.is_zero() && this->is_zero()) || - (fabs(rhs - *this) < migraphx::fp8::numeric_limits>::epsilon())) - return true; - else if(rhs.is_nan() || rhs.is_inf() || this->is_nan() || this->is_inf()) + if(rhs.is_nan() or rhs.is_inf() or this->is_nan() or this->is_inf()) return false; - + else if((rhs.is_zero() and this->is_zero()) or (this->data == rhs.data)) + return true; return false; } @@ -391,15 +320,6 @@ struct float8 } }; -#ifndef MIGRAPHX_JIT_USE_HIPRTC -// Special operator overloading -template -inline std::ostream& operator<<(std::ostream& os, const migraphx::fp8::float8& rhs) -{ - return os << static_cast(rhs); -} -#endif - // NOLINTNEXTLINE #define MIGRAPHX_FP8_BINARY_OP(binary_op, U) \ template \ @@ -422,8 +342,32 @@ MIGRAPHX_FP8_BINARY_OP(>, bool) MIGRAPHX_FP8_BINARY_OP(<, bool) MIGRAPHX_FP8_BINARY_OP(!=, bool) -template -inline MIGRAPHX_HIP_DEVICE migraphx::fp8::float8 fabs(migraphx::fp8::float8 v) +// https://onnx.ai/onnx/technical/float8.html +using fp8e4m3fn = float8; +using fp8e5m2 = float8; +using fp8e4m3fnuz = float8; +using fp8e5m2fnuz = float8; +; + +inline MIGRAPHX_HIP_DEVICE fp8e4m3fnuz fabs(fp8e4m3fnuz v) +{ + v.data = v.data & 0x7f; + return v; +} + +inline MIGRAPHX_HIP_DEVICE fp8e4m3fn fabs(fp8e4m3fn v) +{ + v.data = v.data & 0x7f; + return v; +} + +inline MIGRAPHX_HIP_DEVICE fp8e5m2fnuz fabs(fp8e5m2fnuz v) +{ + v.data = v.data & 0x7f; + return v; +} + +inline MIGRAPHX_HIP_DEVICE fp8e5m2 fabs(fp8e5m2 v) { v.data = v.data & 0x7f; return v; @@ -441,11 +385,6 @@ MIGRAPHX_HIP_DEVICE constexpr T F8_Lowest() return T{0xFF, T::from_bits()}; } -// https://onnx.ai/onnx/technical/float8.html -using fp8e4m3fn = float8; -using fp8e5m2 = float8; -using fp8e4m3fnuz = float8; -using fp8e5m2fnuz = float8; template <> class numeric_limits { @@ 
-624,59 +563,6 @@ inline __host__ __device__ T explicit_downcast(Ta a, uint32_t rng) */ } // namespace fp8 } // namespace migraphx -// define numeric limits for the new data type -#ifndef MIGRAPHX_JIT_USE_HIPRTC -namespace std { -inline bool isfinite(migraphx::fp8::float8 x) // NOLINT -{ - return x.is_inf(); -} - -inline bool isfinite(migraphx::fp8::float8 x) // NOLINT -{ - return x.is_inf(); -} - -inline bool isnan(migraphx::fp8::float8 x) // NOLINT -{ - return x.is_nan(); -} - -inline bool isnan(migraphx::fp8::float8 x) // NOLINT -{ - return x.is_nan(); -} - -template <> -class numeric_limits> - : public migraphx::fp8::numeric_limits> -{ -}; - -template <> -class numeric_limits> - : public migraphx::fp8::numeric_limits> -{ -}; - -template -struct common_type : std::common_type // NOLINT -{ -}; - -template -struct common_type : std::common_type // NOLINT -{ -}; - -template <> -struct common_type -{ - using type = float; -}; - -} // namespace std -#endif // ================================================================================================= #if defined(__clang__) #pragma clang diagnostic pop diff --git a/src/targets/gpu/kernels/include/migraphx/kernels/hip.hpp b/src/targets/gpu/kernels/include/migraphx/kernels/hip.hpp index c999487c85b..e9407d1ef66 100644 --- a/src/targets/gpu/kernels/include/migraphx/kernels/hip.hpp +++ b/src/targets/gpu/kernels/include/migraphx/kernels/hip.hpp @@ -24,7 +24,7 @@ #ifndef MIGRAPHX_GUARD_KERNELS_HIP_HPP #define MIGRAPHX_GUARD_KERNELS_HIP_HPP -#ifndef MIGRAPHX_JIT_USE_HIPRTC +#ifndef MIGRAPHX_USE_HIPRTC #include #include #include diff --git a/src/targets/gpu/kernels/include/migraphx/kernels/types.hpp b/src/targets/gpu/kernels/include/migraphx/kernels/types.hpp index 6575d5b2bf0..a31f3eace82 100644 --- a/src/targets/gpu/kernels/include/migraphx/kernels/types.hpp +++ b/src/targets/gpu/kernels/include/migraphx/kernels/types.hpp @@ -27,7 +27,7 @@ namespace migraphx { -#if defined(MIGRAPHX_ENABLE_HIPRTC_WORKAROUNDS) and 
defined(MIGRAPHX_JIT_USE_HIPRTC) +#if defined(MIGRAPHX_ENABLE_HIPRTC_WORKAROUNDS) and defined(MIGRAPHX_USE_HIPRTC) using int8_t = signed char; using uint8_t = unsigned char; using int16_t = signed short; @@ -36,7 +36,7 @@ using int32_t = signed int; using uint32_t = unsigned int; using int64_t = signed long long; using uint64_t = unsigned long long; -#elif defined(MIGRAPHX_JIT_USE_HIPRTC) +#elif defined(MIGRAPHX_USE_HIPRTC) using int8_t = __hip_int8_t; using uint8_t = __hip_uint8_t; using int16_t = __hip_int16_t; @@ -54,7 +54,7 @@ using int32_t = std::int32_t; using uint32_t = std::uint32_t; using int64_t = std::int64_t; using uint64_t = std::uint64_t; -#endif // MIGRAPHX_JIT_USE_HIPRTC +#endif // MIGRAPHX_USE_HIPRTC using index_int = uint32_t; using diff_int = int32_t; From d2c25a07250644e1e8c6efbc5fe9254fc0c30e5d Mon Sep 17 00:00:00 2001 From: Umang Yadav Date: Fri, 17 Nov 2023 15:17:32 +0000 Subject: [PATCH 034/115] Remove FP8_Lowest/Max --- .../gpu/kernels/include/migraphx/kernels/float8.hpp | 12 ------------ .../kernels/include/migraphx/kernels/type_traits.hpp | 5 +++-- 2 files changed, 3 insertions(+), 14 deletions(-) diff --git a/src/targets/gpu/kernels/include/migraphx/kernels/float8.hpp b/src/targets/gpu/kernels/include/migraphx/kernels/float8.hpp index 8345487e993..2801c0f2e5d 100644 --- a/src/targets/gpu/kernels/include/migraphx/kernels/float8.hpp +++ b/src/targets/gpu/kernels/include/migraphx/kernels/float8.hpp @@ -373,18 +373,6 @@ inline MIGRAPHX_HIP_DEVICE fp8e5m2 fabs(fp8e5m2 v) return v; } -template -MIGRAPHX_HIP_DEVICE constexpr T F8_Max() -{ - return T{0x7F, T::from_bits()}; -} - -template -MIGRAPHX_HIP_DEVICE constexpr T F8_Lowest() -{ - return T{0xFF, T::from_bits()}; -} - template <> class numeric_limits { diff --git a/src/targets/gpu/kernels/include/migraphx/kernels/type_traits.hpp b/src/targets/gpu/kernels/include/migraphx/kernels/type_traits.hpp index d1642bb1399..166bee2e57a 100644 --- 
a/src/targets/gpu/kernels/include/migraphx/kernels/type_traits.hpp +++ b/src/targets/gpu/kernels/include/migraphx/kernels/type_traits.hpp @@ -248,8 +248,9 @@ constexpr T numeric_max() return __FLT_MAX__; else if constexpr(is_same{}) return __FLT16_MAX__; + // TODO: Do it generically for all fp8 types else if constexpr(is_same{}) - return migraphx::fp8::F8_Max(); + return migraphx::fp8::numeric_limits::max(); else return 0; } @@ -265,7 +266,7 @@ constexpr T numeric_lowest() return -numeric_max() - 1; } else if constexpr(is_same{}) - return migraphx::fp8::F8_Lowest(); + return migraphx::fp8::numeric_limits::lowest(); else { return -numeric_max(); From 5da68df6fabc2f50dd315139be4ec9e4a986ba4e Mon Sep 17 00:00:00 2001 From: Umang Yadav Date: Fri, 17 Nov 2023 15:21:21 +0000 Subject: [PATCH 035/115] remove using for dtypes --- .../gpu/kernels/include/migraphx/kernels/float8.hpp | 7 +------ 1 file changed, 1 insertion(+), 6 deletions(-) diff --git a/src/targets/gpu/kernels/include/migraphx/kernels/float8.hpp b/src/targets/gpu/kernels/include/migraphx/kernels/float8.hpp index 2801c0f2e5d..59c287970cb 100644 --- a/src/targets/gpu/kernels/include/migraphx/kernels/float8.hpp +++ b/src/targets/gpu/kernels/include/migraphx/kernels/float8.hpp @@ -30,18 +30,13 @@ #pragma clang diagnostic ignored "-Wc++20-extensions" #endif // __clang__ -#include - #define MIGRAPHX_HIP_DEVICE __device__ // We are clipping in down conversion by default #define MIGRAPHX_F8_DOWNCAST_CLIPPING 1 +#include #include -using uint8_t = migraphx::uint8_t; -using uint16_t = migraphx::uint16_t; -using uint32_t = migraphx::uint32_t; - #include namespace migraphx { From b36f72d3b21983e3fec87ac06a9c127895459fba Mon Sep 17 00:00:00 2001 From: Umang Yadav Date: Fri, 17 Nov 2023 18:10:43 +0000 Subject: [PATCH 036/115] Update float8_impl --- .../include/migraphx/kernels/float8.hpp | 1 - .../include/migraphx/kernels/float8_impl.hpp | 279 ++++++++++-------- 2 files changed, 154 insertions(+), 126 deletions(-) diff 
--git a/src/targets/gpu/kernels/include/migraphx/kernels/float8.hpp b/src/targets/gpu/kernels/include/migraphx/kernels/float8.hpp index 59c287970cb..82c69906c4e 100644 --- a/src/targets/gpu/kernels/include/migraphx/kernels/float8.hpp +++ b/src/targets/gpu/kernels/include/migraphx/kernels/float8.hpp @@ -35,7 +35,6 @@ // We are clipping in down conversion by default #define MIGRAPHX_F8_DOWNCAST_CLIPPING 1 -#include #include #include diff --git a/src/targets/gpu/kernels/include/migraphx/kernels/float8_impl.hpp b/src/targets/gpu/kernels/include/migraphx/kernels/float8_impl.hpp index e45eb832f17..142a93b8873 100644 --- a/src/targets/gpu/kernels/include/migraphx/kernels/float8_impl.hpp +++ b/src/targets/gpu/kernels/include/migraphx/kernels/float8_impl.hpp @@ -47,24 +47,28 @@ struct conditional namespace fp8 { namespace impl { -template -__device__ constexpr uint8_t cast_to_f8(T _x, bool stoch, uint32_t rng) +template +__device__ constexpr uint8_t cast_to_f8(T f_x, bool stoch = false, uint32_t rng = 0) { + constexpr bool is_float = true; + // half is not supported for now + constexpr bool is_half = false; + static_assert(Wm + We == 7, "Wm+We==7"); + static_assert(is_float or is_half, "Only float can be cast to f8"); - static_assert(wm + we == 7, "wm+we==7"); - - const int mfmt = (sizeof(T) == 4) ? 23 : 10; - typename migraphx::detail::conditional::type x; + const uint32_t mfmt = (sizeof(T) == 4) ? 
23 : 10; + typename detail::conditional::type x; if constexpr(sizeof(T) == 4) - x = migraphx::bit_cast(_x); + x = migraphx::bit_cast(f_x); else - x = migraphx::bit_cast(_x); - - uint32_t head, mantissa; - int exponent, bias; - uint32_t sign; + x = migraphx::bit_cast(f_x); + uint32_t head = 0; + uint32_t mantissa = 0; + int exponent = 0; + uint32_t bias = 0; + uint32_t sign = 0; if constexpr(sizeof(T) == 4) { head = x & 0xFF800000; @@ -82,76 +86,79 @@ __device__ constexpr uint8_t cast_to_f8(T _x, bool stoch, uint32_t rng) bias = 15; } - uint32_t signed_inf = (sign << 7) + (((1 << we) - 1) << wm); + uint32_t signed_inf = (sign << 7) + (((1 << We) - 1) << Wm); + uint32_t signed_all_ones = (sign << 7) + ((((1 << We) - 1) << Wm) + ((1 << Wm) - 1)); + + // Calcualte maximum singed value FLT_MAX, FLT_MIN + uint32_t signed_max = signed_all_ones; + if(not NegativeZeroNan) + signed_max = (Wm == 2) ? (signed_max - 4) : (signed_max - 1); // Deal with inf and NaNs - if(negative_zero_nan) + if(NegativeZeroNan) // For the FNUZ cases, it is simple just return NaNs { - if(sizeof(T) == 4) - { - if((x & 0x7F800000) == 0x7F800000) - return 0x80; - } - else - { - // if(__hisinf(x) || __hisnan(x)) - if((x & 0x7C00) == 0x7C00) - return 0x80; - } + if((sizeof(T) == 4 and ((x & 0x7F800000) == 0x7F800000)) or + (sizeof(T) == 2 and ((x & 0x7C00) == 0x7C00))) + return 0x80; } else { - if(sizeof(T) == 4) + // calculate most common NaN mantissa for FP8, which is all Ones in binary + uint32_t nan_mantissa = 1; + for(auto i = 1; i < Wm; ++i) { - if((x & 0x7F800000) == 0x7F800000) - return signed_inf + (mantissa != 0 ? 1 : 0); + nan_mantissa |= (nan_mantissa << 1); } - else + if((sizeof(T) == 4 and ((x & 0x7F800000) == 0x7F800000)) or + (sizeof(T) == 2 and ((x & 0x7C00) == 0x7C00))) { - if((x & 0x7C00) == 0x7C00) - return signed_inf + (mantissa != 0 ? 1 : 0); + // infinity + if(mantissa == 0) + { + if(sign == 0) + return (Wm == 2) ? 0x7B : 0x7E; + else + return (Wm == 2) ? 
0xFB : 0xFE; + } + else // NaNs + return signed_inf + nan_mantissa; } } // handle positive zero if(x == 0) return 0; // handle negative zero - if((sizeof(T) == 4 and x == 0x80000000) or (sizeof(T) == 2 and x == 0x8000)) + else if((sizeof(T) == 4 and x == 0x80000000) or (sizeof(T) == 2 and x == 0x8000)) { - if(negative_zero_nan) - { - return 0; - } - else - { - return 0x80; - } + return NegativeZeroNan ? 0 : 0x80; // For FNUZ types neg zero is just positive zero } - // First need to check if it is normal or denorm as there is a difference of implict 1 - // Then need to adjust the exponent to align with the F8 exponent, in the meanwhile, shift - // The mantissa. Then for stochastic rounding, add rng to mantissa and truncate. And for - // RNE, no need to add rng. Then probably need to check whether there is carry and adjust - // exponent and mantissa again + /* First need to check if it is normal or denorm as there is a difference of implict 1 + Then need to adjust the exponent to align with the F8 exponent, in the meanwhile, shift + The mantissa. Then for stochastic rounding, add rng to mantissa and truncate. And for + RNE, no need to add rng. Then probably need to check whether there is carry and adjust + exponent and mantissa again*/ // For IEEE bias mode, the bias is 2^(k-1) -1 where k is the width of exponent bits - const int f8_bias = (1 << (we - 1)) - 1 + (negative_zero_nan ? 1 : 0); + const int f8_bias = (1 << (We - 1u)) - 1 + (NegativeZeroNan ? 
1 : 0); const int f8_denormal_act_exponent = 1 - f8_bias; // actual exponent of f8 denormal - // act_exponent is the actual exponent of fp32/fp16 (after subtracting bias) - // f8_exponent is the converted f8 exponent with bias encoding - // exponent_diff is the diff between fp32/fp16 exponent and f8 exponent, - // the difference needs to be adjusted and mantissa shifted - int act_exponent, f8_exponent, exponent_diff; + /* act_exponent is the actual exponent of fp32/fp16 (after subtracting bias) + f8_exponent is the converted f8 exponent with bias encoding + exponent_diff is the diff between fp32/fp16 exponent and f8 exponent, + the difference needs to be adjusted and mantissa shifted*/ + int act_exponent = 0; + int f8_exponent = 0; + int exponent_diff = 0; - if(exponent == 0) + if(exponent == 0 and mantissa != 0) { // fp32/fp16 is in denormal. /* fp32 denormal is below 2^-127 so it is usually not a concern here, we mostly concern fp16 -here. In this case, f8 is usually in denormal. But there could be exceptions. fp16 denormal has -exponent bias 15 while bf8 with NANOO has exponent bias 16. It means that there are some numbers in -fp16 denormal but they are bf8 (NANOO) normals - smallest bf8 (NANOO) normal is 2^-15. fp16 numbers -where exponent==0 (actual exponent -14) and highest bit of mantissa is 1 are bf8 (NANOO) normal. In -this case, the fp16 mantissa should be shift left by 1 */ - act_exponent = exponent - bias + 1; + here. In this case, f8 is usually in denormal. But there could be exceptions. fp16 denormal + has exponent bias 15 while bf8 with FNUZ has exponent bias 16. It means that there are some + numbers in fp16 denormal but they are bf8 (FNUZ) normals - smallest bf8 (FNUZ) normal is + 2^-15. fp16 numbers where exponent==0 (actual exponent -14) and highest bit of mantissa is 1 + are bf8 (FNUZ) normal. 
In this case, the fp16 mantissa should be shift left by 1 */ + act_exponent = 1 - bias; exponent_diff = f8_denormal_act_exponent - act_exponent; // actual exponent is exponent-bias+1 as it is denormal } @@ -161,10 +168,10 @@ this case, the fp16 mantissa should be shift left by 1 */ if(act_exponent <= f8_denormal_act_exponent) { /* This is the case where fp32/fp16 is normal but it is in f8 denormal range. - For example fp8 nanoo mode, denormal exponent is -7, but if the fp32/fp16 - actual exponent is -7, it is actually larger due to the implict 1, - Therefore it needs to be adjust to -6 and mantissa shift right by 1. - So for fp32/fp16, exponent -8 is the cut point to convert to fp8 nanoo */ + For example fp8 FNUZ mode, denormal exponent is -7, but if the fp32/fp16 + actual exponent is -7, it is actually larger due to the implict 1, + Therefore it needs to be adjust to -6 and mantissa shift right by 1. + So for fp32/fp16, exponent -8 is the cut point to convert to fp8 FNUZ */ exponent_diff = f8_denormal_act_exponent - act_exponent; } else @@ -176,13 +183,15 @@ this case, the fp16 mantissa should be shift left by 1 */ mantissa += (1 << mfmt); // Add the implicit 1 into mantissa } - bool midpoint = (mantissa & ((1 << (mfmt - wm + exponent_diff)) - 1)) == - (1 << (mfmt - wm + exponent_diff - 1)); + // need to know whether the number is right in the middle of two adjacent fp8 numbers. use max + // value of 31 to avoid undefined behaviour + bool midpoint = (mantissa & ((1u << (mfmt - Wm + exponent_diff)) - 1)) == + (1u << (mfmt - Wm + exponent_diff - 1)); /* This part is a bit tricky. The judgment of whether it is a tie needs to be done before we - shift right as shift right could rip off some residual part and make something not midpoint look - like midpoint. For example, the fp16 number 0x1002 (0 00100 0000000010), it is larger than - midpoint, but after shift right by 4 bits, it would look like midpoint. 
-*/ + shift right as shift right could rip off some residual part and make something not midpoint look + like midpoint. For example, the fp16 number 0x1002 (0 00100 0000000010), it is larger than + midpoint, but after shift right by 4 bits, it would look like midpoint. + */ if(exponent_diff > 0) mantissa >>= exponent_diff; @@ -194,114 +203,134 @@ this case, the fp16 mantissa should be shift left by 1 */ (act_exponent + exponent_diff) /*actual f8 exponent*/ + f8_bias - (implicit_one ? 0 : 1); // Now we have the exponent and mantissa adjusted - uint32_t drop_mask = (1 << (mfmt - wm)) - 1; + uint32_t drop_mask = (1 << (mfmt - Wm)) - 1; bool odd = - mantissa & (1 << (mfmt - wm)); // if the least significant bit that is not truncated is 1 + mantissa & (1 << (mfmt - Wm)); // if the least significant bit that is not truncated is 1 + /* + This part is doing rounding by adding mantissa part that is going to get dropped. + e.g. if the dropped part for less than 0.5 than it would round down. + if the dropped part is more than 0.5 then it would round up by rolling carry to LSB of retained + mantissa. + For the mid point when bit pattern is like this for Odd: `xy1:10000000` for Odd and + `xy0:10000000` for the Even. where `:` is delimiter for dropped v/s retained part. + For the odd case : + this will add xy1:10000000 + 000:10000000 which would roll over carry to LSB of retained + part making it RNE. + For the even case : this will add xy0:10000000 + 000:01111111 which would + round down and keep number Even + */ mantissa += (stoch ? rng : (midpoint ? (odd ? 
mantissa : mantissa - 1) : mantissa)) & drop_mask; // Now we deal with overflow - if(f8_exponent == 0) + if(f8_exponent == 0 and ((1 << mfmt) & mantissa)) { - if((1 << mfmt) & mantissa) - { - f8_exponent = 1; // denormal overflow to become normal, promote exponent - } + f8_exponent = 1; // denormal overflow to become normal, promote exponent } - else + else if((1 << (mfmt + 1)) & mantissa) { - if((1 << (mfmt + 1)) & mantissa) - { - mantissa >>= 1; - f8_exponent++; - } + mantissa >>= 1; + f8_exponent++; } - mantissa >>= (mfmt - wm); + mantissa >>= (mfmt - Wm); // above range: quantize to maximum possible float of the same sign - const int max_exp = (1 << we) - (negative_zero_nan ? 1 : 2); + // for e5m2 case, max_exp is 14, since exp = 15 is reserved for Infs and Nans + const int max_exp = (1 << We) - ((NegativeZeroNan or Wm == 3) ? 1 : 2); if(f8_exponent > max_exp) { - if(clip) - { - mantissa = (1 << wm) - 1; - f8_exponent = max_exp; - } + if(Clip) + return signed_max; else { - return signed_inf; + // https://onnx.ai/onnx/technical/float8.html#cast + if(NegativeZeroNan) + return 0x80; + else + return (Wm == 2) ? signed_inf : signed_all_ones; } } - if(f8_exponent == 0 && mantissa == 0) - return negative_zero_nan ? 0 : (sign << 7); - mantissa &= (1 << wm) - 1; - return (sign << 7) | (f8_exponent << wm) | mantissa; + if(f8_exponent == 0 and mantissa == 0) + return NegativeZeroNan ? 
0 : (sign << 7); + mantissa &= (1 << Wm) - 1; + return (sign << 7) | (f8_exponent << Wm) | mantissa; } -template +template __device__ constexpr T cast_from_f8(uint8_t x) { - constexpr int weo = 8; - constexpr int wmo = 23; + // half is not supported for now + constexpr bool is_half = false; + constexpr bool is_float = true; + static_assert(is_float or is_half, "Only float are supported"); - T fInf, fNegInf, fNaN, fNeg0; - uint32_t ifInf = 0x7F800000; - uint32_t ifNegInf = 0xFF800000; - uint32_t ifNaN = 0x7F800001; - uint32_t ifNeg0 = 0x80000000; - // TODO: need to change T for half but right now it would never called with half - fInf = migraphx::bit_cast(ifInf); - fNegInf = migraphx::bit_cast(ifNegInf); - fNaN = migraphx::bit_cast(ifNaN); - fNeg0 = migraphx::bit_cast(ifNeg0); + constexpr int weo = is_half ? 5 : 8; + constexpr int wmo = is_half ? 10 : (is_float ? 23 : 7); + // NOLINTNEXTLINE + T f_inf, f_neg_inf, f_nan, f_neg0; + + if constexpr(is_float) + { + const uint32_t if_inf = 0x7F800000; + const uint32_t if_neg_inf = 0xFF800000; + const uint32_t if_nan = 0x7F800001; + const uint32_t if_neg0 = 0x80000000; + f_inf = migraphx::bit_cast(if_inf); + f_neg_inf = migraphx::bit_cast(if_neg_inf); + f_nan = migraphx::bit_cast(if_nan); + f_neg0 = migraphx::bit_cast(if_neg0); + } if(x == 0) return 0; - uint32_t sign = x >> 7; - uint32_t mantissa = x & ((1 << wm) - 1); - int exponent = (x & 0x7F) >> wm; - if(negative_zero_nan) + uint32_t sign = x >> 7; // NOLINT + uint32_t mantissa = x & ((1 << Wm) - 1); // NOLINT + int exponent = (x & 0x7F) >> Wm; // NOLINT + if(NegativeZeroNan) { if(x == 0x80) - return fNaN; + return f_nan; } else { if(x == 0x80) - return fNeg0; - if(exponent == ((1 << we) - 1)) - return (mantissa == 0) ? (sign ? fNegInf : fInf) : fNaN; + return f_neg0; + if(exponent == ((1 << We) - 1) and Wm == 2) // NOLINT + return (mantissa == 0) ? (sign ? 
f_neg_inf : f_inf) : f_nan; + else if(Wm == 3 and (x == 0x7F or x == 0xFF)) + return f_nan; } - typename migraphx::detail::conditional::type retval; + typename detail::conditional::type retval; - const int exp_low_cutoff = (1 << (weo - 1)) - (1 << (we - 1)) + 1 - (negative_zero_nan ? 1 : 0); + const int exp_low_cutoff = + (1 << (weo - 1)) - (1 << (We - 1)) + 1 - (NegativeZeroNan ? 1 : 0); // NOLINT // subnormal input if(exponent == 0) { // guaranteed mantissa!=0 since cases 0x0 and 0x80 are handled above - int sh = 1 + __builtin_clz(mantissa) - (32 - wm); - mantissa <<= sh; + int sh = 1 + __builtin_clz(mantissa) - (32 - Wm); + mantissa <<= sh; // NOLINT exponent += 1 - sh; - mantissa &= ((1 << wm) - 1); + mantissa &= ((1 << Wm) - 1); // NOLINT } exponent += exp_low_cutoff - 1; - mantissa <<= wmo - wm; + mantissa <<= wmo - Wm; // NOLINT - // subnormal output (occurs when T=half, we=5, negative_zero_nan=true) + // subnormal output (occurs when T=half, We=5, negative_zero_nan=true) if(exponent <= 0) { - mantissa |= 1 << wmo; - mantissa >>= 1 - exponent; + mantissa |= 1 << wmo; // NOLINT + mantissa >>= 1 - exponent; // NOLINT exponent = 0; } if(sizeof(T) == 2) - retval = (sign << 15) | (exponent << 10) | mantissa; + retval = (sign << 15) | (exponent << 10) | mantissa; // NOLINT else - retval = (sign << 31) | (exponent << 23) | mantissa; + retval = (sign << 31) | (exponent << 23) | mantissa; // NOLINT return migraphx::bit_cast(retval); } } // namespace impl From 85ba819bf323a8e3e4f01a9de1ee6af48336f8da Mon Sep 17 00:00:00 2001 From: Umang Yadav Date: Fri, 17 Nov 2023 20:38:28 +0000 Subject: [PATCH 037/115] constructor from float works with constexpr --- .../include/migraphx/kernels/float8.hpp | 82 ++++++++----------- 1 file changed, 34 insertions(+), 48 deletions(-) diff --git a/src/targets/gpu/kernels/include/migraphx/kernels/float8.hpp b/src/targets/gpu/kernels/include/migraphx/kernels/float8.hpp index 82c69906c4e..7ac94bf2d06 100644 --- 
a/src/targets/gpu/kernels/include/migraphx/kernels/float8.hpp +++ b/src/targets/gpu/kernels/include/migraphx/kernels/float8.hpp @@ -75,7 +75,7 @@ struct float8 // device specific optimized F8 down-conversion code template - static MIGRAPHX_HIP_DEVICE uint8_t cast_to_f8_from_f32(float v, uint32_t rng = 0) + static constexpr MIGRAPHX_HIP_DEVICE uint8_t cast_to_f8_from_f32(float v, uint32_t rng = 0) { uint8_t i8data; union @@ -135,7 +135,7 @@ struct float8 #if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__) // NOTE: ON-DEVICE... always optimal bias - explicit MIGRAPHX_HIP_DEVICE + explicit constexpr MIGRAPHX_HIP_DEVICE float8(float v, migraphx::fp8::rounding_mode rm = migraphx::fp8::rounding_mode::standard, uint32_t rng = 0) @@ -176,7 +176,7 @@ struct float8 data = migraphx::fp8::impl:: cast_to_f8<2, 5, float, FNUZ /*negative_zero_nan*/, false /*clip*/>( v, (rm == migraphx::fp8::rounding_mode::stochastic), rng); -#endif // rocblas_F8_downcast_clipping} +#endif // MIGRAPHX_FP8_DOWNCAST_CLIPPING} } } @@ -314,58 +314,44 @@ struct float8 } }; -// NOLINTNEXTLINE -#define MIGRAPHX_FP8_BINARY_OP(binary_op, U) \ - template \ - inline constexpr U MIGRAPHX_HIP_DEVICE operator binary_op(const migraphx::fp8::float8& lhs, \ - const migraphx::fp8::float8& rhs) \ - { \ - return U(static_cast(lhs) binary_op static_cast(rhs)); \ - } - -// TODO: these should return floats -MIGRAPHX_FP8_BINARY_OP(*, migraphx::fp8::float8) -MIGRAPHX_FP8_BINARY_OP(-, migraphx::fp8::float8) -MIGRAPHX_FP8_BINARY_OP(/, migraphx::fp8::float8) -MIGRAPHX_FP8_BINARY_OP(+, migraphx::fp8::float8) -// TODO: Comparison ops shouldn't convert to float, maybe need to take care of rounding effects. 
-MIGRAPHX_FP8_BINARY_OP(==, bool) -MIGRAPHX_FP8_BINARY_OP(>=, bool) -MIGRAPHX_FP8_BINARY_OP(<=, bool) -MIGRAPHX_FP8_BINARY_OP(>, bool) -MIGRAPHX_FP8_BINARY_OP(<, bool) -MIGRAPHX_FP8_BINARY_OP(!=, bool) - // https://onnx.ai/onnx/technical/float8.html using fp8e4m3fn = float8; using fp8e5m2 = float8; using fp8e4m3fnuz = float8; using fp8e5m2fnuz = float8; -; - -inline MIGRAPHX_HIP_DEVICE fp8e4m3fnuz fabs(fp8e4m3fnuz v) -{ - v.data = v.data & 0x7f; - return v; -} - -inline MIGRAPHX_HIP_DEVICE fp8e4m3fn fabs(fp8e4m3fn v) -{ - v.data = v.data & 0x7f; - return v; -} -inline MIGRAPHX_HIP_DEVICE fp8e5m2fnuz fabs(fp8e5m2fnuz v) -{ - v.data = v.data & 0x7f; - return v; -} +// NOLINTNEXTLINE +#define MIGRAPHX_FP8_BINARY_OP(binary_op, T, U) \ + inline constexpr U MIGRAPHX_HIP_DEVICE operator binary_op(const T& lhs, const T& rhs) \ + { \ + return U(static_cast(lhs) binary_op static_cast(rhs)); \ + } -inline MIGRAPHX_HIP_DEVICE fp8e5m2 fabs(fp8e5m2 v) -{ - v.data = v.data & 0x7f; - return v; -} +// NOLINTNEXTLINE +#define MIGRAPHX_FP8_UNARY_OP(unary_op, T) \ + inline constexpr MIGRAPHX_HIP_DEVICE T unary_op(T v) \ + { \ + v.data = v.data & 0x7f; \ + return v; \ + } + +#define MIGRAPHX_FP8_GEN_OP_OVERLOADS(T) \ + MIGRAPHX_FP8_BINARY_OP(*, T, T) \ + MIGRAPHX_FP8_BINARY_OP(-, T, T) \ + MIGRAPHX_FP8_BINARY_OP(/, T, T) \ + MIGRAPHX_FP8_BINARY_OP(+, T, T) \ + MIGRAPHX_FP8_BINARY_OP(==, T, bool) \ + MIGRAPHX_FP8_BINARY_OP(>=, T, bool) \ + MIGRAPHX_FP8_BINARY_OP(<=, T, bool) \ + MIGRAPHX_FP8_BINARY_OP(>, T, bool) \ + MIGRAPHX_FP8_BINARY_OP(<, T, bool) \ + MIGRAPHX_FP8_BINARY_OP(!=, T, bool) \ + MIGRAPHX_FP8_UNARY_OP(fabs, T) + +MIGRAPHX_FP8_GEN_OP_OVERLOADS(fp8e5m2) +MIGRAPHX_FP8_GEN_OP_OVERLOADS(fp8e5m2fnuz) +MIGRAPHX_FP8_GEN_OP_OVERLOADS(fp8e4m3fn) +MIGRAPHX_FP8_GEN_OP_OVERLOADS(fp8e4m3fnuz) template <> class numeric_limits From aed1922b9f414be7341a8a45e31052951736a4df Mon Sep 17 00:00:00 2001 From: Umang Yadav Date: Fri, 17 Nov 2023 21:31:01 +0000 Subject: [PATCH 038/115] Remove 
unnecessary pragmas --- .../include/migraphx/kernels/float8.hpp | 60 ++----------------- 1 file changed, 6 insertions(+), 54 deletions(-) diff --git a/src/targets/gpu/kernels/include/migraphx/kernels/float8.hpp b/src/targets/gpu/kernels/include/migraphx/kernels/float8.hpp index 7ac94bf2d06..87c6fb2a39f 100644 --- a/src/targets/gpu/kernels/include/migraphx/kernels/float8.hpp +++ b/src/targets/gpu/kernels/include/migraphx/kernels/float8.hpp @@ -24,9 +24,6 @@ #define MIGRAPHX_GUARD_KERNELS_FLOAT8_HPP #if defined(__clang__) #pragma clang diagnostic push -#pragma clang diagnostic ignored "-Wold-style-cast" -#pragma clang diagnostic ignored "-Wfloat-equal" -#pragma clang diagnostic ignored "-Wmacro-redefined" #pragma clang diagnostic ignored "-Wc++20-extensions" #endif // __clang__ @@ -268,7 +265,7 @@ struct float8 } } -#define MIGRAPHX_FP8_UNARY_OP(unary_op, binary_op) \ +#define MIGRAPHX_FP8_SHORT_UNARY_OP(unary_op, binary_op) \ constexpr float8& MIGRAPHX_HIP_DEVICE operator unary_op(const float8& rhs) \ { \ const auto tmp = static_cast(*this) binary_op static_cast(rhs); \ @@ -282,10 +279,10 @@ struct float8 return *this; \ } - MIGRAPHX_FP8_UNARY_OP(*=, *) - MIGRAPHX_FP8_UNARY_OP(-=, -) - MIGRAPHX_FP8_UNARY_OP(+=, +) - MIGRAPHX_FP8_UNARY_OP(/=, /) + MIGRAPHX_FP8_SHORT_UNARY_OP(*=, *) + MIGRAPHX_FP8_SHORT_UNARY_OP(-=, -) + MIGRAPHX_FP8_SHORT_UNARY_OP(+=, +) + MIGRAPHX_FP8_SHORT_UNARY_OP(/=, /) inline MIGRAPHX_HIP_DEVICE constexpr float8& operator=(const float8& rhs) = default; inline MIGRAPHX_HIP_DEVICE constexpr float8& operator=(float8&& rhs) = default; @@ -483,52 +480,7 @@ class numeric_limits return fp8e5m2(0x7C, fp8e5m2::from_bits()); } }; -/* -// Use h/w intrinsic and optimized version when __gfx940__ -template {}) && - (migraphx::is_same{} || - migraphx::is_same{})), - int>::type = 0> -inline __host__ __device__ T explicit_downcast(Ta a, uint32_t rng) -{ -#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__) - // NOTE: we are directly calling 
cast_to_f8_from_f32 instead of constructor to optimize - // away one runtime branch - T val; - if(migraphx::is_same::value) - val.data = migraphx_f8::cast_to_f8_from_f32(float(a), rng); - else - val.data = migraphx_bf8::cast_to_bf8_from_f32(float(a), rng); - return val; -#else // non gfx940 - return T(float(a), - stochastic_rounding ? migraphx::fp8::rounding_mode::stochastic - : migraphx::fp8::rounding_mode::standard, - rng); -#endif // __gfx940__ -} - -// NOTE NOTE: The above code is good if we don't consider HIP-GEMM code and only consider -// the quantization However, if we need HIP-GEMM for fall-back, we would need explicit_cast -// handles Tacc=f32 to To=f16/bf16 conversion -template {}) && - !(migraphx::is_same{} || - migraphx::is_same{})), - int>::type = 0> -inline __host__ __device__ T explicit_downcast(Ta a, uint32_t rng) -{ - // the return type is not a F8 types, no SR for those types - // not sure if we have direct conversion, so converting to float first - // no effect if the input type is float - return T(float(a)); -} -*/ + } // namespace fp8 } // namespace migraphx // ================================================================================================= From f975c63362d9220547625f789449609ccfe522d6 Mon Sep 17 00:00:00 2001 From: Umang Yadav Date: Fri, 17 Nov 2023 21:32:32 +0000 Subject: [PATCH 039/115] Remove clang diagnostics --- .../gpu/kernels/include/migraphx/kernels/float8.hpp | 9 +-------- 1 file changed, 1 insertion(+), 8 deletions(-) diff --git a/src/targets/gpu/kernels/include/migraphx/kernels/float8.hpp b/src/targets/gpu/kernels/include/migraphx/kernels/float8.hpp index 87c6fb2a39f..ae172773245 100644 --- a/src/targets/gpu/kernels/include/migraphx/kernels/float8.hpp +++ b/src/targets/gpu/kernels/include/migraphx/kernels/float8.hpp @@ -22,10 +22,6 @@ #ifndef MIGRAPHX_GUARD_KERNELS_FLOAT8_HPP #define MIGRAPHX_GUARD_KERNELS_FLOAT8_HPP -#if defined(__clang__) -#pragma clang diagnostic push -#pragma clang diagnostic ignored 
"-Wc++20-extensions" -#endif // __clang__ #define MIGRAPHX_HIP_DEVICE __device__ @@ -74,7 +70,7 @@ struct float8 template static constexpr MIGRAPHX_HIP_DEVICE uint8_t cast_to_f8_from_f32(float v, uint32_t rng = 0) { - uint8_t i8data; + uint8_t i8data = 0x00; union { float fval; @@ -484,7 +480,4 @@ class numeric_limits } // namespace fp8 } // namespace migraphx // ================================================================================================= -#if defined(__clang__) -#pragma clang diagnostic pop -#endif #endif // MIGRAPHX_GUARD_KERNELS_FLOAT8_HPP From 32033d856c538daed68007af54d2425fbefefd53 Mon Sep 17 00:00:00 2001 From: Umang Yadav Date: Fri, 17 Nov 2023 21:41:42 +0000 Subject: [PATCH 040/115] Add back floatequal --- .../gpu/kernels/include/migraphx/kernels/float8.hpp | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/src/targets/gpu/kernels/include/migraphx/kernels/float8.hpp b/src/targets/gpu/kernels/include/migraphx/kernels/float8.hpp index ae172773245..6f4219a627c 100644 --- a/src/targets/gpu/kernels/include/migraphx/kernels/float8.hpp +++ b/src/targets/gpu/kernels/include/migraphx/kernels/float8.hpp @@ -22,7 +22,10 @@ #ifndef MIGRAPHX_GUARD_KERNELS_FLOAT8_HPP #define MIGRAPHX_GUARD_KERNELS_FLOAT8_HPP - +#if defined(__clang__) +#pragma clang diagnostic push +#pragma clang diagnostic ignored "-Wfloat-equal" +#endif // __clang__ #define MIGRAPHX_HIP_DEVICE __device__ // We are clipping in down conversion by default @@ -480,4 +483,8 @@ class numeric_limits } // namespace fp8 } // namespace migraphx // ================================================================================================= +#if defined(__clang__) +#pragma clang diagnostic pop +#endif // __clang__ + #endif // MIGRAPHX_GUARD_KERNELS_FLOAT8_HPP From e88d46a6d8b6b55ccc5f8282d328176610bbaa23 Mon Sep 17 00:00:00 2001 From: Umang Yadav Date: Fri, 17 Nov 2023 21:49:14 +0000 Subject: [PATCH 041/115] disable DPP For FP8 --- src/targets/gpu/jit/reduce.cpp | 
36 +++++++++++++++++++++------------- 1 file changed, 22 insertions(+), 14 deletions(-) diff --git a/src/targets/gpu/jit/reduce.cpp b/src/targets/gpu/jit/reduce.cpp index 1e018d2633e..d3a51153fc3 100644 --- a/src/targets/gpu/jit/reduce.cpp +++ b/src/targets/gpu/jit/reduce.cpp @@ -146,6 +146,7 @@ struct simple_reduce_compiler : compiler vectorize vec{}; auto nelements = options.virtual_inputs.back().elements(); auto algo = v.get("algo", get_reduce_algo(options.virtual_inputs)); + if(algo == "block") { // Vectorize if the axis is a reduction axis @@ -169,13 +170,20 @@ struct simple_reduce_compiler : compiler options.kernel_name = "reduce_kernel"; std::string identity = "[](auto x) { return x; }"; auto src = interpolate_string(simple_reduce_kernel, - {{"reduction", v.at("reduction").to()}, - {"init", v.get("init", std::string{"0"})}, - {"read", v.get("read", identity)}, - {"write", v.get("write", identity)}, - {"algo", algo}, - {"transformers", make_transformer_args(vec)}, - {"preamble", v.get("preamble", std::string{})}}); + {{"reduction", v.at("reduction").to()}, + {"init", v.get("init", std::string{"0"})}, + {"read", v.get("read", identity)}, + {"write", v.get("write", identity)}, + {"algo", algo}, + {"transformers", make_transformer_args(vec)}, + {"preamble", v.get("preamble", std::string{})}}); + // disable DPP for FP8 for now,, TODO: need to disable for Any FP8 types + if(std::any_of(inputs.begin(), inputs.end(), [](const auto& s) { + return s.type() == migraphx::shape::fp8e4m3fnuz_type; + })) + { + options.params += "-DMIGRAPHX_HAS_DPP=0 "; + } options.params += "-Wno-float-equal"; return compile_hip_code_object(src, options); } @@ -266,13 +274,13 @@ struct fused_reduce_compiler : compiler auto src = interpolate_string( fused_reduce_kernel, {{"kernel", options.kernel_name}, - {"params", enum_params(inputs.size(), "void * private_p")}, - {"args", enum_params(inputs.size(), "private_p")}, - {"algo", algo}, - {"reduced", "decltype(" + 
generate_make_shape(reduce_output_shape) + ")"}, - {"lambda", v.at("lambda").to()}, - {"transformers", make_transformer_args(vec)}, - {"preamble", v.get("preamble", std::string{})}}); + {"params", enum_params(inputs.size(), "void * private_p")}, + {"args", enum_params(inputs.size(), "private_p")}, + {"algo", algo}, + {"reduced", "decltype(" + generate_make_shape(reduce_output_shape) + ")"}, + {"lambda", v.at("lambda").to()}, + {"transformers", make_transformer_args(vec)}, + {"preamble", v.get("preamble", std::string{})}}); options.params += "-Wno-float-equal"; return compile_hip_code_object(src, options); } From 60dd1f46e3dd3e4da004f4f6f2225c458bc6dc2c Mon Sep 17 00:00:00 2001 From: Umang Yadav Date: Fri, 17 Nov 2023 21:55:10 +0000 Subject: [PATCH 042/115] formatting --- src/targets/gpu/compile_gen.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/targets/gpu/compile_gen.cpp b/src/targets/gpu/compile_gen.cpp index df3c59cc2f7..148e0771370 100644 --- a/src/targets/gpu/compile_gen.cpp +++ b/src/targets/gpu/compile_gen.cpp @@ -315,7 +315,7 @@ std::string generate_reduce(const module& m, const std::string& name) std::transform( params.begin(), params.end(), params.begin(), [](auto s) { return "auto " + s; }); return interpolate_string(inner_template, - {{"inner", inner_name}, + {{"inner", inner_name}, {"params", join_strings(params, ", ")}, {"args", join_strings(args, ", ")}, {"call", call_function}}); From ef425d04bfa193f25bc4b61018ec438149b6cf92 Mon Sep 17 00:00:00 2001 From: Umang Yadav Date: Fri, 17 Nov 2023 21:57:00 +0000 Subject: [PATCH 043/115] revert unwanted changes --- src/targets/gpu/compile_hip.cpp | 15 +++++++++++++-- .../kernels/include/migraphx/kernels/types.hpp | 1 + .../include/migraphx/kernels/vectorize.hpp | 1 - 3 files changed, 14 insertions(+), 3 deletions(-) diff --git a/src/targets/gpu/compile_hip.cpp b/src/targets/gpu/compile_hip.cpp index c3e101a3ae1..e58b681b563 100644 --- a/src/targets/gpu/compile_hip.cpp +++ 
b/src/targets/gpu/compile_hip.cpp @@ -251,10 +251,21 @@ compile_hip_src(const std::vector& srcs, std::string params, const std std::cout << std::string(src.content) << std::endl; } } + auto fname = fs::path{"migraphx-hiprtc-driver"}; +#ifdef _WIN32 + fname.replace_extension(".exe"); +#endif auto p = dynamic_loader::path(&compile_hip_src_with_hiprtc); - auto driver = p.parent_path().parent_path() / "bin" / "migraphx-hiprtc-driver"; + auto driver = p.parent_path() / fname; + + bool found = fs::exists(driver); + if(not found) + { + driver = p.parent_path().parent_path() / "bin" / fname; + found = fs::exists(driver); + } - if(fs::exists(driver)) + if(found) { value v; v["srcs"] = to_value(hsrcs); diff --git a/src/targets/gpu/kernels/include/migraphx/kernels/types.hpp b/src/targets/gpu/kernels/include/migraphx/kernels/types.hpp index a31f3eace82..4f71d1985a1 100644 --- a/src/targets/gpu/kernels/include/migraphx/kernels/types.hpp +++ b/src/targets/gpu/kernels/include/migraphx/kernels/types.hpp @@ -23,6 +23,7 @@ */ #ifndef MIGRAPHX_GUARD_AMDMIGRAPHX_KERNELS_TYPES_HPP #define MIGRAPHX_GUARD_AMDMIGRAPHX_KERNELS_TYPES_HPP + #include namespace migraphx { diff --git a/src/targets/gpu/kernels/include/migraphx/kernels/vectorize.hpp b/src/targets/gpu/kernels/include/migraphx/kernels/vectorize.hpp index b66f88b7383..b456b5c6e45 100644 --- a/src/targets/gpu/kernels/include/migraphx/kernels/vectorize.hpp +++ b/src/targets/gpu/kernels/include/migraphx/kernels/vectorize.hpp @@ -24,7 +24,6 @@ #ifndef MIGRAPHX_GUARD_KERNELS_VECTORIZE_HPP #define MIGRAPHX_GUARD_KERNELS_VECTORIZE_HPP -#include #include #include From bd0ae5fa0ee05b2b72ebf61a13fd7c5492393791 Mon Sep 17 00:00:00 2001 From: Umang Yadav Date: Fri, 17 Nov 2023 23:27:41 +0000 Subject: [PATCH 044/115] add some more tests --- test/verify/test_concat_axis_0.cpp | 15 +++++++--- test/verify/test_gather.cpp | 14 +++++---- test/verify/test_gathernd_default.cpp | 9 ++++-- test/verify/test_isnan_float.cpp | 7 ++++- 
test/verify/test_isnan_half.cpp | 43 --------------------------- 5 files changed, 33 insertions(+), 55 deletions(-) delete mode 100644 test/verify/test_isnan_half.cpp diff --git a/test/verify/test_concat_axis_0.cpp b/test/verify/test_concat_axis_0.cpp index 25e2301adcc..e944df7caa1 100644 --- a/test/verify/test_concat_axis_0.cpp +++ b/test/verify/test_concat_axis_0.cpp @@ -27,16 +27,18 @@ #include #include -struct test_concat_axis_0 : verify_program +template + +struct test_concat_axis_0 : verify_program> { migraphx::program create_program() const { migraphx::program p; auto* mm = p.get_main_module(); int axis = 0; - migraphx::shape s0{migraphx::shape::int32_type, {2, 2}}; - migraphx::shape s1{migraphx::shape::int32_type, {3, 2}}; - migraphx::shape s2{migraphx::shape::int32_type, {1, 2}}; + migraphx::shape s0{DType, {2, 2}}; + migraphx::shape s1{DType, {3, 2}}; + migraphx::shape s2{DType, {1, 2}}; auto l0 = mm->add_parameter("x", s0); auto l1 = mm->add_parameter("y", s1); auto l2 = mm->add_parameter("z", s2); @@ -44,3 +46,8 @@ struct test_concat_axis_0 : verify_program return p; } }; + +template struct test_concat_axis_0; +template struct test_concat_axis_0; +template struct test_concat_axis_0; +template struct test_concat_axis_0; diff --git a/test/verify/test_gather.cpp b/test/verify/test_gather.cpp index 1afdceca4f9..20c5a047fe6 100644 --- a/test/verify/test_gather.cpp +++ b/test/verify/test_gather.cpp @@ -27,14 +27,14 @@ #include #include -template -struct test_gather : verify_program> +template +struct test_gather : verify_program> { migraphx::program create_program() const { migraphx::program p; auto* mm = p.get_main_module(); - migraphx::shape s{migraphx::shape::float_type, {3, 3}}; + migraphx::shape s{DType, {3, 3}}; migraphx::shape s_indices{migraphx::shape::int32_type, {2, 2}}; std::vector indices{1, 2, 2, 1}; auto a0 = mm->add_parameter("data", s); @@ -46,6 +46,10 @@ struct test_gather : verify_program> }; // Standard gather test -template struct 
test_gather<0>; +template struct test_gather<0, migraphx::shape::float_type>; +template struct test_gather<0, migraphx::shape::half_type>; +template struct test_gather<0, migraphx::shape::fp8e4m3fnuz_type>; // Test Negative axis -template struct test_gather<-2>; +template struct test_gather<-2, migraphx::shape::float_type>; +template struct test_gather<-2, migraphx::shape::half_type>; +template struct test_gather<-2, migraphx::shape::fp8e4m3fnuz_type>; diff --git a/test/verify/test_gathernd_default.cpp b/test/verify/test_gathernd_default.cpp index d4d48251c8a..e28cba3b54e 100644 --- a/test/verify/test_gathernd_default.cpp +++ b/test/verify/test_gathernd_default.cpp @@ -26,13 +26,14 @@ #include #include -struct test_gathernd_default : verify_program +template +struct test_gathernd_default : verify_program { migraphx::program create_program() const { migraphx::program p; auto* mm = p.get_main_module(); - migraphx::shape ds{migraphx::shape::float_type, {2, 2}}; + migraphx::shape ds{DType, {2, 2}}; migraphx::shape is{migraphx::shape::int64_type, {2, 2}}; std::vector indices{0, 0, 1, 1}; auto a0 = mm->add_parameter("data", ds); @@ -41,3 +42,7 @@ struct test_gathernd_default : verify_program return p; } }; + +template struct test_gathernd_default; +template struct test_gathernd_default; +template struct test_gathernd_default; diff --git a/test/verify/test_isnan_float.cpp b/test/verify/test_isnan_float.cpp index 86a164dd142..cc1c1329da8 100644 --- a/test/verify/test_isnan_float.cpp +++ b/test/verify/test_isnan_float.cpp @@ -27,7 +27,8 @@ #include #include -struct test_isnan_float : verify_program +template +struct test_isnan : verify_program> { migraphx::program create_program() const { @@ -40,3 +41,7 @@ struct test_isnan_float : verify_program return p; } }; + +template struct test_isnan; +template struct test_isnan; +template struct test_isnan; diff --git a/test/verify/test_isnan_half.cpp b/test/verify/test_isnan_half.cpp deleted file mode 100644 index 
1576c555c81..00000000000 --- a/test/verify/test_isnan_half.cpp +++ /dev/null @@ -1,43 +0,0 @@ -/* - * The MIT License (MIT) - * - * Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved. - * - * Permission is hereby granted, free of charge, to any person obtaining a copy - * of this software and associated documentation files (the "Software"), to deal - * in the Software without restriction, including without limitation the rights - * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell - * copies of the Software, and to permit persons to whom the Software is - * furnished to do so, subject to the following conditions: - * - * The above copyright notice and this permission notice shall be included in - * all copies or substantial portions of the Software. - * - * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR - * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, - * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE - * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER - * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, - * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN - * THE SOFTWARE. 
- */ -#include -#include "verify_program.hpp" -#include -#include -#include -#include - -struct test_isnan_half : verify_program -{ - migraphx::program create_program() const - { - migraphx::program p; - auto* mm = p.get_main_module(); - auto x = mm->add_parameter("x", migraphx::shape{migraphx::shape::half_type, {2}}); - auto l0 = mm->add_literal(std::numeric_limits::quiet_NaN()); - x = mm->add_instruction(migraphx::make_op("concat", {{"axis", 0}}), x, l0); - mm->add_instruction(migraphx::make_op("isnan"), x); - return p; - } -}; From 91cc9c7cf6bf3118ceaaae21d6f065d7348d2aef Mon Sep 17 00:00:00 2001 From: Umang Yadav Date: Sat, 18 Nov 2023 00:35:09 +0000 Subject: [PATCH 045/115] Add math and reduce tests --- .../include/migraphx/kernels/reduce.hpp | 11 +++++------ test/verify/test_acosh.cpp | 9 +++++++-- test/verify/test_asin.cpp | 9 +++++++-- test/verify/test_asinh.cpp | 9 +++++++-- test/verify/test_atan.cpp | 9 +++++++-- test/verify/test_atanh.cpp | 9 +++++++-- test/verify/test_ceil.cpp | 9 +++++++-- test/verify/test_cos.cpp | 9 +++++++-- test/verify/test_cosh.cpp | 9 +++++++-- test/verify/test_erf.cpp | 9 +++++++-- test/verify/test_exp.cpp | 9 +++++++-- test/verify/test_floor.cpp | 9 +++++++-- test/verify/test_fmod_mod.cpp | 1 + test/verify/test_gathernd_default.cpp | 2 +- .../{test_isnan_float.cpp => test_isnan.cpp} | 0 test/verify/test_layernorm.cpp | 12 ++++++++++++ test/verify/test_log.cpp | 9 +++++++-- test/verify/test_min_max.cpp | 2 ++ test/verify/test_nearbyint.cpp | 2 ++ test/verify/test_pad.cpp | 10 ++++++++-- test/verify/test_pow.cpp | 1 + test/verify/test_reduce_add.cpp | 9 +++++++-- test/verify/test_reduce_mean_nhwc.cpp | 12 ++++++++---- test/verify/test_reduce_op_large.cpp | 16 ++++++++++++++++ test/verify/test_reduce_op_small.cpp | 16 ++++++++++++++++ test/verify/test_roialign.cpp | 19 ++++++++++++------- test/verify/test_rsqrt.cpp | 2 ++ test/verify/test_scatternd.cpp | 13 +++++++++---- test/verify/test_sin.cpp | 9 +++++++-- 
test/verify/test_sinh.cpp | 9 +++++++-- test/verify/test_softmax.cpp | 4 ++++ test/verify/test_sqrt.cpp | 9 +++++++-- test/verify/test_tan.cpp | 9 +++++++-- test/verify/test_tanh.cpp | 9 +++++++-- test/verify/test_where.cpp | 7 ++++++- 35 files changed, 231 insertions(+), 61 deletions(-) rename test/verify/{test_isnan_float.cpp => test_isnan.cpp} (100%) diff --git a/src/targets/gpu/kernels/include/migraphx/kernels/reduce.hpp b/src/targets/gpu/kernels/include/migraphx/kernels/reduce.hpp index eb870f39878..a106773d1dc 100644 --- a/src/targets/gpu/kernels/include/migraphx/kernels/reduce.hpp +++ b/src/targets/gpu/kernels/include/migraphx/kernels/reduce.hpp @@ -106,7 +106,7 @@ __device__ auto block_reduce(index idx, Op op, T init, Index n, F f) #endif using type = decltype(index::invoke_loop(f, 0, _c<0>)); __shared__ type buffer[idx.max_nlocal() / lanes_per_thread]; - type x = init; + type x = type(init); idx.local_stride(n, [&](auto i, auto d) { x = op(x, index::invoke_loop(f, i, d)); }); dpp_reduce(x, op); @@ -117,7 +117,7 @@ __device__ auto block_reduce(index idx, Op op, T init, Index n, F f) } __syncthreads(); - type y = init; + type y = type(init); for(index_int i = 0; i < idx.nlocal() / lanes_per_thread; i++) { y = op(y, buffer[i]); @@ -244,9 +244,8 @@ struct reducer_base { auto&& derived = static_cast(*this); auto t = derived.slice(x); - return make_storage_access([=](auto i, auto...) -> auto& { - return t[i]; - }); + return make_storage_access( + [=](auto i, auto...) -> auto& { return t[i]; }); } } @@ -482,7 +481,7 @@ struct lane __device__ auto reduce_impl(Op op, T init, Read read, N n, U&& x, Us&&... 
xs) const { using type = remove_reference_t))>; - type r = init; + type r = type(init); for(index_int j = 0; j < n; j++) { r = op(r, read(x(j, _c<0>), xs(j, _c<0>)...)); diff --git a/test/verify/test_acosh.cpp b/test/verify/test_acosh.cpp index 4e2c39a3103..9acea66cc58 100644 --- a/test/verify/test_acosh.cpp +++ b/test/verify/test_acosh.cpp @@ -27,13 +27,14 @@ #include #include -struct test_acosh : verify_program +template +struct test_acosh : verify_program> { migraphx::program create_program() const { migraphx::program p; auto* mm = p.get_main_module(); - migraphx::shape s{migraphx::shape::float_type, {16}}; + migraphx::shape s{DType, {16}}; auto x = mm->add_parameter("x", s); auto min_val = mm->add_literal(1.1f); auto max_val = mm->add_literal(100.0f); @@ -46,3 +47,7 @@ struct test_acosh : verify_program return p; } }; + +template struct test_acosh; +// template struct test_acosh; +// template struct test_acosh; diff --git a/test/verify/test_asin.cpp b/test/verify/test_asin.cpp index 615cc34cc55..84d609ab63d 100644 --- a/test/verify/test_asin.cpp +++ b/test/verify/test_asin.cpp @@ -27,15 +27,20 @@ #include #include -struct test_asin : verify_program +template +struct test_asin : verify_program> { migraphx::program create_program() const { migraphx::program p; auto* mm = p.get_main_module(); - migraphx::shape s{migraphx::shape::float_type, {16}}; + migraphx::shape s{DType, {16}}; auto x = mm->add_parameter("x", s); mm->add_instruction(migraphx::make_op("asin"), x); return p; } }; + +template struct test_asin; +template struct test_asin; +template struct test_asin; diff --git a/test/verify/test_asinh.cpp b/test/verify/test_asinh.cpp index a06d4d34b53..1e5ddd6b6ab 100644 --- a/test/verify/test_asinh.cpp +++ b/test/verify/test_asinh.cpp @@ -27,15 +27,20 @@ #include #include -struct test_asinh : verify_program +template +struct test_asinh : verify_program> { migraphx::program create_program() const { migraphx::program p; auto* mm = p.get_main_module(); - 
migraphx::shape s{migraphx::shape::float_type, {16}}; + migraphx::shape s{DType, {16}}; auto x = mm->add_parameter("x", s); mm->add_instruction(migraphx::make_op("asinh"), x); return p; } }; + +template struct test_asinh; +template struct test_asinh; +template struct test_asinh; diff --git a/test/verify/test_atan.cpp b/test/verify/test_atan.cpp index a0915fc43f6..ea160c7fee7 100644 --- a/test/verify/test_atan.cpp +++ b/test/verify/test_atan.cpp @@ -27,15 +27,20 @@ #include #include -struct test_atan : verify_program +template +struct test_atan : verify_program> { migraphx::program create_program() const { migraphx::program p; auto* mm = p.get_main_module(); - migraphx::shape s{migraphx::shape::float_type, {16}}; + migraphx::shape s{DType, {16}}; auto x = mm->add_parameter("x", s); mm->add_instruction(migraphx::make_op("atan"), x); return p; } }; + +template struct test_atan; +template struct test_atan; +template struct test_atan; diff --git a/test/verify/test_atanh.cpp b/test/verify/test_atanh.cpp index e61f2688052..ed842aa7d6c 100644 --- a/test/verify/test_atanh.cpp +++ b/test/verify/test_atanh.cpp @@ -27,13 +27,14 @@ #include #include -struct test_atanh : verify_program +template +struct test_atanh : verify_program> { migraphx::program create_program() const { migraphx::program p; auto* mm = p.get_main_module(); - migraphx::shape s{migraphx::shape::float_type, {16}}; + migraphx::shape s{DType, {16}}; auto x = mm->add_parameter("x", s); auto min_val = mm->add_literal(-0.95f); auto max_val = mm->add_literal(0.95f); @@ -46,3 +47,7 @@ struct test_atanh : verify_program return p; } }; + +template struct test_atanh; +// template struct test_atanh; +// template struct test_atanh; diff --git a/test/verify/test_ceil.cpp b/test/verify/test_ceil.cpp index 0d91cbcf0b9..b354df802fb 100644 --- a/test/verify/test_ceil.cpp +++ b/test/verify/test_ceil.cpp @@ -27,16 +27,21 @@ #include #include -struct test_ceil : verify_program +template +struct test_ceil : verify_program> { 
migraphx::program create_program() const { migraphx::program p; auto* mm = p.get_main_module(); - migraphx::shape s{migraphx::shape::double_type, {2, 3, 4, 6}}; + migraphx::shape s{DType, {2, 3, 4, 6}}; auto param = mm->add_parameter("x", s); mm->add_instruction(migraphx::make_op("ceil"), param); return p; }; }; + +template struct test_ceil; +template struct test_ceil; +template struct test_ceil; diff --git a/test/verify/test_cos.cpp b/test/verify/test_cos.cpp index f4d61b30060..726ccc72306 100644 --- a/test/verify/test_cos.cpp +++ b/test/verify/test_cos.cpp @@ -27,15 +27,20 @@ #include #include -struct test_cos : verify_program +template +struct test_cos : verify_program> { migraphx::program create_program() const { migraphx::program p; auto* mm = p.get_main_module(); - migraphx::shape s{migraphx::shape::float_type, {8}}; + migraphx::shape s{DType, {8}}; auto x = mm->add_parameter("x", s); mm->add_instruction(migraphx::make_op("cos"), x); return p; } }; + +template struct test_cos; +template struct test_cos; +template struct test_cos; diff --git a/test/verify/test_cosh.cpp b/test/verify/test_cosh.cpp index 9721c61f9fe..6d180da0580 100644 --- a/test/verify/test_cosh.cpp +++ b/test/verify/test_cosh.cpp @@ -27,15 +27,20 @@ #include #include -struct test_cosh : verify_program +template +struct test_cosh : verify_program> { migraphx::program create_program() const { migraphx::program p; auto* mm = p.get_main_module(); - migraphx::shape s{migraphx::shape::float_type, {16}}; + migraphx::shape s{DType, {16}}; auto x = mm->add_parameter("x", s); mm->add_instruction(migraphx::make_op("cosh"), x); return p; } }; + +template struct test_cosh; +template struct test_cosh; +template struct test_cosh; diff --git a/test/verify/test_erf.cpp b/test/verify/test_erf.cpp index 1581e18a3f4..3907a03404c 100644 --- a/test/verify/test_erf.cpp +++ b/test/verify/test_erf.cpp @@ -27,15 +27,20 @@ #include #include -struct test_erf : verify_program +template +struct test_erf : verify_program> { 
migraphx::program create_program() const { migraphx::program p; auto* mm = p.get_main_module(); - migraphx::shape s{migraphx::shape::float_type, {2, 3, 4, 6}}; + migraphx::shape s{DType, {2, 3, 4, 6}}; auto param = mm->add_parameter("x", s); mm->add_instruction(migraphx::make_op("erf"), param); return p; } }; + +template struct test_erf; +template struct test_erf; +template struct test_erf; diff --git a/test/verify/test_exp.cpp b/test/verify/test_exp.cpp index a699734bfb4..d61cbd4224d 100644 --- a/test/verify/test_exp.cpp +++ b/test/verify/test_exp.cpp @@ -27,15 +27,20 @@ #include #include -struct test_exp : verify_program +template +struct test_exp : verify_program> { migraphx::program create_program() const { migraphx::program p; auto* mm = p.get_main_module(); - migraphx::shape s{migraphx::shape::float_type, {6}}; + migraphx::shape s{DType, {6}}; auto x = mm->add_instruction(migraphx::make_op("abs"), mm->add_parameter("x", s)); mm->add_instruction(migraphx::make_op("exp"), x); return p; } }; + +template struct test_exp; +template struct test_exp; +template struct test_exp; diff --git a/test/verify/test_floor.cpp b/test/verify/test_floor.cpp index 662c891c080..6029bc16b98 100644 --- a/test/verify/test_floor.cpp +++ b/test/verify/test_floor.cpp @@ -27,16 +27,21 @@ #include #include -struct test_floor : verify_program +template +struct test_floor : verify_program> { migraphx::program create_program() const { migraphx::program p; auto* mm = p.get_main_module(); - migraphx::shape s{migraphx::shape::float_type, {2, 3, 4, 6}}; + migraphx::shape s{DType, {2, 3, 4, 6}}; auto param = mm->add_parameter("x", s); mm->add_instruction(migraphx::make_op("floor"), param); return p; }; }; + +template struct test_floor; +template struct test_floor; +template struct test_floor; diff --git a/test/verify/test_fmod_mod.cpp b/test/verify/test_fmod_mod.cpp index 802cba73350..35a1bc383ee 100644 --- a/test/verify/test_fmod_mod.cpp +++ b/test/verify/test_fmod_mod.cpp @@ -71,3 +71,4 @@ 
struct test_mod : verify_program return p; } }; +// TODO: check if requires FP8 test \ No newline at end of file diff --git a/test/verify/test_gathernd_default.cpp b/test/verify/test_gathernd_default.cpp index e28cba3b54e..8e2c4458903 100644 --- a/test/verify/test_gathernd_default.cpp +++ b/test/verify/test_gathernd_default.cpp @@ -27,7 +27,7 @@ #include template -struct test_gathernd_default : verify_program +struct test_gathernd_default : verify_program> { migraphx::program create_program() const { diff --git a/test/verify/test_isnan_float.cpp b/test/verify/test_isnan.cpp similarity index 100% rename from test/verify/test_isnan_float.cpp rename to test/verify/test_isnan.cpp diff --git a/test/verify/test_layernorm.cpp b/test/verify/test_layernorm.cpp index 8bc54bc2e9b..2dd0a885360 100644 --- a/test/verify/test_layernorm.cpp +++ b/test/verify/test_layernorm.cpp @@ -117,6 +117,18 @@ struct test_layernorm_fp16 : verify_program } }; +// struct test_layernorm_fp8 : verify_program +// { +// migraphx::program create_program() const +// { +// migraphx::program p; +// auto* mm = p.get_main_module(); +// std::vector dims = {1, 24, 64}; +// auto x = mm->add_parameter("x", migraphx::shape{migraphx::shape::fp8e4m3fnuz_type, +// dims}); add_layernorm(*mm, x, dims); return p; +// } +// }; + struct test_layernorm_eps : verify_program { migraphx::program create_program() const diff --git a/test/verify/test_log.cpp b/test/verify/test_log.cpp index c12105f3f30..2670632defa 100644 --- a/test/verify/test_log.cpp +++ b/test/verify/test_log.cpp @@ -27,15 +27,20 @@ #include #include -struct test_log : verify_program +template +struct test_log : verify_program> { migraphx::program create_program() const { migraphx::program p; auto* mm = p.get_main_module(); - migraphx::shape s{migraphx::shape::float_type, {6}}; + migraphx::shape s{DType, {6}}; auto x = mm->add_instruction(migraphx::make_op("abs"), mm->add_parameter("x", s)); mm->add_instruction(migraphx::make_op("log"), x); return p; } }; 
+ +template struct test_log; +template struct test_log; +template struct test_log; diff --git a/test/verify/test_min_max.cpp b/test/verify/test_min_max.cpp index 9cc60d8a6f6..acfeefb8851 100644 --- a/test/verify/test_min_max.cpp +++ b/test/verify/test_min_max.cpp @@ -46,7 +46,9 @@ struct test_min_max : verify_program> template struct test_min_max; template struct test_min_max; template struct test_min_max; +template struct test_min_max; template struct test_min_max; template struct test_min_max; template struct test_min_max; +template struct test_min_max; diff --git a/test/verify/test_nearbyint.cpp b/test/verify/test_nearbyint.cpp index 8cdf0c0b410..65d040d5544 100644 --- a/test/verify/test_nearbyint.cpp +++ b/test/verify/test_nearbyint.cpp @@ -22,6 +22,7 @@ * THE SOFTWARE. */ +#include "migraphx/float8.hpp" #include "verify_program.hpp" #include #include @@ -45,3 +46,4 @@ struct test_nearbyint : verify_program> template struct test_nearbyint; template struct test_nearbyint; +template struct test_nearbyint; diff --git a/test/verify/test_pad.cpp b/test/verify/test_pad.cpp index 31bb9cece61..21d20134f78 100644 --- a/test/verify/test_pad.cpp +++ b/test/verify/test_pad.cpp @@ -27,13 +27,14 @@ #include #include -struct test_pad : verify_program +template +struct test_pad : verify_program> { migraphx::program create_program() const { migraphx::program p; auto* mm = p.get_main_module(); - migraphx::shape s0{migraphx::shape::int32_type, {1, 96, 165, 165}}; + migraphx::shape s0{DType, {1, 96, 165, 165}}; std::vector pads0 = {0, 0, 0, 0, 0, 0, 1, 1}; std::vector pads1 = {0, 0, 0, 0, 1, 1, 1, 1}; std::vector pads2 = {1, 1, 1, 1, 0, 0, 0, 0}; @@ -46,3 +47,8 @@ struct test_pad : verify_program return p; } }; + +template struct test_pad; +template struct test_pad; +template struct test_pad; +// template struct test_pad; diff --git a/test/verify/test_pow.cpp b/test/verify/test_pow.cpp index 861f5063aa9..abc6abea33f 100644 --- a/test/verify/test_pow.cpp +++ 
b/test/verify/test_pow.cpp @@ -41,3 +41,4 @@ struct test_pow : verify_program return p; } }; +// TODO: add fp8 tests diff --git a/test/verify/test_reduce_add.cpp b/test/verify/test_reduce_add.cpp index e7c1b56b6c1..7bdd9749544 100644 --- a/test/verify/test_reduce_add.cpp +++ b/test/verify/test_reduce_add.cpp @@ -22,19 +22,21 @@ * THE SOFTWARE. */ +#include "migraphx/shape.hpp" #include "verify_program.hpp" #include #include #include #include -struct test_reduce_add : verify_program +template +struct test_reduce_add : verify_program> { migraphx::program create_program() const { migraphx::program p; auto* mm = p.get_main_module(); - migraphx::shape s{migraphx::shape::float_type, {4, 1000, 2, 2}}; + migraphx::shape s{DType, {4, 1000, 2, 2}}; migraphx::shape bs{migraphx::shape::half_type, {1, 32, 128}}; auto x = mm->add_parameter("x", s); auto reduce_mean = @@ -46,3 +48,6 @@ struct test_reduce_add : verify_program return p; }; }; + +template struct test_reduce_add; +template struct test_reduce_add; diff --git a/test/verify/test_reduce_mean_nhwc.cpp b/test/verify/test_reduce_mean_nhwc.cpp index ef4251c8ab4..3c044992eb5 100644 --- a/test/verify/test_reduce_mean_nhwc.cpp +++ b/test/verify/test_reduce_mean_nhwc.cpp @@ -28,14 +28,14 @@ #include #include -struct test_reduce_mean_nhwc : verify_program +template +struct test_reduce_mean_nhwc : verify_program> { migraphx::program create_program() const { migraphx::program p; - auto* mm = p.get_main_module(); - auto s = migraphx::shape::from_permutation( - migraphx::shape::float_type, {4, 256, 2, 2}, {0, 2, 3, 1}); + auto* mm = p.get_main_module(); + auto s = migraphx::shape::from_permutation(DType, {4, 256, 2, 2}, {0, 2, 3, 1}); auto x = mm->add_parameter("x", s); auto reduce = mm->add_instruction(migraphx::make_op("reduce_mean", {{"axes", {1}}}), x); auto abs = mm->add_instruction(migraphx::make_op("abs"), reduce); @@ -44,3 +44,7 @@ struct test_reduce_mean_nhwc : verify_program return p; }; }; + +template struct 
test_reduce_mean_nhwc; +template struct test_reduce_mean_nhwc; +template struct test_reduce_mean_nhwc; diff --git a/test/verify/test_reduce_op_large.cpp b/test/verify/test_reduce_op_large.cpp index 0b838468e20..11f61452836 100644 --- a/test/verify/test_reduce_op_large.cpp +++ b/test/verify/test_reduce_op_large.cpp @@ -51,6 +51,22 @@ template struct test_reduce_op_large; template struct test_reduce_op_large; +template struct test_reduce_op_large; +template struct test_reduce_op_large; +template struct test_reduce_op_large; +template struct test_reduce_op_large; +template struct test_reduce_op_large; + struct test_reduce_mean_1 : verify_program { migraphx::program create_program() const diff --git a/test/verify/test_reduce_op_small.cpp b/test/verify/test_reduce_op_small.cpp index 2db9c00eca0..81167eda815 100644 --- a/test/verify/test_reduce_op_small.cpp +++ b/test/verify/test_reduce_op_small.cpp @@ -56,3 +56,19 @@ template struct test_reduce_op_small; template struct test_reduce_op_small; template struct test_reduce_op_small; + +template struct test_reduce_op_small; +template struct test_reduce_op_small; +template struct test_reduce_op_small; +template struct test_reduce_op_small; +template struct test_reduce_op_small; diff --git a/test/verify/test_roialign.cpp b/test/verify/test_roialign.cpp index e153f6fdd0f..f235462ae1d 100644 --- a/test/verify/test_roialign.cpp +++ b/test/verify/test_roialign.cpp @@ -27,15 +27,16 @@ #include #include -struct test_roialign : verify_program +template +struct test_roialign : verify_program> { migraphx::program create_program() const { migraphx::program p; auto* mm = p.get_main_module(); - migraphx::shape x_s{migraphx::shape::float_type, {5, 4, 10, 10}}; + migraphx::shape x_s{DType, {5, 4, 10, 10}}; - migraphx::shape roi_s{migraphx::shape::float_type, {5, 4}}; + migraphx::shape roi_s{DType, {5, 4}}; migraphx::shape ind_s{migraphx::shape::int64_type, {5}}; std::vector ind_vec = {0, 2, 3, 4, 1}; @@ -44,10 +45,10 @@ struct test_roialign 
: verify_program auto roi = mm->add_parameter("roi", roi_s); auto ind = mm->add_literal(migraphx::literal(ind_s, ind_vec)); auto r = mm->add_instruction(migraphx::make_op("roialign", - {{"spatial_scale", 1.0}, - {"output_height", 5}, - {"output_width", 5}, - {"sampling_ratio", 2}}), + {{"spatial_scale", 1.0}, + {"output_height", 5}, + {"output_width", 5}, + {"sampling_ratio", 2}}), x, roi, ind); @@ -56,3 +57,7 @@ struct test_roialign : verify_program return p; } }; + +template struct test_roialign; +// template struct test_roialign; +// template struct test_roialign; diff --git a/test/verify/test_rsqrt.cpp b/test/verify/test_rsqrt.cpp index 7ce2dabd81b..d6539725331 100644 --- a/test/verify/test_rsqrt.cpp +++ b/test/verify/test_rsqrt.cpp @@ -47,3 +47,5 @@ struct test_rsqrt : verify_program return p; }; }; + +// TOOD : Add FP8 test \ No newline at end of file diff --git a/test/verify/test_scatternd.cpp b/test/verify/test_scatternd.cpp index f90f6238997..977628143a2 100644 --- a/test/verify/test_scatternd.cpp +++ b/test/verify/test_scatternd.cpp @@ -21,22 +21,23 @@ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN * THE SOFTWARE. 
*/ +#include "migraphx/shape.hpp" #include "verify_program.hpp" #include #include #include -struct test_scatternd : verify_program +template +struct test_scatternd : verify_program> { migraphx::program create_program() const { migraphx::program p; auto* mm = p.get_main_module(); - auto dtype = migraphx::shape::float_type; auto itype = migraphx::shape::int64_type; - migraphx::shape ds{dtype, {1}}; + migraphx::shape ds{DType, {1}}; migraphx::shape is{itype, {4, 1}}; - migraphx::shape us{dtype, {4}}; + migraphx::shape us{DType, {4}}; std::vector ind_vec{4, 3, 1, 7}; auto ld = mm->add_literal(migraphx::literal{ds, {1}}); @@ -51,3 +52,7 @@ struct test_scatternd : verify_program return p; } }; + +template struct test_scatternd; +template struct test_scatternd; +template struct test_scatternd; diff --git a/test/verify/test_sin.cpp b/test/verify/test_sin.cpp index 33ef1603551..ec5885ec3e8 100644 --- a/test/verify/test_sin.cpp +++ b/test/verify/test_sin.cpp @@ -27,15 +27,20 @@ #include #include -struct test_sin : verify_program +template +struct test_sin : verify_program> { migraphx::program create_program() const { migraphx::program p; auto* mm = p.get_main_module(); - migraphx::shape s{migraphx::shape::float_type, {10}}; + migraphx::shape s{DType, {10}}; auto x = mm->add_parameter("x", s); mm->add_instruction(migraphx::make_op("sin"), x); return p; } }; + +template struct test_sin; +template struct test_sin; +template struct test_sin; diff --git a/test/verify/test_sinh.cpp b/test/verify/test_sinh.cpp index ebeb381e78a..5accaf8cda2 100644 --- a/test/verify/test_sinh.cpp +++ b/test/verify/test_sinh.cpp @@ -27,15 +27,20 @@ #include #include -struct test_sinh : verify_program +template +struct test_sinh : verify_program> { migraphx::program create_program() const { migraphx::program p; auto* mm = p.get_main_module(); - migraphx::shape s{migraphx::shape::float_type, {16}}; + migraphx::shape s{DType, {16}}; auto x = mm->add_parameter("x", s); 
mm->add_instruction(migraphx::make_op("sinh"), x); return p; } }; + +template struct test_sinh; +template struct test_sinh; +template struct test_sinh; diff --git a/test/verify/test_softmax.cpp b/test/verify/test_softmax.cpp index 1926e95d219..255b2766ec1 100644 --- a/test/verify/test_softmax.cpp +++ b/test/verify/test_softmax.cpp @@ -48,3 +48,7 @@ template struct test_softmax<0, migraphx::shape::half_type>; template struct test_softmax<1, migraphx::shape::half_type>; template struct test_softmax<2, migraphx::shape::half_type>; template struct test_softmax<3, migraphx::shape::half_type>; +// template struct test_softmax<0, migraphx::shape::fp8e4m3fnuz_type>; +// template struct test_softmax<1, migraphx::shape::fp8e4m3fnuz_type>; +// template struct test_softmax<2, migraphx::shape::fp8e4m3fnuz_type>; +// template struct test_softmax<3, migraphx::shape::fp8e4m3fnuz_type>; diff --git a/test/verify/test_sqrt.cpp b/test/verify/test_sqrt.cpp index ba5de80ac47..105cc58cad5 100644 --- a/test/verify/test_sqrt.cpp +++ b/test/verify/test_sqrt.cpp @@ -27,16 +27,21 @@ #include #include -struct test_sqrt : verify_program +template +struct test_sqrt : verify_program> { migraphx::program create_program() const { migraphx::program p; auto* mm = p.get_main_module(); - migraphx::shape s{migraphx::shape::float_type, {2, 3, 4, 6}}; + migraphx::shape s{DType, {2, 3, 4, 6}}; auto param = mm->add_parameter("x", s); auto param_abs = mm->add_instruction(migraphx::make_op("abs"), param); mm->add_instruction(migraphx::make_op("sqrt"), param_abs); return p; } }; + +template struct test_sqrt; +template struct test_sqrt; +template struct test_sqrt; diff --git a/test/verify/test_tan.cpp b/test/verify/test_tan.cpp index d8e8ab3664d..b14ab324535 100644 --- a/test/verify/test_tan.cpp +++ b/test/verify/test_tan.cpp @@ -27,15 +27,20 @@ #include #include -struct test_tan : verify_program +template +struct test_tan : verify_program> { migraphx::program create_program() const { migraphx::program p; auto* 
mm = p.get_main_module(); - migraphx::shape s{migraphx::shape::float_type, {16}}; + migraphx::shape s{DType, {16}}; auto x = mm->add_parameter("x", s); mm->add_instruction(migraphx::make_op("tan"), x); return p; } }; + +template struct test_tan; +template struct test_tan; +template struct test_tan; diff --git a/test/verify/test_tanh.cpp b/test/verify/test_tanh.cpp index 39ed4283a68..fb8a3d1ea77 100644 --- a/test/verify/test_tanh.cpp +++ b/test/verify/test_tanh.cpp @@ -27,14 +27,19 @@ #include #include -struct test_tanh : verify_program +template +struct test_tanh : verify_program> { migraphx::program create_program() const { migraphx::program p; auto* mm = p.get_main_module(); - auto x = mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {4, 3, 3, 3}}); + auto x = mm->add_parameter("x", migraphx::shape{DType, {4, 3, 3, 3}}); mm->add_instruction(migraphx::make_op("tanh"), x); return p; } }; + +template struct test_tanh; +template struct test_tanh; +template struct test_tanh; diff --git a/test/verify/test_where.cpp b/test/verify/test_where.cpp index 2f9d3b68920..d14d906f37e 100644 --- a/test/verify/test_where.cpp +++ b/test/verify/test_where.cpp @@ -27,7 +27,8 @@ #include #include -struct test_where : verify_program +template +struct test_where : verify_program> { migraphx::program create_program() const { @@ -44,3 +45,7 @@ struct test_where : verify_program return p; }; }; + +template struct test_where; +template struct test_where; +template struct test_where; From e2b0c40634524795f5cceb0d146f77bc03aa8928 Mon Sep 17 00:00:00 2001 From: Umang Yadav Date: Sat, 18 Nov 2023 16:00:16 +0000 Subject: [PATCH 046/115] Fix tidy and other errors --- src/targets/cpu/dnnl.cpp | 1 + .../include/migraphx/kernels/float8.hpp | 22 ++++++++++--------- .../include/migraphx/kernels/float8_impl.hpp | 2 ++ test/verify/test_fmod_mod.cpp | 2 +- test/verify/test_literal_limits.cpp | 7 +++++- test/verify/test_rsqrt.cpp | 2 +- 6 files changed, 23 insertions(+), 13 deletions(-) 
diff --git a/src/targets/cpu/dnnl.cpp b/src/targets/cpu/dnnl.cpp index 3e5c1c1d066..a36fb688a4a 100644 --- a/src/targets/cpu/dnnl.cpp +++ b/src/targets/cpu/dnnl.cpp @@ -68,6 +68,7 @@ dnnl::memory::data_type to_dnnl_memory_data_type(shape::type_t t) case st::int32_type: return dt::s32; case st::int8_type: return dt::s8; case st::uint8_type: return dt::u8; + case st::fp8e4m3fnuz_type: return dt : u8; default: MIGRAPHX_THROW("Unsupported data type"); } } diff --git a/src/targets/gpu/kernels/include/migraphx/kernels/float8.hpp b/src/targets/gpu/kernels/include/migraphx/kernels/float8.hpp index 6f4219a627c..1996bb9f634 100644 --- a/src/targets/gpu/kernels/include/migraphx/kernels/float8.hpp +++ b/src/targets/gpu/kernels/include/migraphx/kernels/float8.hpp @@ -29,7 +29,7 @@ #define MIGRAPHX_HIP_DEVICE __device__ // We are clipping in down conversion by default -#define MIGRAPHX_F8_DOWNCAST_CLIPPING 1 +#define MIGRAPHX_F8_DOWNCAST_CLIPPING 1 // NOLINT #include #include @@ -178,7 +178,7 @@ struct float8 // convert to float // #if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__) -#if 0 // need constexpr operator(). This version can't be constexpr +#if 0 // need constexpr operator(). 
This version can't be constexpr // NOLINT // upcast using device specific intrinsic inline MIGRAPHX_HIP_DEVICE operator float() const { @@ -264,6 +264,7 @@ struct float8 } } +// NOLINTNEXTLINE #define MIGRAPHX_FP8_SHORT_UNARY_OP(unary_op, binary_op) \ constexpr float8& MIGRAPHX_HIP_DEVICE operator unary_op(const float8& rhs) \ { \ @@ -324,13 +325,14 @@ using fp8e5m2fnuz = float8; } // NOLINTNEXTLINE -#define MIGRAPHX_FP8_UNARY_OP(unary_op, T) \ - inline constexpr MIGRAPHX_HIP_DEVICE T unary_op(T v) \ - { \ - v.data = v.data & 0x7f; \ - return v; \ +#define MIGRAPHX_FP8_FABS(T) \ + inline constexpr MIGRAPHX_HIP_DEVICE T fabs(T v) \ + { \ + v.data = v.data & 0x7f; \ + return v; \ } +// NOLINTNEXTLINE #define MIGRAPHX_FP8_GEN_OP_OVERLOADS(T) \ MIGRAPHX_FP8_BINARY_OP(*, T, T) \ MIGRAPHX_FP8_BINARY_OP(-, T, T) \ @@ -342,7 +344,7 @@ using fp8e5m2fnuz = float8; MIGRAPHX_FP8_BINARY_OP(>, T, bool) \ MIGRAPHX_FP8_BINARY_OP(<, T, bool) \ MIGRAPHX_FP8_BINARY_OP(!=, T, bool) \ - MIGRAPHX_FP8_UNARY_OP(fabs, T) + MIGRAPHX_FP8_FABS(T) MIGRAPHX_FP8_GEN_OP_OVERLOADS(fp8e5m2) MIGRAPHX_FP8_GEN_OP_OVERLOADS(fp8e5m2fnuz) @@ -453,10 +455,10 @@ class numeric_limits return fp8e5m2(0x34, fp8e5m2::from_bits()); } // 7D, 7E, 7F are positive NaNs and FD, FE, FF are negative NaNs - static constexpr MIGRAPHX_HIP_DEVICE fp8e5m2 quiet_NaN() + static constexpr MIGRAPHX_HIP_DEVICE fp8e5m2 quiet_NaN() // NOLINT { return fp8e5m2(0xFF, fp8e5m2::from_bits()); - } // NOLINT + } static constexpr MIGRAPHX_HIP_DEVICE fp8e5m2 max() { diff --git a/src/targets/gpu/kernels/include/migraphx/kernels/float8_impl.hpp b/src/targets/gpu/kernels/include/migraphx/kernels/float8_impl.hpp index 142a93b8873..95477c1b120 100644 --- a/src/targets/gpu/kernels/include/migraphx/kernels/float8_impl.hpp +++ b/src/targets/gpu/kernels/include/migraphx/kernels/float8_impl.hpp @@ -47,6 +47,7 @@ struct conditional namespace fp8 { namespace impl { +// NOLINTBEGIN template __device__ constexpr uint8_t cast_to_f8(T f_x, bool stoch = 
false, uint32_t rng = 0) { @@ -256,6 +257,7 @@ __device__ constexpr uint8_t cast_to_f8(T f_x, bool stoch = false, uint32_t rng mantissa &= (1 << Wm) - 1; return (sign << 7) | (f8_exponent << Wm) | mantissa; } +// NOLINTEND template __device__ constexpr T cast_from_f8(uint8_t x) diff --git a/test/verify/test_fmod_mod.cpp b/test/verify/test_fmod_mod.cpp index 35a1bc383ee..4701a7e966a 100644 --- a/test/verify/test_fmod_mod.cpp +++ b/test/verify/test_fmod_mod.cpp @@ -71,4 +71,4 @@ struct test_mod : verify_program return p; } }; -// TODO: check if requires FP8 test \ No newline at end of file +// TODO: check if requires FP8 test diff --git a/test/verify/test_literal_limits.cpp b/test/verify/test_literal_limits.cpp index fa0828585e1..baa21b4e749 100644 --- a/test/verify/test_literal_limits.cpp +++ b/test/verify/test_literal_limits.cpp @@ -26,6 +26,7 @@ #include #include #include +#include template struct test_literal_limits : verify_program> @@ -36,10 +37,14 @@ struct test_literal_limits : verify_program> auto* mm = p.get_main_module(); auto input_s = migraphx::shape(Q, {3, 1}); auto infinity_val = std::numeric_limits::max(); - if constexpr(std::numeric_limits::has_infinity) + if constexpr(std::numeric_limits::has_infinity and std::is_floating_point{}) { infinity_val = std::numeric_limits::infinity(); } + else + { // for the interger vals, infinity doesn't exist + infinity_val = 0; + } std::vector s_data{ infinity_val, static_cast(-infinity_val), std::numeric_limits::quiet_NaN()}; diff --git a/test/verify/test_rsqrt.cpp b/test/verify/test_rsqrt.cpp index d6539725331..2862196b095 100644 --- a/test/verify/test_rsqrt.cpp +++ b/test/verify/test_rsqrt.cpp @@ -48,4 +48,4 @@ struct test_rsqrt : verify_program }; }; -// TOOD : Add FP8 test \ No newline at end of file +// TOOD : Add FP8 test From 9f50051e0c15ffee9292e00538b892715b0722a7 Mon Sep 17 00:00:00 2001 From: Umang Yadav Date: Sat, 18 Nov 2023 16:12:40 +0000 Subject: [PATCH 047/115] fixes --- src/targets/cpu/dnnl.cpp | 2 
+- test/verify/test_literal_limits.cpp | 10 +++------- 2 files changed, 4 insertions(+), 8 deletions(-) diff --git a/src/targets/cpu/dnnl.cpp b/src/targets/cpu/dnnl.cpp index a36fb688a4a..ee3546406e5 100644 --- a/src/targets/cpu/dnnl.cpp +++ b/src/targets/cpu/dnnl.cpp @@ -68,7 +68,7 @@ dnnl::memory::data_type to_dnnl_memory_data_type(shape::type_t t) case st::int32_type: return dt::s32; case st::int8_type: return dt::s8; case st::uint8_type: return dt::u8; - case st::fp8e4m3fnuz_type: return dt : u8; + case st::fp8e4m3fnuz_type: return dt::u8; default: MIGRAPHX_THROW("Unsupported data type"); } } diff --git a/test/verify/test_literal_limits.cpp b/test/verify/test_literal_limits.cpp index baa21b4e749..a7585f72e81 100644 --- a/test/verify/test_literal_limits.cpp +++ b/test/verify/test_literal_limits.cpp @@ -34,17 +34,13 @@ struct test_literal_limits : verify_program> migraphx::program create_program() const { migraphx::program p; - auto* mm = p.get_main_module(); - auto input_s = migraphx::shape(Q, {3, 1}); - auto infinity_val = std::numeric_limits::max(); + auto* mm = p.get_main_module(); + auto input_s = migraphx::shape(Q, {3, 1}); + T infinity_val{0}; if constexpr(std::numeric_limits::has_infinity and std::is_floating_point{}) { infinity_val = std::numeric_limits::infinity(); } - else - { // for the interger vals, infinity doesn't exist - infinity_val = 0; - } std::vector s_data{ infinity_val, static_cast(-infinity_val), std::numeric_limits::quiet_NaN()}; From 249464c12f4d7f06c2fc09375f88c5dd96f4dccd Mon Sep 17 00:00:00 2001 From: Umang Yadav Date: Sat, 18 Nov 2023 16:14:52 +0000 Subject: [PATCH 048/115] add nolint --- src/targets/gpu/kernels/include/migraphx/kernels/float8.hpp | 1 + 1 file changed, 1 insertion(+) diff --git a/src/targets/gpu/kernels/include/migraphx/kernels/float8.hpp b/src/targets/gpu/kernels/include/migraphx/kernels/float8.hpp index 1996bb9f634..f988c39519a 100644 --- a/src/targets/gpu/kernels/include/migraphx/kernels/float8.hpp +++ 
b/src/targets/gpu/kernels/include/migraphx/kernels/float8.hpp @@ -328,6 +328,7 @@ using fp8e5m2fnuz = float8; #define MIGRAPHX_FP8_FABS(T) \ inline constexpr MIGRAPHX_HIP_DEVICE T fabs(T v) \ { \ + /*NOLINTNEXTLINE*/ \ v.data = v.data & 0x7f; \ return v; \ } From 1be958709747b445cbcc07a626531922dcde7bb9 Mon Sep 17 00:00:00 2001 From: Umang Yadav Date: Sat, 18 Nov 2023 16:20:05 +0000 Subject: [PATCH 049/115] tidy fix --- src/targets/gpu/kernels/include/migraphx/kernels/float8.hpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/targets/gpu/kernels/include/migraphx/kernels/float8.hpp b/src/targets/gpu/kernels/include/migraphx/kernels/float8.hpp index f988c39519a..82a9fd7afbd 100644 --- a/src/targets/gpu/kernels/include/migraphx/kernels/float8.hpp +++ b/src/targets/gpu/kernels/include/migraphx/kernels/float8.hpp @@ -284,8 +284,8 @@ struct float8 MIGRAPHX_FP8_SHORT_UNARY_OP(+=, +) MIGRAPHX_FP8_SHORT_UNARY_OP(/=, /) - inline MIGRAPHX_HIP_DEVICE constexpr float8& operator=(const float8& rhs) = default; - inline MIGRAPHX_HIP_DEVICE constexpr float8& operator=(float8&& rhs) = default; + inline MIGRAPHX_HIP_DEVICE constexpr float8& operator=(const float8& rhs) = default; + inline MIGRAPHX_HIP_DEVICE constexpr float8& operator=(float8&& rhs) noexcept = default; inline MIGRAPHX_HIP_DEVICE constexpr bool operator==(const float8& rhs) const { From 13403ab2611625d1739511d79160d6c5a0e2f496 Mon Sep 17 00:00:00 2001 From: Umang Yadav Date: Mon, 20 Nov 2023 15:11:44 +0000 Subject: [PATCH 050/115] roialign, softmax, pow, acosh, atanh,pad tests are enabled now --- src/targets/cpu/dnnl.cpp | 2 +- .../include/migraphx/kernels/float8.hpp | 43 +++++++++++- .../kernels/include/migraphx/kernels/pad.hpp | 5 +- .../include/migraphx/kernels/reduce.hpp | 2 +- .../include/migraphx/kernels/roialign.hpp | 69 ++++++++++--------- .../include/migraphx/kernels/softmax.hpp | 3 +- test/verify/test_acosh.cpp | 18 ++--- test/verify/test_atanh.cpp | 19 ++--- 
test/verify/test_pad.cpp | 2 +- test/verify/test_pow.cpp | 12 ++-- test/verify/test_roialign.cpp | 4 +- test/verify/test_softmax.cpp | 8 +-- 12 files changed, 121 insertions(+), 66 deletions(-) diff --git a/src/targets/cpu/dnnl.cpp b/src/targets/cpu/dnnl.cpp index ee3546406e5..4fff4cca8d7 100644 --- a/src/targets/cpu/dnnl.cpp +++ b/src/targets/cpu/dnnl.cpp @@ -67,7 +67,7 @@ dnnl::memory::data_type to_dnnl_memory_data_type(shape::type_t t) case st::float_type: return dt::f32; case st::int32_type: return dt::s32; case st::int8_type: return dt::s8; - case st::uint8_type: return dt::u8; + case st::uint8_type: case st::fp8e4m3fnuz_type: return dt::u8; default: MIGRAPHX_THROW("Unsupported data type"); } diff --git a/src/targets/gpu/kernels/include/migraphx/kernels/float8.hpp b/src/targets/gpu/kernels/include/migraphx/kernels/float8.hpp index 82a9fd7afbd..17c4ccf9b51 100644 --- a/src/targets/gpu/kernels/include/migraphx/kernels/float8.hpp +++ b/src/targets/gpu/kernels/include/migraphx/kernels/float8.hpp @@ -25,6 +25,7 @@ #if defined(__clang__) #pragma clang diagnostic push #pragma clang diagnostic ignored "-Wfloat-equal" +#pragma clang diagnostic ignored "-Wold-style-cast" #endif // __clang__ #define MIGRAPHX_HIP_DEVICE __device__ @@ -132,7 +133,7 @@ struct float8 // NOTE: ON-DEVICE... 
always optimal bias explicit constexpr MIGRAPHX_HIP_DEVICE - float8(float v, + float8(const float v, migraphx::fp8::rounding_mode rm = migraphx::fp8::rounding_mode::standard, uint32_t rng = 0) { @@ -145,8 +146,7 @@ struct float8 #else // DEVICE for non-gfx940 using s/w simulation explicit constexpr MIGRAPHX_HIP_DEVICE -#endif - float8(float v, + float8(const float v, migraphx::fp8::rounding_mode rm = migraphx::fp8::rounding_mode::standard, uint32_t rng = 0) { @@ -175,7 +175,42 @@ struct float8 #endif // MIGRAPHX_FP8_DOWNCAST_CLIPPING} } } +#endif // __gfx940___ + + // Constructor from half + explicit constexpr MIGRAPHX_HIP_DEVICE + float8(const _Float16 v, rounding_mode rm = rounding_mode::standard, uint32_t rng = 0) + : float8((float)v, rm, rng) + { + } + // constructor from int + explicit constexpr MIGRAPHX_HIP_DEVICE + float8(const int v, rounding_mode rm = rounding_mode::standard, uint32_t rng = 0) + : float8((float)v, rm, rng) + { + } + + // constructor from uint + explicit constexpr MIGRAPHX_HIP_DEVICE + float8(const uint32_t v, rounding_mode rm = rounding_mode::standard, uint32_t rng = 0) + : float8((float)v, rm, rng) + { + } + + // constructor from double + explicit constexpr MIGRAPHX_HIP_DEVICE + float8(const double v, rounding_mode rm = rounding_mode::standard, uint32_t rng = 0) + : float8((float)v, rm, rng) + { + } + + // constructor from bool + explicit constexpr MIGRAPHX_HIP_DEVICE + float8(const bool v, rounding_mode rm = rounding_mode::standard, uint32_t rng = 0) + : float8((float)(v), rm, rng) + { + } // convert to float // #if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__) #if 0 // need constexpr operator(). 
This version can't be constexpr // NOLINT @@ -209,6 +244,8 @@ struct float8 return migraphx::fp8::impl::cast_from_f8<2, 5, float, FNUZ /*negative_zero_nan*/>(data); } + inline constexpr explicit MIGRAPHX_HIP_DEVICE operator bool() const { return not is_zero(); } + // check for zero inline MIGRAPHX_HIP_DEVICE constexpr bool is_zero() const { diff --git a/src/targets/gpu/kernels/include/migraphx/kernels/pad.hpp b/src/targets/gpu/kernels/include/migraphx/kernels/pad.hpp index d4dcb49dfad..ac9cfd4a7a5 100644 --- a/src/targets/gpu/kernels/include/migraphx/kernels/pad.hpp +++ b/src/targets/gpu/kernels/include/migraphx/kernels/pad.hpp @@ -39,6 +39,7 @@ __device__ void pad(const index& idx, const PadVal& pad_val) { auto output_shape = output.get_shape(); + using otype = typename Output::type; idx.global_stride(output_shape.elements(), [&](auto i) { // 1. get current multi-index for output // 2. get the size of the input to determine input boundaries @@ -53,9 +54,9 @@ __device__ void pad(const index& idx, if(any_of(range_multi.begin(), range_multi.end(), [&](auto j) { return multi[j] < offsets[j] or input_idx[j] >= input_bounds[j]; })) - output[multi] = pad_val; + output[multi] = otype(pad_val); else - output[multi] = input[input_idx]; + output[multi] = otype(input[input_idx]); }); } diff --git a/src/targets/gpu/kernels/include/migraphx/kernels/reduce.hpp b/src/targets/gpu/kernels/include/migraphx/kernels/reduce.hpp index a106773d1dc..98c7d9543b2 100644 --- a/src/targets/gpu/kernels/include/migraphx/kernels/reduce.hpp +++ b/src/targets/gpu/kernels/include/migraphx/kernels/reduce.hpp @@ -392,7 +392,7 @@ struct block { using max_iterations = decltype(idx.max_local_stride_iterations(n)); inner_storage storage; - idx.local_stride(n, [&](auto j, auto d) { storage(j, d) = f(xs(j, d)...); }); + idx.local_stride(n, [&](auto j, auto d) { storage(j, d) = R{f(xs(j, d)...)}; }); return storage; } }; diff --git a/src/targets/gpu/kernels/include/migraphx/kernels/roialign.hpp 
b/src/targets/gpu/kernels/include/migraphx/kernels/roialign.hpp index 5d00a570a2f..9f0b538a439 100644 --- a/src/targets/gpu/kernels/include/migraphx/kernels/roialign.hpp +++ b/src/targets/gpu/kernels/include/migraphx/kernels/roialign.hpp @@ -56,13 +56,13 @@ struct avg_pool template MIGRAPHX_DEVICE_CONSTEXPR T operator()(T x, T y) { - return x + y; + return static_cast(x + y); } template MIGRAPHX_DEVICE_CONSTEXPR T final(T x, index_int y) { - return (y == 0) ? 0.0 : (x / y); + return (y == 0) ? static_cast(0.0) : static_cast(x / y); } }; @@ -70,13 +70,14 @@ template MIGRAPHX_DEVICE_CONSTEXPR typename Iterator::value_type bilinear_interpolate( const Iterator data, const array& dims, array xy, Op pooling) { + using ret_type = typename Iterator::value_type; array low{}; array high{}; for(index_int ii = 0; ii < xy.size(); ++ii) { if(xy[ii] < -1.0f or xy[ii] > dims[ii]) { - return 0; + return static_cast(0); } xy[ii] = migraphx::max(xy[ii], 0.0f); @@ -92,11 +93,14 @@ MIGRAPHX_DEVICE_CONSTEXPR typename Iterator::value_type bilinear_interpolate( high[0] * dims[1] + low[1], high[0] * dims[1] + high[1]}; - float ly = xy[0] - low[0]; - float lx = xy[1] - low[1]; - float hy = 1.0f - ly; - float hx = 1.0f - lx; - array ws = {hy * hx, hy * lx, ly * hx, ly * lx}; + float ly = xy[0] - low[0]; + float lx = xy[1] - low[1]; + float hy = 1.0f - ly; + float hx = 1.0f - lx; + array ws = {static_cast(hy * hx), + static_cast(hy * lx), + static_cast(ly * hx), + static_cast(ly * lx)}; auto v01 = pooling(data[locs[0]] * ws[0], data[locs[1]] * ws[1]); auto v23 = pooling(data[locs[2]] * ws[2], data[locs[3]] * ws[3]); @@ -113,8 +117,9 @@ MIGRAPHX_DEVICE_CONSTEXPR auto calc_pooling(const Iterator& data, float roi_offset, Op op) { - typename Iterator::value_type output_val = op.init(); - const int64_t count = bin_grid_size[0] * bin_grid_size[1]; + using in_dtype = typename Iterator::value_type; + in_dtype output_val = in_dtype{op.init()}; + const int64_t count = bin_grid_size[0] * 
bin_grid_size[1]; dfor(bin_grid_size[0], bin_grid_size[1])([&](auto iy, auto ix) { array id = {iy, ix}; array locs = @@ -148,7 +153,7 @@ __device__ void roialign(const T& x_t, const U& rois_t, const V& ind_t, W& y_t, const auto x = x_t.begin(); const auto rois = rois_t.begin(); const auto ind = ind_t.begin(); - + using ytype = typename W::type; // input shape auto x_lens = x_t.get_shape().lens; auto channel_num = x_lens[1]; @@ -176,10 +181,12 @@ __device__ void roialign(const T& x_t, const U& rois_t, const V& ind_t, W& y_t, const auto offset_rois = rois + (n * roi_column_num); const int batch_ind = ind[n]; - array roi_starts = {offset_rois[1] * s.spatial_scale, - offset_rois[0] * s.spatial_scale}; - array roi_ends = {offset_rois[3] * s.spatial_scale, - offset_rois[2] * s.spatial_scale}; + array roi_starts = { + static_cast(offset_rois[1]) * static_cast(s.spatial_scale), + static_cast(offset_rois[0]) * static_cast(s.spatial_scale)}; + array roi_ends = { + static_cast(offset_rois[3]) * static_cast(s.spatial_scale), + static_cast(offset_rois[2]) * static_cast(s.spatial_scale)}; array roi_size{}; array bin_size{}; @@ -199,25 +206,25 @@ __device__ void roialign(const T& x_t, const U& rois_t, const V& ind_t, W& y_t, const auto offset_x = x + ((batch_ind * channel_num + c) * in_dims[0] * in_dims[1]); if constexpr(s.is_avg_pooling) { - y_t[i] = calc_pooling(offset_x, - roi_starts, - bin_size, - {ph, pw}, - bin_grid_size, - in_dims, - s.roi_offset, - avg_pool{}); + y_t[i] = static_cast(calc_pooling(offset_x, + roi_starts, + bin_size, + {ph, pw}, + bin_grid_size, + in_dims, + s.roi_offset, + avg_pool{})); } else { - y_t[i] = calc_pooling(offset_x, - roi_starts, - bin_size, - {ph, pw}, - bin_grid_size, - in_dims, - s.roi_offset, - max_pool{}); + y_t[i] = static_cast(calc_pooling(offset_x, + roi_starts, + bin_size, + {ph, pw}, + bin_grid_size, + in_dims, + s.roi_offset, + max_pool{})); } } } diff --git a/src/targets/gpu/kernels/include/migraphx/kernels/softmax.hpp 
b/src/targets/gpu/kernels/include/migraphx/kernels/softmax.hpp index b0242b302f8..47997e0111e 100644 --- a/src/targets/gpu/kernels/include/migraphx/kernels/softmax.hpp +++ b/src/targets/gpu/kernels/include/migraphx/kernels/softmax.hpp @@ -33,6 +33,7 @@ template __device__ void softmax(Input input1, Output output) { using block = reduce::auto_block()>; + using otype = typename Output::type; block::template run>([&](auto, auto r) { auto input = r.inner(op::id{})(input1); #ifdef MIGRAPHX_USE_FAST_SOFTMAX @@ -43,7 +44,7 @@ __device__ void softmax(Input input1, Output output) auto exp_in = r.inner([&](auto x) { return migraphx::exp(x - c); })(input); auto batch_sum = r.reduce(op::sum{}, 0, [](auto x) { return migraphx::convert(x); })(exp_in); - r.inner([&](auto& y, auto x) { y = x / batch_sum; })(output, exp_in); + r.inner([&](auto& y, auto x) { y = otype{x / batch_sum}; })(output, exp_in); }); } diff --git a/test/verify/test_acosh.cpp b/test/verify/test_acosh.cpp index 9acea66cc58..5125bdf0eec 100644 --- a/test/verify/test_acosh.cpp +++ b/test/verify/test_acosh.cpp @@ -23,21 +23,23 @@ */ #include "verify_program.hpp" +#include #include #include #include -template -struct test_acosh : verify_program> +template +struct test_acosh : verify_program> { migraphx::program create_program() const { migraphx::program p; - auto* mm = p.get_main_module(); + auto* mm = p.get_main_module(); + migraphx::shape::type_t DType = migraphx::shape::get_type(); migraphx::shape s{DType, {16}}; auto x = mm->add_parameter("x", s); - auto min_val = mm->add_literal(1.1f); - auto max_val = mm->add_literal(100.0f); + auto min_val = mm->add_literal(migraphx::literal{migraphx::shape{DType}, {1.1}}); + auto max_val = mm->add_literal(migraphx::literal{migraphx::shape{DType}, {100.0}}); min_val = mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {16}}}), min_val); max_val = @@ -48,6 +50,6 @@ struct test_acosh : verify_program> } }; -template struct test_acosh; -// template struct 
test_acosh; -// template struct test_acosh; +template struct test_acosh; +template struct test_acosh; +template struct test_acosh; diff --git a/test/verify/test_atanh.cpp b/test/verify/test_atanh.cpp index ed842aa7d6c..882a3184a0a 100644 --- a/test/verify/test_atanh.cpp +++ b/test/verify/test_atanh.cpp @@ -23,21 +23,24 @@ */ #include "verify_program.hpp" +#include +#include #include #include #include -template -struct test_atanh : verify_program> +template +struct test_atanh : verify_program> { migraphx::program create_program() const { migraphx::program p; - auto* mm = p.get_main_module(); + auto* mm = p.get_main_module(); + migraphx::shape::type_t DType = migraphx::shape::get_type(); migraphx::shape s{DType, {16}}; auto x = mm->add_parameter("x", s); - auto min_val = mm->add_literal(-0.95f); - auto max_val = mm->add_literal(0.95f); + auto min_val = mm->add_literal(migraphx::literal{migraphx::shape{DType}, {-0.95f}}); + auto max_val = mm->add_literal(migraphx::literal{migraphx::shape{DType}, {0.95f}}); min_val = mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {16}}}), min_val); max_val = @@ -48,6 +51,6 @@ struct test_atanh : verify_program> } }; -template struct test_atanh; -// template struct test_atanh; -// template struct test_atanh; +template struct test_atanh; +template struct test_atanh; +template struct test_atanh; diff --git a/test/verify/test_pad.cpp b/test/verify/test_pad.cpp index 21d20134f78..5366fe4b03f 100644 --- a/test/verify/test_pad.cpp +++ b/test/verify/test_pad.cpp @@ -51,4 +51,4 @@ struct test_pad : verify_program> template struct test_pad; template struct test_pad; template struct test_pad; -// template struct test_pad; +template struct test_pad; diff --git a/test/verify/test_pow.cpp b/test/verify/test_pow.cpp index abc6abea33f..8dd8d96e76e 100644 --- a/test/verify/test_pow.cpp +++ b/test/verify/test_pow.cpp @@ -27,13 +27,15 @@ #include #include -struct test_pow : verify_program +template +struct test_pow : 
verify_program> { migraphx::program create_program() const { migraphx::program p; - auto* mm = p.get_main_module(); - migraphx::shape s{migraphx::shape::float_type, {6}}; + migraphx::shape::type_t DType = migraphx::shape::get_type(); + auto* mm = p.get_main_module(); + migraphx::shape s{DType, {6}}; std::vector vec_e(s.elements(), 2.0f); auto b = mm->add_parameter("x", s); auto e = mm->add_literal(migraphx::literal(s, vec_e)); @@ -41,4 +43,6 @@ struct test_pow : verify_program return p; } }; -// TODO: add fp8 tests +template struct test_pow; +template struct test_pow; +template struct test_pow; diff --git a/test/verify/test_roialign.cpp b/test/verify/test_roialign.cpp index f235462ae1d..6db58d2a076 100644 --- a/test/verify/test_roialign.cpp +++ b/test/verify/test_roialign.cpp @@ -59,5 +59,5 @@ struct test_roialign : verify_program> }; template struct test_roialign; -// template struct test_roialign; -// template struct test_roialign; +template struct test_roialign; +template struct test_roialign; diff --git a/test/verify/test_softmax.cpp b/test/verify/test_softmax.cpp index 255b2766ec1..d966bd68bb5 100644 --- a/test/verify/test_softmax.cpp +++ b/test/verify/test_softmax.cpp @@ -48,7 +48,7 @@ template struct test_softmax<0, migraphx::shape::half_type>; template struct test_softmax<1, migraphx::shape::half_type>; template struct test_softmax<2, migraphx::shape::half_type>; template struct test_softmax<3, migraphx::shape::half_type>; -// template struct test_softmax<0, migraphx::shape::fp8e4m3fnuz_type>; -// template struct test_softmax<1, migraphx::shape::fp8e4m3fnuz_type>; -// template struct test_softmax<2, migraphx::shape::fp8e4m3fnuz_type>; -// template struct test_softmax<3, migraphx::shape::fp8e4m3fnuz_type>; +template struct test_softmax<0, migraphx::shape::fp8e4m3fnuz_type>; +template struct test_softmax<1, migraphx::shape::fp8e4m3fnuz_type>; +template struct test_softmax<2, migraphx::shape::fp8e4m3fnuz_type>; +template struct test_softmax<3, 
migraphx::shape::fp8e4m3fnuz_type>; From f550f814ffffdf46582eae7bd83c70d9245e6ce5 Mon Sep 17 00:00:00 2001 From: Umang Yadav Date: Mon, 20 Nov 2023 17:49:25 +0000 Subject: [PATCH 051/115] add layernorm, remove constexpr for 1/r --- .../include/migraphx/kernels/layernorm.hpp | 25 +++++++++++-------- test/verify/test_layernorm.cpp | 23 +++++++++-------- 2 files changed, 26 insertions(+), 22 deletions(-) diff --git a/src/targets/gpu/kernels/include/migraphx/kernels/layernorm.hpp b/src/targets/gpu/kernels/include/migraphx/kernels/layernorm.hpp index d2e0fe7445d..0958bf057f3 100644 --- a/src/targets/gpu/kernels/include/migraphx/kernels/layernorm.hpp +++ b/src/targets/gpu/kernels/include/migraphx/kernels/layernorm.hpp @@ -52,22 +52,25 @@ __device__ void generic_binary_layernorm( block::template run([&](auto, auto r) { auto input = r.inner([&](auto x1, auto x2) { return op(x1, x2); })(input1, input2); using value_type = typename Input1::type; - constexpr auto relements = r.template elements(); - constexpr auto relements_r = vec_type{1.0 / relements}; - auto relements_rsqrt = sqrt(relements_r); + using vec_value_type = vec_type; + constexpr auto relements = r.template elements(); + auto relements_r = vec_value_type{1.0 / relements}; + auto relements_rsqrt = sqrt(relements_r); - auto means = r.reduce(op::sum{}, make_array>(0, 0), [&](auto x) { - auto x_out = x * relements_r; - // dividing x by sqrt(relements) before squaring allows computing higher values - // before overflow in low precision - auto x2_sqrt = x * relements_rsqrt; - return make_array(x_out, x2_sqrt * x2_sqrt); - })(input); + auto means = r.reduce(op::sum{}, + make_array(vec_value_type{0}, vec_value_type{0}), + [&](auto x) { + auto x_out = x * relements_r; + // dividing x by sqrt(relements) before squaring allows computing + // higher values before overflow in low precision + auto x2_sqrt = x * relements_rsqrt; + return make_array(x_out, x2_sqrt * x2_sqrt); + })(input); auto mean_x = means[0]; auto mean_x2 = 
means[1]; auto variance = mean_x2 - (mean_x * mean_x); - value_type eps_val = eps; // implicit conversion for eps + value_type eps_val = value_type{eps}; r.inner([&](auto& y, auto x, auto... xs) { auto m = x - mean_x; diff --git a/test/verify/test_layernorm.cpp b/test/verify/test_layernorm.cpp index 2dd0a885360..bfcb87b5920 100644 --- a/test/verify/test_layernorm.cpp +++ b/test/verify/test_layernorm.cpp @@ -117,17 +117,18 @@ struct test_layernorm_fp16 : verify_program } }; -// struct test_layernorm_fp8 : verify_program -// { -// migraphx::program create_program() const -// { -// migraphx::program p; -// auto* mm = p.get_main_module(); -// std::vector dims = {1, 24, 64}; -// auto x = mm->add_parameter("x", migraphx::shape{migraphx::shape::fp8e4m3fnuz_type, -// dims}); add_layernorm(*mm, x, dims); return p; -// } -// }; +struct test_layernorm_fp8 : verify_program +{ + migraphx::program create_program() const + { + migraphx::program p; + auto* mm = p.get_main_module(); + std::vector dims = {1, 24, 64}; + auto x = mm->add_parameter("x", migraphx::shape{migraphx::shape::fp8e4m3fnuz_type, dims}); + add_layernorm(*mm, x, dims); + return p; + } +}; struct test_layernorm_eps : verify_program { From 7e3444ce3d12fe8f49ce2b3894983e026b3c8acb Mon Sep 17 00:00:00 2001 From: Umang Yadav Date: Mon, 20 Nov 2023 17:52:52 +0000 Subject: [PATCH 052/115] tidy fixes --- test/verify/test_acosh.cpp | 8 ++++---- test/verify/test_atanh.cpp | 8 ++++---- test/verify/test_pow.cpp | 4 ++-- 3 files changed, 10 insertions(+), 10 deletions(-) diff --git a/test/verify/test_acosh.cpp b/test/verify/test_acosh.cpp index 5125bdf0eec..b8bd655c360 100644 --- a/test/verify/test_acosh.cpp +++ b/test/verify/test_acosh.cpp @@ -35,11 +35,11 @@ struct test_acosh : verify_program> { migraphx::program p; auto* mm = p.get_main_module(); - migraphx::shape::type_t DType = migraphx::shape::get_type(); - migraphx::shape s{DType, {16}}; + migraphx::shape::type_t dtype = migraphx::shape::get_type(); + migraphx::shape 
s{dtype, {16}}; auto x = mm->add_parameter("x", s); - auto min_val = mm->add_literal(migraphx::literal{migraphx::shape{DType}, {1.1}}); - auto max_val = mm->add_literal(migraphx::literal{migraphx::shape{DType}, {100.0}}); + auto min_val = mm->add_literal(migraphx::literal{migraphx::shape{dtype}, {1.1}}); + auto max_val = mm->add_literal(migraphx::literal{migraphx::shape{dtype}, {100.0}}); min_val = mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {16}}}), min_val); max_val = diff --git a/test/verify/test_atanh.cpp b/test/verify/test_atanh.cpp index 882a3184a0a..f2fdb7d993b 100644 --- a/test/verify/test_atanh.cpp +++ b/test/verify/test_atanh.cpp @@ -36,11 +36,11 @@ struct test_atanh : verify_program> { migraphx::program p; auto* mm = p.get_main_module(); - migraphx::shape::type_t DType = migraphx::shape::get_type(); - migraphx::shape s{DType, {16}}; + migraphx::shape::type_t dtype = migraphx::shape::get_type(); + migraphx::shape s{dtype, {16}}; auto x = mm->add_parameter("x", s); - auto min_val = mm->add_literal(migraphx::literal{migraphx::shape{DType}, {-0.95f}}); - auto max_val = mm->add_literal(migraphx::literal{migraphx::shape{DType}, {0.95f}}); + auto min_val = mm->add_literal(migraphx::literal{migraphx::shape{dtype}, {-0.95f}}); + auto max_val = mm->add_literal(migraphx::literal{migraphx::shape{dtype}, {0.95f}}); min_val = mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {16}}}), min_val); max_val = diff --git a/test/verify/test_pow.cpp b/test/verify/test_pow.cpp index 8dd8d96e76e..8b57df68a03 100644 --- a/test/verify/test_pow.cpp +++ b/test/verify/test_pow.cpp @@ -33,9 +33,9 @@ struct test_pow : verify_program> migraphx::program create_program() const { migraphx::program p; - migraphx::shape::type_t DType = migraphx::shape::get_type(); + migraphx::shape::type_t dtype = migraphx::shape::get_type(); auto* mm = p.get_main_module(); - migraphx::shape s{DType, {6}}; + migraphx::shape s{dtype, {6}}; std::vector 
vec_e(s.elements(), 2.0f); auto b = mm->add_parameter("x", s); auto e = mm->add_literal(migraphx::literal(s, vec_e)); From 6155c7822a57a96b95ed683887859b53f2a30e73 Mon Sep 17 00:00:00 2001 From: Umang Yadav Date: Mon, 20 Nov 2023 19:45:42 +0000 Subject: [PATCH 053/115] use __builtin_is_constant_evaluated --- .../include/migraphx/kernels/float8.hpp | 208 ++++++++++-------- .../include/migraphx/kernels/layernorm.hpp | 13 +- .../include/migraphx/kernels/softmax.hpp | 2 +- 3 files changed, 122 insertions(+), 101 deletions(-) diff --git a/src/targets/gpu/kernels/include/migraphx/kernels/float8.hpp b/src/targets/gpu/kernels/include/migraphx/kernels/float8.hpp index 17c4ccf9b51..469b1913354 100644 --- a/src/targets/gpu/kernels/include/migraphx/kernels/float8.hpp +++ b/src/targets/gpu/kernels/include/migraphx/kernels/float8.hpp @@ -26,8 +26,8 @@ #pragma clang diagnostic push #pragma clang diagnostic ignored "-Wfloat-equal" #pragma clang diagnostic ignored "-Wold-style-cast" +#pragma clang diagnostic ignored "-Wc++20-extensions" #endif // __clang__ -#define MIGRAPHX_HIP_DEVICE __device__ // We are clipping in down conversion by default #define MIGRAPHX_F8_DOWNCAST_CLIPPING 1 // NOLINT @@ -58,21 +58,21 @@ struct float8 { uint8_t data; // default constructor - MIGRAPHX_HIP_DEVICE constexpr float8() = default; + __device__ constexpr float8() = default; // default copy constructor - MIGRAPHX_HIP_DEVICE constexpr float8(const float8& y) = default; + __device__ constexpr float8(const float8& y) = default; struct from_bits_t { }; - static constexpr MIGRAPHX_HIP_DEVICE from_bits_t from_bits() { return from_bits_t(); } + static constexpr __device__ from_bits_t from_bits() { return from_bits_t(); } - MIGRAPHX_HIP_DEVICE explicit constexpr float8(uint8_t bits, from_bits_t) : data(bits) {} + __device__ explicit constexpr float8(uint8_t bits, from_bits_t) : data(bits) {} #if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__) // device specific optimized F8 down-conversion 
code template - static constexpr MIGRAPHX_HIP_DEVICE uint8_t cast_to_f8_from_f32(float v, uint32_t rng = 0) + static __device__ uint8_t cast_to_f8_from_f32(float v, uint32_t rng = 0) { uint8_t i8data = 0x00; union @@ -132,20 +132,50 @@ struct float8 #if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__) // NOTE: ON-DEVICE... always optimal bias - explicit constexpr MIGRAPHX_HIP_DEVICE + explicit constexpr __device__ float8(const float v, migraphx::fp8::rounding_mode rm = migraphx::fp8::rounding_mode::standard, uint32_t rng = 0) { - // runtime branch, use cast_to_f8_from_f32 if want to avoid it - if(rm == migraphx::fp8::rounding_mode::stochastic) - data = cast_to_f8_from_f32(v, rng); + if(__builtin_is_constant_evaluated()) + { + if constexpr(T == migraphx::fp8::f8_type::fp8) + { +#ifdef MIGRAPHX_F8_DOWNCAST_CLIPPING + data = migraphx::fp8::impl:: + cast_to_f8<3, 4, float, FNUZ /*negative_zero_nan*/, true /*clip*/>( + v, (rm == migraphx::fp8::rounding_mode::stochastic), rng); +#else // MIGRAPHX_F8_DOWNCAST_CLIPPING + data = migraphx::fp8::impl:: + cast_to_f8<3, 4, float, FNUZ /*negative_zero_nan*/, false /*clip*/>( + v, (rm == migraphx::fp8::rounding_mode::stochastic), rng); +#endif // MIGRAPHX_F8_DOWNCAST_CLIPPING + } + else + { +#ifdef MIGRAPHX_F8_DOWNCAST_CLIPPING + data = migraphx::fp8::impl:: + cast_to_f8<2, 5, float, FNUZ /*negative_zero_nan*/, true /*clip*/>( + v, (rm == migraphx::fp8::rounding_mode::stochastic), rng); +#else // MIGRAPHX_F8_DOWNCAST_CLIPPING + data = migraphx::fp8::impl:: + cast_to_f8<2, 5, float, FNUZ /*negative_zero_nan*/, false /*clip*/>( + v, (rm == migraphx::fp8::rounding_mode::stochastic), rng); +#endif // MIGRAPHX_FP8_DOWNCAST_CLIPPING} + } + } else - data = cast_to_f8_from_f32(v); + { + // runtime branch, use cast_to_f8_from_f32 if want to avoid it + if(rm == migraphx::fp8::rounding_mode::stochastic) + data = cast_to_f8_from_f32(v, rng); + else + data = cast_to_f8_from_f32(v); + } } #else // DEVICE for non-gfx940 using 
s/w simulation - explicit constexpr MIGRAPHX_HIP_DEVICE + explicit constexpr __device__ float8(const float v, migraphx::fp8::rounding_mode rm = migraphx::fp8::rounding_mode::standard, uint32_t rng = 0) @@ -178,64 +208,74 @@ struct float8 #endif // __gfx940___ // Constructor from half - explicit constexpr MIGRAPHX_HIP_DEVICE + explicit constexpr __device__ float8(const _Float16 v, rounding_mode rm = rounding_mode::standard, uint32_t rng = 0) : float8((float)v, rm, rng) { } // constructor from int - explicit constexpr MIGRAPHX_HIP_DEVICE + explicit constexpr __device__ float8(const int v, rounding_mode rm = rounding_mode::standard, uint32_t rng = 0) : float8((float)v, rm, rng) { } // constructor from uint - explicit constexpr MIGRAPHX_HIP_DEVICE + explicit constexpr __device__ float8(const uint32_t v, rounding_mode rm = rounding_mode::standard, uint32_t rng = 0) : float8((float)v, rm, rng) { } // constructor from double - explicit constexpr MIGRAPHX_HIP_DEVICE + explicit constexpr __device__ float8(const double v, rounding_mode rm = rounding_mode::standard, uint32_t rng = 0) : float8((float)v, rm, rng) { } // constructor from bool - explicit constexpr MIGRAPHX_HIP_DEVICE + explicit constexpr __device__ float8(const bool v, rounding_mode rm = rounding_mode::standard, uint32_t rng = 0) : float8((float)(v), rm, rng) { } // convert to float -// #if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__) -#if 0 // need constexpr operator(). 
This version can't be constexpr // NOLINT +#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__) // NOLINT // upcast using device specific intrinsic - inline MIGRAPHX_HIP_DEVICE operator float() const + inline constexpr __device__ operator float() const { - float fval; - uint32_t i32val = static_cast(data); - - // upcast - if constexpr(T == migraphx::fp8::f8_type::fp8) + if(__builtin_is_constant_evaluated()) { - asm volatile("v_cvt_f32_fp8 %0, %1 src0_sel:BYTE_0" : "=v"(fval) : "v"(i32val)); + if constexpr(T == migraphx::fp8::f8_type::fp8) + { + return migraphx::fp8::impl::cast_from_f8<3, 4, float, FNUZ /*negative_zero_nan*/>( + data); + } // else + return migraphx::fp8::impl::cast_from_f8<2, 5, float, FNUZ /*negative_zero_nan*/>(data); } else { - asm volatile("v_cvt_f32_bf8 %0, %1 src0_sel:BYTE_0" : "=v"(fval) : "v"(i32val)); - } + float fval = 0; + uint32_t i32val = static_cast(data); + + // upcast + if constexpr(T == migraphx::fp8::f8_type::fp8) + { + __asm__ volatile("v_cvt_f32_fp8 %0, %1 src0_sel:BYTE_0" : "=v"(fval) : "v"(i32val)); + } + else + { + __asm__ volatile("v_cvt_f32_bf8 %0, %1 src0_sel:BYTE_0" : "=v"(fval) : "v"(i32val)); + } - return fval; + return fval; + } } #else // non gfx940 - inline constexpr MIGRAPHX_HIP_DEVICE operator float() const -#endif + inline constexpr __device__ operator float() const { if constexpr(T == migraphx::fp8::f8_type::fp8) { @@ -243,11 +283,12 @@ struct float8 } // else return migraphx::fp8::impl::cast_from_f8<2, 5, float, FNUZ /*negative_zero_nan*/>(data); } +#endif - inline constexpr explicit MIGRAPHX_HIP_DEVICE operator bool() const { return not is_zero(); } + inline constexpr explicit __device__ operator bool() const { return not is_zero(); } // check for zero - inline MIGRAPHX_HIP_DEVICE constexpr bool is_zero() const + inline __device__ constexpr bool is_zero() const { if constexpr(FNUZ) { @@ -260,7 +301,7 @@ struct float8 } // check for nan - inline MIGRAPHX_HIP_DEVICE constexpr bool is_nan() const + 
inline __device__ constexpr bool is_nan() const { if constexpr(FNUZ) { @@ -281,7 +322,7 @@ struct float8 } // check for inf - inline MIGRAPHX_HIP_DEVICE constexpr bool is_inf() const + inline __device__ constexpr bool is_inf() const { if constexpr(FNUZ) { @@ -303,13 +344,13 @@ struct float8 // NOLINTNEXTLINE #define MIGRAPHX_FP8_SHORT_UNARY_OP(unary_op, binary_op) \ - constexpr float8& MIGRAPHX_HIP_DEVICE operator unary_op(const float8& rhs) \ + constexpr float8& __device__ operator unary_op(const float8& rhs) \ { \ const auto tmp = static_cast(*this) binary_op static_cast(rhs); \ *this = static_cast(tmp); \ return *this; \ } \ - constexpr float8& MIGRAPHX_HIP_DEVICE operator unary_op(const float& rhs) \ + constexpr float8& __device__ operator unary_op(const float& rhs) \ { \ const auto tmp = static_cast(*this) binary_op static_cast(rhs); \ *this = static_cast(tmp); \ @@ -321,10 +362,10 @@ struct float8 MIGRAPHX_FP8_SHORT_UNARY_OP(+=, +) MIGRAPHX_FP8_SHORT_UNARY_OP(/=, /) - inline MIGRAPHX_HIP_DEVICE constexpr float8& operator=(const float8& rhs) = default; - inline MIGRAPHX_HIP_DEVICE constexpr float8& operator=(float8&& rhs) noexcept = default; + inline __device__ constexpr float8& operator=(const float8& rhs) = default; + inline __device__ constexpr float8& operator=(float8&& rhs) noexcept = default; - inline MIGRAPHX_HIP_DEVICE constexpr bool operator==(const float8& rhs) const + inline __device__ constexpr bool operator==(const float8& rhs) const { if(rhs.is_nan() or rhs.is_inf() or this->is_nan() or this->is_inf()) return false; @@ -333,14 +374,14 @@ struct float8 return false; } - inline MIGRAPHX_HIP_DEVICE constexpr bool operator<(const float8& rhs) const + inline __device__ constexpr bool operator<(const float8& rhs) const { const auto we = static_cast(*this); const auto them = static_cast(rhs); return we < them; } - inline MIGRAPHX_HIP_DEVICE constexpr bool operator>(const float8& rhs) const + inline __device__ constexpr bool operator>(const float8& rhs) 
const { const auto we = static_cast(*this); const auto them = static_cast(rhs); @@ -355,19 +396,19 @@ using fp8e4m3fnuz = float8; using fp8e5m2fnuz = float8; // NOLINTNEXTLINE -#define MIGRAPHX_FP8_BINARY_OP(binary_op, T, U) \ - inline constexpr U MIGRAPHX_HIP_DEVICE operator binary_op(const T& lhs, const T& rhs) \ - { \ - return U(static_cast(lhs) binary_op static_cast(rhs)); \ +#define MIGRAPHX_FP8_BINARY_OP(binary_op, T, U) \ + inline constexpr U __device__ operator binary_op(const T& lhs, const T& rhs) \ + { \ + return U(static_cast(lhs) binary_op static_cast(rhs)); \ } // NOLINTNEXTLINE -#define MIGRAPHX_FP8_FABS(T) \ - inline constexpr MIGRAPHX_HIP_DEVICE T fabs(T v) \ - { \ - /*NOLINTNEXTLINE*/ \ - v.data = v.data & 0x7f; \ - return v; \ +#define MIGRAPHX_FP8_FABS(T) \ + inline constexpr __device__ T fabs(T v) \ + { \ + /*NOLINTNEXTLINE*/ \ + v.data = v.data & 0x7f; \ + return v; \ } // NOLINTNEXTLINE @@ -394,27 +435,27 @@ class numeric_limits { public: static constexpr bool has_infinity = false; - static constexpr MIGRAPHX_HIP_DEVICE fp8e4m3fnuz epsilon() + static constexpr __device__ fp8e4m3fnuz epsilon() { return fp8e4m3fnuz(0x28, fp8e4m3fnuz::from_bits()); } // NOLINTNEXTLINE - static constexpr MIGRAPHX_HIP_DEVICE fp8e4m3fnuz quiet_NaN() + static constexpr __device__ fp8e4m3fnuz quiet_NaN() { return fp8e4m3fnuz(0x80, fp8e4m3fnuz::from_bits()); } - static constexpr MIGRAPHX_HIP_DEVICE fp8e4m3fnuz max() + static constexpr __device__ fp8e4m3fnuz max() { return fp8e4m3fnuz(0x7F, fp8e4m3fnuz::from_bits()); } // this is min value that is not DeNorm. 
DeNorm min is 0x01 - static constexpr MIGRAPHX_HIP_DEVICE fp8e4m3fnuz min() + static constexpr __device__ fp8e4m3fnuz min() { return fp8e4m3fnuz(0x08, fp8e4m3fnuz::from_bits()); } - static constexpr MIGRAPHX_HIP_DEVICE fp8e4m3fnuz lowest() + static constexpr __device__ fp8e4m3fnuz lowest() { return fp8e4m3fnuz(0xFF, fp8e4m3fnuz::from_bits()); } @@ -425,27 +466,21 @@ class numeric_limits { public: static constexpr bool has_infinity = false; - static constexpr MIGRAPHX_HIP_DEVICE fp8e4m3fn epsilon() + static constexpr __device__ fp8e4m3fn epsilon() { return fp8e4m3fn(0x20, fp8e4m3fn::from_bits()); } // NOLINTNEXTLINE - static constexpr MIGRAPHX_HIP_DEVICE fp8e4m3fn quiet_NaN() + static constexpr __device__ fp8e4m3fn quiet_NaN() { return fp8e4m3fn(0x7F, fp8e4m3fn::from_bits()); } - static constexpr MIGRAPHX_HIP_DEVICE fp8e4m3fn max() - { - return fp8e4m3fn(0x7E, fp8e4m3fn::from_bits()); - } + static constexpr __device__ fp8e4m3fn max() { return fp8e4m3fn(0x7E, fp8e4m3fn::from_bits()); } // this is min value that is not DeNorm. 
DeNorm min is 0x01 - static constexpr MIGRAPHX_HIP_DEVICE fp8e4m3fn min() - { - return fp8e4m3fn(0x08, fp8e4m3fn::from_bits()); - } + static constexpr __device__ fp8e4m3fn min() { return fp8e4m3fn(0x08, fp8e4m3fn::from_bits()); } - static constexpr MIGRAPHX_HIP_DEVICE fp8e4m3fn lowest() + static constexpr __device__ fp8e4m3fn lowest() { return fp8e4m3fn(0xFE, fp8e4m3fn::from_bits()); } @@ -456,28 +491,28 @@ class numeric_limits { public: static constexpr bool has_infinity = false; - static constexpr MIGRAPHX_HIP_DEVICE fp8e5m2fnuz epsilon() + static constexpr __device__ fp8e5m2fnuz epsilon() { return fp8e5m2fnuz(0x34, fp8e5m2fnuz::from_bits()); } - static constexpr MIGRAPHX_HIP_DEVICE fp8e5m2fnuz quiet_NaN() // NOLINT + static constexpr __device__ fp8e5m2fnuz quiet_NaN() // NOLINT { return fp8e5m2fnuz(0x80, fp8e5m2fnuz::from_bits()); } - static constexpr MIGRAPHX_HIP_DEVICE fp8e5m2fnuz max() + static constexpr __device__ fp8e5m2fnuz max() { return fp8e5m2fnuz(0x7F, fp8e5m2fnuz::from_bits()); } // this is min value that is not DeNorm. DeNorm min is 0x01. I am not sure if we want to make // this distinction. For the floating points we would end up using lowest most of the times. 
- static constexpr MIGRAPHX_HIP_DEVICE fp8e5m2fnuz min() + static constexpr __device__ fp8e5m2fnuz min() { return fp8e5m2fnuz(0x4, fp8e5m2fnuz::from_bits()); } - static constexpr MIGRAPHX_HIP_DEVICE fp8e5m2fnuz lowest() + static constexpr __device__ fp8e5m2fnuz lowest() { return fp8e5m2fnuz(0xFF, fp8e5m2fnuz::from_bits()); } @@ -488,36 +523,21 @@ class numeric_limits { public: static constexpr bool has_infinity = true; - static constexpr MIGRAPHX_HIP_DEVICE fp8e5m2 epsilon() - { - return fp8e5m2(0x34, fp8e5m2::from_bits()); - } + static constexpr __device__ fp8e5m2 epsilon() { return fp8e5m2(0x34, fp8e5m2::from_bits()); } // 7D, 7E, 7F are positive NaNs and FD, FE, FF are negative NaNs - static constexpr MIGRAPHX_HIP_DEVICE fp8e5m2 quiet_NaN() // NOLINT + static constexpr __device__ fp8e5m2 quiet_NaN() // NOLINT { return fp8e5m2(0xFF, fp8e5m2::from_bits()); } - static constexpr MIGRAPHX_HIP_DEVICE fp8e5m2 max() - { - return fp8e5m2(0x7B, fp8e5m2::from_bits()); - } + static constexpr __device__ fp8e5m2 max() { return fp8e5m2(0x7B, fp8e5m2::from_bits()); } // this is min value that is not DeNorm. DeNorm min is 0x01. I am not sure if we want to make // this distinction. For the floating points we would end up using lowest most of the times. 
- static constexpr MIGRAPHX_HIP_DEVICE fp8e5m2 min() - { - return fp8e5m2(0x4, fp8e5m2::from_bits()); - } + static constexpr __device__ fp8e5m2 min() { return fp8e5m2(0x4, fp8e5m2::from_bits()); } - static constexpr MIGRAPHX_HIP_DEVICE fp8e5m2 lowest() - { - return fp8e5m2(0xFB, fp8e5m2::from_bits()); - } + static constexpr __device__ fp8e5m2 lowest() { return fp8e5m2(0xFB, fp8e5m2::from_bits()); } // 7C and FC both are infinity - static constexpr MIGRAPHX_HIP_DEVICE fp8e5m2 infinity() - { - return fp8e5m2(0x7C, fp8e5m2::from_bits()); - } + static constexpr __device__ fp8e5m2 infinity() { return fp8e5m2(0x7C, fp8e5m2::from_bits()); } }; } // namespace fp8 diff --git a/src/targets/gpu/kernels/include/migraphx/kernels/layernorm.hpp b/src/targets/gpu/kernels/include/migraphx/kernels/layernorm.hpp index 0958bf057f3..c82c9df296a 100644 --- a/src/targets/gpu/kernels/include/migraphx/kernels/layernorm.hpp +++ b/src/targets/gpu/kernels/include/migraphx/kernels/layernorm.hpp @@ -52,13 +52,14 @@ __device__ void generic_binary_layernorm( block::template run([&](auto, auto r) { auto input = r.inner([&](auto x1, auto x2) { return op(x1, x2); })(input1, input2); using value_type = typename Input1::type; - using vec_value_type = vec_type; - constexpr auto relements = r.template elements(); - auto relements_r = vec_value_type{1.0 / relements}; - auto relements_rsqrt = sqrt(relements_r); + using vec_value_type = vec_type; + constexpr auto relements = r.template elements(); + constexpr auto relements_r = static_cast(1.0 / relements); + auto relements_rsqrt = sqrt(relements_r); auto means = r.reduce(op::sum{}, - make_array(vec_value_type{0}, vec_value_type{0}), + make_array(static_cast(0), + static_cast(0)), [&](auto x) { auto x_out = x * relements_r; // dividing x by sqrt(relements) before squaring allows computing @@ -70,7 +71,7 @@ __device__ void generic_binary_layernorm( auto mean_x = means[0]; auto mean_x2 = means[1]; auto variance = mean_x2 - (mean_x * mean_x); - value_type 
eps_val = value_type{eps}; + value_type eps_val = static_cast(eps); r.inner([&](auto& y, auto x, auto... xs) { auto m = x - mean_x; diff --git a/src/targets/gpu/kernels/include/migraphx/kernels/softmax.hpp b/src/targets/gpu/kernels/include/migraphx/kernels/softmax.hpp index 47997e0111e..6664ae536d5 100644 --- a/src/targets/gpu/kernels/include/migraphx/kernels/softmax.hpp +++ b/src/targets/gpu/kernels/include/migraphx/kernels/softmax.hpp @@ -44,7 +44,7 @@ __device__ void softmax(Input input1, Output output) auto exp_in = r.inner([&](auto x) { return migraphx::exp(x - c); })(input); auto batch_sum = r.reduce(op::sum{}, 0, [](auto x) { return migraphx::convert(x); })(exp_in); - r.inner([&](auto& y, auto x) { y = otype{x / batch_sum}; })(output, exp_in); + r.inner([&](auto& y, auto x) { y = static_cast(x / batch_sum); })(output, exp_in); }); } From 13ef41484b6e76f9f2a3f547e33a1251847dc06c Mon Sep 17 00:00:00 2001 From: Umang Yadav Date: Mon, 20 Nov 2023 19:53:24 +0000 Subject: [PATCH 054/115] add test for rsqrt and remove old-style-cast --- .../include/migraphx/kernels/float8.hpp | 1 - test/verify/test_rsqrt.cpp | 20 ++++++++++++------- 2 files changed, 13 insertions(+), 8 deletions(-) diff --git a/src/targets/gpu/kernels/include/migraphx/kernels/float8.hpp b/src/targets/gpu/kernels/include/migraphx/kernels/float8.hpp index 469b1913354..8c8609520a2 100644 --- a/src/targets/gpu/kernels/include/migraphx/kernels/float8.hpp +++ b/src/targets/gpu/kernels/include/migraphx/kernels/float8.hpp @@ -25,7 +25,6 @@ #if defined(__clang__) #pragma clang diagnostic push #pragma clang diagnostic ignored "-Wfloat-equal" -#pragma clang diagnostic ignored "-Wold-style-cast" #pragma clang diagnostic ignored "-Wc++20-extensions" #endif // __clang__ diff --git a/test/verify/test_rsqrt.cpp b/test/verify/test_rsqrt.cpp index 2862196b095..fa131365a93 100644 --- a/test/verify/test_rsqrt.cpp +++ b/test/verify/test_rsqrt.cpp @@ -23,22 +23,26 @@ */ #include "verify_program.hpp" +#include #include
#include #include -struct test_rsqrt : verify_program +template +struct test_rsqrt : verify_program> { migraphx::program create_program() const { migraphx::program p; - auto* mm = p.get_main_module(); + auto* mm = p.get_main_module(); + migraphx::shape::type_t dtype = migraphx::shape::get_type(); std::vector input_lens{1, 3, 16, 16}; - migraphx::shape s{migraphx::shape::float_type, input_lens}; + migraphx::shape s{dtype, input_lens}; auto x = mm->add_parameter("x", s); - auto min_val = mm->add_literal(1.0f); - auto max_val = mm->add_literal(std::numeric_limits::max()); - min_val = mm->add_instruction( + auto min_val = mm->add_literal(migraphx::literal{migraphx::shape{dtype}, {1.0}}); + auto max_val = mm->add_literal( + migraphx::literal{migraphx::shape{dtype}, {std::numeric_limits::max()}}); + min_val = mm->add_instruction( migraphx::make_op("multibroadcast", {{"out_lens", input_lens}}), min_val); max_val = mm->add_instruction( migraphx::make_op("multibroadcast", {{"out_lens", input_lens}}), max_val); @@ -48,4 +52,6 @@ struct test_rsqrt : verify_program }; }; -// TOOD : Add FP8 test +template struct test_rsqrt; +template struct test_rsqrt; +template struct test_rsqrt; From 8660572319d943dcc1169a7987e28eb68a0df891 Mon Sep 17 00:00:00 2001 From: Umang Yadav Date: Mon, 20 Nov 2023 19:54:01 +0000 Subject: [PATCH 055/115] add comment about c++20 extensions --- src/targets/gpu/kernels/include/migraphx/kernels/float8.hpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/targets/gpu/kernels/include/migraphx/kernels/float8.hpp b/src/targets/gpu/kernels/include/migraphx/kernels/float8.hpp index 8c8609520a2..e30c54e1f4a 100644 --- a/src/targets/gpu/kernels/include/migraphx/kernels/float8.hpp +++ b/src/targets/gpu/kernels/include/migraphx/kernels/float8.hpp @@ -25,8 +25,8 @@ #if defined(__clang__) #pragma clang diagnostic push #pragma clang diagnostic ignored "-Wfloat-equal" -#pragma clang diagnostic ignored "-Wc++20-extensions" -#endif // __clang__ 
+#pragma clang diagnostic ignored "-Wc++20-extensions" // required for "asm" inside constexpr +#endif // __clang__ // We are clipping in down conversion by default #define MIGRAPHX_F8_DOWNCAST_CLIPPING 1 // NOLINT From 6fbd997003cae98032bc09deee6e32e4f42b7822 Mon Sep 17 00:00:00 2001 From: Umang Yadav Date: Mon, 20 Nov 2023 19:59:12 +0000 Subject: [PATCH 056/115] Remove old cast --- .../gpu/kernels/include/migraphx/kernels/float8.hpp | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/src/targets/gpu/kernels/include/migraphx/kernels/float8.hpp b/src/targets/gpu/kernels/include/migraphx/kernels/float8.hpp index e30c54e1f4a..95a8249e8ae 100644 --- a/src/targets/gpu/kernels/include/migraphx/kernels/float8.hpp +++ b/src/targets/gpu/kernels/include/migraphx/kernels/float8.hpp @@ -209,35 +209,35 @@ struct float8 // Constructor from half explicit constexpr __device__ float8(const _Float16 v, rounding_mode rm = rounding_mode::standard, uint32_t rng = 0) - : float8((float)v, rm, rng) + : float8(static_cast(v), rm, rng) { } // constructor from int explicit constexpr __device__ float8(const int v, rounding_mode rm = rounding_mode::standard, uint32_t rng = 0) - : float8((float)v, rm, rng) + : float8(static_cast(v), rm, rng) { } // constructor from uint explicit constexpr __device__ float8(const uint32_t v, rounding_mode rm = rounding_mode::standard, uint32_t rng = 0) - : float8((float)v, rm, rng) + : float8(static_cast(v), rm, rng) { } // constructor from double explicit constexpr __device__ float8(const double v, rounding_mode rm = rounding_mode::standard, uint32_t rng = 0) - : float8((float)v, rm, rng) + : float8(static_cast(v), rm, rng) { } // constructor from bool explicit constexpr __device__ float8(const bool v, rounding_mode rm = rounding_mode::standard, uint32_t rng = 0) - : float8((float)(v), rm, rng) + : float8(static_cast(v), rm, rng) { } // convert to float From 2acd265b9969abfa27e51ac5ed22ef422b6b9648 Mon Sep 17 00:00:00 2001 From: Umang 
Yadav Date: Mon, 20 Nov 2023 20:02:46 +0000 Subject: [PATCH 057/115] Remove DPP --- src/targets/gpu/jit/reduce.cpp | 7 ------- 1 file changed, 7 deletions(-) diff --git a/src/targets/gpu/jit/reduce.cpp b/src/targets/gpu/jit/reduce.cpp index d3a51153fc3..e001b5ba510 100644 --- a/src/targets/gpu/jit/reduce.cpp +++ b/src/targets/gpu/jit/reduce.cpp @@ -177,13 +177,6 @@ struct simple_reduce_compiler : compiler {"algo", algo}, {"transformers", make_transformer_args(vec)}, {"preamble", v.get("preamble", std::string{})}}); - // disable DPP for FP8 for now,, TODO: need to disable for Any FP8 types - if(std::any_of(inputs.begin(), inputs.end(), [](const auto& s) { - return s.type() == migraphx::shape::fp8e4m3fnuz_type; - })) - { - options.params += "-DMIGRAPHX_HAS_DPP=0 "; - } options.params += "-Wno-float-equal"; return compile_hip_code_object(src, options); } From 836e201e49c67bd277b84ef4ed911bcce7738f69 Mon Sep 17 00:00:00 2001 From: Umang Yadav Date: Mon, 20 Nov 2023 21:13:51 +0000 Subject: [PATCH 058/115] Remove MIN max overloads --- .../gpu/kernels/include/migraphx/kernels/float8.hpp | 2 -- .../gpu/kernels/include/migraphx/kernels/math.hpp | 11 ----------- 2 files changed, 13 deletions(-) diff --git a/src/targets/gpu/kernels/include/migraphx/kernels/float8.hpp b/src/targets/gpu/kernels/include/migraphx/kernels/float8.hpp index 95a8249e8ae..5622f73ad1b 100644 --- a/src/targets/gpu/kernels/include/migraphx/kernels/float8.hpp +++ b/src/targets/gpu/kernels/include/migraphx/kernels/float8.hpp @@ -419,8 +419,6 @@ using fp8e5m2fnuz = float8; MIGRAPHX_FP8_BINARY_OP(==, T, bool) \ MIGRAPHX_FP8_BINARY_OP(>=, T, bool) \ MIGRAPHX_FP8_BINARY_OP(<=, T, bool) \ - MIGRAPHX_FP8_BINARY_OP(>, T, bool) \ - MIGRAPHX_FP8_BINARY_OP(<, T, bool) \ MIGRAPHX_FP8_BINARY_OP(!=, T, bool) \ MIGRAPHX_FP8_FABS(T) diff --git a/src/targets/gpu/kernels/include/migraphx/kernels/math.hpp b/src/targets/gpu/kernels/include/migraphx/kernels/math.hpp index 50e815a6c8a..e2acad66826 100644 --- 
a/src/targets/gpu/kernels/include/migraphx/kernels/math.hpp +++ b/src/targets/gpu/kernels/include/migraphx/kernels/math.hpp @@ -81,14 +81,6 @@ constexpr T as_float(T x) auto __device__ name(migraphx::fp8::fp8e4m3fnuz x, Ts... xs) MIGRAPHX_RETURNS( \ migraphx::fp8::fp8e4m3fnuz(fname(math::as_float(x), math::as_float(xs)...))) -// NOLINTNEXTLINE -#define MIGRAPHX_DEVICE_MATH_BINARY_FOR_FP8(name, fname) \ - inline auto __device__ name(migraphx::fp8::fp8e4m3fnuz x, migraphx::fp8::fp8e4m3fnuz y) \ - -> migraphx::fp8::fp8e4m3fnuz \ - { \ - return migraphx::fp8::fp8e4m3fnuz(fname(math::as_float(x), math::as_float(y))); \ - } - // Template with two overloads for math functions, one for half2 type and one for more generic // vectorization where N is 4 or another even number. @@ -239,9 +231,6 @@ MIGRAPHX_DEVICE_MATH_BINARY_FOR(double, min, ::min) MIGRAPHX_DEVICE_MATH_BINARY_FOR(migraphx::half, max, ::__hmax) MIGRAPHX_DEVICE_MATH_BINARY_FOR(migraphx::half, min, ::__hmin) -MIGRAPHX_DEVICE_MATH_BINARY_FOR_FP8(max, ::max) -MIGRAPHX_DEVICE_MATH_BINARY_FOR_FP8(min, ::min) - template ())> constexpr auto max(const T& a, const T& b) { From f9542d5b86aed60ad8ffca4364f6fe40d83bd31a Mon Sep 17 00:00:00 2001 From: Umang Yadav Date: Mon, 20 Nov 2023 23:03:20 +0000 Subject: [PATCH 059/115] Put numeric_max and numeric_lowest into float8 --- .../include/migraphx/kernels/float8.hpp | 19 +++++++++++++++++++ .../include/migraphx/kernels/float8_impl.hpp | 19 +++---------------- .../kernels/include/migraphx/kernels/math.hpp | 1 + .../include/migraphx/kernels/type_traits.hpp | 9 +-------- 4 files changed, 24 insertions(+), 24 deletions(-) diff --git a/src/targets/gpu/kernels/include/migraphx/kernels/float8.hpp b/src/targets/gpu/kernels/include/migraphx/kernels/float8.hpp index 5622f73ad1b..9c8535129fb 100644 --- a/src/targets/gpu/kernels/include/migraphx/kernels/float8.hpp +++ b/src/targets/gpu/kernels/include/migraphx/kernels/float8.hpp @@ -33,6 +33,7 @@ #include #include +#include namespace
migraphx { namespace fp8 { @@ -538,6 +539,24 @@ class numeric_limits }; } // namespace fp8 + +// NOLINTNEXTLINE +#define MIGRAPHX_FP8_MIN_MAX(T) \ + template <> \ + constexpr T numeric_max() \ + { \ + return fp8::numeric_limits::max(); \ + } \ + template <> \ + constexpr T numeric_lowest() \ + { \ + return fp8::numeric_limits::lowest(); \ + } + +MIGRAPHX_FP8_MIN_MAX(fp8::fp8e4m3fnuz); +MIGRAPHX_FP8_MIN_MAX(fp8::fp8e5m2fnuz); +MIGRAPHX_FP8_MIN_MAX(fp8::fp8e4m3fn); +MIGRAPHX_FP8_MIN_MAX(fp8::fp8e5m2); } // namespace migraphx // ================================================================================================= #if defined(__clang__) diff --git a/src/targets/gpu/kernels/include/migraphx/kernels/float8_impl.hpp b/src/targets/gpu/kernels/include/migraphx/kernels/float8_impl.hpp index 95477c1b120..2eca5ed4af1 100644 --- a/src/targets/gpu/kernels/include/migraphx/kernels/float8_impl.hpp +++ b/src/targets/gpu/kernels/include/migraphx/kernels/float8_impl.hpp @@ -23,26 +23,13 @@ #ifndef MIGRAPHX_GUARD_KERNELS_FP8_IMPL_HPP #define MIGRAPHX_GUARD_KERNELS_FP8_IMPL_HPP #include +#include #if defined(__clang__) #pragma clang diagnostic push #pragma clang diagnostic ignored "-Wreserved-identifier" #endif namespace migraphx { -namespace detail { -template -struct conditional -{ - using type = T; -}; - -template -struct conditional -{ - using type = F; -}; - -} // namespace detail namespace fp8 { namespace impl { @@ -58,7 +45,7 @@ __device__ constexpr uint8_t cast_to_f8(T f_x, bool stoch = false, uint32_t rng static_assert(is_float or is_half, "Only float can be cast to f8"); const uint32_t mfmt = (sizeof(T) == 4) ? 
23 : 10; - typename detail::conditional::type x; + typename migraphx::conditional_t x; if constexpr(sizeof(T) == 4) x = migraphx::bit_cast(f_x); @@ -304,7 +291,7 @@ __device__ constexpr T cast_from_f8(uint8_t x) else if(Wm == 3 and (x == 0x7F or x == 0xFF)) return f_nan; } - typename detail::conditional::type retval; + typename migraphx::conditional_t retval; const int exp_low_cutoff = (1 << (weo - 1)) - (1 << (We - 1)) + 1 - (NegativeZeroNan ? 1 : 0); // NOLINT diff --git a/src/targets/gpu/kernels/include/migraphx/kernels/math.hpp b/src/targets/gpu/kernels/include/migraphx/kernels/math.hpp index e2acad66826..0b54b26bf77 100644 --- a/src/targets/gpu/kernels/include/migraphx/kernels/math.hpp +++ b/src/targets/gpu/kernels/include/migraphx/kernels/math.hpp @@ -25,6 +25,7 @@ #define MIGRAPHX_GUARD_KERNELS_MATH_HPP #include +#include #include #include #include diff --git a/src/targets/gpu/kernels/include/migraphx/kernels/type_traits.hpp b/src/targets/gpu/kernels/include/migraphx/kernels/type_traits.hpp index 166bee2e57a..890e55837a3 100644 --- a/src/targets/gpu/kernels/include/migraphx/kernels/type_traits.hpp +++ b/src/targets/gpu/kernels/include/migraphx/kernels/type_traits.hpp @@ -26,7 +26,6 @@ #include #include -#include namespace migraphx { @@ -231,8 +230,7 @@ constexpr unsigned long int_max(unsigned long n) template {} or is_floating_point{} or - is_same{} or - is_same{})> + is_same{})> constexpr T numeric_max() { if constexpr(is_integral{}) @@ -248,9 +246,6 @@ constexpr T numeric_max() return __FLT_MAX__; else if constexpr(is_same{}) return __FLT16_MAX__; - // TODO: Do it generically for all fp8 types - else if constexpr(is_same{}) - return migraphx::fp8::numeric_limits::max(); else return 0; } @@ -265,8 +260,6 @@ constexpr T numeric_lowest() else return -numeric_max() - 1; } - else if constexpr(is_same{}) - return migraphx::fp8::numeric_limits::lowest(); else { return -numeric_max(); From 480288f82a747e35569d6f620896fc29b7eb0877 Mon Sep 17 00:00:00 2001 From: 
Umang Yadav Date: Tue, 21 Nov 2023 00:01:40 +0000 Subject: [PATCH 060/115] use void for highest to match template candidates --- src/targets/gpu/kernels/include/migraphx/kernels/float8.hpp | 2 +- src/targets/gpu/kernels/include/migraphx/kernels/math.hpp | 2 +- src/targets/gpu/kernels/include/migraphx/kernels/ops.hpp | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/src/targets/gpu/kernels/include/migraphx/kernels/float8.hpp b/src/targets/gpu/kernels/include/migraphx/kernels/float8.hpp index 9c8535129fb..87cdfcd0a15 100644 --- a/src/targets/gpu/kernels/include/migraphx/kernels/float8.hpp +++ b/src/targets/gpu/kernels/include/migraphx/kernels/float8.hpp @@ -32,8 +32,8 @@ #define MIGRAPHX_F8_DOWNCAST_CLIPPING 1 // NOLINT #include -#include #include +#include namespace migraphx { namespace fp8 { diff --git a/src/targets/gpu/kernels/include/migraphx/kernels/math.hpp b/src/targets/gpu/kernels/include/migraphx/kernels/math.hpp index 0b54b26bf77..81b532fddf5 100644 --- a/src/targets/gpu/kernels/include/migraphx/kernels/math.hpp +++ b/src/targets/gpu/kernels/include/migraphx/kernels/math.hpp @@ -25,11 +25,11 @@ #define MIGRAPHX_GUARD_KERNELS_MATH_HPP #include -#include #include #include #include #include +#include namespace migraphx { diff --git a/src/targets/gpu/kernels/include/migraphx/kernels/ops.hpp b/src/targets/gpu/kernels/include/migraphx/kernels/ops.hpp index 898ce637e6a..f70e4933347 100644 --- a/src/targets/gpu/kernels/include/migraphx/kernels/ops.hpp +++ b/src/targets/gpu/kernels/include/migraphx/kernels/ops.hpp @@ -118,7 +118,7 @@ struct highest template constexpr operator T() const { - return numeric_max>(); + return numeric_max, void>(); } }; } // namespace migraphx From a6c57726106b1a24dab8e9ee9a028532e0ffc5af Mon Sep 17 00:00:00 2001 From: Umang Yadav Date: Tue, 21 Nov 2023 00:13:37 +0000 Subject: [PATCH 061/115] add float8 for tensorview --- src/targets/gpu/kernels/include/migraphx/kernels/tensor_view.hpp | 1 + 1 file changed, 1 
insertion(+) diff --git a/src/targets/gpu/kernels/include/migraphx/kernels/tensor_view.hpp b/src/targets/gpu/kernels/include/migraphx/kernels/tensor_view.hpp index 56a8b6f7c35..3d608757ad0 100644 --- a/src/targets/gpu/kernels/include/migraphx/kernels/tensor_view.hpp +++ b/src/targets/gpu/kernels/include/migraphx/kernels/tensor_view.hpp @@ -27,6 +27,7 @@ #include #include #include +#include namespace migraphx { From 3aa465fd3f2117701d1e61fe947914848618bb2b Mon Sep 17 00:00:00 2001 From: Umang Yadav Date: Sun, 26 Nov 2023 15:47:04 +0000 Subject: [PATCH 062/115] compiles all right --- src/targets/gpu/CMakeLists.txt | 9 ++ src/targets/gpu/gemm_impl.cpp | 113 +++++++++++++++--- .../gpu/include/migraphx/gpu/rocblas.hpp | 2 + src/targets/gpu/lowering.cpp | 32 +++++ src/targets/gpu/rocblas.cpp | 9 ++ test/verify/test_convert.cpp | 18 +-- 6 files changed, 159 insertions(+), 24 deletions(-) diff --git a/src/targets/gpu/CMakeLists.txt b/src/targets/gpu/CMakeLists.txt index e54c5050512..9c3f51687db 100644 --- a/src/targets/gpu/CMakeLists.txt +++ b/src/targets/gpu/CMakeLists.txt @@ -253,6 +253,8 @@ check_library_exists(MIOpen "miopenHiddenSetConvolutionFindMode" "${MIOPEN_LOCAT check_library_exists(MIOpen "miopenFindSolutions" "${MIOPEN_LOCATION}" HAS_FIND_2_API) # Beta API for automated GEMM tuning check_library_exists(roc::rocblas "rocblas_gemm_ex_get_solutions" "${ROCBLAS_LOCATION}" HAS_ROCBLAS_TUNING_BETA_FEATURE_API) +# rocblas FP8 API +check_library_exists(roc::rocblas "rocblas_gemm_ex3" "${ROCBLAS_LOCATION}" HAS_ROCBLAS_FP8_BETA_API) set(MIGRAPHX_USE_FIND_2_API "${HAS_FIND_2_API}" CACHE BOOL "") @@ -282,6 +284,13 @@ else() message(STATUS "rocBLAS does not have User Tuning Beta API") endif() +if(HAS_ROCBLAS_FP8_BETA_API) + target_compile_definitions(migraphx_gpu PUBLIC -DMIGRAPHX_USE_ROCBLAS_FP8_API -DROCBLAS_BETA_FEATURES_API -DROCBLAS_NO_DEPRECATED_WARNINGS) + message(STATUA "MIGraphX is using BETA API of rocBLAS for FP8 computations") +else() + message(STATUS "rocBLAS 
does not have FP8 BETA API") +endif() + target_link_libraries(migraphx_gpu PUBLIC migraphx MIOpen roc::rocblas) target_link_libraries(migraphx_gpu PRIVATE migraphx_device migraphx_kernels) if(MIGRAPHX_USE_COMPOSABLEKERNEL) diff --git a/src/targets/gpu/gemm_impl.cpp b/src/targets/gpu/gemm_impl.cpp index 4495e21ecac..f43b02b45fa 100644 --- a/src/targets/gpu/gemm_impl.cpp +++ b/src/targets/gpu/gemm_impl.cpp @@ -23,10 +23,12 @@ */ #include +#include #include #include #include #include +#include using microseconds = std::chrono::duration; @@ -46,7 +48,7 @@ rocblas_datatype get_type(shape::type_t type) case shape::uint8_type: return rocblas_datatype_u8_r; case shape::int32_type: return rocblas_datatype_i32_r; case shape::uint32_type: return rocblas_datatype_u32_r; - case shape::fp8e4m3fnuz_type: + case shape::fp8e4m3fnuz_type: return rocblas_datatype_f8_r; case shape::tuple_type: case shape::bool_type: case shape::uint16_type: @@ -217,23 +219,50 @@ struct gemm_impl void run(context& ctx, const std::vector& input_args, int32_t solution_idx = 0) const { - if(strided_batched) + if(rocblas_fp8_available() and + std::any_of(input_args.begin(), input_args.end(), [](const auto i) { + return i.get_shape().type() == migraphx::shape::fp8e4m3fnuz_type; + })) { - auto common_args = create_strided_batched_args_common(ctx, input_args); - rocblas_invoke(&rocblas_gemm_strided_batched_ex, - common_args, - rocblas_gemm_algo_solution_index, - solution_idx, - gemm_flags); + if(strided_batched) + { + auto common_args = create_strided_batched_args_common_fp8(ctx, input_args); + rocblas_invoke(&rocblas_gemm_strided_batched_ex3, + common_args, + rocblas_gemm_algo_solution_index, + solution_idx, + gemm_flags); + } + else + { + auto common_args = create_gemm_ex_args_common_fp8(ctx, input_args); + rocblas_invoke(&rocblas_gemm_ex3, + common_args, + rocblas_gemm_algo_solution_index, + solution_idx, + gemm_flags); + } } else { - auto common_args = create_gemm_ex_args_common(ctx, input_args); - 
rocblas_invoke(&rocblas_gemm_ex, - common_args, - rocblas_gemm_algo_solution_index, - solution_idx, - gemm_flags); + if(strided_batched) + { + auto common_args = create_strided_batched_args_common(ctx, input_args); + rocblas_invoke(&rocblas_gemm_strided_batched_ex, + common_args, + rocblas_gemm_algo_solution_index, + solution_idx, + gemm_flags); + } + else + { + auto common_args = create_gemm_ex_args_common(ctx, input_args); + rocblas_invoke(&rocblas_gemm_ex, + common_args, + rocblas_gemm_algo_solution_index, + solution_idx, + gemm_flags); + } } } @@ -331,6 +360,36 @@ struct gemm_impl num_matrices, compute_type); } + auto create_strided_batched_args_common_fp8(context& ctx, + const std::vector& args) const + { + return pack(ctx.get_stream().get_rocblas(), + transb ? rocblas_operation_transpose : rocblas_operation_none, + transa ? rocblas_operation_transpose : rocblas_operation_none, + n, + m, + k, + get_alpha(), + args[1].data(), + arg_type, + ldb, + b_stride, + args[0].data(), + arg_type, + lda, + a_stride, + get_beta(), + args[2].data(), + output_type, + ldc, + c_stride, + is_3inputs ? args[3].data() : args[2].data(), + output_type, + ldd, + d_stride, + num_matrices, + rocblas_compute_type_f8_f8_f32); + } /** * Helper method to create that subset of a long rocBLAS argument list that is common @@ -366,6 +425,30 @@ struct gemm_impl ldd, compute_type); } + auto create_gemm_ex_args_common_fp8(context& ctx, const std::vector& args) const + { + return pack(ctx.get_stream().get_rocblas(), + transb ? rocblas_operation_transpose : rocblas_operation_none, + transa ? rocblas_operation_transpose : rocblas_operation_none, + n, + m, + k, + get_alpha(), + args[1].data(), + arg_type, + ldb, + args[0].data(), + arg_type, + lda, + get_beta(), + args[2].data(), + output_type, + ldc, + is_3inputs ? 
args[3].data() : args[2].data(), + output_type, + ldd, + rocblas_compute_type_f8_f8_f32); + } #ifdef MIGRAPHX_USE_ROCBLAS_TUNING_API /** * Find best rocBLAS solution: Get list of solutions and try them all, returning the index @@ -481,8 +564,8 @@ struct gemm_impl rocblas_int b_stride = 0; rocblas_int c_stride = 0; rocblas_int d_stride = 0; - rocblas_datatype compute_type = rocblas_datatype_f32_r; rocblas_datatype arg_type = rocblas_datatype_f32_r; + rocblas_datatype compute_type = rocblas_datatype_f32_r; rocblas_datatype output_type = rocblas_datatype_f32_r; bool strided_batched = true; bool is_3inputs = true; diff --git a/src/targets/gpu/include/migraphx/gpu/rocblas.hpp b/src/targets/gpu/include/migraphx/gpu/rocblas.hpp index b103dac997e..e72666e25ae 100644 --- a/src/targets/gpu/include/migraphx/gpu/rocblas.hpp +++ b/src/targets/gpu/include/migraphx/gpu/rocblas.hpp @@ -40,6 +40,8 @@ struct context; MIGRAPHX_GPU_EXPORT bool get_compute_fp32_flag(); +MIGRAPHX_GPU_EXPORT bool rocblas_fp8_available(); + } // namespace gpu } // namespace MIGRAPHX_INLINE_NS } // namespace migraphx diff --git a/src/targets/gpu/lowering.cpp b/src/targets/gpu/lowering.cpp index ea0e29f8853..bd87ba2800e 100644 --- a/src/targets/gpu/lowering.cpp +++ b/src/targets/gpu/lowering.cpp @@ -220,12 +220,44 @@ struct miopen_apply return mod->insert_instruction(ins, make_op("allocate", {{"shape", to_value(s)}})); } + instruction_ref convert_fp8_to_fp32(instruction_ref ins) + { + std::vector fp8_inputs = ins->inputs(); + std::vector fp32_inputs; + for(const auto& i : fp8_inputs) + { + fp32_inputs.push_back(mod->insert_instruction( + ins, + migraphx::make_op( + "convert", + {{"target_type", migraphx::to_value(migraphx::shape::type_t::float_type)}}), + i)); + } + auto fp32_ins = mod->insert_instruction(ins, ins->get_operator(), {fp32_inputs}); + auto fp8_ins = mod->insert_instruction( + ins, + migraphx::make_op( + "convert", + {{"target_type", 
migraphx::to_value(migraphx::shape::type_t::fp8e4m3fnuz_type)}}), + fp32_ins); + mod->replace_instruction(ins, fp8_ins); + return fp32_ins; + } + template void add_gemm_op(const std::string& name) { apply_map.emplace(name, [=](instruction_ref ins) { std::vector refs = ins->inputs(); assert(refs.size() == 2); + if(not rocblas_fp8_available() and + std::any_of(refs.begin(), refs.end(), [](const auto i) { + return i->get_shape().type() == migraphx::shape::fp8e4m3fnuz_type; + })) + { + // replace fp8 ins with fp32 ins + ins = convert_fp8_to_fp32(ins); + } auto output = insert_allocation(ins, ins->get_shape()); refs.push_back(output); return mod->replace_instruction(ins, rocblas_gemm{Op{}, 1, 0, compute_fp32}, refs); diff --git a/src/targets/gpu/rocblas.cpp b/src/targets/gpu/rocblas.cpp index 0a0faa4719f..d73813c2283 100644 --- a/src/targets/gpu/rocblas.cpp +++ b/src/targets/gpu/rocblas.cpp @@ -53,6 +53,15 @@ bool get_compute_fp32_flag() return (starts_with(device_name, "gfx9") and device_name >= "gfx908"); } +bool rocblas_fp8_available() { +#ifndef MIGRAPHX_USE_ROCBLAS_FP8_API + return false; +#else + const auto device_name = trim(split_string(get_device_name(), ':').front()); + return (starts_with(device_name, "gfx9") and device_name >= "gfx940"); +#endif +} + } // namespace gpu } // namespace MIGRAPHX_INLINE_NS } // namespace migraphx diff --git a/test/verify/test_convert.cpp b/test/verify/test_convert.cpp index d50f4146f0d..02a82e7a2f2 100644 --- a/test/verify/test_convert.cpp +++ b/test/verify/test_convert.cpp @@ -29,26 +29,26 @@ #include -struct test_convert : verify_program +template +struct test_convert : verify_program> { migraphx::program create_program() const { migraphx::program p; auto* mm = p.get_main_module(); - migraphx::shape sa{migraphx::shape::int8_type, {8, 24}}; - migraphx::shape sb{migraphx::shape::int8_type, {24, 6}}; + migraphx::shape sa{From, {8, 24}}; + migraphx::shape sb{From, {24, 6}}; auto pa = mm->add_parameter("a", sa); auto pb = 
mm->add_parameter("b", sb); auto ia = mm->add_instruction( - migraphx::make_op("convert", - {{"target_type", migraphx::to_value(migraphx::shape::float_type)}}), - pa); + migraphx::make_op("convert", {{"target_type", migraphx::to_value(To)}}), pa); auto ib = mm->add_instruction( - migraphx::make_op("convert", - {{"target_type", migraphx::to_value(migraphx::shape::float_type)}}), - pb); + migraphx::make_op("convert", {{"target_type", migraphx::to_value(To)}}), pb); mm->add_instruction(migraphx::make_op("dot"), ia, ib); return p; }; }; + +template struct test_convert; +template struct test_convert; From 037205c535b9d0065d2045c8ed61f91e07ebe506 Mon Sep 17 00:00:00 2001 From: Umang Yadav Date: Sun, 26 Nov 2023 16:27:36 +0000 Subject: [PATCH 063/115] Works now --- src/targets/gpu/CMakeLists.txt | 4 ++-- src/targets/gpu/gemm_impl.cpp | 8 ++++---- test/verify/gemm_2args_bmv.cpp | 11 ++++++++--- test/verify/gemm_2args_mm_1.cpp | 10 +++++++--- 4 files changed, 21 insertions(+), 12 deletions(-) diff --git a/src/targets/gpu/CMakeLists.txt b/src/targets/gpu/CMakeLists.txt index 9c3f51687db..d4f600394cf 100644 --- a/src/targets/gpu/CMakeLists.txt +++ b/src/targets/gpu/CMakeLists.txt @@ -286,9 +286,9 @@ endif() if(HAS_ROCBLAS_FP8_BETA_API) target_compile_definitions(migraphx_gpu PUBLIC -DMIGRAPHX_USE_ROCBLAS_FP8_API -DROCBLAS_BETA_FEATURES_API -DROCBLAS_NO_DEPRECATED_WARNINGS) - message(STATUA "MIGraphX is using BETA API of rocBLAS for FP8 computations") + message(STATUS "MIGraphX is using Beta API of rocBLAS for FP8 computations") else() - message(STATUS "rocBLAS does not have FP8 BETA API") + message(STATUS "rocBLAS does not have Fp8 Beta API") endif() target_link_libraries(migraphx_gpu PUBLIC migraphx MIOpen roc::rocblas) diff --git a/src/targets/gpu/gemm_impl.cpp b/src/targets/gpu/gemm_impl.cpp index f43b02b45fa..b27fabc2bc8 100644 --- a/src/targets/gpu/gemm_impl.cpp +++ b/src/targets/gpu/gemm_impl.cpp @@ -229,7 +229,7 @@ struct gemm_impl auto common_args = 
create_strided_batched_args_common_fp8(ctx, input_args); rocblas_invoke(&rocblas_gemm_strided_batched_ex3, common_args, - rocblas_gemm_algo_solution_index, + rocblas_gemm_algo_standard, solution_idx, gemm_flags); } @@ -238,7 +238,7 @@ struct gemm_impl auto common_args = create_gemm_ex_args_common_fp8(ctx, input_args); rocblas_invoke(&rocblas_gemm_ex3, common_args, - rocblas_gemm_algo_solution_index, + rocblas_gemm_algo_standard, solution_idx, gemm_flags); } @@ -388,7 +388,7 @@ struct gemm_impl ldd, d_stride, num_matrices, - rocblas_compute_type_f8_f8_f32); + rocblas_compute_type_f32); } /** @@ -447,7 +447,7 @@ struct gemm_impl is_3inputs ? args[3].data() : args[2].data(), output_type, ldd, - rocblas_compute_type_f8_f8_f32); + rocblas_compute_type_f32); } #ifdef MIGRAPHX_USE_ROCBLAS_TUNING_API /** diff --git a/test/verify/gemm_2args_bmv.cpp b/test/verify/gemm_2args_bmv.cpp index 896f2df64d6..d192de96bab 100644 --- a/test/verify/gemm_2args_bmv.cpp +++ b/test/verify/gemm_2args_bmv.cpp @@ -27,14 +27,15 @@ #include #include -struct gemm_2args_bmv : verify_program +template +struct gemm_2args_bmv : verify_program> { migraphx::program create_program() const { migraphx::program p; auto* mm = p.get_main_module(); - migraphx::shape m1_shape{migraphx::shape::float_type, {2, 3, 3, 5}}; - migraphx::shape m2_shape{migraphx::shape::float_type, {5}}; + migraphx::shape m1_shape{DType, {2, 3, 3, 5}}; + migraphx::shape m2_shape{DType, {5}}; auto l1 = mm->add_parameter("1", m1_shape); auto l2 = mm->add_parameter("2", m2_shape); auto ul2 = mm->add_instruction(migraphx::make_op("unsqueeze", {{"axes", {1}}}), l2); @@ -46,3 +47,7 @@ struct gemm_2args_bmv : verify_program return p; } }; + +template struct gemm_2args_bmv; +template struct gemm_2args_bmv; + diff --git a/test/verify/gemm_2args_mm_1.cpp b/test/verify/gemm_2args_mm_1.cpp index 204f957ce7a..6df8d42dd0f 100644 --- a/test/verify/gemm_2args_mm_1.cpp +++ b/test/verify/gemm_2args_mm_1.cpp @@ -27,14 +27,15 @@ #include #include -struct 
gemm_2args_mm_1 : verify_program +template +struct gemm_2args_mm_1 : verify_program> { migraphx::program create_program() const { migraphx::program p; auto* mm = p.get_main_module(); - migraphx::shape m1_shape{migraphx::shape::float_type, {2, 2, 3}}; - migraphx::shape m2_shape{migraphx::shape::float_type, {1, 3, 4}}; + migraphx::shape m1_shape{DType, {2, 2, 3}}; + migraphx::shape m2_shape{DType, {1, 3, 4}}; auto l1 = mm->add_parameter("1", m1_shape); auto l2 = mm->add_parameter("2", m2_shape); auto bl2 = @@ -45,3 +46,6 @@ struct gemm_2args_mm_1 : verify_program return p; } }; + +template struct gemm_2args_mm_1; +template struct gemm_2args_mm_1; From 87548b5d02134c74e9350aa769824dad84ed88f0 Mon Sep 17 00:00:00 2001 From: Umang Yadav Date: Sun, 26 Nov 2023 16:33:04 +0000 Subject: [PATCH 064/115] add ifdef to compile --- src/targets/gpu/gemm_impl.cpp | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/targets/gpu/gemm_impl.cpp b/src/targets/gpu/gemm_impl.cpp index b27fabc2bc8..6d30df64718 100644 --- a/src/targets/gpu/gemm_impl.cpp +++ b/src/targets/gpu/gemm_impl.cpp @@ -219,6 +219,7 @@ struct gemm_impl void run(context& ctx, const std::vector& input_args, int32_t solution_idx = 0) const { +#ifdef MIGRAPHX_USE_ROCBLAS_FP8_API if(rocblas_fp8_available() and std::any_of(input_args.begin(), input_args.end(), [](const auto i) { return i.get_shape().type() == migraphx::shape::fp8e4m3fnuz_type; @@ -244,6 +245,7 @@ struct gemm_impl } } else +#endif { if(strided_batched) { From d473b80283b8030108fcb997e409697d47fb103f Mon Sep 17 00:00:00 2001 From: Umang Yadav Date: Sun, 26 Nov 2023 16:42:13 +0000 Subject: [PATCH 065/115] add tests and fix cmake --- src/targets/gpu/CMakeLists.txt | 2 +- test/verify/gemm_2args_mm_2.cpp | 11 ++++++++--- test/verify/gemm_2args_mm_3.cpp | 11 ++++++++--- test/verify/gemm_2args_mm_4.cpp | 11 ++++++++--- test/verify/gemm_2args_mm_5.cpp | 10 +++++++--- test/verify/gemm_2args_mm_6.cpp | 11 ++++++++--- test/verify/gemm_2args_mm_7.cpp | 10 +++++++--- 
test/verify/gemm_2args_mm_8.cpp | 10 +++++++--- test/verify/gemm_2args_mv.cpp | 10 +++++++--- test/verify/gemm_2args_vbm.cpp | 10 +++++++--- 10 files changed, 68 insertions(+), 28 deletions(-) diff --git a/src/targets/gpu/CMakeLists.txt b/src/targets/gpu/CMakeLists.txt index d4f600394cf..508400d8fac 100644 --- a/src/targets/gpu/CMakeLists.txt +++ b/src/targets/gpu/CMakeLists.txt @@ -254,7 +254,7 @@ check_library_exists(MIOpen "miopenFindSolutions" "${MIOPEN_LOCATION}" HAS_FIND_ # Beta API for automated GEMM tuning check_library_exists(roc::rocblas "rocblas_gemm_ex_get_solutions" "${ROCBLAS_LOCATION}" HAS_ROCBLAS_TUNING_BETA_FEATURE_API) # rocblas FP8 API -check_library_exists(roc::rocblas "rocblas_gemm_ex3" "${ROCBLAS_LOCATION}" HAS_ROCBLAS_FP8_BETA_API) +check_library_exists(roc::rocblas "rocblas_gemm_strided_batched_ex3" "${ROCBLAS_LOCATION}" HAS_ROCBLAS_FP8_BETA_API) set(MIGRAPHX_USE_FIND_2_API "${HAS_FIND_2_API}" CACHE BOOL "") diff --git a/test/verify/gemm_2args_mm_2.cpp b/test/verify/gemm_2args_mm_2.cpp index 7e2405abc4e..448a76c7c99 100644 --- a/test/verify/gemm_2args_mm_2.cpp +++ b/test/verify/gemm_2args_mm_2.cpp @@ -22,19 +22,21 @@ * THE SOFTWARE. 
*/ +#include "migraphx/shape.hpp" #include "verify_program.hpp" #include #include #include -struct gemm_2args_mm_2 : verify_program +template +struct gemm_2args_mm_2 : verify_program> { migraphx::program create_program() const { migraphx::program p; auto* mm = p.get_main_module(); - migraphx::shape m1_shape{migraphx::shape::float_type, {2, 2, 3}}; - migraphx::shape m2_shape{migraphx::shape::float_type, {3, 4}}; + migraphx::shape m1_shape{DType, {2, 2, 3}}; + migraphx::shape m2_shape{DType, {3, 4}}; auto l1 = mm->add_parameter("1", m1_shape); auto l2 = mm->add_parameter("2", m2_shape); auto bl2 = @@ -45,3 +47,6 @@ struct gemm_2args_mm_2 : verify_program return p; } }; + +template struct gemm_2args_mm_2; +template struct gemm_2args_mm_2; diff --git a/test/verify/gemm_2args_mm_3.cpp b/test/verify/gemm_2args_mm_3.cpp index d0edcc9b2a3..668a55f1c88 100644 --- a/test/verify/gemm_2args_mm_3.cpp +++ b/test/verify/gemm_2args_mm_3.cpp @@ -22,19 +22,21 @@ * THE SOFTWARE. */ +#include "migraphx/shape.hpp" #include "verify_program.hpp" #include #include #include -struct gemm_2args_mm_3 : verify_program +template +struct gemm_2args_mm_3 : verify_program> { migraphx::program create_program() const { migraphx::program p; auto* mm = p.get_main_module(); - migraphx::shape m1_shape{migraphx::shape::float_type, {1, 2, 3}}; - migraphx::shape m2_shape{migraphx::shape::float_type, {3, 3, 4}}; + migraphx::shape m1_shape{DType, {1, 2, 3}}; + migraphx::shape m2_shape{DType, {3, 3, 4}}; auto l1 = mm->add_parameter("1", m1_shape); auto bl1 = mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {3, 2, 3}}}), l1); @@ -45,3 +47,6 @@ struct gemm_2args_mm_3 : verify_program return p; } }; + +template struct gemm_2args_mm_3; +template struct gemm_2args_mm_3; diff --git a/test/verify/gemm_2args_mm_4.cpp b/test/verify/gemm_2args_mm_4.cpp index af04b896c68..c0e02d8dfa6 100644 --- a/test/verify/gemm_2args_mm_4.cpp +++ b/test/verify/gemm_2args_mm_4.cpp @@ -22,19 +22,21 @@ * THE 
SOFTWARE. */ +#include "migraphx/shape.hpp" #include "verify_program.hpp" #include #include #include -struct gemm_2args_mm_4 : verify_program +template +struct gemm_2args_mm_4 : verify_program> { migraphx::program create_program() const { migraphx::program p; auto* mm = p.get_main_module(); - migraphx::shape m1_shape{migraphx::shape::float_type, {2, 3}}; - migraphx::shape m2_shape{migraphx::shape::float_type, {3, 3, 4}}; + migraphx::shape m1_shape{DType, {2, 3}}; + migraphx::shape m2_shape{DType, {3, 3, 4}}; auto l1 = mm->add_parameter("1", m1_shape); auto bl1 = mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {3, 2, 3}}}), l1); @@ -45,3 +47,6 @@ struct gemm_2args_mm_4 : verify_program return p; } }; + +template struct gemm_2args_mm_4; +template struct gemm_2args_mm_4; diff --git a/test/verify/gemm_2args_mm_5.cpp b/test/verify/gemm_2args_mm_5.cpp index 93316ee30ea..adb83929afd 100644 --- a/test/verify/gemm_2args_mm_5.cpp +++ b/test/verify/gemm_2args_mm_5.cpp @@ -27,14 +27,15 @@ #include #include -struct gemm_2args_mm_5 : verify_program +template +struct gemm_2args_mm_5 : verify_program> { migraphx::program create_program() const { migraphx::program p; auto* mm = p.get_main_module(); - migraphx::shape m1_shape{migraphx::shape::float_type, {2, 1, 2, 3}}; - migraphx::shape m2_shape{migraphx::shape::float_type, {2, 3, 3, 4}}; + migraphx::shape m1_shape{DType, {2, 1, 2, 3}}; + migraphx::shape m2_shape{DType, {2, 3, 3, 4}}; auto l1 = mm->add_parameter("1", m1_shape); auto bl1 = mm->add_instruction( migraphx::make_op("multibroadcast", {{"out_lens", {2, 3, 2, 3}}}), l1); @@ -45,3 +46,6 @@ struct gemm_2args_mm_5 : verify_program return p; } }; + +template struct gemm_2args_mm_5; +template struct gemm_2args_mm_5; diff --git a/test/verify/gemm_2args_mm_6.cpp b/test/verify/gemm_2args_mm_6.cpp index e76b2170e4f..d6f4c51b4b0 100644 --- a/test/verify/gemm_2args_mm_6.cpp +++ b/test/verify/gemm_2args_mm_6.cpp @@ -27,14 +27,16 @@ #include #include -struct 
gemm_2args_mm_6 : verify_program +template + +struct gemm_2args_mm_6 : verify_program> { migraphx::program create_program() const { migraphx::program p; auto* mm = p.get_main_module(); - migraphx::shape m1_shape{migraphx::shape::float_type, {2, 1, 2, 3}}; - migraphx::shape m2_shape{migraphx::shape::float_type, {1, 3, 3, 4}}; + migraphx::shape m1_shape{DType, {2, 1, 2, 3}}; + migraphx::shape m2_shape{DType, {1, 3, 3, 4}}; auto l1 = mm->add_parameter("1", m1_shape); auto bl1 = mm->add_instruction( migraphx::make_op("multibroadcast", {{"out_lens", {2, 3, 2, 3}}}), l1); @@ -47,3 +49,6 @@ struct gemm_2args_mm_6 : verify_program return p; } }; + +template struct gemm_2args_mm_6; +template struct gemm_2args_mm_6; diff --git a/test/verify/gemm_2args_mm_7.cpp b/test/verify/gemm_2args_mm_7.cpp index 4cfa8bd67f3..3dcb5d1ecef 100644 --- a/test/verify/gemm_2args_mm_7.cpp +++ b/test/verify/gemm_2args_mm_7.cpp @@ -27,14 +27,15 @@ #include #include -struct gemm_2args_mm_7 : verify_program +template +struct gemm_2args_mm_7 : verify_program> { migraphx::program create_program() const { migraphx::program p; auto* mm = p.get_main_module(); - migraphx::shape m1_shape{migraphx::shape::float_type, {2, 3}}; - migraphx::shape m2_shape{migraphx::shape::float_type, {2, 3, 3, 4}}; + migraphx::shape m1_shape{DType, {2, 3}}; + migraphx::shape m2_shape{DType, {2, 3, 3, 4}}; auto l1 = mm->add_parameter("1", m1_shape); auto bl1 = mm->add_instruction( migraphx::make_op("multibroadcast", {{"out_lens", {2, 3, 2, 3}}}), l1); @@ -45,3 +46,6 @@ struct gemm_2args_mm_7 : verify_program return p; } }; + +template struct gemm_2args_mm_7; +template struct gemm_2args_mm_7; diff --git a/test/verify/gemm_2args_mm_8.cpp b/test/verify/gemm_2args_mm_8.cpp index 027643f2f52..b4da6d40990 100644 --- a/test/verify/gemm_2args_mm_8.cpp +++ b/test/verify/gemm_2args_mm_8.cpp @@ -27,14 +27,15 @@ #include #include -struct gemm_2args_mm_8 : verify_program +template +struct gemm_2args_mm_8 : verify_program> { 
migraphx::program create_program() const { migraphx::program p; auto* mm = p.get_main_module(); - migraphx::shape a_shape{migraphx::shape::float_type, {2, 128, 32}, {4096, 1, 128}}; - migraphx::shape b_shape{migraphx::shape::float_type, {32, 32}}; + migraphx::shape a_shape{DType, {2, 128, 32}, {4096, 1, 128}}; + migraphx::shape b_shape{DType, {32, 32}}; auto a = mm->add_parameter("a", a_shape); auto b = mm->add_parameter("b", b_shape); auto bb = mm->add_instruction( @@ -45,3 +46,6 @@ struct gemm_2args_mm_8 : verify_program return p; } }; + +template struct gemm_2args_mm_8; +template struct gemm_2args_mm_8; diff --git a/test/verify/gemm_2args_mv.cpp b/test/verify/gemm_2args_mv.cpp index ae81cd39724..773ec758a13 100644 --- a/test/verify/gemm_2args_mv.cpp +++ b/test/verify/gemm_2args_mv.cpp @@ -27,14 +27,15 @@ #include #include -struct gemm_2args_mv : verify_program +template +struct gemm_2args_mv : verify_program> { migraphx::program create_program() const { migraphx::program p; auto* mm = p.get_main_module(); - migraphx::shape m1_shape{migraphx::shape::float_type, {3, 5}}; - migraphx::shape m2_shape{migraphx::shape::float_type, {5}}; + migraphx::shape m1_shape{DType, {3, 5}}; + migraphx::shape m2_shape{DType, {5}}; auto l1 = mm->add_parameter("1", m1_shape); auto l2 = mm->add_parameter("2", m2_shape); auto ul2 = mm->add_instruction(migraphx::make_op("unsqueeze", {{"axes", {1}}}), l2); @@ -44,3 +45,6 @@ struct gemm_2args_mv : verify_program return p; } }; + +template struct gemm_2args_mv; +template struct gemm_2args_mv; diff --git a/test/verify/gemm_2args_vbm.cpp b/test/verify/gemm_2args_vbm.cpp index e1f24d5e644..f525e74efc5 100644 --- a/test/verify/gemm_2args_vbm.cpp +++ b/test/verify/gemm_2args_vbm.cpp @@ -27,14 +27,15 @@ #include #include -struct gemm_2args_vbm : verify_program +template +struct gemm_2args_vbm : verify_program> { migraphx::program create_program() const { migraphx::program p; auto* mm = p.get_main_module(); - migraphx::shape 
m1_shape{migraphx::shape::float_type, {5}}; - migraphx::shape m2_shape{migraphx::shape::float_type, {2, 2, 5, 4}}; + migraphx::shape m1_shape{DType, {5}}; + migraphx::shape m2_shape{DType, {2, 2, 5, 4}}; auto l1 = mm->add_parameter("1", m1_shape); auto ul1 = mm->add_instruction(migraphx::make_op("unsqueeze", {{"axes", {0}}}), l1); auto bul1 = mm->add_instruction( @@ -48,3 +49,6 @@ struct gemm_2args_vbm : verify_program return p; } }; + +template struct gemm_2args_vbm; +template struct gemm_2args_vbm; From 4604f2e17bca2a764fe70449b05fb8a6818196c7 Mon Sep 17 00:00:00 2001 From: Umang Yadav Date: Sun, 26 Nov 2023 17:40:51 +0000 Subject: [PATCH 066/115] add tests --- .../gpu/kernels/include/migraphx/kernels/math.hpp | 2 +- test/verify/gemm_2args_vm.cpp | 10 +++++++--- test/verify/gemm_2args_vv.cpp | 10 +++++++--- test/verify/gemm_add.cpp | 13 +++++++++---- test/verify/gemm_add_broadcast1.cpp | 13 +++++++++---- test/verify/gemm_add_broadcast2.cpp | 13 +++++++++---- test/verify/gemm_literal.cpp | 10 +++++++--- test/verify/gemm_multi_3args.cpp | 12 ++++++++---- test/verify/gemm_multi_3args_alpha0.cpp | 13 +++++++++---- test/verify/gemm_multi_3args_beta0.cpp | 12 ++++++++---- test/verify/gemm_multi_3args_c25.cpp | 13 +++++++++---- test/verify/gemm_multi_dim_2.cpp | 10 +++++++--- test/verify/gemm_multi_dim_2_3.cpp | 10 +++++++--- test/verify/gemm_multi_transpose.cpp | 10 +++++++--- test/verify/test_gemm.cpp | 11 +++++++---- test/verify/test_gemm_copy.cpp | 12 ++++++++---- test/verify/test_gemm_ex.cpp | 9 ++++++--- test/verify/test_gemm_transposea.cpp | 10 +++++++--- test/verify/test_gemm_transposea_ex.cpp | 10 +++++++--- test/verify/test_gemm_transposeab.cpp | 10 +++++++--- test/verify/test_gemm_transposeb.cpp | 10 +++++++--- test/verify/test_gemm_transposeb_ex.cpp | 11 +++++++---- test/verify/test_mul_dot_a.cpp | 13 ++++++++----- test/verify/test_mul_dot_b.cpp | 14 +++++++++----- test/verify/test_unbatched_gemm_1.cpp | 13 +++++++++---- test/verify/test_unbatched_gemm_2.cpp 
| 11 ++++++++--- 26 files changed, 194 insertions(+), 91 deletions(-) diff --git a/src/targets/gpu/kernels/include/migraphx/kernels/math.hpp b/src/targets/gpu/kernels/include/migraphx/kernels/math.hpp index 81b532fddf5..45f816d29b2 100644 --- a/src/targets/gpu/kernels/include/migraphx/kernels/math.hpp +++ b/src/targets/gpu/kernels/include/migraphx/kernels/math.hpp @@ -290,7 +290,7 @@ MIGRAPHX_DEVICE_MATH_VEC(where) template constexpr auto convert(U v) { - return vec_transform(v)([](auto x) -> T { return x; }); + return vec_transform(v)([](auto x) { return static_cast(x); }); } } // namespace migraphx diff --git a/test/verify/gemm_2args_vm.cpp b/test/verify/gemm_2args_vm.cpp index f36714c893d..067ebc29d87 100644 --- a/test/verify/gemm_2args_vm.cpp +++ b/test/verify/gemm_2args_vm.cpp @@ -27,14 +27,15 @@ #include #include -struct gemm_2args_vm : verify_program +template +struct gemm_2args_vm : verify_program> { migraphx::program create_program() const { migraphx::program p; auto* mm = p.get_main_module(); - migraphx::shape m1_shape{migraphx::shape::float_type, {5}}; - migraphx::shape m2_shape{migraphx::shape::float_type, {5, 4}}; + migraphx::shape m1_shape{DType, {5}}; + migraphx::shape m2_shape{DType, {5, 4}}; auto l1 = mm->add_parameter("1", m1_shape); auto ul1 = mm->add_instruction(migraphx::make_op("unsqueeze", {{"axes", {0}}}), l1); auto l2 = mm->add_parameter("2", m2_shape); @@ -45,3 +46,6 @@ struct gemm_2args_vm : verify_program return p; } }; + +template struct gemm_2args_vm; +template struct gemm_2args_vm; diff --git a/test/verify/gemm_2args_vv.cpp b/test/verify/gemm_2args_vv.cpp index 5def7fd377c..a42a154615d 100644 --- a/test/verify/gemm_2args_vv.cpp +++ b/test/verify/gemm_2args_vv.cpp @@ -28,14 +28,15 @@ #include #include -struct gemm_2args_vv : verify_program +template +struct gemm_2args_vv : verify_program> { migraphx::program create_program() const { migraphx::program p; auto* mm = p.get_main_module(); - migraphx::shape 
m1_shape{migraphx::shape::float_type, {8}}; - migraphx::shape m2_shape{migraphx::shape::float_type, {8}}; + migraphx::shape m1_shape{DType, {8}}; + migraphx::shape m2_shape{DType, {8}}; auto l1 = mm->add_parameter("1", m1_shape); auto ul1 = mm->add_instruction(migraphx::make_op("unsqueeze", {{"axes", {0}}}), l1); auto l2 = mm->add_parameter("2", m2_shape); @@ -48,3 +49,6 @@ struct gemm_2args_vv : verify_program return p; } }; + +template struct gemm_2args_vv; +template struct gemm_2args_vv; diff --git a/test/verify/gemm_add.cpp b/test/verify/gemm_add.cpp index d5624d25f7c..041ddc2c94f 100644 --- a/test/verify/gemm_add.cpp +++ b/test/verify/gemm_add.cpp @@ -27,15 +27,17 @@ #include #include #include -struct gemm_add : verify_program + +template +struct gemm_add : verify_program> { migraphx::program create_program() const { migraphx::program p; auto* mm = p.get_main_module(); - migraphx::shape m1_shape{migraphx::shape::float_type, {1, 2, 3}}; - migraphx::shape m2_shape{migraphx::shape::float_type, {1, 3, 4}}; - migraphx::shape m3_shape{migraphx::shape::float_type, {1, 2, 4}}; + migraphx::shape m1_shape{DType, {1, 2, 3}}; + migraphx::shape m2_shape{DType, {1, 3, 4}}; + migraphx::shape m3_shape{DType, {1, 2, 4}}; auto l1 = mm->add_parameter("1", m1_shape); auto l2 = mm->add_parameter("2", m2_shape); auto l3 = mm->add_parameter("3", m3_shape); @@ -45,3 +47,6 @@ struct gemm_add : verify_program return p; } }; + +template struct gemm_add; +template struct gemm_add; diff --git a/test/verify/gemm_add_broadcast1.cpp b/test/verify/gemm_add_broadcast1.cpp index 93840d991ff..f153c4f9f69 100644 --- a/test/verify/gemm_add_broadcast1.cpp +++ b/test/verify/gemm_add_broadcast1.cpp @@ -27,15 +27,17 @@ #include #include #include -struct gemm_add_broadcast1 : verify_program + +template +struct gemm_add_broadcast1 : verify_program> { migraphx::program create_program() const { migraphx::program p; auto* mm = p.get_main_module(); - migraphx::shape m1_shape{migraphx::shape::float_type, {1, 
2, 3}}; - migraphx::shape m2_shape{migraphx::shape::float_type, {1, 3, 4}}; - migraphx::shape m3_shape{migraphx::shape::float_type, {1, 1, 4}}; + migraphx::shape m1_shape{DType, {1, 2, 3}}; + migraphx::shape m2_shape{DType, {1, 3, 4}}; + migraphx::shape m3_shape{DType, {1, 1, 4}}; auto l1 = mm->add_parameter("1", m1_shape); auto l2 = mm->add_parameter("2", m2_shape); auto l3 = mm->add_parameter("3", m3_shape); @@ -47,3 +49,6 @@ struct gemm_add_broadcast1 : verify_program return p; } }; + +template struct gemm_add_broadcast1; +template struct gemm_add_broadcast1; diff --git a/test/verify/gemm_add_broadcast2.cpp b/test/verify/gemm_add_broadcast2.cpp index c42a5ece843..5a7ff9bbe55 100644 --- a/test/verify/gemm_add_broadcast2.cpp +++ b/test/verify/gemm_add_broadcast2.cpp @@ -27,15 +27,17 @@ #include #include #include -struct gemm_add_broadcast2 : verify_program + +template +struct gemm_add_broadcast2 : verify_program> { migraphx::program create_program() const { migraphx::program p; auto* mm = p.get_main_module(); - migraphx::shape m1_shape{migraphx::shape::float_type, {1, 2, 3}}; - migraphx::shape m2_shape{migraphx::shape::float_type, {1, 3, 4}}; - migraphx::shape m3_shape{migraphx::shape::float_type, {1, 2, 1}}; + migraphx::shape m1_shape{DType, {1, 2, 3}}; + migraphx::shape m2_shape{DType, {1, 3, 4}}; + migraphx::shape m3_shape{DType, {1, 2, 1}}; auto l1 = mm->add_parameter("1", m1_shape); auto l2 = mm->add_parameter("2", m2_shape); auto l3 = mm->add_parameter("3", m3_shape); @@ -47,3 +49,6 @@ struct gemm_add_broadcast2 : verify_program return p; } }; + +template struct gemm_add_broadcast2; +template struct gemm_add_broadcast2; diff --git a/test/verify/gemm_literal.cpp b/test/verify/gemm_literal.cpp index 3ea52af477b..41ff6cf62dc 100644 --- a/test/verify/gemm_literal.cpp +++ b/test/verify/gemm_literal.cpp @@ -27,14 +27,15 @@ #include #include -struct gemm_literal : verify_program +template +struct gemm_literal : verify_program> { migraphx::program create_program() 
const { migraphx::program p; auto* mm = p.get_main_module(); - migraphx::shape a_shape{migraphx::shape::float_type, {2, 4}}; - migraphx::shape b_shape{migraphx::shape::float_type, {4, 4}}; + migraphx::shape a_shape{DType, {2, 4}}; + migraphx::shape b_shape{DType, {4, 4}}; auto a = mm->add_literal(migraphx::generate_literal(a_shape)); auto b = mm->add_parameter("b", b_shape); @@ -43,3 +44,6 @@ struct gemm_literal : verify_program return p; } }; + +template struct gemm_literal; +template struct gemm_literal; diff --git a/test/verify/gemm_multi_3args.cpp b/test/verify/gemm_multi_3args.cpp index 383354326e0..1b5af2654ba 100644 --- a/test/verify/gemm_multi_3args.cpp +++ b/test/verify/gemm_multi_3args.cpp @@ -28,15 +28,16 @@ #include #include -struct gemm_multi_3args : verify_program +template +struct gemm_multi_3args : verify_program> { migraphx::program create_program() const { migraphx::program p; auto* mm = p.get_main_module(); - migraphx::shape m1_shape{migraphx::shape::float_type, {2, 3, 2, 3}}; - migraphx::shape m2_shape{migraphx::shape::float_type, {2, 3, 3, 2}}; - migraphx::shape m3_shape{migraphx::shape::float_type, {2, 3, 2, 2}}; + migraphx::shape m1_shape{DType, {2, 3, 2, 3}}; + migraphx::shape m2_shape{DType, {2, 3, 3, 2}}; + migraphx::shape m3_shape{DType, {2, 3, 2, 2}}; auto l1 = mm->add_parameter("1", m1_shape); auto l2 = mm->add_parameter("2", m2_shape); @@ -47,3 +48,6 @@ struct gemm_multi_3args : verify_program return p; } }; + +template struct gemm_multi_3args; +template struct gemm_multi_3args; diff --git a/test/verify/gemm_multi_3args_alpha0.cpp b/test/verify/gemm_multi_3args_alpha0.cpp index 7c993ad4247..638ec4adfc7 100644 --- a/test/verify/gemm_multi_3args_alpha0.cpp +++ b/test/verify/gemm_multi_3args_alpha0.cpp @@ -27,15 +27,17 @@ #include #include #include -struct gemm_multi_3args_alpha0 : verify_program + +template +struct gemm_multi_3args_alpha0 : verify_program> { migraphx::program create_program() const { migraphx::program p; auto* mm = 
p.get_main_module(); - migraphx::shape m1_shape{migraphx::shape::float_type, {1, 2, 3}}; - migraphx::shape m2_shape{migraphx::shape::float_type, {1, 3, 4}}; - migraphx::shape m3_shape{migraphx::shape::float_type, {1, 2, 4}}; + migraphx::shape m1_shape{DType, {1, 2, 3}}; + migraphx::shape m2_shape{DType, {1, 3, 4}}; + migraphx::shape m3_shape{DType, {1, 2, 4}}; auto l1 = mm->add_parameter("1", m1_shape); auto l2 = mm->add_parameter("2", m2_shape); auto l3 = mm->add_parameter("3", m3_shape); @@ -46,3 +48,6 @@ struct gemm_multi_3args_alpha0 : verify_program return p; } }; + +template struct gemm_multi_3args_alpha0; +template struct gemm_multi_3args_alpha0; diff --git a/test/verify/gemm_multi_3args_beta0.cpp b/test/verify/gemm_multi_3args_beta0.cpp index 3d501ef60e8..efcf3263511 100644 --- a/test/verify/gemm_multi_3args_beta0.cpp +++ b/test/verify/gemm_multi_3args_beta0.cpp @@ -28,15 +28,16 @@ #include #include -struct gemm_multi_3args_beta0 : verify_program +template +struct gemm_multi_3args_beta0 : verify_program> { migraphx::program create_program() const { migraphx::program p; auto* mm = p.get_main_module(); - migraphx::shape m1_shape{migraphx::shape::float_type, {1, 2, 3}}; - migraphx::shape m2_shape{migraphx::shape::float_type, {1, 3, 4}}; - migraphx::shape m3_shape{migraphx::shape::float_type, {1, 2, 4}}; + migraphx::shape m1_shape{DType, {1, 2, 3}}; + migraphx::shape m2_shape{DType, {1, 3, 4}}; + migraphx::shape m3_shape{DType, {1, 2, 4}}; auto l1 = mm->add_parameter("1", m1_shape); auto l2 = mm->add_parameter("2", m2_shape); auto l3 = mm->add_parameter("3", m3_shape); @@ -47,3 +48,6 @@ struct gemm_multi_3args_beta0 : verify_program return p; } }; + +template struct gemm_multi_3args_beta0; +template struct gemm_multi_3args_beta0; diff --git a/test/verify/gemm_multi_3args_c25.cpp b/test/verify/gemm_multi_3args_c25.cpp index f2125b0c3f7..74eb60c0898 100644 --- a/test/verify/gemm_multi_3args_c25.cpp +++ b/test/verify/gemm_multi_3args_c25.cpp @@ -28,15 +28,16 @@ 
#include #include -struct gemm_multi_3args_c25 : verify_program +template +struct gemm_multi_3args_c25 : verify_program> { migraphx::program create_program() const { migraphx::program p; auto* mm = p.get_main_module(); - migraphx::shape m1_shape{migraphx::shape::float_type, {2, 3}}; - migraphx::shape m2_shape{migraphx::shape::float_type, {3, 5}}; - migraphx::shape m3_shape{migraphx::shape::float_type, {2, 5}}; + migraphx::shape m1_shape{DType, {2, 3}}; + migraphx::shape m2_shape{DType, {3, 5}}; + migraphx::shape m3_shape{DType, {2, 5}}; auto l1 = mm->add_parameter("1", m1_shape); auto l2 = mm->add_parameter("2", m2_shape); @@ -47,3 +48,7 @@ struct gemm_multi_3args_c25 : verify_program return p; } }; + +template struct gemm_multi_3args_c25; +template struct gemm_multi_3args_c25; + diff --git a/test/verify/gemm_multi_dim_2.cpp b/test/verify/gemm_multi_dim_2.cpp index a393e3e0649..f2e5a50f666 100644 --- a/test/verify/gemm_multi_dim_2.cpp +++ b/test/verify/gemm_multi_dim_2.cpp @@ -27,14 +27,15 @@ #include #include -struct gemm_multi_dim_2 : verify_program +template +struct gemm_multi_dim_2 : verify_program> { migraphx::program create_program() const { migraphx::program p; auto* mm = p.get_main_module(); - migraphx::shape m1_shape{migraphx::shape::float_type, {2, 2, 3}}; - migraphx::shape m2_shape{migraphx::shape::float_type, {2, 3, 4}}; + migraphx::shape m1_shape{DType, {2, 2, 3}}; + migraphx::shape m2_shape{DType, {2, 3, 4}}; auto l1 = mm->add_parameter("1", m1_shape); auto l2 = mm->add_parameter("2", m2_shape); @@ -43,3 +44,6 @@ struct gemm_multi_dim_2 : verify_program return p; } }; + +template struct gemm_multi_dim_2; +template struct gemm_multi_dim_2; diff --git a/test/verify/gemm_multi_dim_2_3.cpp b/test/verify/gemm_multi_dim_2_3.cpp index 6e9d4f66923..70d6222b571 100644 --- a/test/verify/gemm_multi_dim_2_3.cpp +++ b/test/verify/gemm_multi_dim_2_3.cpp @@ -27,14 +27,15 @@ #include #include -struct gemm_multi_dim_2_3 : verify_program +template +struct 
gemm_multi_dim_2_3 : verify_program> { migraphx::program create_program() const { migraphx::program p; auto* mm = p.get_main_module(); - migraphx::shape m1_shape{migraphx::shape::float_type, {2, 3, 2, 3}}; - migraphx::shape m2_shape{migraphx::shape::float_type, {2, 3, 3, 2}}; + migraphx::shape m1_shape{DType, {2, 3, 2, 3}}; + migraphx::shape m2_shape{DType, {2, 3, 3, 2}}; auto l1 = mm->add_parameter("1", m1_shape); auto l2 = mm->add_parameter("2", m2_shape); @@ -43,3 +44,6 @@ struct gemm_multi_dim_2_3 : verify_program return p; } }; + +template struct gemm_multi_dim_2_3; +template struct gemm_multi_dim_2_3; diff --git a/test/verify/gemm_multi_transpose.cpp b/test/verify/gemm_multi_transpose.cpp index d683c05fee2..d3d77698b9e 100644 --- a/test/verify/gemm_multi_transpose.cpp +++ b/test/verify/gemm_multi_transpose.cpp @@ -28,14 +28,15 @@ #include #include -struct gemm_multi_transpose : verify_program +template +struct gemm_multi_transpose : verify_program> { migraphx::program create_program() const { migraphx::program p; auto* mm = p.get_main_module(); - migraphx::shape m1_shape{migraphx::shape::float_type, {2, 2, 3}}; - migraphx::shape m2_shape{migraphx::shape::float_type, {3, 2, 4}}; + migraphx::shape m1_shape{DType, {2, 2, 3}}; + migraphx::shape m2_shape{DType, {3, 2, 4}}; auto l1 = mm->add_parameter("1", m1_shape); auto l2 = mm->add_parameter("2", m2_shape); auto tl2 = @@ -47,3 +48,6 @@ struct gemm_multi_transpose : verify_program return p; } }; + +template struct gemm_multi_transpose; +template struct gemm_multi_transpose; diff --git a/test/verify/test_gemm.cpp b/test/verify/test_gemm.cpp index 770062701da..38c2b7b154d 100644 --- a/test/verify/test_gemm.cpp +++ b/test/verify/test_gemm.cpp @@ -26,16 +26,19 @@ #include #include #include - -struct test_gemm : verify_program +template +struct test_gemm : verify_program> { migraphx::program create_program() const { migraphx::program p; auto* mm = p.get_main_module(); - auto a = mm->add_parameter("a", 
migraphx::shape{migraphx::shape::float_type, {4, 5}}); - auto b = mm->add_parameter("b", migraphx::shape{migraphx::shape::float_type, {5, 3}}); + auto a = mm->add_parameter("a", migraphx::shape{DType, {4, 5}}); + auto b = mm->add_parameter("b", migraphx::shape{DType, {5, 3}}); mm->add_instruction(migraphx::make_op("dot"), a, b); return p; } }; + +template struct test_gemm; +template struct test_gemm; diff --git a/test/verify/test_gemm_copy.cpp b/test/verify/test_gemm_copy.cpp index e7cfc76d842..360314ae2b1 100644 --- a/test/verify/test_gemm_copy.cpp +++ b/test/verify/test_gemm_copy.cpp @@ -28,15 +28,16 @@ #include #include -struct test_gemm_copy : verify_program +template +struct test_gemm_copy : verify_program> { migraphx::program create_program() const { migraphx::program p; auto* mm = p.get_main_module(); - migraphx::shape sa{migraphx::shape::float_type, {2, 16}}; - migraphx::shape sb{migraphx::shape::float_type, {16, 8}}; - migraphx::shape sc{migraphx::shape::float_type, {1, 8}}; + migraphx::shape sa{DType, {2, 16}}; + migraphx::shape sb{DType, {16, 8}}; + migraphx::shape sc{DType, {1, 8}}; auto pa = mm->add_parameter("a", sa); auto pb = mm->add_parameter("b", sb); auto pc = mm->add_parameter("c", sc); @@ -46,3 +47,6 @@ struct test_gemm_copy : verify_program return p; } }; + +template struct test_gemm_copy; +template struct test_gemm_copy; diff --git a/test/verify/test_gemm_ex.cpp b/test/verify/test_gemm_ex.cpp index a1fadf6cca4..57fc2a79311 100644 --- a/test/verify/test_gemm_ex.cpp +++ b/test/verify/test_gemm_ex.cpp @@ -27,15 +27,18 @@ #include #include -struct test_gemm_ex : verify_program +template +struct test_gemm_ex : verify_program> { migraphx::program create_program() const { migraphx::program p; auto* mm = p.get_main_module(); - auto a = mm->add_parameter("a", migraphx::shape{migraphx::shape::float_type, {1, 1, 4, 5}}); - auto b = mm->add_parameter("b", migraphx::shape{migraphx::shape::float_type, {1, 1, 5, 3}}); + auto a = mm->add_parameter("a", 
migraphx::shape{DType, {1, 1, 4, 5}}); + auto b = mm->add_parameter("b", migraphx::shape{DType, {1, 1, 5, 3}}); mm->add_instruction(migraphx::make_op("dot"), a, b); return p; } }; +template struct test_gemm_ex; +template struct test_gemm_ex; diff --git a/test/verify/test_gemm_transposea.cpp b/test/verify/test_gemm_transposea.cpp index 403293849ac..40922ee4f6f 100644 --- a/test/verify/test_gemm_transposea.cpp +++ b/test/verify/test_gemm_transposea.cpp @@ -27,16 +27,20 @@ #include #include -struct test_gemm_transposea : verify_program +template +struct test_gemm_transposea : verify_program> { migraphx::program create_program() const { migraphx::program p; auto* mm = p.get_main_module(); - auto a = mm->add_parameter("a", migraphx::shape{migraphx::shape::float_type, {5, 4}}); - auto b = mm->add_parameter("b", migraphx::shape{migraphx::shape::float_type, {5, 3}}); + auto a = mm->add_parameter("a", migraphx::shape{DType, {5, 4}}); + auto b = mm->add_parameter("b", migraphx::shape{DType, {5, 3}}); auto at = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), a); mm->add_instruction(migraphx::make_op("dot"), at, b); return p; } }; + +template struct test_gemm_transposea; +template struct test_gemm_transposea; diff --git a/test/verify/test_gemm_transposea_ex.cpp b/test/verify/test_gemm_transposea_ex.cpp index 8001171661c..c3ffe7da253 100644 --- a/test/verify/test_gemm_transposea_ex.cpp +++ b/test/verify/test_gemm_transposea_ex.cpp @@ -27,17 +27,21 @@ #include #include -struct test_gemm_transposea_ex : verify_program +template +struct test_gemm_transposea_ex : verify_program> { migraphx::program create_program() const { migraphx::program p; auto* mm = p.get_main_module(); - auto a = mm->add_parameter("a", migraphx::shape{migraphx::shape::float_type, {1, 1, 5, 4}}); - auto b = mm->add_parameter("b", migraphx::shape{migraphx::shape::float_type, {1, 1, 5, 3}}); + auto a = mm->add_parameter("a", migraphx::shape{DType, {1, 1, 5, 4}}); + auto b = 
mm->add_parameter("b", migraphx::shape{DType, {1, 1, 5, 3}}); auto at = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 1, 3, 2}}}), a); mm->add_instruction(migraphx::make_op("dot"), at, b); return p; } }; + +template struct test_gemm_transposea_ex; +template struct test_gemm_transposea_ex; diff --git a/test/verify/test_gemm_transposeab.cpp b/test/verify/test_gemm_transposeab.cpp index 402cc57914e..5f6d70dd9e5 100644 --- a/test/verify/test_gemm_transposeab.cpp +++ b/test/verify/test_gemm_transposeab.cpp @@ -27,17 +27,21 @@ #include #include -struct test_gemm_transposeab : verify_program +template +struct test_gemm_transposeab : verify_program> { migraphx::program create_program() const { migraphx::program p; auto* mm = p.get_main_module(); - auto a = mm->add_parameter("a", migraphx::shape{migraphx::shape::float_type, {5, 4}}); - auto b = mm->add_parameter("b", migraphx::shape{migraphx::shape::float_type, {3, 5}}); + auto a = mm->add_parameter("a", migraphx::shape{DType, {5, 4}}); + auto b = mm->add_parameter("b", migraphx::shape{DType, {3, 5}}); auto at = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), a); auto bt = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), b); mm->add_instruction(migraphx::make_op("dot"), at, bt); return p; } }; + +template struct test_gemm_transposeab; +template struct test_gemm_transposeab; diff --git a/test/verify/test_gemm_transposeb.cpp b/test/verify/test_gemm_transposeb.cpp index 4013a6e9b23..37455a5a18e 100644 --- a/test/verify/test_gemm_transposeb.cpp +++ b/test/verify/test_gemm_transposeb.cpp @@ -27,16 +27,20 @@ #include #include -struct test_gemm_transposeb : verify_program +template +struct test_gemm_transposeb : verify_program> { migraphx::program create_program() const { migraphx::program p; auto* mm = p.get_main_module(); - auto a = mm->add_parameter("a", migraphx::shape{migraphx::shape::float_type, {4, 5}}); - auto b = mm->add_parameter("b", 
migraphx::shape{migraphx::shape::float_type, {3, 5}}); + auto a = mm->add_parameter("a", migraphx::shape{DType, {4, 5}}); + auto b = mm->add_parameter("b", migraphx::shape{DType, {3, 5}}); auto bt = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), b); mm->add_instruction(migraphx::make_op("dot"), a, bt); return p; } }; + +template struct test_gemm_transposeb; +template struct test_gemm_transposeb; diff --git a/test/verify/test_gemm_transposeb_ex.cpp b/test/verify/test_gemm_transposeb_ex.cpp index 3ff680b964c..080fe721545 100644 --- a/test/verify/test_gemm_transposeb_ex.cpp +++ b/test/verify/test_gemm_transposeb_ex.cpp @@ -26,18 +26,21 @@ #include #include #include - -struct test_gemm_transposeb_ex : verify_program +template +struct test_gemm_transposeb_ex : verify_program> { migraphx::program create_program() const { migraphx::program p; auto* mm = p.get_main_module(); - auto a = mm->add_parameter("a", migraphx::shape{migraphx::shape::float_type, {1, 4, 5}}); - auto b = mm->add_parameter("b", migraphx::shape{migraphx::shape::float_type, {1, 3, 5}}); + auto a = mm->add_parameter("a", migraphx::shape{DType, {1, 4, 5}}); + auto b = mm->add_parameter("b", migraphx::shape{DType, {1, 3, 5}}); auto bt = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 2, 1}}}), b); mm->add_instruction(migraphx::make_op("dot"), a, bt); return p; } }; + +template struct test_gemm_transposeb_ex; +template struct test_gemm_transposeb_ex; diff --git a/test/verify/test_mul_dot_a.cpp b/test/verify/test_mul_dot_a.cpp index 5d6e39dc0c5..ece43386d52 100644 --- a/test/verify/test_mul_dot_a.cpp +++ b/test/verify/test_mul_dot_a.cpp @@ -27,17 +27,17 @@ #include #include -struct test_mul_dot_a : verify_program +template +struct test_mul_dot_a : verify_program> { migraphx::program create_program() const { migraphx::program p; auto* mm = p.get_main_module(); - migraphx::shape as{migraphx::shape::float_type, {2, 256, 32}}; - migraphx::shape 
bs{migraphx::shape::float_type, {2, 32, 128}}; + migraphx::shape as{DType, {2, 256, 32}}; + migraphx::shape bs{DType, {2, 32, 128}}; auto a = mm->add_parameter("input", as); - auto lit = - mm->add_literal(migraphx::generate_literal({migraphx::shape::float_type, {1, 1, 32}})); + auto lit = mm->add_literal(migraphx::generate_literal({DType, {1, 1, 32}})); auto litb = mm->add_instruction( migraphx::make_op("multibroadcast", {{"out_lens", as.lens()}}), lit); auto mul = mm->add_instruction(migraphx::make_op("mul"), a, litb); @@ -47,3 +47,6 @@ struct test_mul_dot_a : verify_program return p; } }; + +template struct test_mul_dot_a; +template struct test_mul_dot_a; diff --git a/test/verify/test_mul_dot_b.cpp b/test/verify/test_mul_dot_b.cpp index 7ff1d3db962..a275a06b4ed 100644 --- a/test/verify/test_mul_dot_b.cpp +++ b/test/verify/test_mul_dot_b.cpp @@ -27,17 +27,18 @@ #include #include -struct test_mul_dot_b : verify_program +template + +struct test_mul_dot_b : verify_program> { migraphx::program create_program() const { migraphx::program p; auto* mm = p.get_main_module(); - migraphx::shape as{migraphx::shape::float_type, {2, 256, 32}}; - migraphx::shape bs{migraphx::shape::float_type, {2, 32, 128}}; + migraphx::shape as{DType, {2, 256, 32}}; + migraphx::shape bs{DType, {2, 32, 128}}; auto b = mm->add_parameter("input", bs); - auto lit = - mm->add_literal(migraphx::generate_literal({migraphx::shape::float_type, {1, 32, 1}})); + auto lit = mm->add_literal(migraphx::generate_literal({DType, {1, 32, 1}})); auto litb = mm->add_instruction( migraphx::make_op("multibroadcast", {{"out_lens", bs.lens()}}), lit); auto mul = mm->add_instruction(migraphx::make_op("mul"), b, litb); @@ -47,3 +48,6 @@ struct test_mul_dot_b : verify_program return p; } }; + +template struct test_mul_dot_b; +template struct test_mul_dot_b; diff --git a/test/verify/test_unbatched_gemm_1.cpp b/test/verify/test_unbatched_gemm_1.cpp index 92e7329e635..7b40af94615 100644 --- 
a/test/verify/test_unbatched_gemm_1.cpp +++ b/test/verify/test_unbatched_gemm_1.cpp @@ -27,15 +27,17 @@ #include #include #include -struct test_unbatched_gemm_1 : verify_program + +template +struct test_unbatched_gemm_1 : verify_program> { migraphx::program create_program() const { migraphx::program p; auto* mm = p.get_main_module(); - migraphx::shape m1_shape{migraphx::shape::float_type, {2, 32, 64}}; - migraphx::shape m2_shape{migraphx::shape::float_type, {64, 64}}; - migraphx::shape m3_shape{migraphx::shape::float_type, {2, 32, 192}}; + migraphx::shape m1_shape{DType, {2, 32, 64}}; + migraphx::shape m2_shape{DType, {64, 64}}; + migraphx::shape m3_shape{DType, {2, 32, 192}}; auto l1 = mm->add_parameter("1", m1_shape); auto l2 = mm->add_literal(migraphx::generate_literal(m2_shape)); l2 = mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {2, 64, 64}}}), @@ -56,3 +58,6 @@ struct test_unbatched_gemm_1 : verify_program return p; } }; + +template struct test_unbatched_gemm_1; +template struct test_unbatched_gemm_1; diff --git a/test/verify/test_unbatched_gemm_2.cpp b/test/verify/test_unbatched_gemm_2.cpp index 204f27e5985..a2c16ffe26b 100644 --- a/test/verify/test_unbatched_gemm_2.cpp +++ b/test/verify/test_unbatched_gemm_2.cpp @@ -27,14 +27,16 @@ #include #include #include -struct test_unbatched_gemm_2 : verify_program + +template +struct test_unbatched_gemm_2 : verify_program> { migraphx::program create_program() const { migraphx::program p; auto* mm = p.get_main_module(); - migraphx::shape m1_shape{migraphx::shape::float_type, {4, 32, 64}}; - migraphx::shape m2_shape{migraphx::shape::float_type, {64, 64}}; + migraphx::shape m1_shape{DType, {4, 32, 64}}; + migraphx::shape m2_shape{DType, {64, 64}}; auto l1 = mm->add_parameter("1", m1_shape); auto l2 = mm->add_literal(migraphx::generate_literal(m2_shape)); l2 = mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {4, 64, 64}}}), @@ -44,3 +46,6 @@ struct test_unbatched_gemm_2 : 
verify_program return p; } }; + +template struct test_unbatched_gemm_2; +template struct test_unbatched_gemm_2; From ad9c25eaf9ba1f0d3e6084810846786bedef7f58 Mon Sep 17 00:00:00 2001 From: Umang Yadav Date: Sun, 26 Nov 2023 19:15:33 +0000 Subject: [PATCH 067/115] add eliminate_fp8 pass --- src/CMakeLists.txt | 1 + src/eliminate_fp8.cpp | 62 +++++++++++++++++++++ src/include/migraphx/eliminate_fp8.hpp | 52 ++++++++++++++++++ src/targets/gpu/gemm_impl.cpp | 76 +++++--------------------- src/targets/gpu/target.cpp | 7 +++ 5 files changed, 137 insertions(+), 61 deletions(-) create mode 100644 src/eliminate_fp8.cpp create mode 100644 src/include/migraphx/eliminate_fp8.hpp diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 38a2b14d80c..da9d3a09d75 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -49,6 +49,7 @@ add_library(migraphx eliminate_concat.cpp eliminate_contiguous.cpp eliminate_data_type.cpp + eliminate_fp8.cpp eliminate_identity.cpp eliminate_pad.cpp env.cpp diff --git a/src/eliminate_fp8.cpp b/src/eliminate_fp8.cpp new file mode 100644 index 00000000000..402faa2c1c7 --- /dev/null +++ b/src/eliminate_fp8.cpp @@ -0,0 +1,62 @@ +/* + * The MIT License (MIT) + * + * Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. 
+ * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + * THE SOFTWARE. + */ +#include +#include +#include +#include +#include +#include +#include +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { + +void eliminate_fp8::apply(module& m) const +{ + for(auto ins : iterator_for(m)) + { + if(not contains(op_names, ins->name())) + continue; + migraphx::shape::type_t orig_type = ins->get_shape().type(); + std::vector orig_inputs = ins->inputs(); + std::vector new_inputs; + for(const auto& i : orig_inputs) + { + new_inputs.push_back(m.insert_instruction( + ins, + migraphx::make_op("convert", {{"target_type", migraphx::to_value(target_type)}}), + i)); + } + auto new_ins = m.insert_instruction(ins, ins->get_operator(), {new_inputs}); + auto convert_back_ins = m.insert_instruction( + ins, + migraphx::make_op("convert", {{"target_type", migraphx::to_value(orig_type)}}), + new_ins); + m.replace_instruction(ins, convert_back_ins); + } +} + +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx diff --git a/src/include/migraphx/eliminate_fp8.hpp b/src/include/migraphx/eliminate_fp8.hpp new file mode 100644 index 00000000000..c3304dd054e --- /dev/null +++ b/src/include/migraphx/eliminate_fp8.hpp @@ -0,0 +1,52 @@ +/* + * The MIT License (MIT) + * + * Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved. 
+ * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + * THE SOFTWARE. + */ +#ifndef MIGRAPHX_GUARD_AMDMIGRAPHX_ELIMINATE_FP8_HPP +#define MIGRAPHX_GUARD_AMDMIGRAPHX_ELIMINATE_FP8_HPP + +#include +#include +#include +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { + +struct module; + +/** +This will insert convert operators for the operators that are not implemented for FP8 dtypes + */ +struct MIGRAPHX_EXPORT eliminate_fp8 +{ + // TODO: Add all device ops as a later PR and add tests for those. 
+ std::set op_names; + shape::type_t target_type = migraphx::shape::float_type; + std::string name() const { return "eliminate_fp8"; } + void apply(module& m) const; +}; + +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx + +#endif diff --git a/src/targets/gpu/gemm_impl.cpp b/src/targets/gpu/gemm_impl.cpp index 6d30df64718..057e81395af 100644 --- a/src/targets/gpu/gemm_impl.cpp +++ b/src/targets/gpu/gemm_impl.cpp @@ -227,18 +227,20 @@ struct gemm_impl { if(strided_batched) { - auto common_args = create_strided_batched_args_common_fp8(ctx, input_args); + auto common_args = create_strided_batched_args_common(ctx, input_args); rocblas_invoke(&rocblas_gemm_strided_batched_ex3, common_args, + rocblas_compute_type_f32, rocblas_gemm_algo_standard, solution_idx, gemm_flags); } else { - auto common_args = create_gemm_ex_args_common_fp8(ctx, input_args); + auto common_args = create_gemm_ex_args_common(ctx, input_args); rocblas_invoke(&rocblas_gemm_ex3, common_args, + rocblas_compute_type_f32, rocblas_gemm_algo_standard, solution_idx, gemm_flags); @@ -252,6 +254,7 @@ struct gemm_impl auto common_args = create_strided_batched_args_common(ctx, input_args); rocblas_invoke(&rocblas_gemm_strided_batched_ex, common_args, + compute_type, rocblas_gemm_algo_solution_index, solution_idx, gemm_flags); @@ -261,6 +264,7 @@ struct gemm_impl auto common_args = create_gemm_ex_args_common(ctx, input_args); rocblas_invoke(&rocblas_gemm_ex, common_args, + compute_type, rocblas_gemm_algo_solution_index, solution_idx, gemm_flags); @@ -300,6 +304,7 @@ struct gemm_impl auto common_args = create_strided_batched_args_common(ctx, input_args); check_valid = rocblas_invoke(&rocblas_gemm_strided_batched_ex, common_args, + compute_type, rocblas_gemm_algo_solution_index, solution_idx, rocblas_gemm_flags_check_solution_index); @@ -309,6 +314,7 @@ struct gemm_impl auto common_args = create_gemm_ex_args_common(ctx, input_args); check_valid = rocblas_invoke(&rocblas_gemm_ex, common_args, + 
compute_type, rocblas_gemm_algo_solution_index, solution_idx, rocblas_gemm_flags_check_solution_index); @@ -359,40 +365,8 @@ struct gemm_impl output_type, ldd, d_stride, - num_matrices, - compute_type); + num_matrices); } - auto create_strided_batched_args_common_fp8(context& ctx, - const std::vector& args) const - { - return pack(ctx.get_stream().get_rocblas(), - transb ? rocblas_operation_transpose : rocblas_operation_none, - transa ? rocblas_operation_transpose : rocblas_operation_none, - n, - m, - k, - get_alpha(), - args[1].data(), - arg_type, - ldb, - b_stride, - args[0].data(), - arg_type, - lda, - a_stride, - get_beta(), - args[2].data(), - output_type, - ldc, - c_stride, - is_3inputs ? args[3].data() : args[2].data(), - output_type, - ldd, - d_stride, - num_matrices, - rocblas_compute_type_f32); - } - /** * Helper method to create that subset of a long rocBLAS argument list that is common * to multiple "gemm_ex..." calls. @@ -424,33 +398,9 @@ struct gemm_impl ldc, is_3inputs ? args[3].data() : args[2].data(), output_type, - ldd, - compute_type); - } - auto create_gemm_ex_args_common_fp8(context& ctx, const std::vector& args) const - { - return pack(ctx.get_stream().get_rocblas(), - transb ? rocblas_operation_transpose : rocblas_operation_none, - transa ? rocblas_operation_transpose : rocblas_operation_none, - n, - m, - k, - get_alpha(), - args[1].data(), - arg_type, - ldb, - args[0].data(), - arg_type, - lda, - get_beta(), - args[2].data(), - output_type, - ldc, - is_3inputs ? 
args[3].data() : args[2].data(), - output_type, - ldd, - rocblas_compute_type_f32); + ldd); } + #ifdef MIGRAPHX_USE_ROCBLAS_TUNING_API /** * Find best rocBLAS solution: Get list of solutions and try them all, returning the index @@ -478,6 +428,7 @@ struct gemm_impl auto common_args = create_strided_batched_args_common(ctx, input_args); rocblas_invoke(&rocblas_gemm_strided_batched_ex_get_solutions, common_args, + compute_type, rocblas_gemm_algo_solution_index, gemm_flags, nullptr, @@ -487,6 +438,7 @@ struct gemm_impl auto common_sol_args = create_strided_batched_args_common(ctx, input_args); rocblas_invoke(&rocblas_gemm_strided_batched_ex_get_solutions, common_sol_args, + compute_type, rocblas_gemm_algo_solution_index, gemm_flags, solution_indices.data(), @@ -497,6 +449,7 @@ struct gemm_impl auto common_args = create_gemm_ex_args_common(ctx, input_args); rocblas_invoke(&rocblas_gemm_ex_get_solutions, common_args, + compute_type, rocblas_gemm_algo_solution_index, gemm_flags, nullptr, @@ -506,6 +459,7 @@ struct gemm_impl auto common_sol_args = create_gemm_ex_args_common(ctx, input_args); rocblas_invoke(&rocblas_gemm_ex_get_solutions, common_sol_args, + compute_type, rocblas_gemm_algo_solution_index, gemm_flags, solution_indices.data(), diff --git a/src/targets/gpu/target.cpp b/src/targets/gpu/target.cpp index dc1f8cd7991..b5ae99b074f 100644 --- a/src/targets/gpu/target.cpp +++ b/src/targets/gpu/target.cpp @@ -52,6 +52,7 @@ #include #include #include +#include #include #include #include @@ -105,6 +106,11 @@ std::vector target::get_passes(migraphx::context& gctx, const compile_opti unsupported_types.erase(shape::type_t::uint8_type); unsupported_types.erase(shape::type_t::int32_type); unsupported_types.erase(shape::type_t::tuple_type); + std::set unsupported_fp8_ops = {}; + if(not gpu::rocblas_fp8_available()) + { + unsupported_fp8_ops.insert("dot"); + } // clang-format off return { @@ -147,6 +153,7 @@ std::vector target::get_passes(migraphx::context& gctx, const 
compile_opti dead_code_elimination{}, enable_pass(mlir_enabled(), fuse_mlir{&ctx}), dead_code_elimination{}, + eliminate_fp8{unsupported_fp8_ops}, lowering{&ctx, options.offload_copy}, eliminate_contiguous{"gpu::contiguous"}, dead_code_elimination{}, From 8734ffa39d3c432db62dd1cbfc978c2e7abb9f31 Mon Sep 17 00:00:00 2001 From: Umang Yadav Date: Sun, 26 Nov 2023 19:30:32 +0000 Subject: [PATCH 068/115] remove convert from lowering --- src/targets/gpu/lowering.cpp | 32 -------------------------------- 1 file changed, 32 deletions(-) diff --git a/src/targets/gpu/lowering.cpp b/src/targets/gpu/lowering.cpp index bd87ba2800e..ea0e29f8853 100644 --- a/src/targets/gpu/lowering.cpp +++ b/src/targets/gpu/lowering.cpp @@ -220,44 +220,12 @@ struct miopen_apply return mod->insert_instruction(ins, make_op("allocate", {{"shape", to_value(s)}})); } - instruction_ref convert_fp8_to_fp32(instruction_ref ins) - { - std::vector fp8_inputs = ins->inputs(); - std::vector fp32_inputs; - for(const auto& i : fp8_inputs) - { - fp32_inputs.push_back(mod->insert_instruction( - ins, - migraphx::make_op( - "convert", - {{"target_type", migraphx::to_value(migraphx::shape::type_t::float_type)}}), - i)); - } - auto fp32_ins = mod->insert_instruction(ins, ins->get_operator(), {fp32_inputs}); - auto fp8_ins = mod->insert_instruction( - ins, - migraphx::make_op( - "convert", - {{"target_type", migraphx::to_value(migraphx::shape::type_t::fp8e4m3fnuz_type)}}), - fp32_ins); - mod->replace_instruction(ins, fp8_ins); - return fp32_ins; - } - template void add_gemm_op(const std::string& name) { apply_map.emplace(name, [=](instruction_ref ins) { std::vector refs = ins->inputs(); assert(refs.size() == 2); - if(not rocblas_fp8_available() and - std::any_of(refs.begin(), refs.end(), [](const auto i) { - return i->get_shape().type() == migraphx::shape::fp8e4m3fnuz_type; - })) - { - // replace fp8 ins with fp32 ins - ins = convert_fp8_to_fp32(ins); - } auto output = insert_allocation(ins, ins->get_shape()); 
refs.push_back(output); return mod->replace_instruction(ins, rocblas_gemm{Op{}, 1, 0, compute_fp32}, refs); From f014fb963afc8b92c7474bc28da74f9cb3355fbb Mon Sep 17 00:00:00 2001 From: Umang Yadav Date: Sun, 26 Nov 2023 19:37:16 +0000 Subject: [PATCH 069/115] Fix eliminate_fp8 pass --- src/eliminate_fp8.cpp | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/eliminate_fp8.cpp b/src/eliminate_fp8.cpp index 402faa2c1c7..77793337d7b 100644 --- a/src/eliminate_fp8.cpp +++ b/src/eliminate_fp8.cpp @@ -37,7 +37,8 @@ void eliminate_fp8::apply(module& m) const { for(auto ins : iterator_for(m)) { - if(not contains(op_names, ins->name())) + if(not contains(op_names, ins->name()) or + ins->get_shape().type() != migraphx::shape::fp8e4m3fnuz_type) continue; migraphx::shape::type_t orig_type = ins->get_shape().type(); std::vector orig_inputs = ins->inputs(); From 83ce487ab1b1f0e892829ccf1fa3b13ee7a09a6f Mon Sep 17 00:00:00 2001 From: Umang Yadav Date: Sun, 26 Nov 2023 21:59:57 +0000 Subject: [PATCH 070/115] Move pass before optimize module --- src/targets/gpu/target.cpp | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/targets/gpu/target.cpp b/src/targets/gpu/target.cpp index b5ae99b074f..d6aadf52569 100644 --- a/src/targets/gpu/target.cpp +++ b/src/targets/gpu/target.cpp @@ -142,6 +142,8 @@ std::vector target::get_passes(migraphx::context& gctx, const compile_opti prefuse_ops{}, dead_code_elimination{}, auto_contiguous{}, + eliminate_fp8{unsupported_fp8_ops}, + dead_code_elimination{}, optimize_module{}, fuse_pointwise{}, dead_code_elimination{}, @@ -153,7 +155,6 @@ std::vector target::get_passes(migraphx::context& gctx, const compile_opti dead_code_elimination{}, enable_pass(mlir_enabled(), fuse_mlir{&ctx}), dead_code_elimination{}, - eliminate_fp8{unsupported_fp8_ops}, lowering{&ctx, options.offload_copy}, eliminate_contiguous{"gpu::contiguous"}, dead_code_elimination{}, From 9a9e96484ab17e88660d371ac522b752caa125a7 Mon Sep 17 00:00:00 
2001 From: Umang Yadav Date: Sun, 26 Nov 2023 22:11:55 +0000 Subject: [PATCH 071/115] formatting --- src/targets/gpu/rocblas.cpp | 3 ++- test/verify/gemm_2args_bmv.cpp | 3 +-- test/verify/gemm_multi_3args_c25.cpp | 1 - 3 files changed, 3 insertions(+), 4 deletions(-) diff --git a/src/targets/gpu/rocblas.cpp b/src/targets/gpu/rocblas.cpp index d73813c2283..59452408801 100644 --- a/src/targets/gpu/rocblas.cpp +++ b/src/targets/gpu/rocblas.cpp @@ -53,7 +53,8 @@ bool get_compute_fp32_flag() return (starts_with(device_name, "gfx9") and device_name >= "gfx908"); } -bool rocblas_fp8_available() { +bool rocblas_fp8_available() +{ #ifndef MIGRAPHX_USE_ROCBLAS_FP8_API return false; #else diff --git a/test/verify/gemm_2args_bmv.cpp b/test/verify/gemm_2args_bmv.cpp index d192de96bab..51f8a297900 100644 --- a/test/verify/gemm_2args_bmv.cpp +++ b/test/verify/gemm_2args_bmv.cpp @@ -27,7 +27,7 @@ #include #include -template +template struct gemm_2args_bmv : verify_program> { migraphx::program create_program() const @@ -50,4 +50,3 @@ struct gemm_2args_bmv : verify_program> template struct gemm_2args_bmv; template struct gemm_2args_bmv; - diff --git a/test/verify/gemm_multi_3args_c25.cpp b/test/verify/gemm_multi_3args_c25.cpp index 74eb60c0898..47a47c30af4 100644 --- a/test/verify/gemm_multi_3args_c25.cpp +++ b/test/verify/gemm_multi_3args_c25.cpp @@ -51,4 +51,3 @@ struct gemm_multi_3args_c25 : verify_program> template struct gemm_multi_3args_c25; template struct gemm_multi_3args_c25; - From c40a39c34103b5fa35f145c69c73cd15564be201 Mon Sep 17 00:00:00 2001 From: Umang Yadav Date: Sun, 26 Nov 2023 22:55:34 +0000 Subject: [PATCH 072/115] fix cppcheck --- src/eliminate_fp8.cpp | 20 +++++++++++++------- 1 file changed, 13 insertions(+), 7 deletions(-) diff --git a/src/eliminate_fp8.cpp b/src/eliminate_fp8.cpp index 77793337d7b..62b8f8ba6f5 100644 --- a/src/eliminate_fp8.cpp +++ b/src/eliminate_fp8.cpp @@ -21,6 +21,8 @@ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER 
DEALINGS IN * THE SOFTWARE. */ +#include "migraphx/serialize.hpp" +#include #include #include #include @@ -43,13 +45,17 @@ void eliminate_fp8::apply(module& m) const migraphx::shape::type_t orig_type = ins->get_shape().type(); std::vector orig_inputs = ins->inputs(); std::vector new_inputs; - for(const auto& i : orig_inputs) - { - new_inputs.push_back(m.insert_instruction( - ins, - migraphx::make_op("convert", {{"target_type", migraphx::to_value(target_type)}}), - i)); - } + std::transform(orig_inputs.begin(), + orig_inputs.end(), + std::back_inserter(new_inputs), + [&](const auto& i) { + return m.insert_instruction( + ins, + migraphx::make_op( + "convert", {{"target_type", migraphx::to_value(target_type)}}), + i); + }); + auto new_ins = m.insert_instruction(ins, ins->get_operator(), {new_inputs}); auto convert_back_ins = m.insert_instruction( ins, From f155b0e6db4f88d479b155bd2356baf96ab905e1 Mon Sep 17 00:00:00 2001 From: Umang Yadav Date: Fri, 1 Dec 2023 23:27:49 +0000 Subject: [PATCH 073/115] merge changes --- src/eliminate_fp8.cpp | 2 +- .../include/migraphx/kernels/bit_cast.hpp | 1 - .../include/migraphx/kernels/float8.hpp | 8 ++--- .../include/migraphx/kernels/layernorm.hpp | 1 - .../kernels/include/migraphx/kernels/ops.hpp | 2 +- .../kernels/include/migraphx/kernels/pad.hpp | 1 - .../include/migraphx/kernels/roialign.hpp | 36 +++++++++---------- .../include/migraphx/kernels/softmax.hpp | 1 - 8 files changed, 22 insertions(+), 30 deletions(-) diff --git a/src/eliminate_fp8.cpp b/src/eliminate_fp8.cpp index 62b8f8ba6f5..9a43253b94b 100644 --- a/src/eliminate_fp8.cpp +++ b/src/eliminate_fp8.cpp @@ -21,7 +21,6 @@ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN * THE SOFTWARE. 
*/ -#include "migraphx/serialize.hpp" #include #include #include @@ -30,6 +29,7 @@ #include #include #include +#include #include namespace migraphx { diff --git a/src/targets/gpu/kernels/include/migraphx/kernels/bit_cast.hpp b/src/targets/gpu/kernels/include/migraphx/kernels/bit_cast.hpp index d66a36959f7..c98395bbe10 100644 --- a/src/targets/gpu/kernels/include/migraphx/kernels/bit_cast.hpp +++ b/src/targets/gpu/kernels/include/migraphx/kernels/bit_cast.hpp @@ -22,7 +22,6 @@ #ifndef MIGRAPHX_GUARD_KERNELS_BITCAST_HPP #define MIGRAPHX_GUARD_KERNELS_BITCAST_HPP - #include namespace migraphx { diff --git a/src/targets/gpu/kernels/include/migraphx/kernels/float8.hpp b/src/targets/gpu/kernels/include/migraphx/kernels/float8.hpp index da87d6e2314..d16c062299b 100644 --- a/src/targets/gpu/kernels/include/migraphx/kernels/float8.hpp +++ b/src/targets/gpu/kernels/include/migraphx/kernels/float8.hpp @@ -394,7 +394,6 @@ using fp8e5m2fnuz = float8; } // NOLINTNEXTLINE - #define MIGRAPHX_FP8_OTHER_OPS(T) \ inline constexpr __device__ T fabs(T v) \ { \ @@ -502,7 +501,6 @@ class numeric_limits { return fp8e5m2fnuz(0x7F, fp8e5m2fnuz::from_bits()); } - // this is min value that is not DeNormalized(DeNorm). DeNorm min is 0x01. I am not sure if we // want to make this distinction. For the floating points we would end up using lowest most of // the times. @@ -530,7 +528,9 @@ class numeric_limits } static constexpr __device__ fp8e5m2 max() { return fp8e5m2(0x7B, fp8e5m2::from_bits()); } - // this is min value that is not DeNormalized(DeNorm). DeNorm min is 0x01. + // this is min value that is not DeNormalized(DeNorm). DeNorm min is 0x01. I am not sure if we + // want to make this distinction. For the floating points we would end up using lowest most of + // the times. 
static constexpr __device__ fp8e5m2 min() { return fp8e5m2(0x4, fp8e5m2::from_bits()); } static constexpr __device__ fp8e5m2 lowest() { return fp8e5m2(0xFB, fp8e5m2::from_bits()); } @@ -539,7 +539,6 @@ class numeric_limits }; } // namespace fp8 - template {} or is_same{} or is_same{} or is_same{})> @@ -560,7 +559,6 @@ constexpr T numeric_lowest(migraphx::fp8::f8_type unused = migraphx::fp8::f8_typ (void)(unused); return fp8::numeric_limits::lowest(); } - } // namespace migraphx // ================================================================================================= #if defined(__clang__) diff --git a/src/targets/gpu/kernels/include/migraphx/kernels/layernorm.hpp b/src/targets/gpu/kernels/include/migraphx/kernels/layernorm.hpp index 048bd450c64..e186c457ed0 100644 --- a/src/targets/gpu/kernels/include/migraphx/kernels/layernorm.hpp +++ b/src/targets/gpu/kernels/include/migraphx/kernels/layernorm.hpp @@ -54,7 +54,6 @@ __device__ void generic_binary_layernorm( using value_type = typename Input1::type; using vec_value_type = vec_type; constexpr auto relements = r.template elements(); - constexpr auto relements_r = vec_value_type{1.0 / relements}; auto relements_rsqrt = sqrt(relements_r); diff --git a/src/targets/gpu/kernels/include/migraphx/kernels/ops.hpp b/src/targets/gpu/kernels/include/migraphx/kernels/ops.hpp index f70e4933347..898ce637e6a 100644 --- a/src/targets/gpu/kernels/include/migraphx/kernels/ops.hpp +++ b/src/targets/gpu/kernels/include/migraphx/kernels/ops.hpp @@ -118,7 +118,7 @@ struct highest template constexpr operator T() const { - return numeric_max, void>(); + return numeric_max>(); } }; } // namespace migraphx diff --git a/src/targets/gpu/kernels/include/migraphx/kernels/pad.hpp b/src/targets/gpu/kernels/include/migraphx/kernels/pad.hpp index 47af1cc7e46..95a7e05763c 100644 --- a/src/targets/gpu/kernels/include/migraphx/kernels/pad.hpp +++ b/src/targets/gpu/kernels/include/migraphx/kernels/pad.hpp @@ -40,7 +40,6 @@ __device__ void 
pad(const index& idx, const PadVal& pad_val) { auto output_shape = output.get_shape(); - using otype = typename Output::type; idx.global_stride(output_shape.elements(), [&](auto i) { // 1. get current multi-index for output // 2. get the size of the input to determine input boundaries diff --git a/src/targets/gpu/kernels/include/migraphx/kernels/roialign.hpp b/src/targets/gpu/kernels/include/migraphx/kernels/roialign.hpp index 706eb925c67..93cc18e28bd 100644 --- a/src/targets/gpu/kernels/include/migraphx/kernels/roialign.hpp +++ b/src/targets/gpu/kernels/include/migraphx/kernels/roialign.hpp @@ -56,7 +56,7 @@ struct avg_pool template MIGRAPHX_DEVICE_CONSTEXPR T operator()(T x, T y) { - return static_cast(x + y); + return x + y; } template @@ -70,7 +70,6 @@ template MIGRAPHX_DEVICE_CONSTEXPR typename Iterator::value_type bilinear_interpolate( const Iterator data, const array& dims, array xy, Op pooling) { - using ret_type = typename Iterator::value_type; array low{}; array high{}; for(index_int ii = 0; ii < xy.size(); ++ii) @@ -93,7 +92,6 @@ MIGRAPHX_DEVICE_CONSTEXPR typename Iterator::value_type bilinear_interpolate( high[0] * dims[1] + low[1], high[0] * dims[1] + high[1]}; - float ly = xy[0] - low[0]; float lx = xy[1] - low[1]; float hy = 1.0f - ly; @@ -204,25 +202,25 @@ __device__ void roialign(const T& x_t, const U& rois_t, const V& ind_t, W& y_t, const auto offset_x = x + ((batch_ind * channel_num + c) * in_dims[0] * in_dims[1]); if constexpr(s.is_avg_pooling) { - y_t[i] = static_cast(calc_pooling(offset_x, - roi_starts, - bin_size, - {ph, pw}, - bin_grid_size, - in_dims, - s.roi_offset, - avg_pool{})); + y_t[i] = calc_pooling(offset_x, + roi_starts, + bin_size, + {ph, pw}, + bin_grid_size, + in_dims, + s.roi_offset, + avg_pool{}); } else { - y_t[i] = static_cast(calc_pooling(offset_x, - roi_starts, - bin_size, - {ph, pw}, - bin_grid_size, - in_dims, - s.roi_offset, - max_pool{})); + y_t[i] = calc_pooling(offset_x, + roi_starts, + bin_size, + {ph, pw}, + 
bin_grid_size, + in_dims, + s.roi_offset, + max_pool{}); } } } diff --git a/src/targets/gpu/kernels/include/migraphx/kernels/softmax.hpp b/src/targets/gpu/kernels/include/migraphx/kernels/softmax.hpp index c713cc5ca15..b967d76b368 100644 --- a/src/targets/gpu/kernels/include/migraphx/kernels/softmax.hpp +++ b/src/targets/gpu/kernels/include/migraphx/kernels/softmax.hpp @@ -33,7 +33,6 @@ template __device__ void softmax(Input input1, Output output) { using block = reduce::auto_block()>; - using otype = typename Output::type; block::template run>([&](auto, auto r) { auto input = r.inner(op::id{})(input1); #ifdef MIGRAPHX_USE_FAST_SOFTMAX From 38218edccd202856b2ce26ce2824c84cbf1aa8d8 Mon Sep 17 00:00:00 2001 From: Umang Yadav Date: Fri, 1 Dec 2023 23:30:09 +0000 Subject: [PATCH 074/115] few changes --- src/targets/gpu/jit/reduce.cpp | 29 +++++++++---------- .../include/migraphx/kernels/float8.hpp | 8 ++--- test/verify/gemm_2args_mm_2.cpp | 2 +- test/verify/gemm_2args_mm_3.cpp | 2 +- test/verify/gemm_2args_mm_4.cpp | 2 +- 5 files changed, 19 insertions(+), 24 deletions(-) diff --git a/src/targets/gpu/jit/reduce.cpp b/src/targets/gpu/jit/reduce.cpp index e001b5ba510..1e018d2633e 100644 --- a/src/targets/gpu/jit/reduce.cpp +++ b/src/targets/gpu/jit/reduce.cpp @@ -146,7 +146,6 @@ struct simple_reduce_compiler : compiler vectorize vec{}; auto nelements = options.virtual_inputs.back().elements(); auto algo = v.get("algo", get_reduce_algo(options.virtual_inputs)); - if(algo == "block") { // Vectorize if the axis is a reduction axis @@ -170,13 +169,13 @@ struct simple_reduce_compiler : compiler options.kernel_name = "reduce_kernel"; std::string identity = "[](auto x) { return x; }"; auto src = interpolate_string(simple_reduce_kernel, - {{"reduction", v.at("reduction").to()}, - {"init", v.get("init", std::string{"0"})}, - {"read", v.get("read", identity)}, - {"write", v.get("write", identity)}, - {"algo", algo}, - {"transformers", make_transformer_args(vec)}, - {"preamble", 
v.get("preamble", std::string{})}}); + {{"reduction", v.at("reduction").to()}, + {"init", v.get("init", std::string{"0"})}, + {"read", v.get("read", identity)}, + {"write", v.get("write", identity)}, + {"algo", algo}, + {"transformers", make_transformer_args(vec)}, + {"preamble", v.get("preamble", std::string{})}}); options.params += "-Wno-float-equal"; return compile_hip_code_object(src, options); } @@ -267,13 +266,13 @@ struct fused_reduce_compiler : compiler auto src = interpolate_string( fused_reduce_kernel, {{"kernel", options.kernel_name}, - {"params", enum_params(inputs.size(), "void * private_p")}, - {"args", enum_params(inputs.size(), "private_p")}, - {"algo", algo}, - {"reduced", "decltype(" + generate_make_shape(reduce_output_shape) + ")"}, - {"lambda", v.at("lambda").to()}, - {"transformers", make_transformer_args(vec)}, - {"preamble", v.get("preamble", std::string{})}}); + {"params", enum_params(inputs.size(), "void * private_p")}, + {"args", enum_params(inputs.size(), "private_p")}, + {"algo", algo}, + {"reduced", "decltype(" + generate_make_shape(reduce_output_shape) + ")"}, + {"lambda", v.at("lambda").to()}, + {"transformers", make_transformer_args(vec)}, + {"preamble", v.get("preamble", std::string{})}}); options.params += "-Wno-float-equal"; return compile_hip_code_object(src, options); } diff --git a/src/targets/gpu/kernels/include/migraphx/kernels/float8.hpp b/src/targets/gpu/kernels/include/migraphx/kernels/float8.hpp index d16c062299b..543a2165685 100644 --- a/src/targets/gpu/kernels/include/migraphx/kernels/float8.hpp +++ b/src/targets/gpu/kernels/include/migraphx/kernels/float8.hpp @@ -501,9 +501,7 @@ class numeric_limits { return fp8e5m2fnuz(0x7F, fp8e5m2fnuz::from_bits()); } - // this is min value that is not DeNormalized(DeNorm). DeNorm min is 0x01. I am not sure if we - // want to make this distinction. For the floating points we would end up using lowest most of - // the times. + // this is min value that is not DeNormalized(DeNorm). 
DeNorm min is 0x01. static constexpr __device__ fp8e5m2fnuz min() { return fp8e5m2fnuz(0x4, fp8e5m2fnuz::from_bits()); @@ -528,9 +526,7 @@ class numeric_limits } static constexpr __device__ fp8e5m2 max() { return fp8e5m2(0x7B, fp8e5m2::from_bits()); } - // this is min value that is not DeNormalized(DeNorm). DeNorm min is 0x01. I am not sure if we - // want to make this distinction. For the floating points we would end up using lowest most of - // the times. + // this is min value that is not DeNormalized(DeNorm). DeNorm min is 0x01. static constexpr __device__ fp8e5m2 min() { return fp8e5m2(0x4, fp8e5m2::from_bits()); } static constexpr __device__ fp8e5m2 lowest() { return fp8e5m2(0xFB, fp8e5m2::from_bits()); } diff --git a/test/verify/gemm_2args_mm_2.cpp b/test/verify/gemm_2args_mm_2.cpp index 448a76c7c99..e6f184d84a2 100644 --- a/test/verify/gemm_2args_mm_2.cpp +++ b/test/verify/gemm_2args_mm_2.cpp @@ -22,9 +22,9 @@ * THE SOFTWARE. */ -#include "migraphx/shape.hpp" #include "verify_program.hpp" #include +#include #include #include diff --git a/test/verify/gemm_2args_mm_3.cpp b/test/verify/gemm_2args_mm_3.cpp index 668a55f1c88..e41c2b42a91 100644 --- a/test/verify/gemm_2args_mm_3.cpp +++ b/test/verify/gemm_2args_mm_3.cpp @@ -22,9 +22,9 @@ * THE SOFTWARE. */ -#include "migraphx/shape.hpp" #include "verify_program.hpp" #include +#include #include #include diff --git a/test/verify/gemm_2args_mm_4.cpp b/test/verify/gemm_2args_mm_4.cpp index c0e02d8dfa6..bbe13b6f6ab 100644 --- a/test/verify/gemm_2args_mm_4.cpp +++ b/test/verify/gemm_2args_mm_4.cpp @@ -22,8 +22,8 @@ * THE SOFTWARE. 
*/ -#include "migraphx/shape.hpp" #include "verify_program.hpp" +#include #include #include #include From 379692fcec1a425ce605286f1d60666b40120605 Mon Sep 17 00:00:00 2001 From: Umang Yadav Date: Fri, 1 Dec 2023 23:32:03 +0000 Subject: [PATCH 075/115] few more cosmetic changes --- test/verify/test_gemm_transposeb_ex.cpp | 1 + test/verify/test_mul_dot_b.cpp | 1 - test/verify/test_nearbyint.cpp | 1 - test/verify/test_reduce_add.cpp | 1 - test/verify/test_scatternd.cpp | 1 - 5 files changed, 1 insertion(+), 4 deletions(-) diff --git a/test/verify/test_gemm_transposeb_ex.cpp b/test/verify/test_gemm_transposeb_ex.cpp index 080fe721545..8d8b305ba71 100644 --- a/test/verify/test_gemm_transposeb_ex.cpp +++ b/test/verify/test_gemm_transposeb_ex.cpp @@ -26,6 +26,7 @@ #include #include #include + template struct test_gemm_transposeb_ex : verify_program> { diff --git a/test/verify/test_mul_dot_b.cpp b/test/verify/test_mul_dot_b.cpp index a275a06b4ed..34d36165797 100644 --- a/test/verify/test_mul_dot_b.cpp +++ b/test/verify/test_mul_dot_b.cpp @@ -28,7 +28,6 @@ #include template - struct test_mul_dot_b : verify_program> { migraphx::program create_program() const diff --git a/test/verify/test_nearbyint.cpp b/test/verify/test_nearbyint.cpp index 7bd9d5a8ceb..c1b3f972bf1 100644 --- a/test/verify/test_nearbyint.cpp +++ b/test/verify/test_nearbyint.cpp @@ -22,7 +22,6 @@ * THE SOFTWARE. */ -#include "migraphx/float8.hpp" #include "verify_program.hpp" #include #include diff --git a/test/verify/test_reduce_add.cpp b/test/verify/test_reduce_add.cpp index 554497deb12..4c4a1ed5da7 100644 --- a/test/verify/test_reduce_add.cpp +++ b/test/verify/test_reduce_add.cpp @@ -22,7 +22,6 @@ * THE SOFTWARE. 
*/ -#include "migraphx/shape.hpp" #include "verify_program.hpp" #include #include diff --git a/test/verify/test_scatternd.cpp b/test/verify/test_scatternd.cpp index 98e2ef58013..b8022b7be5b 100644 --- a/test/verify/test_scatternd.cpp +++ b/test/verify/test_scatternd.cpp @@ -21,7 +21,6 @@ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN * THE SOFTWARE. */ -#include "migraphx/shape.hpp" #include "verify_program.hpp" #include #include From 381b2d9e8d6b32ca6ca43b194019ac4d5af6399b Mon Sep 17 00:00:00 2001 From: Umang Yadav Date: Sat, 2 Dec 2023 01:26:51 +0000 Subject: [PATCH 076/115] add half tests --- test/verify/gemm_2args_bmv.cpp | 1 + test/verify/gemm_2args_mm_1.cpp | 1 + test/verify/gemm_2args_mm_2.cpp | 1 + test/verify/gemm_2args_mm_3.cpp | 1 + test/verify/gemm_2args_mm_4.cpp | 1 + test/verify/gemm_2args_mm_5.cpp | 1 + test/verify/gemm_2args_mm_6.cpp | 1 + test/verify/gemm_2args_mm_7.cpp | 1 + test/verify/gemm_2args_mm_8.cpp | 1 + test/verify/gemm_2args_mv.cpp | 1 + test/verify/gemm_2args_vbm.cpp | 1 + test/verify/gemm_2args_vm.cpp | 1 + test/verify/gemm_2args_vv.cpp | 1 + test/verify/gemm_add.cpp | 1 + test/verify/gemm_add_broadcast1.cpp | 1 + test/verify/gemm_add_broadcast2.cpp | 1 + test/verify/gemm_add_broadcast_half.cpp | 49 ------------------------- test/verify/gemm_add_half.cpp | 47 ------------------------ test/verify/gemm_literal.cpp | 1 + test/verify/gemm_multi_3args.cpp | 1 + test/verify/gemm_multi_3args_alpha0.cpp | 1 + test/verify/gemm_multi_3args_beta0.cpp | 1 + test/verify/gemm_multi_3args_c25.cpp | 1 + test/verify/gemm_multi_dim_2.cpp | 1 + test/verify/gemm_multi_dim_2_3.cpp | 1 + test/verify/gemm_multi_transpose.cpp | 1 + test/verify/test_gemm.cpp | 1 + test/verify/test_gemm_copy.cpp | 1 + test/verify/test_gemm_ex.cpp | 1 + test/verify/test_gemm_half.cpp | 41 --------------------- test/verify/test_gemm_transposea.cpp | 1 + test/verify/test_gemm_transposea_ex.cpp | 1 + test/verify/test_gemm_transposeab.cpp | 1 + 
test/verify/test_gemm_transposeb.cpp | 1 + test/verify/test_gemm_transposeb_ex.cpp | 1 + test/verify/test_mul_dot_a.cpp | 1 + test/verify/test_mul_dot_b.cpp | 2 + test/verify/test_unbatched_gemm_1.cpp | 1 + test/verify/test_unbatched_gemm_2.cpp | 1 + 39 files changed, 37 insertions(+), 137 deletions(-) delete mode 100644 test/verify/gemm_add_broadcast_half.cpp delete mode 100644 test/verify/gemm_add_half.cpp delete mode 100644 test/verify/test_gemm_half.cpp diff --git a/test/verify/gemm_2args_bmv.cpp b/test/verify/gemm_2args_bmv.cpp index 51f8a297900..c6a4e72b380 100644 --- a/test/verify/gemm_2args_bmv.cpp +++ b/test/verify/gemm_2args_bmv.cpp @@ -49,4 +49,5 @@ struct gemm_2args_bmv : verify_program> }; template struct gemm_2args_bmv; +template struct gemm_2args_bmv; template struct gemm_2args_bmv; diff --git a/test/verify/gemm_2args_mm_1.cpp b/test/verify/gemm_2args_mm_1.cpp index 6df8d42dd0f..4b79db619ac 100644 --- a/test/verify/gemm_2args_mm_1.cpp +++ b/test/verify/gemm_2args_mm_1.cpp @@ -48,4 +48,5 @@ struct gemm_2args_mm_1 : verify_program> }; template struct gemm_2args_mm_1; +template struct gemm_2args_mm_1; template struct gemm_2args_mm_1; diff --git a/test/verify/gemm_2args_mm_2.cpp b/test/verify/gemm_2args_mm_2.cpp index e6f184d84a2..ec1c33f9ae0 100644 --- a/test/verify/gemm_2args_mm_2.cpp +++ b/test/verify/gemm_2args_mm_2.cpp @@ -49,4 +49,5 @@ struct gemm_2args_mm_2 : verify_program> }; template struct gemm_2args_mm_2; +template struct gemm_2args_mm_2; template struct gemm_2args_mm_2; diff --git a/test/verify/gemm_2args_mm_3.cpp b/test/verify/gemm_2args_mm_3.cpp index e41c2b42a91..750a48a2c6a 100644 --- a/test/verify/gemm_2args_mm_3.cpp +++ b/test/verify/gemm_2args_mm_3.cpp @@ -49,4 +49,5 @@ struct gemm_2args_mm_3 : verify_program> }; template struct gemm_2args_mm_3; +template struct gemm_2args_mm_3; template struct gemm_2args_mm_3; diff --git a/test/verify/gemm_2args_mm_4.cpp b/test/verify/gemm_2args_mm_4.cpp index bbe13b6f6ab..4f5c81b8265 100644 --- 
a/test/verify/gemm_2args_mm_4.cpp +++ b/test/verify/gemm_2args_mm_4.cpp @@ -49,4 +49,5 @@ struct gemm_2args_mm_4 : verify_program> }; template struct gemm_2args_mm_4; +template struct gemm_2args_mm_4; template struct gemm_2args_mm_4; diff --git a/test/verify/gemm_2args_mm_5.cpp b/test/verify/gemm_2args_mm_5.cpp index adb83929afd..b88dd002c69 100644 --- a/test/verify/gemm_2args_mm_5.cpp +++ b/test/verify/gemm_2args_mm_5.cpp @@ -48,4 +48,5 @@ struct gemm_2args_mm_5 : verify_program> }; template struct gemm_2args_mm_5; +template struct gemm_2args_mm_5; template struct gemm_2args_mm_5; diff --git a/test/verify/gemm_2args_mm_6.cpp b/test/verify/gemm_2args_mm_6.cpp index d6f4c51b4b0..0bbfaf82743 100644 --- a/test/verify/gemm_2args_mm_6.cpp +++ b/test/verify/gemm_2args_mm_6.cpp @@ -51,4 +51,5 @@ struct gemm_2args_mm_6 : verify_program> }; template struct gemm_2args_mm_6; +template struct gemm_2args_mm_6; template struct gemm_2args_mm_6; diff --git a/test/verify/gemm_2args_mm_7.cpp b/test/verify/gemm_2args_mm_7.cpp index 3dcb5d1ecef..3c8a7747ed9 100644 --- a/test/verify/gemm_2args_mm_7.cpp +++ b/test/verify/gemm_2args_mm_7.cpp @@ -48,4 +48,5 @@ struct gemm_2args_mm_7 : verify_program> }; template struct gemm_2args_mm_7; +template struct gemm_2args_mm_7; template struct gemm_2args_mm_7; diff --git a/test/verify/gemm_2args_mm_8.cpp b/test/verify/gemm_2args_mm_8.cpp index b4da6d40990..8779ee1b8a2 100644 --- a/test/verify/gemm_2args_mm_8.cpp +++ b/test/verify/gemm_2args_mm_8.cpp @@ -48,4 +48,5 @@ struct gemm_2args_mm_8 : verify_program> }; template struct gemm_2args_mm_8; +template struct gemm_2args_mm_8; template struct gemm_2args_mm_8; diff --git a/test/verify/gemm_2args_mv.cpp b/test/verify/gemm_2args_mv.cpp index 773ec758a13..883ce63b396 100644 --- a/test/verify/gemm_2args_mv.cpp +++ b/test/verify/gemm_2args_mv.cpp @@ -47,4 +47,5 @@ struct gemm_2args_mv : verify_program> }; template struct gemm_2args_mv; +template struct gemm_2args_mv; template struct gemm_2args_mv; diff 
--git a/test/verify/gemm_2args_vbm.cpp b/test/verify/gemm_2args_vbm.cpp index f525e74efc5..bc4e73f9c61 100644 --- a/test/verify/gemm_2args_vbm.cpp +++ b/test/verify/gemm_2args_vbm.cpp @@ -51,4 +51,5 @@ struct gemm_2args_vbm : verify_program> }; template struct gemm_2args_vbm; +template struct gemm_2args_vbm; template struct gemm_2args_vbm; diff --git a/test/verify/gemm_2args_vm.cpp b/test/verify/gemm_2args_vm.cpp index 067ebc29d87..af6f870fee2 100644 --- a/test/verify/gemm_2args_vm.cpp +++ b/test/verify/gemm_2args_vm.cpp @@ -48,4 +48,5 @@ struct gemm_2args_vm : verify_program> }; template struct gemm_2args_vm; +template struct gemm_2args_vm; template struct gemm_2args_vm; diff --git a/test/verify/gemm_2args_vv.cpp b/test/verify/gemm_2args_vv.cpp index a42a154615d..c8a01140523 100644 --- a/test/verify/gemm_2args_vv.cpp +++ b/test/verify/gemm_2args_vv.cpp @@ -51,4 +51,5 @@ struct gemm_2args_vv : verify_program> }; template struct gemm_2args_vv; +template struct gemm_2args_vv; template struct gemm_2args_vv; diff --git a/test/verify/gemm_add.cpp b/test/verify/gemm_add.cpp index 041ddc2c94f..d75a05e0b7f 100644 --- a/test/verify/gemm_add.cpp +++ b/test/verify/gemm_add.cpp @@ -49,4 +49,5 @@ struct gemm_add : verify_program> }; template struct gemm_add; +template struct gemm_add; template struct gemm_add; diff --git a/test/verify/gemm_add_broadcast1.cpp b/test/verify/gemm_add_broadcast1.cpp index f153c4f9f69..290b0de50b8 100644 --- a/test/verify/gemm_add_broadcast1.cpp +++ b/test/verify/gemm_add_broadcast1.cpp @@ -51,4 +51,5 @@ struct gemm_add_broadcast1 : verify_program> }; template struct gemm_add_broadcast1; +template struct gemm_add_broadcast1; template struct gemm_add_broadcast1; diff --git a/test/verify/gemm_add_broadcast2.cpp b/test/verify/gemm_add_broadcast2.cpp index 5a7ff9bbe55..e35eefd9e59 100644 --- a/test/verify/gemm_add_broadcast2.cpp +++ b/test/verify/gemm_add_broadcast2.cpp @@ -51,4 +51,5 @@ struct gemm_add_broadcast2 : verify_program> }; template struct 
gemm_add_broadcast2; +template struct gemm_add_broadcast2; template struct gemm_add_broadcast2; diff --git a/test/verify/gemm_add_broadcast_half.cpp b/test/verify/gemm_add_broadcast_half.cpp deleted file mode 100644 index fb1918b1715..00000000000 --- a/test/verify/gemm_add_broadcast_half.cpp +++ /dev/null @@ -1,49 +0,0 @@ -/* - * The MIT License (MIT) - * - * Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved. - * - * Permission is hereby granted, free of charge, to any person obtaining a copy - * of this software and associated documentation files (the "Software"), to deal - * in the Software without restriction, including without limitation the rights - * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell - * copies of the Software, and to permit persons to whom the Software is - * furnished to do so, subject to the following conditions: - * - * The above copyright notice and this permission notice shall be included in - * all copies or substantial portions of the Software. - * - * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR - * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, - * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE - * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER - * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, - * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN - * THE SOFTWARE. 
- */ - -#include "verify_program.hpp" -#include -#include -#include -#include -struct gemm_add_broadcast_half : verify_program -{ - migraphx::program create_program() const - { - migraphx::program p; - auto* mm = p.get_main_module(); - migraphx::shape m1_shape{migraphx::shape::half_type, {1, 2, 3}}; - migraphx::shape m2_shape{migraphx::shape::half_type, {1, 3, 4}}; - migraphx::shape m3_shape{migraphx::shape::half_type, {1, 1, 4}}; - auto l1 = mm->add_parameter("1", m1_shape); - auto l2 = mm->add_parameter("2", m2_shape); - auto l3 = mm->add_parameter("3", m3_shape); - auto l3_b = - mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {1, 2, 4}}}), l3); - - auto dot = mm->add_instruction(migraphx::make_op("dot"), l1, l2); - mm->add_instruction(migraphx::make_op("add"), dot, l3_b); - return p; - } -}; diff --git a/test/verify/gemm_add_half.cpp b/test/verify/gemm_add_half.cpp deleted file mode 100644 index 168fc853e6e..00000000000 --- a/test/verify/gemm_add_half.cpp +++ /dev/null @@ -1,47 +0,0 @@ -/* - * The MIT License (MIT) - * - * Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved. - * - * Permission is hereby granted, free of charge, to any person obtaining a copy - * of this software and associated documentation files (the "Software"), to deal - * in the Software without restriction, including without limitation the rights - * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell - * copies of the Software, and to permit persons to whom the Software is - * furnished to do so, subject to the following conditions: - * - * The above copyright notice and this permission notice shall be included in - * all copies or substantial portions of the Software. - * - * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR - * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, - * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE - * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER - * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, - * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN - * THE SOFTWARE. - */ - -#include "verify_program.hpp" -#include -#include -#include -#include -struct gemm_add_half : verify_program -{ - migraphx::program create_program() const - { - migraphx::program p; - auto* mm = p.get_main_module(); - migraphx::shape m1_shape{migraphx::shape::half_type, {1, 2, 3}}; - migraphx::shape m2_shape{migraphx::shape::half_type, {1, 3, 4}}; - migraphx::shape m3_shape{migraphx::shape::half_type, {1, 2, 4}}; - auto l1 = mm->add_parameter("1", m1_shape); - auto l2 = mm->add_parameter("2", m2_shape); - auto l3 = mm->add_parameter("3", m3_shape); - - auto dot = mm->add_instruction(migraphx::make_op("dot"), l1, l2); - mm->add_instruction(migraphx::make_op("add"), dot, l3); - return p; - } -}; diff --git a/test/verify/gemm_literal.cpp b/test/verify/gemm_literal.cpp index 41ff6cf62dc..f45384727ef 100644 --- a/test/verify/gemm_literal.cpp +++ b/test/verify/gemm_literal.cpp @@ -46,4 +46,5 @@ struct gemm_literal : verify_program> }; template struct gemm_literal; +template struct gemm_literal; template struct gemm_literal; diff --git a/test/verify/gemm_multi_3args.cpp b/test/verify/gemm_multi_3args.cpp index 1b5af2654ba..162ffbed25b 100644 --- a/test/verify/gemm_multi_3args.cpp +++ b/test/verify/gemm_multi_3args.cpp @@ -50,4 +50,5 @@ struct gemm_multi_3args : verify_program> }; template struct gemm_multi_3args; +template struct gemm_multi_3args; template struct gemm_multi_3args; diff --git a/test/verify/gemm_multi_3args_alpha0.cpp b/test/verify/gemm_multi_3args_alpha0.cpp index 638ec4adfc7..db3e5afc1d5 100644 --- a/test/verify/gemm_multi_3args_alpha0.cpp +++ b/test/verify/gemm_multi_3args_alpha0.cpp @@ -50,4 +50,5 @@ struct gemm_multi_3args_alpha0 : verify_program> }; template struct 
gemm_multi_3args_alpha0; +template struct gemm_multi_3args_alpha0; template struct gemm_multi_3args_alpha0; diff --git a/test/verify/gemm_multi_3args_beta0.cpp b/test/verify/gemm_multi_3args_beta0.cpp index efcf3263511..52e30251624 100644 --- a/test/verify/gemm_multi_3args_beta0.cpp +++ b/test/verify/gemm_multi_3args_beta0.cpp @@ -50,4 +50,5 @@ struct gemm_multi_3args_beta0 : verify_program> }; template struct gemm_multi_3args_beta0; +template struct gemm_multi_3args_beta0; template struct gemm_multi_3args_beta0; diff --git a/test/verify/gemm_multi_3args_c25.cpp b/test/verify/gemm_multi_3args_c25.cpp index 47a47c30af4..1006aad7f9f 100644 --- a/test/verify/gemm_multi_3args_c25.cpp +++ b/test/verify/gemm_multi_3args_c25.cpp @@ -50,4 +50,5 @@ struct gemm_multi_3args_c25 : verify_program> }; template struct gemm_multi_3args_c25; +template struct gemm_multi_3args_c25; template struct gemm_multi_3args_c25; diff --git a/test/verify/gemm_multi_dim_2.cpp b/test/verify/gemm_multi_dim_2.cpp index f2e5a50f666..ce825c4701b 100644 --- a/test/verify/gemm_multi_dim_2.cpp +++ b/test/verify/gemm_multi_dim_2.cpp @@ -46,4 +46,5 @@ struct gemm_multi_dim_2 : verify_program> }; template struct gemm_multi_dim_2; +template struct gemm_multi_dim_2; template struct gemm_multi_dim_2; diff --git a/test/verify/gemm_multi_dim_2_3.cpp b/test/verify/gemm_multi_dim_2_3.cpp index 70d6222b571..f3dd2579b5e 100644 --- a/test/verify/gemm_multi_dim_2_3.cpp +++ b/test/verify/gemm_multi_dim_2_3.cpp @@ -46,4 +46,5 @@ struct gemm_multi_dim_2_3 : verify_program> }; template struct gemm_multi_dim_2_3; +template struct gemm_multi_dim_2_3; template struct gemm_multi_dim_2_3; diff --git a/test/verify/gemm_multi_transpose.cpp b/test/verify/gemm_multi_transpose.cpp index d3d77698b9e..4c877e7942d 100644 --- a/test/verify/gemm_multi_transpose.cpp +++ b/test/verify/gemm_multi_transpose.cpp @@ -50,4 +50,5 @@ struct gemm_multi_transpose : verify_program> }; template struct gemm_multi_transpose; +template struct 
gemm_multi_transpose; template struct gemm_multi_transpose; diff --git a/test/verify/test_gemm.cpp b/test/verify/test_gemm.cpp index 38c2b7b154d..9374cc710b9 100644 --- a/test/verify/test_gemm.cpp +++ b/test/verify/test_gemm.cpp @@ -41,4 +41,5 @@ struct test_gemm : verify_program> }; template struct test_gemm; +template struct test_gemm; template struct test_gemm; diff --git a/test/verify/test_gemm_copy.cpp b/test/verify/test_gemm_copy.cpp index 360314ae2b1..b325bc8508b 100644 --- a/test/verify/test_gemm_copy.cpp +++ b/test/verify/test_gemm_copy.cpp @@ -49,4 +49,5 @@ struct test_gemm_copy : verify_program> }; template struct test_gemm_copy; +template struct test_gemm_copy; template struct test_gemm_copy; diff --git a/test/verify/test_gemm_ex.cpp b/test/verify/test_gemm_ex.cpp index 57fc2a79311..9f250dfa19f 100644 --- a/test/verify/test_gemm_ex.cpp +++ b/test/verify/test_gemm_ex.cpp @@ -41,4 +41,5 @@ struct test_gemm_ex : verify_program> } }; template struct test_gemm_ex; +template struct test_gemm_ex; template struct test_gemm_ex; diff --git a/test/verify/test_gemm_half.cpp b/test/verify/test_gemm_half.cpp deleted file mode 100644 index 602e978be50..00000000000 --- a/test/verify/test_gemm_half.cpp +++ /dev/null @@ -1,41 +0,0 @@ -/* - * The MIT License (MIT) - * - * Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved. - * - * Permission is hereby granted, free of charge, to any person obtaining a copy - * of this software and associated documentation files (the "Software"), to deal - * in the Software without restriction, including without limitation the rights - * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell - * copies of the Software, and to permit persons to whom the Software is - * furnished to do so, subject to the following conditions: - * - * The above copyright notice and this permission notice shall be included in - * all copies or substantial portions of the Software. 
- * - * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR - * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, - * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE - * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER - * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, - * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN - * THE SOFTWARE. - */ - -#include "verify_program.hpp" -#include -#include -#include - -struct test_gemm_half : verify_program -{ - migraphx::program create_program() const - { - migraphx::program p; - auto* mm = p.get_main_module(); - auto a = mm->add_parameter("a", migraphx::shape{migraphx::shape::half_type, {4, 5}}); - auto b = mm->add_parameter("b", migraphx::shape{migraphx::shape::half_type, {5, 3}}); - mm->add_instruction(migraphx::make_op("dot"), a, b); - return p; - } -}; diff --git a/test/verify/test_gemm_transposea.cpp b/test/verify/test_gemm_transposea.cpp index 40922ee4f6f..345f3a7ee37 100644 --- a/test/verify/test_gemm_transposea.cpp +++ b/test/verify/test_gemm_transposea.cpp @@ -43,4 +43,5 @@ struct test_gemm_transposea : verify_program> }; template struct test_gemm_transposea; +template struct test_gemm_transposea; template struct test_gemm_transposea; diff --git a/test/verify/test_gemm_transposea_ex.cpp b/test/verify/test_gemm_transposea_ex.cpp index c3ffe7da253..49e33a856e1 100644 --- a/test/verify/test_gemm_transposea_ex.cpp +++ b/test/verify/test_gemm_transposea_ex.cpp @@ -44,4 +44,5 @@ struct test_gemm_transposea_ex : verify_program> }; template struct test_gemm_transposea_ex; +template struct test_gemm_transposea_ex; template struct test_gemm_transposea_ex; diff --git a/test/verify/test_gemm_transposeab.cpp b/test/verify/test_gemm_transposeab.cpp index 5f6d70dd9e5..5b800379cdc 100644 --- a/test/verify/test_gemm_transposeab.cpp +++ b/test/verify/test_gemm_transposeab.cpp @@ 
-44,4 +44,5 @@ struct test_gemm_transposeab : verify_program> }; template struct test_gemm_transposeab; +template struct test_gemm_transposeab; template struct test_gemm_transposeab; diff --git a/test/verify/test_gemm_transposeb.cpp b/test/verify/test_gemm_transposeb.cpp index 37455a5a18e..087d245f016 100644 --- a/test/verify/test_gemm_transposeb.cpp +++ b/test/verify/test_gemm_transposeb.cpp @@ -43,4 +43,5 @@ struct test_gemm_transposeb : verify_program> }; template struct test_gemm_transposeb; +template struct test_gemm_transposeb; template struct test_gemm_transposeb; diff --git a/test/verify/test_gemm_transposeb_ex.cpp b/test/verify/test_gemm_transposeb_ex.cpp index 8d8b305ba71..58c20a48275 100644 --- a/test/verify/test_gemm_transposeb_ex.cpp +++ b/test/verify/test_gemm_transposeb_ex.cpp @@ -44,4 +44,5 @@ struct test_gemm_transposeb_ex : verify_program> }; template struct test_gemm_transposeb_ex; +template struct test_gemm_transposeb_ex; template struct test_gemm_transposeb_ex; diff --git a/test/verify/test_mul_dot_a.cpp b/test/verify/test_mul_dot_a.cpp index ece43386d52..ae1ca6445c2 100644 --- a/test/verify/test_mul_dot_a.cpp +++ b/test/verify/test_mul_dot_a.cpp @@ -49,4 +49,5 @@ struct test_mul_dot_a : verify_program> }; template struct test_mul_dot_a; +template struct test_mul_dot_a; template struct test_mul_dot_a; diff --git a/test/verify/test_mul_dot_b.cpp b/test/verify/test_mul_dot_b.cpp index 34d36165797..3748576dd27 100644 --- a/test/verify/test_mul_dot_b.cpp +++ b/test/verify/test_mul_dot_b.cpp @@ -29,6 +29,7 @@ template struct test_mul_dot_b : verify_program> + { migraphx::program create_program() const { @@ -49,4 +50,5 @@ struct test_mul_dot_b : verify_program> }; template struct test_mul_dot_b; +template struct test_mul_dot_b; template struct test_mul_dot_b; diff --git a/test/verify/test_unbatched_gemm_1.cpp b/test/verify/test_unbatched_gemm_1.cpp index 7b40af94615..6bd36c691f1 100644 --- a/test/verify/test_unbatched_gemm_1.cpp +++ 
b/test/verify/test_unbatched_gemm_1.cpp @@ -60,4 +60,5 @@ struct test_unbatched_gemm_1 : verify_program> }; template struct test_unbatched_gemm_1; +template struct test_unbatched_gemm_1; template struct test_unbatched_gemm_1; diff --git a/test/verify/test_unbatched_gemm_2.cpp b/test/verify/test_unbatched_gemm_2.cpp index a2c16ffe26b..a27736c8d90 100644 --- a/test/verify/test_unbatched_gemm_2.cpp +++ b/test/verify/test_unbatched_gemm_2.cpp @@ -48,4 +48,5 @@ struct test_unbatched_gemm_2 : verify_program> }; template struct test_unbatched_gemm_2; +template struct test_unbatched_gemm_2; template struct test_unbatched_gemm_2; From ce61ea6b6aa6ff51b680b483f64bc0c61ec7e4ec Mon Sep 17 00:00:00 2001 From: Umang Yadav Date: Sat, 2 Dec 2023 21:19:36 +0000 Subject: [PATCH 077/115] add quant_dot support for fp8 --- src/include/migraphx/op/quant_dot.hpp | 9 +++- src/simplify_reshapes.cpp | 5 ++ src/targets/gpu/gemm_impl.cpp | 9 ++-- src/targets/gpu/include/migraphx/gpu/gemm.hpp | 2 +- src/targets/ref/lowering.cpp | 50 +++++++++++++++---- test/verify/batch_quant_dot_1.cpp | 18 +++++-- test/verify/batch_quant_dot_2.cpp | 11 ++-- test/verify/batch_quant_dot_3.cpp | 9 ++-- test/verify/batch_quant_dot_4.cpp | 9 ++-- test/verify/batch_quant_dot_5.cpp | 9 ++-- test/verify/quant_dot_3args_1.cpp | 18 +++++-- test/verify/quant_dot_3args_2.cpp | 17 +++++-- test/verify/quant_dot_3args_3.cpp | 16 ++++-- test/verify/quant_dot_3args_4.cpp | 17 +++++-- test/verify/quant_dot_3args_5.cpp | 14 ++++-- 15 files changed, 151 insertions(+), 62 deletions(-) diff --git a/src/include/migraphx/op/quant_dot.hpp b/src/include/migraphx/op/quant_dot.hpp index 6289adae534..1cc9acc70dc 100644 --- a/src/include/migraphx/op/quant_dot.hpp +++ b/src/include/migraphx/op/quant_dot.hpp @@ -44,9 +44,10 @@ struct quant_dot const shape& a = inputs.at(0); const shape& b = inputs.at(1); auto t = a.type(); - if(t != shape::int8_type) + std::set suppported_types = {shape::int8_type, shape::fp8e4m3fnuz_type}; + if(not 
contains(suppported_types, t)) { - MIGRAPHX_THROW("QUANT_DOT: only support data type int8_t"); + MIGRAPHX_THROW("QUANT_DOT: only support data type int8_t and fp8e4m3fnuz_type"); } if(not std::all_of( @@ -73,6 +74,10 @@ struct quant_dot auto out_lens = a.lens(); out_lens[dim_1] = b.lens()[dim_1]; + if(t == shape::fp8e4m3fnuz_type) + { + return {shape::float_type, out_lens}; + } // else int8 gemm return {shape::int32_type, out_lens}; } }; diff --git a/src/simplify_reshapes.cpp b/src/simplify_reshapes.cpp index 0dc093026a3..7b5479cf522 100644 --- a/src/simplify_reshapes.cpp +++ b/src/simplify_reshapes.cpp @@ -183,6 +183,11 @@ struct find_nested_convert auto x = ins->inputs().front(); auto input = x->inputs().front(); + while(input->name() == "convert") + { + input = input->inputs().front(); + } + if(ins->get_shape() != input->get_shape()) return; diff --git a/src/targets/gpu/gemm_impl.cpp b/src/targets/gpu/gemm_impl.cpp index 057e81395af..3c21602d89a 100644 --- a/src/targets/gpu/gemm_impl.cpp +++ b/src/targets/gpu/gemm_impl.cpp @@ -180,12 +180,9 @@ struct gemm_impl ldd = is_3inputs ? input_shapes[3].strides()[dim_0] : ldc; arg_type = get_type(input_shapes[0].type()); - output_type = arg_type; - if(output_type == rocblas_datatype_i8_r) - { - output_type = rocblas_datatype_i32_r; - } - compute_type = output_type; + output_type = get_type(input_shapes[2].type()); + compute_type = + output_type; // not valid for ex3 BETA APIs. it has different type and set differently. 
if(compute_fp32) { if(arg_type == rocblas_datatype_f16_r) diff --git a/src/targets/gpu/include/migraphx/gpu/gemm.hpp b/src/targets/gpu/include/migraphx/gpu/gemm.hpp index bd9b0eefa14..056fc175c1b 100644 --- a/src/targets/gpu/include/migraphx/gpu/gemm.hpp +++ b/src/targets/gpu/include/migraphx/gpu/gemm.hpp @@ -112,7 +112,7 @@ struct rocblas_gemm argument compute(context& ctx, const shape& output_shape, const std::vector& args) const { - if(this->name() == "gpu::gemm") + if(this->name() == "gpu::gemm" or output_shape.type() == migraphx::shape::float_type) { gemm_compute(ctx, output_shape, args, alpha, beta, compute_fp32, solution_idx); } diff --git a/src/targets/ref/lowering.cpp b/src/targets/ref/lowering.cpp index eb1e20fe369..12026e9cd9b 100644 --- a/src/targets/ref/lowering.cpp +++ b/src/targets/ref/lowering.cpp @@ -24,6 +24,7 @@ #include #include +#include #include #include #include @@ -307,19 +308,46 @@ struct ref_quant_gemm { argument result{output_shape}; // first, convert the args[0] and args[1] from int8_t to int32_t - argument arg_0{{shape::int32_type, {args.at(0).get_shape().lens()}}}; - argument arg_1{{shape::int32_type, {args.at(1).get_shape().lens()}}}; - arg_0.visit([&](auto output) { - args.at(0).visit( - [&](auto input) { std::copy(input.begin(), input.end(), output.begin()); }); - }); + argument arg_0{{output_shape.type(), {args.at(0).get_shape().lens()}}}; + argument arg_1{{output_shape.type(), {args.at(1).get_shape().lens()}}}; + if(output_shape.type() == migraphx::shape::float_type) + { + arg_0.visit([&](auto output) { + args.at(0).visit([&](auto input) { + std::transform(input.begin(), input.end(), output.begin(), [&](const auto x) { + return static_cast(x); + }); + }); + }); - arg_1.visit([&](auto output) { - args.at(1).visit( - [&](auto input) { std::copy(input.begin(), input.end(), output.begin()); }); - }); + arg_1.visit([&](auto output) { + args.at(1).visit([&](auto input) { + std::transform(input.begin(), input.end(), output.begin(), 
[&](const auto x) { + return static_cast(x); + }); + }); + }); + migemm(result, arg_0, arg_1, 1.0f, 0.0f); + } + else if(output_shape.type() == migraphx::shape::int32_type) + { + arg_0.visit([&](auto output) { + args.at(0).visit([&](auto input) { + std::transform(input.begin(), input.end(), output.begin(), [&](const auto x) { + return static_cast(x); + }); + }); + }); - migemm(result, arg_0, arg_1, int32_t{1}, int32_t{0}); + arg_1.visit([&](auto output) { + args.at(1).visit([&](auto input) { + std::transform(input.begin(), input.end(), output.begin(), [&](const auto x) { + return static_cast(x); + }); + }); + }); + migemm(result, arg_0, arg_1, int32_t{1}, int32_t{0}); + } return result; } diff --git a/test/verify/batch_quant_dot_1.cpp b/test/verify/batch_quant_dot_1.cpp index 28a10f81287..3a7e488a40a 100644 --- a/test/verify/batch_quant_dot_1.cpp +++ b/test/verify/batch_quant_dot_1.cpp @@ -24,19 +24,23 @@ #include "verify_program.hpp" #include +#include #include #include #include -struct batch_quant_dot_1 : verify_program +template +struct batch_quant_dot_1 : verify_program> { migraphx::program create_program() const { migraphx::program p; auto* mm = p.get_main_module(); - migraphx::shape m1_shape{migraphx::shape::int8_type, {3, 2, 8, 2}}; - migraphx::shape m2_shape{migraphx::shape::int8_type, {3, 2, 7, 8}}; - migraphx::shape m3_shape{migraphx::shape::int32_type, {3, 2, 2, 7}}; + auto dtype = migraphx::shape::get_type{}; + auto ctype = migraphx::shape::get_type{}; + migraphx::shape m1_shape{dtype, {3, 2, 8, 2}}; + migraphx::shape m2_shape{dtype, {3, 2, 7, 8}}; + migraphx::shape m3_shape{ctype, {3, 2, 2, 7}}; auto l1 = mm->add_parameter("a", m1_shape); auto tl1 = mm->add_instruction( @@ -45,7 +49,11 @@ struct batch_quant_dot_1 : verify_program auto tl2 = mm->add_instruction( migraphx::make_op("transpose", {{"permutation", {0, 1, 3, 2}}}), l2); auto l3 = mm->add_parameter("c", m3_shape); - migraphx::add_apply_alpha_beta(*mm, {tl1, tl2, l3}, 
migraphx::make_op("quant_dot"), 3, 2); + migraphx::add_apply_alpha_beta( + *mm, {tl1, tl2, l3}, migraphx::make_op("quant_dot"), CType{3}, CType{2}); return p; } }; + +template struct batch_quant_dot_1; +template struct batch_quant_dot_1; diff --git a/test/verify/batch_quant_dot_2.cpp b/test/verify/batch_quant_dot_2.cpp index 241cac71a39..3a1b2004f16 100644 --- a/test/verify/batch_quant_dot_2.cpp +++ b/test/verify/batch_quant_dot_2.cpp @@ -28,15 +28,16 @@ #include #include -struct batch_quant_dot_2 : verify_program +template +struct batch_quant_dot_2 : verify_program> { migraphx::program create_program() const { migraphx::program p; auto* mm = p.get_main_module(); - migraphx::shape m1_shape{migraphx::shape::int8_type, {3, 2, 2, 8}}; - migraphx::shape m2_shape{migraphx::shape::int8_type, {3, 2, 8, 7}}; - migraphx::shape m3_shape{migraphx::shape::int32_type, {3, 2, 2, 7}}; + migraphx::shape m1_shape{DType, {3, 2, 2, 8}}; + migraphx::shape m2_shape{DType, {3, 2, 8, 7}}; + migraphx::shape m3_shape{CType, {3, 2, 2, 7}}; auto l1 = mm->add_parameter("a", m1_shape); auto l2 = mm->add_parameter("b", m2_shape); @@ -45,3 +46,5 @@ struct batch_quant_dot_2 : verify_program return p; } }; +template struct batch_quant_dot_2; +template struct batch_quant_dot_2; diff --git a/test/verify/batch_quant_dot_3.cpp b/test/verify/batch_quant_dot_3.cpp index 05bcc1420f6..8c861db01dc 100644 --- a/test/verify/batch_quant_dot_3.cpp +++ b/test/verify/batch_quant_dot_3.cpp @@ -27,14 +27,15 @@ #include #include -struct batch_quant_dot_3 : verify_program +template +struct batch_quant_dot_3 : verify_program> { migraphx::program create_program() const { migraphx::program p; auto* mm = p.get_main_module(); - migraphx::shape m1_shape{migraphx::shape::int8_type, {3, 2, 2, 6}}; - migraphx::shape m2_shape{migraphx::shape::int8_type, {3, 2, 6, 7}}; + migraphx::shape m1_shape{DType, {3, 2, 2, 6}}; + migraphx::shape m2_shape{DType, {3, 2, 6, 7}}; auto l1 = mm->add_parameter("a", m1_shape); auto l2 = 
mm->add_parameter("b", m2_shape); @@ -42,3 +43,5 @@ struct batch_quant_dot_3 : verify_program return p; } }; +template struct batch_quant_dot_3; +template struct batch_quant_dot_3; diff --git a/test/verify/batch_quant_dot_4.cpp b/test/verify/batch_quant_dot_4.cpp index 7865b9e46e8..230a1988321 100644 --- a/test/verify/batch_quant_dot_4.cpp +++ b/test/verify/batch_quant_dot_4.cpp @@ -27,14 +27,15 @@ #include #include -struct batch_quant_dot_4 : verify_program +template +struct batch_quant_dot_4 : verify_program> { migraphx::program create_program() const { migraphx::program p; auto* mm = p.get_main_module(); - migraphx::shape m1_shape{migraphx::shape::int8_type, {2, 4, 6, 3}}; - migraphx::shape m2_shape{migraphx::shape::int8_type, {7, 2, 6, 3}}; + migraphx::shape m1_shape{DType, {2, 4, 6, 3}}; + migraphx::shape m2_shape{DType, {7, 2, 6, 3}}; auto l1 = mm->add_parameter("a", m1_shape); auto l2 = mm->add_parameter("b", m2_shape); @@ -46,3 +47,5 @@ struct batch_quant_dot_4 : verify_program return p; } }; +template struct batch_quant_dot_4; +template struct batch_quant_dot_4; diff --git a/test/verify/batch_quant_dot_5.cpp b/test/verify/batch_quant_dot_5.cpp index 5f5ba073183..78426615c73 100644 --- a/test/verify/batch_quant_dot_5.cpp +++ b/test/verify/batch_quant_dot_5.cpp @@ -27,14 +27,15 @@ #include #include -struct batch_quant_dot_5 : verify_program +template +struct batch_quant_dot_5 : verify_program> { migraphx::program create_program() const { migraphx::program p; auto* mm = p.get_main_module(); - migraphx::shape m1_shape{migraphx::shape::int8_type, {3, 2, 7, 2}}; - migraphx::shape m2_shape{migraphx::shape::int8_type, {3, 2, 5, 7}}; + migraphx::shape m1_shape{DType, {3, 2, 7, 2}}; + migraphx::shape m2_shape{DType, {3, 2, 5, 7}}; auto l1 = mm->add_parameter("a", m1_shape); auto l2 = mm->add_parameter("b", m2_shape); @@ -48,3 +49,5 @@ struct batch_quant_dot_5 : verify_program return p; } }; +template struct batch_quant_dot_5; +template struct batch_quant_dot_5; diff 
--git a/test/verify/quant_dot_3args_1.cpp b/test/verify/quant_dot_3args_1.cpp index c233e4a22bb..ab45e9ece72 100644 --- a/test/verify/quant_dot_3args_1.cpp +++ b/test/verify/quant_dot_3args_1.cpp @@ -25,23 +25,31 @@ #include "verify_program.hpp" #include #include +#include #include #include -struct quant_dot_3args_1 : verify_program +template +struct quant_dot_3args_1 : verify_program> { migraphx::program create_program() const { migraphx::program p; auto* mm = p.get_main_module(); - migraphx::shape m1_shape{migraphx::shape::int8_type, {2, 8}}; - migraphx::shape m2_shape{migraphx::shape::int8_type, {8, 7}}; - migraphx::shape m3_shape{migraphx::shape::int32_type, {2, 7}}; + auto ctype = migraphx::shape::get_type(); + auto dtype = migraphx::shape::get_type(); + migraphx::shape m1_shape{dtype, {2, 8}}; + migraphx::shape m2_shape{dtype, {8, 7}}; + migraphx::shape m3_shape{ctype, {2, 7}}; auto l1 = mm->add_parameter("a", m1_shape); auto l2 = mm->add_parameter("b", m2_shape); auto l3 = mm->add_parameter("c", m3_shape); - migraphx::add_apply_alpha_beta(*mm, {l1, l2, l3}, migraphx::make_op("quant_dot"), 1, 1); + migraphx::add_apply_alpha_beta( + *mm, {l1, l2, l3}, migraphx::make_op("quant_dot"), CType{1}, CType{1}); return p; } }; + +template struct quant_dot_3args_1; +template struct quant_dot_3args_1; diff --git a/test/verify/quant_dot_3args_2.cpp b/test/verify/quant_dot_3args_2.cpp index b546e5194e8..5037960bb61 100644 --- a/test/verify/quant_dot_3args_2.cpp +++ b/test/verify/quant_dot_3args_2.cpp @@ -28,22 +28,29 @@ #include #include -struct quant_dot_3args_2 : verify_program +template +struct quant_dot_3args_2 : verify_program> { migraphx::program create_program() const { migraphx::program p; auto* mm = p.get_main_module(); - migraphx::shape m1_shape{migraphx::shape::int8_type, {8, 2}}; - migraphx::shape m2_shape{migraphx::shape::int8_type, {8, 7}}; - migraphx::shape m3_shape{migraphx::shape::int32_type, {2, 7}}; + auto ctype = migraphx::shape::get_type(); + auto 
dtype = migraphx::shape::get_type(); + migraphx::shape m1_shape{dtype, {8, 2}}; + migraphx::shape m2_shape{dtype, {8, 7}}; + migraphx::shape m3_shape{ctype, {2, 7}}; auto l1 = mm->add_parameter("a", m1_shape); auto tl1 = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), l1); auto l2 = mm->add_parameter("b", m2_shape); auto l3 = mm->add_parameter("c", m3_shape); - migraphx::add_apply_alpha_beta(*mm, {tl1, l2, l3}, migraphx::make_op("quant_dot"), 1, 3); + migraphx::add_apply_alpha_beta( + *mm, {tl1, l2, l3}, migraphx::make_op("quant_dot"), CType{1}, CType{3}); return p; } }; + +template struct quant_dot_3args_2; +template struct quant_dot_3args_2; diff --git a/test/verify/quant_dot_3args_3.cpp b/test/verify/quant_dot_3args_3.cpp index 12ba110eb96..2c0bcd3f2ea 100644 --- a/test/verify/quant_dot_3args_3.cpp +++ b/test/verify/quant_dot_3args_3.cpp @@ -28,22 +28,28 @@ #include #include -struct quant_dot_3args_3 : verify_program +template +struct quant_dot_3args_3 : verify_program> { migraphx::program create_program() const { migraphx::program p; auto* mm = p.get_main_module(); - migraphx::shape m1_shape{migraphx::shape::int8_type, {2, 8}}; - migraphx::shape m2_shape{migraphx::shape::int8_type, {7, 8}}; - migraphx::shape m3_shape{migraphx::shape::int32_type, {2, 7}}; + auto ctype = migraphx::shape::get_type(); + auto dtype = migraphx::shape::get_type(); + migraphx::shape m1_shape{dtype, {2, 8}}; + migraphx::shape m2_shape{dtype, {7, 8}}; + migraphx::shape m3_shape{ctype, {2, 7}}; auto l1 = mm->add_parameter("a", m1_shape); auto l2 = mm->add_parameter("b", m2_shape); auto tl2 = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), l2); auto l3 = mm->add_parameter("c", m3_shape); - migraphx::add_apply_alpha_beta(*mm, {l1, tl2, l3}, migraphx::make_op("quant_dot"), 2, 3); + migraphx::add_apply_alpha_beta( + *mm, {l1, tl2, l3}, migraphx::make_op("quant_dot"), CType{2}, CType{3}); return p; } }; +template struct 
quant_dot_3args_3; +template struct quant_dot_3args_3; diff --git a/test/verify/quant_dot_3args_4.cpp b/test/verify/quant_dot_3args_4.cpp index cc559be70b4..9872c76d00d 100644 --- a/test/verify/quant_dot_3args_4.cpp +++ b/test/verify/quant_dot_3args_4.cpp @@ -28,15 +28,18 @@ #include #include -struct quant_dot_3args_4 : verify_program +template +struct quant_dot_3args_4 : verify_program> { migraphx::program create_program() const { migraphx::program p; auto* mm = p.get_main_module(); - migraphx::shape m1_shape{migraphx::shape::int8_type, {8, 2}}; - migraphx::shape m2_shape{migraphx::shape::int8_type, {7, 8}}; - migraphx::shape m3_shape{migraphx::shape::int32_type, {2, 7}}; + auto ctype = migraphx::shape::get_type(); + auto dtype = migraphx::shape::get_type(); + migraphx::shape m1_shape{dtype, {8, 2}}; + migraphx::shape m2_shape{dtype, {7, 8}}; + migraphx::shape m3_shape{ctype, {2, 7}}; auto l1 = mm->add_parameter("a", m1_shape); auto tl1 = @@ -45,7 +48,11 @@ struct quant_dot_3args_4 : verify_program auto tl2 = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), l2); auto l3 = mm->add_parameter("c", m3_shape); - migraphx::add_apply_alpha_beta(*mm, {tl1, tl2, l3}, migraphx::make_op("quant_dot"), 3, 2); + migraphx::add_apply_alpha_beta( + *mm, {tl1, tl2, l3}, migraphx::make_op("quant_dot"), CType{3}, CType{2}); return p; } }; + +template struct quant_dot_3args_4; +template struct quant_dot_3args_4; diff --git a/test/verify/quant_dot_3args_5.cpp b/test/verify/quant_dot_3args_5.cpp index 120487e93f3..7d3926981ea 100644 --- a/test/verify/quant_dot_3args_5.cpp +++ b/test/verify/quant_dot_3args_5.cpp @@ -28,14 +28,17 @@ #include #include -struct quant_dot_3args_5 : verify_program +template +struct quant_dot_3args_5 : verify_program> { migraphx::program create_program() const { migraphx::program p; auto* mm = p.get_main_module(); - migraphx::shape m1_shape{migraphx::shape::int8_type, {6, 2}}; - migraphx::shape m2_shape{migraphx::shape::int8_type, 
{7, 6}}; + auto dtype = migraphx::shape::get_type(); + + migraphx::shape m1_shape{dtype, {6, 2}}; + migraphx::shape m2_shape{dtype, {7, 6}}; auto l1 = mm->add_parameter("a", m1_shape); auto tl1 = @@ -43,7 +46,10 @@ struct quant_dot_3args_5 : verify_program auto l2 = mm->add_parameter("b", m2_shape); auto tl2 = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), l2); - migraphx::add_apply_alpha_beta(*mm, {tl1, tl2}, migraphx::make_op("quant_dot"), 3); + migraphx::add_apply_alpha_beta(*mm, {tl1, tl2}, migraphx::make_op("quant_dot"), CType{3}); return p; } }; + +template struct quant_dot_3args_5; +template struct quant_dot_3args_5; From 575fc04a392cf02521d36c5912a6a53df853e44a Mon Sep 17 00:00:00 2001 From: Umang Yadav Date: Sun, 26 Nov 2023 23:22:05 +0000 Subject: [PATCH 078/115] mlir fp8 --- src/targets/gpu/fuse_ck.cpp | 3 ++- src/targets/gpu/fuse_mlir.cpp | 7 +++++-- src/targets/gpu/mlir.cpp | 2 ++ 3 files changed, 9 insertions(+), 3 deletions(-) diff --git a/src/targets/gpu/fuse_ck.cpp b/src/targets/gpu/fuse_ck.cpp index 7043985573b..4b0ad4c7d51 100644 --- a/src/targets/gpu/fuse_ck.cpp +++ b/src/targets/gpu/fuse_ck.cpp @@ -69,7 +69,8 @@ struct ck_gemm static bool is_ck_supported_type(shape::type_t t) { - return contains({shape::half_type, shape::int8_type, shape::int32_type}, t); + return contains( + {shape::half_type, shape::int8_type, shape::int32_type, shape::fp8e4m3fnuz_type}, t); } }; MIGRAPHX_REGISTER_OP(ck_gemm); diff --git a/src/targets/gpu/fuse_mlir.cpp b/src/targets/gpu/fuse_mlir.cpp index 7e08a901b00..65c697a8f68 100644 --- a/src/targets/gpu/fuse_mlir.cpp +++ b/src/targets/gpu/fuse_mlir.cpp @@ -223,6 +223,8 @@ auto is_mlir_conv(mlir_mode mode) return false; if(ins->get_shape().type() == shape::int8_type) return true; + if(ins->get_shape().type() == shape::fp8e4m3fnuz_type) + return true; if(mode == mlir_mode::int8) return false; if(mode == mlir_mode::all) @@ -288,6 +290,7 @@ bool is_pointwise_op_supported_by_mlir(const 
instruction& i) const auto result_type = i.get_shape().type(); const std::initializer_list allowed_types = {type_t::float_type, type_t::half_type, + type_t::fp8e4m3fnuz_type, type_t::int8_type, type_t::int32_type, type_t::bool_type}; @@ -327,7 +330,7 @@ bool is_pointwise_op_supported_by_mlir(const instruction& i) "softmax", "tanh", }; - bool is_float = contains({type_t::float_type, type_t::half_type}, result_type); + bool is_float = contains({type_t::float_type, type_t::half_type, type_t::fp8e4m3fnuz_type}, result_type); if(contains(any_type_ops, name)) return true; if(result_type != type_t::bool_type and contains(no_bool_ops, name)) @@ -404,7 +407,7 @@ struct find_mlir_standalone_op // enable only for fp32/fp16/i8 types if(std::any_of(gemm_based_op->inputs().begin(), gemm_based_op->inputs().end(), [&](auto i) { return not contains( - {shape::type_t::float_type, shape::type_t::half_type, shape::type_t::int8_type}, + {shape::type_t::float_type, shape::type_t::half_type, shape::type_t::int8_type, shape::type_t::fp8e4m3fnuz_type}, i->get_shape().type()); })) return; diff --git a/src/targets/gpu/mlir.cpp b/src/targets/gpu/mlir.cpp index 138d24bb136..77a16a4ffef 100644 --- a/src/targets/gpu/mlir.cpp +++ b/src/targets/gpu/mlir.cpp @@ -300,6 +300,8 @@ struct mlir_program result = mlirF32TypeGet(ctx.get()); else if(as.type_enum() == shape::half_type) result = mlirF16TypeGet(ctx.get()); + else if(as.type_enum() == shape::fp8e4m3fnuz_type) + result = mlirFloat8E4M3FNUZTypeGet(ctx.get()); else if(as.type_enum() == shape::double_type) result = mlirF64TypeGet(ctx.get()); else if(as.is_integral()) From afb55bdf0bdfad0dee5e75a3f370293ec954c48b Mon Sep 17 00:00:00 2001 From: Umang Yadav Date: Mon, 27 Nov 2023 13:56:12 +0000 Subject: [PATCH 079/115] add some MLIR fp8 tests for convolutions --- .../verify/test_conv_add_1x1_diff_strides.cpp | 16 ++++++++------- test/verify/test_conv_bn.cpp | 16 +++++++++------ test/verify/test_conv_bn_relu_pooling.cpp | 16 +++++++++------ 
test/verify/test_conv_bn_relu_pooling2.cpp | 20 +++++++++++-------- 4 files changed, 41 insertions(+), 27 deletions(-) diff --git a/test/verify/test_conv_add_1x1_diff_strides.cpp b/test/verify/test_conv_add_1x1_diff_strides.cpp index 9e2be95966d..c07467fa99f 100644 --- a/test/verify/test_conv_add_1x1_diff_strides.cpp +++ b/test/verify/test_conv_add_1x1_diff_strides.cpp @@ -27,18 +27,17 @@ #include #include -struct test_conv_add_1x1_diff_strides : verify_program +template +struct test_conv_add_1x1_diff_strides : verify_program> { migraphx::program create_program() const { migraphx::program p; auto* mm = p.get_main_module(); - auto x = mm->add_parameter("x", {migraphx::shape::float_type, {1, 8, 2, 2}}); - auto w = mm->add_literal( - migraphx::generate_literal({migraphx::shape::float_type, {2, 8, 1, 1}}, 1)); - auto y = mm->add_parameter("y", {migraphx::shape::float_type, {1, 8, 4, 4}}); - auto v = mm->add_literal( - migraphx::generate_literal({migraphx::shape::float_type, {2, 8, 1, 1}}, 2)); + auto x = mm->add_parameter("x", {DType, {1, 8, 2, 2}}); + auto w = mm->add_literal(migraphx::generate_literal({DType, {2, 8, 1, 1}}, 1)); + auto y = mm->add_parameter("y", {DType, {1, 8, 4, 4}}); + auto v = mm->add_literal(migraphx::generate_literal({DType, {2, 8, 1, 1}}, 2)); auto conv1 = mm->add_instruction(migraphx::make_op("convolution"), x, w); auto conv2 = mm->add_instruction( migraphx::make_op("convolution", {{"padding", {0, 0}}, {"stride", {2, 2}}}), y, v); @@ -47,3 +46,6 @@ struct test_conv_add_1x1_diff_strides : verify_program; +template struct test_conv_add_1x1_diff_strides; diff --git a/test/verify/test_conv_bn.cpp b/test/verify/test_conv_bn.cpp index cda424de5c1..510dd6da5ad 100644 --- a/test/verify/test_conv_bn.cpp +++ b/test/verify/test_conv_bn.cpp @@ -29,16 +29,17 @@ #include #include -struct test_conv_bn : verify_program +template +struct test_conv_bn : verify_program> { migraphx::program create_program() const { migraphx::program p; auto* mm = 
p.get_main_module(); - migraphx::shape xs{migraphx::shape::float_type, {1, 3, 224, 224}}; - migraphx::shape ws{migraphx::shape::float_type, {64, 3, 7, 7}}; - migraphx::shape vars{migraphx::shape::float_type, {64}}; + migraphx::shape xs{DType, {1, 3, 224, 224}}; + migraphx::shape ws{DType, {64, 3, 7, 7}}; + migraphx::shape vars{DType, {64}}; auto x = mm->add_parameter("x", xs); auto w = mm->add_parameter("w", ws); // non-symmetrical tiling @@ -53,8 +54,8 @@ struct test_conv_bn : verify_program auto mean = mm->add_literal(migraphx::abs(migraphx::generate_literal(vars, 3))); auto variance = mm->add_literal(migraphx::abs(migraphx::generate_literal(vars, 4))); - auto rt = mm->add_literal(migraphx::literal{migraphx::shape::float_type, {0.5}}); - auto eps = mm->add_literal(migraphx::literal{migraphx::shape::float_type, {1e-5f}}); + auto rt = mm->add_literal(migraphx::literal{DType, {0.5}}); + auto eps = mm->add_literal(migraphx::literal{DType, {1e-5f}}); auto usq_scale = mm->add_instruction(migraphx::make_op("unsqueeze", {{"axes", {1, 2}}}), scale); @@ -74,3 +75,6 @@ struct test_conv_bn : verify_program return p; } }; + +template struct test_conv_bn; +template struct test_conv_bn; diff --git a/test/verify/test_conv_bn_relu_pooling.cpp b/test/verify/test_conv_bn_relu_pooling.cpp index 4d4c8abb1f4..6ee0e67dd9b 100644 --- a/test/verify/test_conv_bn_relu_pooling.cpp +++ b/test/verify/test_conv_bn_relu_pooling.cpp @@ -30,16 +30,17 @@ #include #include -struct test_conv_bn_relu_pooling : verify_program +template +struct test_conv_bn_relu_pooling : verify_program> { migraphx::program create_program() const { migraphx::program p; auto* mm = p.get_main_module(); - migraphx::shape xs{migraphx::shape::float_type, {1, 3, 224, 224}}; - migraphx::shape ws{migraphx::shape::float_type, {64, 3, 7, 7}}; - migraphx::shape vars{migraphx::shape::float_type, {64}}; + migraphx::shape xs{DType, {1, 3, 224, 224}}; + migraphx::shape ws{DType, {64, 3, 7, 7}}; + migraphx::shape vars{DType, {64}}; 
auto x = mm->add_parameter("x", xs); auto w = mm->add_parameter("w", ws); auto conv = mm->add_instruction( @@ -52,8 +53,8 @@ struct test_conv_bn_relu_pooling : verify_program auto mean = mm->add_literal(migraphx::abs(migraphx::generate_literal(vars, 3))); auto variance = mm->add_literal(migraphx::abs(migraphx::generate_literal(vars, 4))); - auto rt = mm->add_literal(migraphx::literal{migraphx::shape::float_type, {0.5}}); - auto eps = mm->add_literal(migraphx::literal{migraphx::shape::float_type, {1e-5f}}); + auto rt = mm->add_literal(migraphx::literal{DType, {0.5}}); + auto eps = mm->add_literal(migraphx::literal{DType, {1e-5f}}); auto usq_scale = mm->add_instruction(migraphx::make_op("unsqueeze", {{"axes", {1, 2}}}), scale); @@ -82,3 +83,6 @@ struct test_conv_bn_relu_pooling : verify_program return p; } }; + +template struct test_conv_bn_relu_pooling; +template struct test_conv_bn_relu_pooling; diff --git a/test/verify/test_conv_bn_relu_pooling2.cpp b/test/verify/test_conv_bn_relu_pooling2.cpp index 39abacd7c28..e38434cd1bb 100644 --- a/test/verify/test_conv_bn_relu_pooling2.cpp +++ b/test/verify/test_conv_bn_relu_pooling2.cpp @@ -30,21 +30,22 @@ #include #include -struct test_conv_bn_relu_pooling2 : verify_program +template +struct test_conv_bn_relu_pooling2 : verify_program> { static migraphx::instruction_ref add_bn(migraphx::module& m, migraphx::instruction_ref x) { auto bn_lens = x->get_shape().lens(); auto c_len = bn_lens.at(1); - migraphx::shape vars{migraphx::shape::float_type, {c_len}}; + migraphx::shape vars{DType, {c_len}}; auto scale = m.add_literal(migraphx::abs(migraphx::generate_literal(vars, 1 + c_len))); auto bias = m.add_literal(migraphx::abs(migraphx::generate_literal(vars, 2 + c_len))); auto mean = m.add_literal(migraphx::abs(migraphx::generate_literal(vars, 3 + c_len))); auto variance = m.add_literal(migraphx::abs(migraphx::generate_literal(vars, 4 + c_len))); - auto rt = m.add_literal(migraphx::literal{migraphx::shape::float_type, {0.5}}); - 
auto eps = m.add_literal(migraphx::literal{migraphx::shape::float_type, {1e-5f}}); + auto rt = m.add_literal(migraphx::literal{DType, {0.5}}); + auto eps = m.add_literal(migraphx::literal{DType, {1e-5f}}); auto usq_scale = m.add_instruction(migraphx::make_op("unsqueeze", {{"axes", {1, 2}}}), scale); @@ -66,10 +67,10 @@ struct test_conv_bn_relu_pooling2 : verify_program migraphx::program p; auto* mm = p.get_main_module(); - migraphx::shape xs1{migraphx::shape::float_type, {1, 512, 7, 7}}; - migraphx::shape xs2{migraphx::shape::float_type, {1, 1024, 14, 14}}; - migraphx::shape ws1{migraphx::shape::float_type, {2048, 512, 1, 1}}; - migraphx::shape ws2{migraphx::shape::float_type, {2048, 1024, 1, 1}}; + migraphx::shape xs1{DType, {1, 512, 7, 7}}; + migraphx::shape xs2{DType, {1, 1024, 14, 14}}; + migraphx::shape ws1{DType, {2048, 512, 1, 1}}; + migraphx::shape ws2{DType, {2048, 1024, 1, 1}}; auto x1 = mm->add_parameter("x1", xs1); auto w1 = mm->add_parameter("w1", ws1); auto conv1 = mm->add_instruction( @@ -98,3 +99,6 @@ struct test_conv_bn_relu_pooling2 : verify_program return p; } }; + +template struct test_conv_bn_relu_pooling2; +template struct test_conv_bn_relu_pooling2; From f2931937540604baa4871d271949c701278aa4b4 Mon Sep 17 00:00:00 2001 From: Umang Yadav Date: Tue, 28 Nov 2023 01:43:14 +0000 Subject: [PATCH 080/115] small example for fp8 fail case --- test/verify/test_conv_bn.cpp | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/test/verify/test_conv_bn.cpp b/test/verify/test_conv_bn.cpp index 510dd6da5ad..0df17f0d843 100644 --- a/test/verify/test_conv_bn.cpp +++ b/test/verify/test_conv_bn.cpp @@ -43,11 +43,12 @@ struct test_conv_bn : verify_program> auto x = mm->add_parameter("x", xs); auto w = mm->add_parameter("w", ws); // non-symmetrical tiling - auto conv = mm->add_instruction( - migraphx::make_op("convolution", - {{"padding", {3, 3}}, {"stride", {2, 2}}, {"dilation", {1, 1}}}), - x, - w); + // auto conv = mm->add_instruction( 
+ // migraphx::make_op("convolution", + // {{"padding", {3, 3}}, {"stride", {2, 2}}, {"dilation", {1, 1}}}), + // x, + // w); + auto conv = mm->add_parameter("conv", migraphx::shape{DType, {1, 64, 112, 112}}); auto scale = mm->add_literal(migraphx::abs(migraphx::generate_literal(vars, 1))); auto bias = mm->add_literal(migraphx::abs(migraphx::generate_literal(vars, 2))); From a8ef91245360ce91cac42b450c3c6873b04d1361 Mon Sep 17 00:00:00 2001 From: Umang Yadav Date: Tue, 28 Nov 2023 14:34:04 +0000 Subject: [PATCH 081/115] add test for conv_bn with 1e-1f --- test/verify/test_conv_bn.cpp | 16 ++++++++++------ 1 file changed, 10 insertions(+), 6 deletions(-) diff --git a/test/verify/test_conv_bn.cpp b/test/verify/test_conv_bn.cpp index 0df17f0d843..ea9e03d5edb 100644 --- a/test/verify/test_conv_bn.cpp +++ b/test/verify/test_conv_bn.cpp @@ -43,12 +43,11 @@ struct test_conv_bn : verify_program> auto x = mm->add_parameter("x", xs); auto w = mm->add_parameter("w", ws); // non-symmetrical tiling - // auto conv = mm->add_instruction( - // migraphx::make_op("convolution", - // {{"padding", {3, 3}}, {"stride", {2, 2}}, {"dilation", {1, 1}}}), - // x, - // w); - auto conv = mm->add_parameter("conv", migraphx::shape{DType, {1, 64, 112, 112}}); + auto conv = mm->add_instruction( + migraphx::make_op("convolution", + {{"padding", {3, 3}}, {"stride", {2, 2}}, {"dilation", {1, 1}}}), + x, + w); auto scale = mm->add_literal(migraphx::abs(migraphx::generate_literal(vars, 1))); auto bias = mm->add_literal(migraphx::abs(migraphx::generate_literal(vars, 2))); @@ -56,7 +55,12 @@ struct test_conv_bn : verify_program> auto variance = mm->add_literal(migraphx::abs(migraphx::generate_literal(vars, 4))); auto rt = mm->add_literal(migraphx::literal{DType, {0.5}}); + auto eps = mm->add_literal(migraphx::literal{DType, {1e-5f}}); + if constexpr((DType) == migraphx::shape::fp8e4m3fnuz_type) + { + eps = mm->add_literal(migraphx::literal{DType, {1e-1f}}); + } auto usq_scale = 
mm->add_instruction(migraphx::make_op("unsqueeze", {{"axes", {1, 2}}}), scale); From 32e08558229192b02ee2f47fff84fb30929c6322 Mon Sep 17 00:00:00 2001 From: Umang Yadav Date: Tue, 28 Nov 2023 14:38:07 +0000 Subject: [PATCH 082/115] fix conv_bn eps --- test/verify/test_conv_bn.cpp | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/test/verify/test_conv_bn.cpp b/test/verify/test_conv_bn.cpp index ea9e03d5edb..10447abc3a2 100644 --- a/test/verify/test_conv_bn.cpp +++ b/test/verify/test_conv_bn.cpp @@ -59,7 +59,8 @@ struct test_conv_bn : verify_program> auto eps = mm->add_literal(migraphx::literal{DType, {1e-5f}}); if constexpr((DType) == migraphx::shape::fp8e4m3fnuz_type) { - eps = mm->add_literal(migraphx::literal{DType, {1e-1f}}); + // use 1e-2f for the fp8 + eps = mm->add_literal(migraphx::literal{DType, {1e-2f}}); } auto usq_scale = From f18418be9024bb73794d6aca173dc016403e4666 Mon Sep 17 00:00:00 2001 From: Umang Yadav Date: Tue, 28 Nov 2023 14:43:45 +0000 Subject: [PATCH 083/115] add pooling to unsupported ops --- src/targets/gpu/target.cpp | 1 + 1 file changed, 1 insertion(+) diff --git a/src/targets/gpu/target.cpp b/src/targets/gpu/target.cpp index d6aadf52569..0610a745900 100644 --- a/src/targets/gpu/target.cpp +++ b/src/targets/gpu/target.cpp @@ -111,6 +111,7 @@ std::vector target::get_passes(migraphx::context& gctx, const compile_opti { unsupported_fp8_ops.insert("dot"); } + unsupported_fp8_ops.insert("pooling"); // clang-format off return { From 9acd36ad7dd47224720c8658b93e3537d16e6c81 Mon Sep 17 00:00:00 2001 From: Umang Yadav Date: Tue, 28 Nov 2023 14:58:34 +0000 Subject: [PATCH 084/115] update eps --- test/verify/test_conv_bn.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/verify/test_conv_bn.cpp b/test/verify/test_conv_bn.cpp index 10447abc3a2..76844b475ef 100644 --- a/test/verify/test_conv_bn.cpp +++ b/test/verify/test_conv_bn.cpp @@ -60,7 +60,7 @@ struct test_conv_bn : verify_program> if constexpr((DType) 
== migraphx::shape::fp8e4m3fnuz_type) { // use 1e-2f for the fp8 - eps = mm->add_literal(migraphx::literal{DType, {1e-2f}}); + eps = mm->add_literal(migraphx::literal{DType, {5e-2f}}); } auto usq_scale = From c31411990112ae17a4ebf8b9b382f0f6034101f0 Mon Sep 17 00:00:00 2001 From: Umang Yadav Date: Tue, 28 Nov 2023 14:58:48 +0000 Subject: [PATCH 085/115] update eps --- test/verify/test_conv_bn.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/verify/test_conv_bn.cpp b/test/verify/test_conv_bn.cpp index 76844b475ef..150ee68c0df 100644 --- a/test/verify/test_conv_bn.cpp +++ b/test/verify/test_conv_bn.cpp @@ -59,7 +59,7 @@ struct test_conv_bn : verify_program> auto eps = mm->add_literal(migraphx::literal{DType, {1e-5f}}); if constexpr((DType) == migraphx::shape::fp8e4m3fnuz_type) { - // use 1e-2f for the fp8 + // use 5e-2f for the fp8 eps = mm->add_literal(migraphx::literal{DType, {5e-2f}}); } From 88eb355979ed2c4270137f5a81b311eea63114dd Mon Sep 17 00:00:00 2001 From: Umang Yadav Date: Tue, 28 Nov 2023 16:39:55 +0000 Subject: [PATCH 086/115] add conv tests supported by MLIR --- src/targets/gpu/fuse_mlir.cpp | 4 +-- test/verify/test_conv.cpp | 12 +++++---- test/verify/test_conv2.cpp | 9 ++++--- test/verify/test_conv_add.cpp | 16 ++++++----- test/verify/test_conv_add_relu.cpp | 12 ++++++--- test/verify/test_conv_bias_clipped_relu.cpp | 19 ++++++------- test/verify/test_conv_bn_add.cpp | 30 +++++++++++++-------- test/verify/test_conv_bn_relu_pooling.cpp | 6 ++++- test/verify/test_conv_bn_relu_pooling2.cpp | 6 ++++- test/verify/test_conv_group_add.cpp | 11 +++++--- test/verify/test_conv_pooling.cpp | 10 ++++--- test/verify/test_conv_relu.cpp | 11 ++++---- 12 files changed, 91 insertions(+), 55 deletions(-) diff --git a/src/targets/gpu/fuse_mlir.cpp b/src/targets/gpu/fuse_mlir.cpp index 65c697a8f68..d407c249e9f 100644 --- a/src/targets/gpu/fuse_mlir.cpp +++ b/src/targets/gpu/fuse_mlir.cpp @@ -221,10 +221,10 @@ auto is_mlir_conv(mlir_mode mode) // 
Avoid MLIR assertion: Index < Length && "Invalid index!" if(ins->get_shape().lens().size() != 4) return false; - if(ins->get_shape().type() == shape::int8_type) - return true; if(ins->get_shape().type() == shape::fp8e4m3fnuz_type) return true; + if(ins->get_shape().type() == shape::int8_type) + return true; if(mode == mlir_mode::int8) return false; if(mode == mlir_mode::all) diff --git a/test/verify/test_conv.cpp b/test/verify/test_conv.cpp index 873016bb5a6..118048f3f81 100644 --- a/test/verify/test_conv.cpp +++ b/test/verify/test_conv.cpp @@ -27,17 +27,19 @@ #include #include -struct test_conv : verify_program +template +struct test_conv : verify_program> { migraphx::program create_program() const { migraphx::program p; auto* mm = p.get_main_module(); - auto input = - mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {4, 3, 3, 3}}); - auto weights = - mm->add_parameter("w", migraphx::shape{migraphx::shape::float_type, {4, 3, 3, 3}}); + auto input = mm->add_parameter("x", migraphx::shape{DType, {4, 3, 3, 3}}); + auto weights = mm->add_parameter("w", migraphx::shape{DType, {4, 3, 3, 3}}); mm->add_instruction(migraphx::make_op("convolution"), input, weights); return p; } }; + +template struct test_conv; +template struct test_conv; diff --git a/test/verify/test_conv2.cpp b/test/verify/test_conv2.cpp index e6dea116f20..a3d2a123868 100644 --- a/test/verify/test_conv2.cpp +++ b/test/verify/test_conv2.cpp @@ -27,16 +27,17 @@ #include #include -struct test_conv2 : verify_program +template +struct test_conv2 : verify_program> { migraphx::program create_program() const { migraphx::program p; auto* mm = p.get_main_module(); auto input = - mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {1, 512, 28, 28}}); + mm->add_parameter("x", migraphx::shape{DType, {1, 512, 28, 28}}); auto weights = - mm->add_parameter("w", migraphx::shape{migraphx::shape::float_type, {256, 512, 1, 1}}); + mm->add_parameter("w", migraphx::shape{DType, {256, 512, 1, 
1}}); mm->add_instruction( migraphx::make_op("convolution", {{"padding", {0, 0}}, {"stride", {1, 1}}, {"dilation", {1, 1}}}), @@ -45,3 +46,5 @@ struct test_conv2 : verify_program return p; } }; +template struct test_conv2; +template struct test_conv2; diff --git a/test/verify/test_conv_add.cpp b/test/verify/test_conv_add.cpp index 934a1985709..751c5798156 100644 --- a/test/verify/test_conv_add.cpp +++ b/test/verify/test_conv_add.cpp @@ -27,18 +27,17 @@ #include #include -struct test_conv_add : verify_program +template +struct test_conv_add : verify_program> { migraphx::program create_program() const { migraphx::program p; auto* mm = p.get_main_module(); - auto x = mm->add_parameter("x", {migraphx::shape::float_type, {1, 8, 4, 4}}); - auto w = mm->add_literal( - migraphx::generate_literal({migraphx::shape::float_type, {2, 8, 3, 3}}, 1)); - auto y = mm->add_parameter("y", {migraphx::shape::float_type, {1, 8, 4, 4}}); - auto v = mm->add_literal( - migraphx::generate_literal({migraphx::shape::float_type, {2, 8, 3, 3}}, 2)); + auto x = mm->add_parameter("x", {DType, {1, 8, 4, 4}}); + auto w = mm->add_literal(migraphx::generate_literal({DType, {2, 8, 3, 3}}, 1)); + auto y = mm->add_parameter("y", {DType, {1, 8, 4, 4}}); + auto v = mm->add_literal(migraphx::generate_literal({DType, {2, 8, 3, 3}}, 2)); auto conv1 = mm->add_instruction(migraphx::make_op("convolution"), x, w); auto conv2 = mm->add_instruction(migraphx::make_op("convolution"), y, v); auto sum = mm->add_instruction(migraphx::make_op("add"), conv1, conv2); @@ -46,3 +45,6 @@ struct test_conv_add : verify_program return p; } }; + +template struct test_conv_add; +template struct test_conv_add; diff --git a/test/verify/test_conv_add_relu.cpp b/test/verify/test_conv_add_relu.cpp index 74533c86a13..69e60792ad0 100644 --- a/test/verify/test_conv_add_relu.cpp +++ b/test/verify/test_conv_add_relu.cpp @@ -28,17 +28,18 @@ #include #include -struct test_conv_add_relu : verify_program +template +struct test_conv_add_relu : 
verify_program> { migraphx::program create_program() const { migraphx::program p; auto* mm = p.get_main_module(); auto input = - mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {4, 3, 3, 3}}); + mm->add_parameter("x", migraphx::shape{DType, {4, 3, 3, 3}}); auto weights = - mm->add_parameter("w", migraphx::shape{migraphx::shape::float_type, {4, 3, 3, 3}}); - auto bias_literal = migraphx::literal{migraphx::shape{migraphx::shape::float_type, {4}}, + mm->add_parameter("w", migraphx::shape{DType, {4, 3, 3, 3}}); + auto bias_literal = migraphx::literal{migraphx::shape{DType, {4}}, {2.0f, 2.0f, 2.0f, 2.0f}}; auto bias = mm->add_literal(bias_literal); auto conv = mm->add_instruction(migraphx::make_op("convolution"), input, weights); @@ -50,3 +51,6 @@ struct test_conv_add_relu : verify_program return p; } }; + +template struct test_conv_add_relu; +template struct test_conv_add_relu; diff --git a/test/verify/test_conv_bias_clipped_relu.cpp b/test/verify/test_conv_bias_clipped_relu.cpp index bd9fc3bff07..28844dbfb55 100644 --- a/test/verify/test_conv_bias_clipped_relu.cpp +++ b/test/verify/test_conv_bias_clipped_relu.cpp @@ -29,26 +29,24 @@ #include -struct test_conv_bias_clipped_relu : verify_program +template +struct test_conv_bias_clipped_relu : verify_program> { migraphx::program create_program() const { migraphx::program p; auto* mm = p.get_main_module(); - auto input = - mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {4, 3, 3, 3}}); - auto weights = - mm->add_parameter("w", migraphx::shape{migraphx::shape::float_type, {4, 3, 3, 3}}); - auto l0 = migraphx::literal{migraphx::shape{migraphx::shape::float_type, {4}}, - {2.0f, 2.0f, 2.0f, 2.0f}}; + auto input = mm->add_parameter("x", migraphx::shape{DType, {4, 3, 3, 3}}); + auto weights = mm->add_parameter("w", migraphx::shape{DType, {4, 3, 3, 3}}); + auto l0 = migraphx::literal{migraphx::shape{DType, {4}}, {2.0f, 2.0f, 2.0f, 2.0f}}; auto bias = mm->add_literal(l0); auto conv = 
mm->add_instruction(migraphx::make_op("convolution"), input, weights); auto bcast_add = mm->add_instruction( migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", conv->get_shape().lens()}}), bias); auto bias_add = mm->add_instruction(migraphx::make_op("add"), conv, bcast_add); - auto min_val = mm->add_literal(0.0f); - auto max_val = mm->add_literal(6.0f); + auto min_val = mm->add_literal(migraphx::literal(DType, {0.0f})); + auto max_val = mm->add_literal(migraphx::literal(DType, {6.0f})); min_val = mm->add_instruction( migraphx::make_op("multibroadcast", {{"out_lens", conv->get_shape().lens()}}), min_val); max_val = mm->add_instruction( @@ -57,3 +55,6 @@ struct test_conv_bias_clipped_relu : verify_program return p; } }; + +template struct test_conv_bias_clipped_relu; +template struct test_conv_bias_clipped_relu; diff --git a/test/verify/test_conv_bn_add.cpp b/test/verify/test_conv_bn_add.cpp index 3433314ecad..99319f44b28 100644 --- a/test/verify/test_conv_bn_add.cpp +++ b/test/verify/test_conv_bn_add.cpp @@ -29,22 +29,27 @@ #include #include -struct test_conv_bn_add : verify_program +template +struct test_conv_bn_add : verify_program> { static migraphx::instruction_ref add_bn(migraphx::module& m, migraphx::instruction_ref x) { auto bn_lens = x->get_shape().lens(); auto c_len = bn_lens.at(1); - migraphx::shape vars{migraphx::shape::float_type, {c_len}}; + migraphx::shape vars{DType, {c_len}}; auto scale = m.add_literal(migraphx::abs(migraphx::generate_literal(vars, 1 + c_len))); auto bias = m.add_literal(migraphx::abs(migraphx::generate_literal(vars, 2 + c_len))); auto mean = m.add_literal(migraphx::abs(migraphx::generate_literal(vars, 3 + c_len))); auto variance = m.add_literal(migraphx::abs(migraphx::generate_literal(vars, 4 + c_len))); - auto rt = m.add_literal(migraphx::literal{migraphx::shape::float_type, {0.5}}); - auto eps = m.add_literal(migraphx::literal{migraphx::shape::float_type, {1e-5f}}); - + auto rt = m.add_literal(migraphx::literal{DType, 
{0.5}}); + auto eps = m.add_literal(migraphx::literal{DType, {1e-5f}}); + if constexpr((DType) == migraphx::shape::fp8e4m3fnuz_type) + { + // use 5e-2f for the fp8 + eps = m.add_literal(migraphx::literal{DType, {5e-2f}}); + } auto usq_scale = m.add_instruction(migraphx::make_op("unsqueeze", {{"axes", {1, 2}}}), scale); auto usq_bias = m.add_instruction(migraphx::make_op("unsqueeze", {{"axes", {1, 2}}}), bias); @@ -66,12 +71,12 @@ struct test_conv_bn_add : verify_program auto* mm = p.get_main_module(); std::size_t ichannels = 64; std::size_t ochannels = 256; - auto x = mm->add_parameter("x", {migraphx::shape::float_type, {1, ichannels, 56, 56}}); - auto w = mm->add_literal(migraphx::generate_literal( - {migraphx::shape::float_type, {ochannels, ichannels, 1, 1}}, 1)); - auto y = mm->add_parameter("y", {migraphx::shape::float_type, {1, ichannels, 56, 56}}); - auto v = mm->add_literal(migraphx::generate_literal( - {migraphx::shape::float_type, {ochannels, ichannels, 1, 1}}, 2)); + auto x = mm->add_parameter("x", {DType, {1, ichannels, 56, 56}}); + auto w = + mm->add_literal(migraphx::generate_literal({DType, {ochannels, ichannels, 1, 1}}, 1)); + auto y = mm->add_parameter("y", {DType, {1, ichannels, 56, 56}}); + auto v = + mm->add_literal(migraphx::generate_literal({DType, {ochannels, ichannels, 1, 1}}, 2)); auto relu1 = mm->add_instruction(migraphx::make_op("relu"), x); auto conv1 = mm->add_instruction(migraphx::make_op("convolution"), relu1, w); auto bn1 = add_bn(*mm, conv1); @@ -83,3 +88,6 @@ struct test_conv_bn_add : verify_program return p; } }; + +template struct test_conv_bn_add; +template struct test_conv_bn_add; diff --git a/test/verify/test_conv_bn_relu_pooling.cpp b/test/verify/test_conv_bn_relu_pooling.cpp index 6ee0e67dd9b..4b283779c45 100644 --- a/test/verify/test_conv_bn_relu_pooling.cpp +++ b/test/verify/test_conv_bn_relu_pooling.cpp @@ -55,7 +55,11 @@ struct test_conv_bn_relu_pooling : verify_programadd_literal(migraphx::literal{DType, {0.5}}); auto 
eps = mm->add_literal(migraphx::literal{DType, {1e-5f}}); - + if constexpr((DType) == migraphx::shape::fp8e4m3fnuz_type) + { + // use 5e-2f for the fp8 + eps = mm->add_literal(migraphx::literal{DType, {5e-2f}}); + } auto usq_scale = mm->add_instruction(migraphx::make_op("unsqueeze", {{"axes", {1, 2}}}), scale); auto usq_bias = diff --git a/test/verify/test_conv_bn_relu_pooling2.cpp b/test/verify/test_conv_bn_relu_pooling2.cpp index e38434cd1bb..3bf9e907d97 100644 --- a/test/verify/test_conv_bn_relu_pooling2.cpp +++ b/test/verify/test_conv_bn_relu_pooling2.cpp @@ -46,7 +46,11 @@ struct test_conv_bn_relu_pooling2 : verify_program #include -struct test_conv_group_add : verify_program +template +struct test_conv_group_add : verify_program> { migraphx::program create_program() const { migraphx::program p; auto* mm = p.get_main_module(); - migraphx::shape s{migraphx::shape::float_type, {1, 68, 28, 28}}; + migraphx::shape s{DType, {1, 68, 28, 28}}; auto x = mm->add_parameter("x", s); - auto w = mm->add_parameter("w", {migraphx::shape::float_type, {68, 17, 1, 1}}); - auto b = mm->add_parameter("b", {migraphx::shape::float_type, {68}}); + auto w = mm->add_parameter("w", {DType, {68, 17, 1, 1}}); + auto b = mm->add_parameter("b", {DType, {68}}); auto conv = mm->add_instruction(migraphx::make_op("convolution", {{"group", 4}}), x, w); auto bb = mm->add_instruction( migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", {1, 68, 28, 28}}}), b); @@ -44,3 +45,5 @@ struct test_conv_group_add : verify_program return p; } }; +template struct test_conv_group_add; +// template struct test_conv_group_add; diff --git a/test/verify/test_conv_pooling.cpp b/test/verify/test_conv_pooling.cpp index d4e7b7b66af..d12c81a21ef 100644 --- a/test/verify/test_conv_pooling.cpp +++ b/test/verify/test_conv_pooling.cpp @@ -28,16 +28,17 @@ #include #include -struct test_conv_pooling : verify_program +template +struct test_conv_pooling : verify_program> { migraphx::program create_program() const { 
migraphx::program p; auto* mm = p.get_main_module(); auto input = - mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {4, 3, 32, 32}}); + mm->add_parameter("x", migraphx::shape{DType, {4, 3, 32, 32}}); auto weights = - mm->add_parameter("w", migraphx::shape{migraphx::shape::float_type, {4, 3, 3, 3}}); + mm->add_parameter("w", migraphx::shape{DType, {4, 3, 3, 3}}); auto conv = mm->add_instruction(migraphx::make_op("convolution"), input, weights); auto pooling = mm->add_instruction( migraphx::make_op("pooling", {{"mode", migraphx::op::pooling_mode::max}}), conv); @@ -45,3 +46,6 @@ struct test_conv_pooling : verify_program return p; } }; + +template struct test_conv_pooling; +template struct test_conv_pooling; diff --git a/test/verify/test_conv_relu.cpp b/test/verify/test_conv_relu.cpp index 312cac4f6a5..1ed2613cf6a 100644 --- a/test/verify/test_conv_relu.cpp +++ b/test/verify/test_conv_relu.cpp @@ -27,18 +27,19 @@ #include #include -struct test_conv_relu : verify_program +template +struct test_conv_relu : verify_program> { migraphx::program create_program() const { migraphx::program p; auto* mm = p.get_main_module(); - auto input = - mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {4, 3, 3, 3}}); - auto weights = - mm->add_parameter("w", migraphx::shape{migraphx::shape::float_type, {4, 3, 3, 3}}); + auto input = mm->add_parameter("x", migraphx::shape{DType, {4, 3, 3, 3}}); + auto weights = mm->add_parameter("w", migraphx::shape{DType, {4, 3, 3, 3}}); auto conv = mm->add_instruction(migraphx::make_op("convolution"), input, weights); mm->add_instruction(migraphx::make_op("relu"), conv); return p; } }; +template struct test_conv_relu; +template struct test_conv_relu; From 3f213325293a62a2201a190f02e76bb3534225fd Mon Sep 17 00:00:00 2001 From: Umang Yadav Date: Sun, 3 Dec 2023 14:26:58 +0000 Subject: [PATCH 087/115] remove half test and add it as template --- test/verify/test_conv_relu.cpp | 1 + test/verify/test_conv_relu_half.cpp | 
44 ----------------------------- 2 files changed, 1 insertion(+), 44 deletions(-) delete mode 100644 test/verify/test_conv_relu_half.cpp diff --git a/test/verify/test_conv_relu.cpp b/test/verify/test_conv_relu.cpp index 1ed2613cf6a..2d41deed401 100644 --- a/test/verify/test_conv_relu.cpp +++ b/test/verify/test_conv_relu.cpp @@ -42,4 +42,5 @@ struct test_conv_relu : verify_program> } }; template struct test_conv_relu; +template struct test_conv_relu; template struct test_conv_relu; diff --git a/test/verify/test_conv_relu_half.cpp b/test/verify/test_conv_relu_half.cpp deleted file mode 100644 index a61865c7347..00000000000 --- a/test/verify/test_conv_relu_half.cpp +++ /dev/null @@ -1,44 +0,0 @@ -/* - * The MIT License (MIT) - * - * Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved. - * - * Permission is hereby granted, free of charge, to any person obtaining a copy - * of this software and associated documentation files (the "Software"), to deal - * in the Software without restriction, including without limitation the rights - * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell - * copies of the Software, and to permit persons to whom the Software is - * furnished to do so, subject to the following conditions: - * - * The above copyright notice and this permission notice shall be included in - * all copies or substantial portions of the Software. - * - * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR - * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, - * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE - * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER - * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, - * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN - * THE SOFTWARE. 
- */ - -#include "verify_program.hpp" -#include -#include -#include - -struct test_conv_relu_half : verify_program -{ - migraphx::program create_program() const - { - migraphx::program p; - auto* mm = p.get_main_module(); - auto input = - mm->add_parameter("x", migraphx::shape{migraphx::shape::half_type, {4, 3, 3, 3}}); - auto weights = - mm->add_parameter("w", migraphx::shape{migraphx::shape::half_type, {4, 3, 3, 3}}); - auto conv = mm->add_instruction(migraphx::make_op("convolution"), input, weights); - mm->add_instruction(migraphx::make_op("relu"), conv); - return p; - } -}; From 050184cb25b57c7b1f6db78f35eb3a2c9a889baf Mon Sep 17 00:00:00 2001 From: Umang Yadav Date: Sun, 3 Dec 2023 14:29:31 +0000 Subject: [PATCH 088/115] revert some changes --- src/include/migraphx/op/quant_dot.hpp | 9 +--- src/simplify_reshapes.cpp | 5 -- src/targets/gpu/fuse_ck.cpp | 3 +- src/targets/gpu/gemm_impl.cpp | 9 ++-- src/targets/gpu/include/migraphx/gpu/gemm.hpp | 2 +- src/targets/ref/lowering.cpp | 50 ++++--------------- 6 files changed, 21 insertions(+), 57 deletions(-) diff --git a/src/include/migraphx/op/quant_dot.hpp b/src/include/migraphx/op/quant_dot.hpp index 1cc9acc70dc..6289adae534 100644 --- a/src/include/migraphx/op/quant_dot.hpp +++ b/src/include/migraphx/op/quant_dot.hpp @@ -44,10 +44,9 @@ struct quant_dot const shape& a = inputs.at(0); const shape& b = inputs.at(1); auto t = a.type(); - std::set suppported_types = {shape::int8_type, shape::fp8e4m3fnuz_type}; - if(not contains(suppported_types, t)) + if(t != shape::int8_type) { - MIGRAPHX_THROW("QUANT_DOT: only support data type int8_t and fp8e4m3fnuz_type"); + MIGRAPHX_THROW("QUANT_DOT: only support data type int8_t"); } if(not std::all_of( @@ -74,10 +73,6 @@ struct quant_dot auto out_lens = a.lens(); out_lens[dim_1] = b.lens()[dim_1]; - if(t == shape::fp8e4m3fnuz_type) - { - return {shape::float_type, out_lens}; - } // else int8 gemm return {shape::int32_type, out_lens}; } }; diff --git a/src/simplify_reshapes.cpp 
b/src/simplify_reshapes.cpp index 7b5479cf522..0dc093026a3 100644 --- a/src/simplify_reshapes.cpp +++ b/src/simplify_reshapes.cpp @@ -183,11 +183,6 @@ struct find_nested_convert auto x = ins->inputs().front(); auto input = x->inputs().front(); - while(input->name() == "convert") - { - input = input->inputs().front(); - } - if(ins->get_shape() != input->get_shape()) return; diff --git a/src/targets/gpu/fuse_ck.cpp b/src/targets/gpu/fuse_ck.cpp index 4b0ad4c7d51..7043985573b 100644 --- a/src/targets/gpu/fuse_ck.cpp +++ b/src/targets/gpu/fuse_ck.cpp @@ -69,8 +69,7 @@ struct ck_gemm static bool is_ck_supported_type(shape::type_t t) { - return contains( - {shape::half_type, shape::int8_type, shape::int32_type, shape::fp8e4m3fnuz_type}, t); + return contains({shape::half_type, shape::int8_type, shape::int32_type}, t); } }; MIGRAPHX_REGISTER_OP(ck_gemm); diff --git a/src/targets/gpu/gemm_impl.cpp b/src/targets/gpu/gemm_impl.cpp index 3c21602d89a..057e81395af 100644 --- a/src/targets/gpu/gemm_impl.cpp +++ b/src/targets/gpu/gemm_impl.cpp @@ -180,9 +180,12 @@ struct gemm_impl ldd = is_3inputs ? input_shapes[3].strides()[dim_0] : ldc; arg_type = get_type(input_shapes[0].type()); - output_type = get_type(input_shapes[2].type()); - compute_type = - output_type; // not valid for ex3 BETA APIs. it has different type and set differently. 
+ output_type = arg_type; + if(output_type == rocblas_datatype_i8_r) + { + output_type = rocblas_datatype_i32_r; + } + compute_type = output_type; if(compute_fp32) { if(arg_type == rocblas_datatype_f16_r) diff --git a/src/targets/gpu/include/migraphx/gpu/gemm.hpp b/src/targets/gpu/include/migraphx/gpu/gemm.hpp index 056fc175c1b..bd9b0eefa14 100644 --- a/src/targets/gpu/include/migraphx/gpu/gemm.hpp +++ b/src/targets/gpu/include/migraphx/gpu/gemm.hpp @@ -112,7 +112,7 @@ struct rocblas_gemm argument compute(context& ctx, const shape& output_shape, const std::vector& args) const { - if(this->name() == "gpu::gemm" or output_shape.type() == migraphx::shape::float_type) + if(this->name() == "gpu::gemm") { gemm_compute(ctx, output_shape, args, alpha, beta, compute_fp32, solution_idx); } diff --git a/src/targets/ref/lowering.cpp b/src/targets/ref/lowering.cpp index 12026e9cd9b..eb1e20fe369 100644 --- a/src/targets/ref/lowering.cpp +++ b/src/targets/ref/lowering.cpp @@ -24,7 +24,6 @@ #include #include -#include #include #include #include @@ -308,46 +307,19 @@ struct ref_quant_gemm { argument result{output_shape}; // first, convert the args[0] and args[1] from int8_t to int32_t - argument arg_0{{output_shape.type(), {args.at(0).get_shape().lens()}}}; - argument arg_1{{output_shape.type(), {args.at(1).get_shape().lens()}}}; - if(output_shape.type() == migraphx::shape::float_type) - { - arg_0.visit([&](auto output) { - args.at(0).visit([&](auto input) { - std::transform(input.begin(), input.end(), output.begin(), [&](const auto x) { - return static_cast(x); - }); - }); - }); + argument arg_0{{shape::int32_type, {args.at(0).get_shape().lens()}}}; + argument arg_1{{shape::int32_type, {args.at(1).get_shape().lens()}}}; + arg_0.visit([&](auto output) { + args.at(0).visit( + [&](auto input) { std::copy(input.begin(), input.end(), output.begin()); }); + }); - arg_1.visit([&](auto output) { - args.at(1).visit([&](auto input) { - std::transform(input.begin(), input.end(), 
output.begin(), [&](const auto x) { - return static_cast(x); - }); - }); - }); - migemm(result, arg_0, arg_1, 1.0f, 0.0f); - } - else if(output_shape.type() == migraphx::shape::int32_type) - { - arg_0.visit([&](auto output) { - args.at(0).visit([&](auto input) { - std::transform(input.begin(), input.end(), output.begin(), [&](const auto x) { - return static_cast(x); - }); - }); - }); + arg_1.visit([&](auto output) { + args.at(1).visit( + [&](auto input) { std::copy(input.begin(), input.end(), output.begin()); }); + }); - arg_1.visit([&](auto output) { - args.at(1).visit([&](auto input) { - std::transform(input.begin(), input.end(), output.begin(), [&](const auto x) { - return static_cast(x); - }); - }); - }); - migemm(result, arg_0, arg_1, int32_t{1}, int32_t{0}); - } + migemm(result, arg_0, arg_1, int32_t{1}, int32_t{0}); return result; } From 4e07dfcc5bde2a58e704a4a388eb5a2148c22613 Mon Sep 17 00:00:00 2001 From: Umang Yadav Date: Sun, 3 Dec 2023 14:30:19 +0000 Subject: [PATCH 089/115] revert some changes --- test/verify/batch_quant_dot_1.cpp | 18 +++++------------- test/verify/batch_quant_dot_2.cpp | 11 ++++------- test/verify/batch_quant_dot_3.cpp | 9 +++------ test/verify/batch_quant_dot_4.cpp | 9 +++------ test/verify/batch_quant_dot_5.cpp | 9 +++------ test/verify/quant_dot_3args_1.cpp | 18 +++++------------- test/verify/quant_dot_3args_2.cpp | 17 +++++------------ test/verify/quant_dot_3args_3.cpp | 16 +++++----------- test/verify/quant_dot_3args_4.cpp | 17 +++++------------ test/verify/quant_dot_3args_5.cpp | 14 ++++---------- 10 files changed, 42 insertions(+), 96 deletions(-) diff --git a/test/verify/batch_quant_dot_1.cpp b/test/verify/batch_quant_dot_1.cpp index 3a7e488a40a..28a10f81287 100644 --- a/test/verify/batch_quant_dot_1.cpp +++ b/test/verify/batch_quant_dot_1.cpp @@ -24,23 +24,19 @@ #include "verify_program.hpp" #include -#include #include #include #include -template -struct batch_quant_dot_1 : verify_program> +struct batch_quant_dot_1 : 
verify_program { migraphx::program create_program() const { migraphx::program p; auto* mm = p.get_main_module(); - auto dtype = migraphx::shape::get_type{}; - auto ctype = migraphx::shape::get_type{}; - migraphx::shape m1_shape{dtype, {3, 2, 8, 2}}; - migraphx::shape m2_shape{dtype, {3, 2, 7, 8}}; - migraphx::shape m3_shape{ctype, {3, 2, 2, 7}}; + migraphx::shape m1_shape{migraphx::shape::int8_type, {3, 2, 8, 2}}; + migraphx::shape m2_shape{migraphx::shape::int8_type, {3, 2, 7, 8}}; + migraphx::shape m3_shape{migraphx::shape::int32_type, {3, 2, 2, 7}}; auto l1 = mm->add_parameter("a", m1_shape); auto tl1 = mm->add_instruction( @@ -49,11 +45,7 @@ struct batch_quant_dot_1 : verify_program> auto tl2 = mm->add_instruction( migraphx::make_op("transpose", {{"permutation", {0, 1, 3, 2}}}), l2); auto l3 = mm->add_parameter("c", m3_shape); - migraphx::add_apply_alpha_beta( - *mm, {tl1, tl2, l3}, migraphx::make_op("quant_dot"), CType{3}, CType{2}); + migraphx::add_apply_alpha_beta(*mm, {tl1, tl2, l3}, migraphx::make_op("quant_dot"), 3, 2); return p; } }; - -template struct batch_quant_dot_1; -template struct batch_quant_dot_1; diff --git a/test/verify/batch_quant_dot_2.cpp b/test/verify/batch_quant_dot_2.cpp index 3a1b2004f16..241cac71a39 100644 --- a/test/verify/batch_quant_dot_2.cpp +++ b/test/verify/batch_quant_dot_2.cpp @@ -28,16 +28,15 @@ #include #include -template -struct batch_quant_dot_2 : verify_program> +struct batch_quant_dot_2 : verify_program { migraphx::program create_program() const { migraphx::program p; auto* mm = p.get_main_module(); - migraphx::shape m1_shape{DType, {3, 2, 2, 8}}; - migraphx::shape m2_shape{DType, {3, 2, 8, 7}}; - migraphx::shape m3_shape{CType, {3, 2, 2, 7}}; + migraphx::shape m1_shape{migraphx::shape::int8_type, {3, 2, 2, 8}}; + migraphx::shape m2_shape{migraphx::shape::int8_type, {3, 2, 8, 7}}; + migraphx::shape m3_shape{migraphx::shape::int32_type, {3, 2, 2, 7}}; auto l1 = mm->add_parameter("a", m1_shape); auto l2 = 
mm->add_parameter("b", m2_shape); @@ -46,5 +45,3 @@ struct batch_quant_dot_2 : verify_program> return p; } }; -template struct batch_quant_dot_2; -template struct batch_quant_dot_2; diff --git a/test/verify/batch_quant_dot_3.cpp b/test/verify/batch_quant_dot_3.cpp index 8c861db01dc..05bcc1420f6 100644 --- a/test/verify/batch_quant_dot_3.cpp +++ b/test/verify/batch_quant_dot_3.cpp @@ -27,15 +27,14 @@ #include #include -template -struct batch_quant_dot_3 : verify_program> +struct batch_quant_dot_3 : verify_program { migraphx::program create_program() const { migraphx::program p; auto* mm = p.get_main_module(); - migraphx::shape m1_shape{DType, {3, 2, 2, 6}}; - migraphx::shape m2_shape{DType, {3, 2, 6, 7}}; + migraphx::shape m1_shape{migraphx::shape::int8_type, {3, 2, 2, 6}}; + migraphx::shape m2_shape{migraphx::shape::int8_type, {3, 2, 6, 7}}; auto l1 = mm->add_parameter("a", m1_shape); auto l2 = mm->add_parameter("b", m2_shape); @@ -43,5 +42,3 @@ struct batch_quant_dot_3 : verify_program> return p; } }; -template struct batch_quant_dot_3; -template struct batch_quant_dot_3; diff --git a/test/verify/batch_quant_dot_4.cpp b/test/verify/batch_quant_dot_4.cpp index 230a1988321..7865b9e46e8 100644 --- a/test/verify/batch_quant_dot_4.cpp +++ b/test/verify/batch_quant_dot_4.cpp @@ -27,15 +27,14 @@ #include #include -template -struct batch_quant_dot_4 : verify_program> +struct batch_quant_dot_4 : verify_program { migraphx::program create_program() const { migraphx::program p; auto* mm = p.get_main_module(); - migraphx::shape m1_shape{DType, {2, 4, 6, 3}}; - migraphx::shape m2_shape{DType, {7, 2, 6, 3}}; + migraphx::shape m1_shape{migraphx::shape::int8_type, {2, 4, 6, 3}}; + migraphx::shape m2_shape{migraphx::shape::int8_type, {7, 2, 6, 3}}; auto l1 = mm->add_parameter("a", m1_shape); auto l2 = mm->add_parameter("b", m2_shape); @@ -47,5 +46,3 @@ struct batch_quant_dot_4 : verify_program> return p; } }; -template struct batch_quant_dot_4; -template struct batch_quant_dot_4; 
diff --git a/test/verify/batch_quant_dot_5.cpp b/test/verify/batch_quant_dot_5.cpp index 78426615c73..5f5ba073183 100644 --- a/test/verify/batch_quant_dot_5.cpp +++ b/test/verify/batch_quant_dot_5.cpp @@ -27,15 +27,14 @@ #include #include -template -struct batch_quant_dot_5 : verify_program> +struct batch_quant_dot_5 : verify_program { migraphx::program create_program() const { migraphx::program p; auto* mm = p.get_main_module(); - migraphx::shape m1_shape{DType, {3, 2, 7, 2}}; - migraphx::shape m2_shape{DType, {3, 2, 5, 7}}; + migraphx::shape m1_shape{migraphx::shape::int8_type, {3, 2, 7, 2}}; + migraphx::shape m2_shape{migraphx::shape::int8_type, {3, 2, 5, 7}}; auto l1 = mm->add_parameter("a", m1_shape); auto l2 = mm->add_parameter("b", m2_shape); @@ -49,5 +48,3 @@ struct batch_quant_dot_5 : verify_program> return p; } }; -template struct batch_quant_dot_5; -template struct batch_quant_dot_5; diff --git a/test/verify/quant_dot_3args_1.cpp b/test/verify/quant_dot_3args_1.cpp index ab45e9ece72..c233e4a22bb 100644 --- a/test/verify/quant_dot_3args_1.cpp +++ b/test/verify/quant_dot_3args_1.cpp @@ -25,31 +25,23 @@ #include "verify_program.hpp" #include #include -#include #include #include -template -struct quant_dot_3args_1 : verify_program> +struct quant_dot_3args_1 : verify_program { migraphx::program create_program() const { migraphx::program p; auto* mm = p.get_main_module(); - auto ctype = migraphx::shape::get_type(); - auto dtype = migraphx::shape::get_type(); - migraphx::shape m1_shape{dtype, {2, 8}}; - migraphx::shape m2_shape{dtype, {8, 7}}; - migraphx::shape m3_shape{ctype, {2, 7}}; + migraphx::shape m1_shape{migraphx::shape::int8_type, {2, 8}}; + migraphx::shape m2_shape{migraphx::shape::int8_type, {8, 7}}; + migraphx::shape m3_shape{migraphx::shape::int32_type, {2, 7}}; auto l1 = mm->add_parameter("a", m1_shape); auto l2 = mm->add_parameter("b", m2_shape); auto l3 = mm->add_parameter("c", m3_shape); - migraphx::add_apply_alpha_beta( - *mm, {l1, l2, l3}, 
migraphx::make_op("quant_dot"), CType{1}, CType{1}); + migraphx::add_apply_alpha_beta(*mm, {l1, l2, l3}, migraphx::make_op("quant_dot"), 1, 1); return p; } }; - -template struct quant_dot_3args_1; -template struct quant_dot_3args_1; diff --git a/test/verify/quant_dot_3args_2.cpp b/test/verify/quant_dot_3args_2.cpp index 5037960bb61..b546e5194e8 100644 --- a/test/verify/quant_dot_3args_2.cpp +++ b/test/verify/quant_dot_3args_2.cpp @@ -28,29 +28,22 @@ #include #include -template -struct quant_dot_3args_2 : verify_program> +struct quant_dot_3args_2 : verify_program { migraphx::program create_program() const { migraphx::program p; auto* mm = p.get_main_module(); - auto ctype = migraphx::shape::get_type(); - auto dtype = migraphx::shape::get_type(); - migraphx::shape m1_shape{dtype, {8, 2}}; - migraphx::shape m2_shape{dtype, {8, 7}}; - migraphx::shape m3_shape{ctype, {2, 7}}; + migraphx::shape m1_shape{migraphx::shape::int8_type, {8, 2}}; + migraphx::shape m2_shape{migraphx::shape::int8_type, {8, 7}}; + migraphx::shape m3_shape{migraphx::shape::int32_type, {2, 7}}; auto l1 = mm->add_parameter("a", m1_shape); auto tl1 = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), l1); auto l2 = mm->add_parameter("b", m2_shape); auto l3 = mm->add_parameter("c", m3_shape); - migraphx::add_apply_alpha_beta( - *mm, {tl1, l2, l3}, migraphx::make_op("quant_dot"), CType{1}, CType{3}); + migraphx::add_apply_alpha_beta(*mm, {tl1, l2, l3}, migraphx::make_op("quant_dot"), 1, 3); return p; } }; - -template struct quant_dot_3args_2; -template struct quant_dot_3args_2; diff --git a/test/verify/quant_dot_3args_3.cpp b/test/verify/quant_dot_3args_3.cpp index 2c0bcd3f2ea..12ba110eb96 100644 --- a/test/verify/quant_dot_3args_3.cpp +++ b/test/verify/quant_dot_3args_3.cpp @@ -28,28 +28,22 @@ #include #include -template -struct quant_dot_3args_3 : verify_program> +struct quant_dot_3args_3 : verify_program { migraphx::program create_program() const { migraphx::program p; 
auto* mm = p.get_main_module(); - auto ctype = migraphx::shape::get_type(); - auto dtype = migraphx::shape::get_type(); - migraphx::shape m1_shape{dtype, {2, 8}}; - migraphx::shape m2_shape{dtype, {7, 8}}; - migraphx::shape m3_shape{ctype, {2, 7}}; + migraphx::shape m1_shape{migraphx::shape::int8_type, {2, 8}}; + migraphx::shape m2_shape{migraphx::shape::int8_type, {7, 8}}; + migraphx::shape m3_shape{migraphx::shape::int32_type, {2, 7}}; auto l1 = mm->add_parameter("a", m1_shape); auto l2 = mm->add_parameter("b", m2_shape); auto tl2 = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), l2); auto l3 = mm->add_parameter("c", m3_shape); - migraphx::add_apply_alpha_beta( - *mm, {l1, tl2, l3}, migraphx::make_op("quant_dot"), CType{2}, CType{3}); + migraphx::add_apply_alpha_beta(*mm, {l1, tl2, l3}, migraphx::make_op("quant_dot"), 2, 3); return p; } }; -template struct quant_dot_3args_3; -template struct quant_dot_3args_3; diff --git a/test/verify/quant_dot_3args_4.cpp b/test/verify/quant_dot_3args_4.cpp index 9872c76d00d..cc559be70b4 100644 --- a/test/verify/quant_dot_3args_4.cpp +++ b/test/verify/quant_dot_3args_4.cpp @@ -28,18 +28,15 @@ #include #include -template -struct quant_dot_3args_4 : verify_program> +struct quant_dot_3args_4 : verify_program { migraphx::program create_program() const { migraphx::program p; auto* mm = p.get_main_module(); - auto ctype = migraphx::shape::get_type(); - auto dtype = migraphx::shape::get_type(); - migraphx::shape m1_shape{dtype, {8, 2}}; - migraphx::shape m2_shape{dtype, {7, 8}}; - migraphx::shape m3_shape{ctype, {2, 7}}; + migraphx::shape m1_shape{migraphx::shape::int8_type, {8, 2}}; + migraphx::shape m2_shape{migraphx::shape::int8_type, {7, 8}}; + migraphx::shape m3_shape{migraphx::shape::int32_type, {2, 7}}; auto l1 = mm->add_parameter("a", m1_shape); auto tl1 = @@ -48,11 +45,7 @@ struct quant_dot_3args_4 : verify_program> auto tl2 = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", 
{1, 0}}}), l2); auto l3 = mm->add_parameter("c", m3_shape); - migraphx::add_apply_alpha_beta( - *mm, {tl1, tl2, l3}, migraphx::make_op("quant_dot"), CType{3}, CType{2}); + migraphx::add_apply_alpha_beta(*mm, {tl1, tl2, l3}, migraphx::make_op("quant_dot"), 3, 2); return p; } }; - -template struct quant_dot_3args_4; -template struct quant_dot_3args_4; diff --git a/test/verify/quant_dot_3args_5.cpp b/test/verify/quant_dot_3args_5.cpp index 7d3926981ea..120487e93f3 100644 --- a/test/verify/quant_dot_3args_5.cpp +++ b/test/verify/quant_dot_3args_5.cpp @@ -28,17 +28,14 @@ #include #include -template -struct quant_dot_3args_5 : verify_program> +struct quant_dot_3args_5 : verify_program { migraphx::program create_program() const { migraphx::program p; auto* mm = p.get_main_module(); - auto dtype = migraphx::shape::get_type(); - - migraphx::shape m1_shape{dtype, {6, 2}}; - migraphx::shape m2_shape{dtype, {7, 6}}; + migraphx::shape m1_shape{migraphx::shape::int8_type, {6, 2}}; + migraphx::shape m2_shape{migraphx::shape::int8_type, {7, 6}}; auto l1 = mm->add_parameter("a", m1_shape); auto tl1 = @@ -46,10 +43,7 @@ struct quant_dot_3args_5 : verify_program> auto l2 = mm->add_parameter("b", m2_shape); auto tl2 = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), l2); - migraphx::add_apply_alpha_beta(*mm, {tl1, tl2}, migraphx::make_op("quant_dot"), CType{3}); + migraphx::add_apply_alpha_beta(*mm, {tl1, tl2}, migraphx::make_op("quant_dot"), 3); return p; } }; - -template struct quant_dot_3args_5; -template struct quant_dot_3args_5; From 370d18c187385e742459351385982ba80d64a67e Mon Sep 17 00:00:00 2001 From: Umang Yadav Date: Sun, 3 Dec 2023 15:19:12 +0000 Subject: [PATCH 090/115] add quant_conv tests --- src/include/migraphx/op/quant_convolution.hpp | 16 +++++++++++----- src/targets/gpu/fuse_mlir.cpp | 8 +++++--- test/verify/quant_conv.cpp | 10 +++++++--- test/verify/quant_conv_1.cpp | 10 +++++++--- test/verify/quant_conv_1d.cpp | 11 ++++++++--- 
test/verify/quant_conv_2.cpp | 12 +++++++++--- test/verify/quant_conv_padding.cpp | 10 +++++++--- test/verify/quant_conv_padding_stride.cpp | 9 ++++++--- 8 files changed, 60 insertions(+), 26 deletions(-) diff --git a/src/include/migraphx/op/quant_convolution.hpp b/src/include/migraphx/op/quant_convolution.hpp index fb20eff6b74..dec50315f7a 100644 --- a/src/include/migraphx/op/quant_convolution.hpp +++ b/src/include/migraphx/op/quant_convolution.hpp @@ -24,6 +24,7 @@ #ifndef MIGRAPHX_GUARD_OPERATORS_QUANT_CONVOLUTION_HPP #define MIGRAPHX_GUARD_OPERATORS_QUANT_CONVOLUTION_HPP +#include "migraphx/shape.hpp" #include #include #include @@ -87,11 +88,13 @@ struct quant_convolution } // all input type must be int8_type and output is float_type - if(t != shape::int8_type) + std::set supported_types = {shape::int8_type, + shape::fp8e4m3fnuz_type}; + if(not contains(supported_types, t)) { - MIGRAPHX_THROW("QUANT_CONVOLUTION: only accept input and weights of type int8_t"); + MIGRAPHX_THROW("QUANT_CONVOLUTION: only accept input and weights of type int8_t or " + "fp8e4m3fnuz_type"); } - t = shape::int32_type; std::vector output_lens{input.lens()[0], weights.lens()[0]}; auto padding_size = padding.size(); @@ -107,8 +110,11 @@ struct quant_convolution stride[i] + 1))); } - - return inputs[0].with_lens(t, output_lens); + if(t == shape::int8_type) + { + return inputs[0].with_lens(shape::int32_type, output_lens); + } // else fp8 conv + return inputs[0].with_lens(shape::float_type, output_lens); } size_t kdims() const diff --git a/src/targets/gpu/fuse_mlir.cpp b/src/targets/gpu/fuse_mlir.cpp index d407c249e9f..52726f1bdc0 100644 --- a/src/targets/gpu/fuse_mlir.cpp +++ b/src/targets/gpu/fuse_mlir.cpp @@ -214,6 +214,7 @@ auto is_mlir_conv(mlir_mode mode) return false; if(ins->name() != "convolution" and ins->name() != "quant_convolution") return false; + auto input_arg_t = ins->inputs().front()->get_shape().type(); value v = ins->get_operator().to_value(); auto group = 
v.at("group").to(); if(group != 1) @@ -223,6 +224,8 @@ auto is_mlir_conv(mlir_mode mode) return false; if(ins->get_shape().type() == shape::fp8e4m3fnuz_type) return true; + if(ins->get_shape().type() == shape::float_type and input_arg_t == shape::fp8e4m3fnuz_type) + return true; if(ins->get_shape().type() == shape::int8_type) return true; if(mode == mlir_mode::int8) @@ -403,8 +406,7 @@ struct find_mlir_standalone_op void apply(module_pass_manager& mpm, const match::matcher_result& r) const { auto gemm_based_op = r.result; - // - // enable only for fp32/fp16/i8 types + // enable only for fp32/fp16/i8/fp8 types if(std::any_of(gemm_based_op->inputs().begin(), gemm_based_op->inputs().end(), [&](auto i) { return not contains( {shape::type_t::float_type, shape::type_t::half_type, shape::type_t::int8_type, shape::type_t::fp8e4m3fnuz_type}, @@ -530,7 +532,7 @@ void fuse_mlir::apply(module_pass_manager& mpm) const match::find_matches( mpm, - find_mlir_standalone_convolution_op{get_mode("convolution", mlir_mode::int8)}, + find_mlir_standalone_convolution_op{get_mode("convolution", mlir_mode::all)}, find_mlir_standalone_dot_op{get_mode("dot", mlir_mode::none)}); #else (void)mpm; diff --git a/test/verify/quant_conv.cpp b/test/verify/quant_conv.cpp index 72f32f453f3..616b38b0e04 100644 --- a/test/verify/quant_conv.cpp +++ b/test/verify/quant_conv.cpp @@ -27,17 +27,21 @@ #include #include -struct quant_conv : verify_program +template +struct quant_conv : verify_program> { migraphx::program create_program() const { migraphx::program p; auto* mm = p.get_main_module(); - migraphx::shape a_shape{migraphx::shape::int8_type, {2, 3, 4, 4}}; + migraphx::shape a_shape{DType, {2, 3, 4, 4}}; auto pa = mm->add_parameter("a", a_shape); - migraphx::shape c_shape{migraphx::shape::int8_type, {2, 3, 3, 3}}; + migraphx::shape c_shape{DType, {2, 3, 3, 3}}; auto pc = mm->add_parameter("c", c_shape); mm->add_instruction(migraphx::make_op("quant_convolution"), pa, pc); return p; } }; + +template 
struct quant_conv; +template struct quant_conv; diff --git a/test/verify/quant_conv_1.cpp b/test/verify/quant_conv_1.cpp index 928badbd7cb..a13bd5ce3f9 100644 --- a/test/verify/quant_conv_1.cpp +++ b/test/verify/quant_conv_1.cpp @@ -27,17 +27,21 @@ #include #include -struct quant_conv_1 : verify_program +template +struct quant_conv_1 : verify_program> { migraphx::program create_program() const { migraphx::program p; auto* mm = p.get_main_module(); - migraphx::shape a_shape{migraphx::shape::int8_type, {2, 3, 4, 4}}; + migraphx::shape a_shape{DType, {2, 3, 4, 4}}; auto pa = mm->add_parameter("a", a_shape); - migraphx::shape c_shape{migraphx::shape::int8_type, {2, 3, 3, 3}}; + migraphx::shape c_shape{DType, {2, 3, 3, 3}}; auto pc = mm->add_parameter("c", c_shape); mm->add_instruction(migraphx::op::quant_convolution{{{0, 0}}, {{1, 1}}, {{1, 1}}}, pa, pc); return p; } }; + +template struct quant_conv_1; +template struct quant_conv_1; diff --git a/test/verify/quant_conv_1d.cpp b/test/verify/quant_conv_1d.cpp index 2648134c4e3..c8a08cf92ab 100644 --- a/test/verify/quant_conv_1d.cpp +++ b/test/verify/quant_conv_1d.cpp @@ -27,15 +27,16 @@ #include #include -struct quant_conv_1d : verify_program +template +struct quant_conv_1d : verify_program> { migraphx::program create_program() const { migraphx::program p; auto* mm = p.get_main_module(); - migraphx::shape a_shape{migraphx::shape::int8_type, {2, 3, 4}}; + migraphx::shape a_shape{DType, {2, 3, 4}}; auto pa = mm->add_parameter("a", a_shape); - migraphx::shape c_shape{migraphx::shape::int8_type, {2, 3, 3}}; + migraphx::shape c_shape{DType, {2, 3, 3}}; auto pc = mm->add_parameter("c", c_shape); mm->add_instruction( migraphx::make_op("quant_convolution", @@ -45,3 +46,7 @@ struct quant_conv_1d : verify_program return p; } }; + +template struct quant_conv_1d; +// MLIR 1D convolution is not supported in MIGraphX yet. 
+// template struct quant_conv_1d; diff --git a/test/verify/quant_conv_2.cpp b/test/verify/quant_conv_2.cpp index 9ae561f732b..8ea6bb7ebff 100644 --- a/test/verify/quant_conv_2.cpp +++ b/test/verify/quant_conv_2.cpp @@ -27,17 +27,23 @@ #include #include -struct quant_conv_2 : verify_program +template +struct quant_conv_2 : verify_program> { migraphx::program create_program() const { migraphx::program p; auto* mm = p.get_main_module(); - migraphx::shape a_shape{migraphx::shape::int8_type, {16, 16, 4, 4}}; + migraphx::shape a_shape{DType, {16, 16, 4, 4}}; auto pa = mm->add_parameter("a", a_shape); - migraphx::shape c_shape{migraphx::shape::int8_type, {16, 16, 3, 3}}; + migraphx::shape c_shape{DType, {16, 16, 3, 3}}; auto pc = mm->add_parameter("c", c_shape); mm->add_instruction(migraphx::op::quant_convolution{{{0, 0}}, {{1, 1}}, {{1, 1}}}, pa, pc); return p; } }; + +template struct quant_conv_2; +template struct quant_conv_2; + + diff --git a/test/verify/quant_conv_padding.cpp b/test/verify/quant_conv_padding.cpp index f566c314f4c..29159ef7f81 100644 --- a/test/verify/quant_conv_padding.cpp +++ b/test/verify/quant_conv_padding.cpp @@ -27,15 +27,16 @@ #include #include -struct quant_conv_padding : verify_program +template +struct quant_conv_padding : verify_program> { migraphx::program create_program() const { migraphx::program p; auto* mm = p.get_main_module(); - migraphx::shape a_shape{migraphx::shape::int8_type, {2, 3, 4, 4}}; + migraphx::shape a_shape{DType, {2, 3, 4, 4}}; auto pa = mm->add_parameter("a", a_shape); - migraphx::shape c_shape{migraphx::shape::int8_type, {2, 3, 3, 3}}; + migraphx::shape c_shape{DType, {2, 3, 3, 3}}; auto pc = mm->add_parameter("c", c_shape); mm->add_instruction( migraphx::make_op("quant_convolution", {{"padding", {1, 1}}, {"stride", {1, 1}}}), @@ -44,3 +45,6 @@ struct quant_conv_padding : verify_program return p; } }; + +template struct quant_conv_padding; +template struct quant_conv_padding; diff --git 
a/test/verify/quant_conv_padding_stride.cpp b/test/verify/quant_conv_padding_stride.cpp index f1c07399fc0..955a3b23352 100644 --- a/test/verify/quant_conv_padding_stride.cpp +++ b/test/verify/quant_conv_padding_stride.cpp @@ -27,15 +27,16 @@ #include #include -struct quant_conv_padding_stride : verify_program +template +struct quant_conv_padding_stride : verify_program> { migraphx::program create_program() const { migraphx::program p; auto* mm = p.get_main_module(); - migraphx::shape a_shape{migraphx::shape::int8_type, {2, 3, 4, 4}}; + migraphx::shape a_shape{DType, {2, 3, 4, 4}}; auto pa = mm->add_parameter("a", a_shape); - migraphx::shape c_shape{migraphx::shape::int8_type, {2, 3, 3, 3}}; + migraphx::shape c_shape{DType, {2, 3, 3, 3}}; auto pc = mm->add_parameter("c", c_shape); mm->add_instruction( migraphx::make_op("quant_convolution", {{"padding", {1, 1}}, {"stride", {2, 2}}}), @@ -45,3 +46,5 @@ struct quant_conv_padding_stride : verify_program return p; } }; +template struct quant_conv_padding_stride; +template struct quant_conv_padding_stride; From 24c63d708cded3f1778b48bbd8c08a4f0f2181d6 Mon Sep 17 00:00:00 2001 From: Umang Yadav Date: Sun, 3 Dec 2023 15:19:22 +0000 Subject: [PATCH 091/115] add comment for 1d convs --- test/verify/quant_conv_1d.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/verify/quant_conv_1d.cpp b/test/verify/quant_conv_1d.cpp index c8a08cf92ab..069cc4efd22 100644 --- a/test/verify/quant_conv_1d.cpp +++ b/test/verify/quant_conv_1d.cpp @@ -48,5 +48,5 @@ struct quant_conv_1d : verify_program> }; template struct quant_conv_1d; -// MLIR 1D convolution is not supported in MIGraphX yet. +// MLIR 1D convolution is not supported in MIGraphX yet. Enable this through MIOpen route later. 
// template struct quant_conv_1d; From c522d47aed672f033ca765402a84f510e72b18ad Mon Sep 17 00:00:00 2001 From: Umang Yadav Date: Sun, 3 Dec 2023 15:21:49 +0000 Subject: [PATCH 092/115] I don't know why this test was disabled for the GPU, but enabling it since it passes. --- test/verify/main.cpp | 1 - 1 file changed, 1 deletion(-) diff --git a/test/verify/main.cpp b/test/verify/main.cpp index 9a7f8481226..23404864533 100644 --- a/test/verify/main.cpp +++ b/test/verify/main.cpp @@ -77,6 +77,5 @@ int main(int argc, const char* argv[]) "test_split_single_dyn_dim", "test_instancenorm_large_3d", "test_instancenorm_large_3d"}); - rv.disable_test_for("gpu", {"test_conv_bn_add"}); rv.run(argc, argv); } From fe585d428bbbd5836024cc157b836f3fb33749a2 Mon Sep 17 00:00:00 2001 From: Umang Yadav Date: Sun, 3 Dec 2023 15:27:49 +0000 Subject: [PATCH 093/115] Disable FP8 tests for the non-gfx940 arches --- src/targets/gpu/target.cpp | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/src/targets/gpu/target.cpp b/src/targets/gpu/target.cpp index 0610a745900..8d1c0d8205a 100644 --- a/src/targets/gpu/target.cpp +++ b/src/targets/gpu/target.cpp @@ -112,6 +112,11 @@ std::vector target::get_passes(migraphx::context& gctx, const compile_opti unsupported_fp8_ops.insert("dot"); } unsupported_fp8_ops.insert("pooling"); + if(not starts_with(gpu::get_device_name(), "gfx94")) + { + unsupported_fp8_ops.insert("conv"); + unsupported_fp8_ops.insert("quant_conv"); + } // clang-format off return { From 994d24b6457a831c3e17bf46c1da93aafaf410da Mon Sep 17 00:00:00 2001 From: Umang Yadav Date: Sun, 3 Dec 2023 15:34:20 +0000 Subject: [PATCH 094/115] use helper function to determine gfx940 --- src/targets/gpu/device_name.cpp | 6 ++++++ src/targets/gpu/include/migraphx/gpu/device_name.hpp | 2 ++ src/targets/gpu/rocblas.cpp | 3 +-- src/targets/gpu/target.cpp | 4 +++- 4 files changed, 12 insertions(+), 3 deletions(-) diff --git a/src/targets/gpu/device_name.cpp b/src/targets/gpu/device_name.cpp index
ac38d6e8057..e65b97622f6 100644 --- a/src/targets/gpu/device_name.cpp +++ b/src/targets/gpu/device_name.cpp @@ -49,6 +49,12 @@ std::string get_device_name() return props.gcnArchName; } +bool gfx_has_fp8_intrinsics() +{ + const auto device_name = trim(split_string(get_device_name(), ':').front()); + return (starts_with(device_name, "gfx9") and device_name >= "gfx940"); +} + } // namespace gpu } // namespace MIGRAPHX_INLINE_NS } // namespace migraphx diff --git a/src/targets/gpu/include/migraphx/gpu/device_name.hpp b/src/targets/gpu/include/migraphx/gpu/device_name.hpp index 44312d1f845..54ea873feea 100644 --- a/src/targets/gpu/include/migraphx/gpu/device_name.hpp +++ b/src/targets/gpu/include/migraphx/gpu/device_name.hpp @@ -37,6 +37,8 @@ MIGRAPHX_GPU_EXPORT std::string get_device_name(); MIGRAPHX_GPU_EXPORT int get_device_id(); +MIGRAPHX_GPU_EXPORT bool gfx_has_fp8_intrinsics(); + } // namespace gpu } // namespace MIGRAPHX_INLINE_NS } // namespace migraphx diff --git a/src/targets/gpu/rocblas.cpp b/src/targets/gpu/rocblas.cpp index 59452408801..9697189d921 100644 --- a/src/targets/gpu/rocblas.cpp +++ b/src/targets/gpu/rocblas.cpp @@ -58,8 +58,7 @@ bool rocblas_fp8_available() #ifndef MIGRAPHX_USE_ROCBLAS_FP8_API return false; #else - const auto device_name = trim(split_string(get_device_name(), ':').front()); - return (starts_with(device_name, "gfx9") and device_name >= "gfx940"); + return gfx_has_fp8_intrinsics(); #endif } diff --git a/src/targets/gpu/target.cpp b/src/targets/gpu/target.cpp index 8d1c0d8205a..00060d6ad11 100644 --- a/src/targets/gpu/target.cpp +++ b/src/targets/gpu/target.cpp @@ -106,13 +106,15 @@ std::vector target::get_passes(migraphx::context& gctx, const compile_opti unsupported_types.erase(shape::type_t::uint8_type); unsupported_types.erase(shape::type_t::int32_type); unsupported_types.erase(shape::type_t::tuple_type); + // whitelist supported Ops for the FP8 std::set unsupported_fp8_ops = {}; if(not gpu::rocblas_fp8_available()) {
unsupported_fp8_ops.insert("dot"); } + // MIOpen doesn't have support for fp8 pooling yet. unsupported_fp8_ops.insert("pooling"); - if(not starts_with(gpu::get_device_name(), "gfx94")) + if(not gpu::gfx_has_fp8_intrinsics()) { unsupported_fp8_ops.insert("conv"); unsupported_fp8_ops.insert("quant_conv"); From 51ac4fdd586756ec144ce3a48219daf2098a4f77 Mon Sep 17 00:00:00 2001 From: Umang Yadav Date: Sun, 3 Dec 2023 16:00:59 +0000 Subject: [PATCH 095/115] fix naming --- src/targets/gpu/target.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/targets/gpu/target.cpp b/src/targets/gpu/target.cpp index 00060d6ad11..0a40f495853 100644 --- a/src/targets/gpu/target.cpp +++ b/src/targets/gpu/target.cpp @@ -116,8 +116,8 @@ std::vector target::get_passes(migraphx::context& gctx, const compile_opti unsupported_fp8_ops.insert("pooling"); if(not gpu::gfx_has_fp8_intrinsics()) { - unsupported_fp8_ops.insert("conv"); - unsupported_fp8_ops.insert("quant_conv"); + unsupported_fp8_ops.insert("convolution"); + unsupported_fp8_ops.insert("quant_convolution"); } // clang-format off return From d06dd8ddbb7e3589db43e08f2694688f0399130e Mon Sep 17 00:00:00 2001 From: Umang Yadav Date: Sun, 3 Dec 2023 16:14:50 +0000 Subject: [PATCH 096/115] use generale_type --- src/eliminate_fp8.cpp | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/src/eliminate_fp8.cpp b/src/eliminate_fp8.cpp index 9a43253b94b..e84fb75299a 100644 --- a/src/eliminate_fp8.cpp +++ b/src/eliminate_fp8.cpp @@ -39,8 +39,7 @@ void eliminate_fp8::apply(module& m) const { for(auto ins : iterator_for(m)) { - if(not contains(op_names, ins->name()) or - ins->get_shape().type() != migraphx::shape::fp8e4m3fnuz_type) + if(not contains(op_names, ins->name())) continue; migraphx::shape::type_t orig_type = ins->get_shape().type(); std::vector orig_inputs = ins->inputs(); @@ -55,8 +54,13 @@ void eliminate_fp8::apply(module& m) const "convert", {{"target_type", 
migraphx::to_value(target_type)}}), i); }); - - auto new_ins = m.insert_instruction(ins, ins->get_operator(), {new_inputs}); + auto op = ins->get_operator(); + auto attributes = op.attributes(); + if(attributes.contains("general_data_type")) + { + op = make_op(attributes["general_data_type"].to(), op.to_value()); + } + auto new_ins = m.insert_instruction(ins, op, {new_inputs}); auto convert_back_ins = m.insert_instruction( ins, migraphx::make_op("convert", {{"target_type", migraphx::to_value(orig_type)}}), From 40e7698a3aac4079033937f4b385eba32fc97065 Mon Sep 17 00:00:00 2001 From: Umang Yadav Date: Sun, 3 Dec 2023 16:27:37 +0000 Subject: [PATCH 097/115] do not use brackets --- src/eliminate_fp8.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/eliminate_fp8.cpp b/src/eliminate_fp8.cpp index e84fb75299a..4de035bf67e 100644 --- a/src/eliminate_fp8.cpp +++ b/src/eliminate_fp8.cpp @@ -60,7 +60,7 @@ void eliminate_fp8::apply(module& m) const { op = make_op(attributes["general_data_type"].to(), op.to_value()); } - auto new_ins = m.insert_instruction(ins, op, {new_inputs}); + auto new_ins = m.insert_instruction(ins, op, new_inputs); auto convert_back_ins = m.insert_instruction( ins, migraphx::make_op("convert", {{"target_type", migraphx::to_value(orig_type)}}), From 119a6b8653bd6c583aa90158c1768f7b17efdc06 Mon Sep 17 00:00:00 2001 From: Umang Yadav Date: Sun, 3 Dec 2023 16:42:34 +0000 Subject: [PATCH 098/115] Try removing fusing converts --- src/targets/gpu/fuse_mlir.cpp | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/targets/gpu/fuse_mlir.cpp b/src/targets/gpu/fuse_mlir.cpp index 52726f1bdc0..f600f47e703 100644 --- a/src/targets/gpu/fuse_mlir.cpp +++ b/src/targets/gpu/fuse_mlir.cpp @@ -344,6 +344,10 @@ bool is_pointwise_op_supported_by_mlir(const instruction& i) // supported. 
if(is_float and name == "convert") { + if(result_type == shape::fp8e4m3fnuz_type) + { + return false; + } // else return std::all_of(i.inputs().begin(), i.inputs().end(), [](const auto& arg) { return contains({type_t::float_type, type_t::half_type}, arg->get_shape().type()); }); From fc093b02e869e4c1d334f499ded27171816f85a2 Mon Sep 17 00:00:00 2001 From: Umang Yadav Date: Sun, 3 Dec 2023 17:05:23 +0000 Subject: [PATCH 099/115] formatting --- src/targets/gpu/fuse_mlir.cpp | 11 +++++++---- test/verify/quant_conv_2.cpp | 2 -- test/verify/test_conv.cpp | 2 +- test/verify/test_conv2.cpp | 6 ++---- test/verify/test_conv_add.cpp | 6 +++--- test/verify/test_conv_add_relu.cpp | 10 ++++------ test/verify/test_conv_bn.cpp | 2 +- test/verify/test_conv_bn_add.cpp | 2 +- test/verify/test_conv_pooling.cpp | 6 ++---- test/verify/test_conv_relu.cpp | 2 +- 10 files changed, 22 insertions(+), 27 deletions(-) diff --git a/src/targets/gpu/fuse_mlir.cpp b/src/targets/gpu/fuse_mlir.cpp index f600f47e703..756a09432ff 100644 --- a/src/targets/gpu/fuse_mlir.cpp +++ b/src/targets/gpu/fuse_mlir.cpp @@ -333,7 +333,8 @@ bool is_pointwise_op_supported_by_mlir(const instruction& i) "softmax", "tanh", }; - bool is_float = contains({type_t::float_type, type_t::half_type, type_t::fp8e4m3fnuz_type}, result_type); + bool is_float = + contains({type_t::float_type, type_t::half_type, type_t::fp8e4m3fnuz_type}, result_type); if(contains(any_type_ops, name)) return true; if(result_type != type_t::bool_type and contains(no_bool_ops, name)) @@ -412,9 +413,11 @@ struct find_mlir_standalone_op auto gemm_based_op = r.result; // enable only for fp32/fp16/i8/fp8 types if(std::any_of(gemm_based_op->inputs().begin(), gemm_based_op->inputs().end(), [&](auto i) { - return not contains( - {shape::type_t::float_type, shape::type_t::half_type, shape::type_t::int8_type, shape::type_t::fp8e4m3fnuz_type}, - i->get_shape().type()); + return not contains({shape::type_t::float_type, + shape::type_t::half_type, + 
shape::type_t::int8_type, + shape::type_t::fp8e4m3fnuz_type}, + i->get_shape().type()); })) return; static size_t counter = 0; diff --git a/test/verify/quant_conv_2.cpp b/test/verify/quant_conv_2.cpp index 8ea6bb7ebff..1873852fee5 100644 --- a/test/verify/quant_conv_2.cpp +++ b/test/verify/quant_conv_2.cpp @@ -45,5 +45,3 @@ struct quant_conv_2 : verify_program> template struct quant_conv_2; template struct quant_conv_2; - - diff --git a/test/verify/test_conv.cpp b/test/verify/test_conv.cpp index 118048f3f81..9b5d0caef7d 100644 --- a/test/verify/test_conv.cpp +++ b/test/verify/test_conv.cpp @@ -34,7 +34,7 @@ struct test_conv : verify_program> { migraphx::program p; auto* mm = p.get_main_module(); - auto input = mm->add_parameter("x", migraphx::shape{DType, {4, 3, 3, 3}}); + auto input = mm->add_parameter("x", migraphx::shape{DType, {4, 3, 3, 3}}); auto weights = mm->add_parameter("w", migraphx::shape{DType, {4, 3, 3, 3}}); mm->add_instruction(migraphx::make_op("convolution"), input, weights); return p; diff --git a/test/verify/test_conv2.cpp b/test/verify/test_conv2.cpp index a3d2a123868..bbdf9d1a1c2 100644 --- a/test/verify/test_conv2.cpp +++ b/test/verify/test_conv2.cpp @@ -34,10 +34,8 @@ struct test_conv2 : verify_program> { migraphx::program p; auto* mm = p.get_main_module(); - auto input = - mm->add_parameter("x", migraphx::shape{DType, {1, 512, 28, 28}}); - auto weights = - mm->add_parameter("w", migraphx::shape{DType, {256, 512, 1, 1}}); + auto input = mm->add_parameter("x", migraphx::shape{DType, {1, 512, 28, 28}}); + auto weights = mm->add_parameter("w", migraphx::shape{DType, {256, 512, 1, 1}}); mm->add_instruction( migraphx::make_op("convolution", {{"padding", {0, 0}}, {"stride", {1, 1}}, {"dilation", {1, 1}}}), diff --git a/test/verify/test_conv_add.cpp b/test/verify/test_conv_add.cpp index 751c5798156..d97a4c9652f 100644 --- a/test/verify/test_conv_add.cpp +++ b/test/verify/test_conv_add.cpp @@ -34,9 +34,9 @@ struct test_conv_add : verify_program> { 
migraphx::program p; auto* mm = p.get_main_module(); - auto x = mm->add_parameter("x", {DType, {1, 8, 4, 4}}); - auto w = mm->add_literal(migraphx::generate_literal({DType, {2, 8, 3, 3}}, 1)); - auto y = mm->add_parameter("y", {DType, {1, 8, 4, 4}}); + auto x = mm->add_parameter("x", {DType, {1, 8, 4, 4}}); + auto w = mm->add_literal(migraphx::generate_literal({DType, {2, 8, 3, 3}}, 1)); + auto y = mm->add_parameter("y", {DType, {1, 8, 4, 4}}); auto v = mm->add_literal(migraphx::generate_literal({DType, {2, 8, 3, 3}}, 2)); auto conv1 = mm->add_instruction(migraphx::make_op("convolution"), x, w); auto conv2 = mm->add_instruction(migraphx::make_op("convolution"), y, v); diff --git a/test/verify/test_conv_add_relu.cpp b/test/verify/test_conv_add_relu.cpp index 69e60792ad0..2611e2f99d4 100644 --- a/test/verify/test_conv_add_relu.cpp +++ b/test/verify/test_conv_add_relu.cpp @@ -35,12 +35,10 @@ struct test_conv_add_relu : verify_program> { migraphx::program p; auto* mm = p.get_main_module(); - auto input = - mm->add_parameter("x", migraphx::shape{DType, {4, 3, 3, 3}}); - auto weights = - mm->add_parameter("w", migraphx::shape{DType, {4, 3, 3, 3}}); - auto bias_literal = migraphx::literal{migraphx::shape{DType, {4}}, - {2.0f, 2.0f, 2.0f, 2.0f}}; + auto input = mm->add_parameter("x", migraphx::shape{DType, {4, 3, 3, 3}}); + auto weights = mm->add_parameter("w", migraphx::shape{DType, {4, 3, 3, 3}}); + auto bias_literal = + migraphx::literal{migraphx::shape{DType, {4}}, {2.0f, 2.0f, 2.0f, 2.0f}}; auto bias = mm->add_literal(bias_literal); auto conv = mm->add_instruction(migraphx::make_op("convolution"), input, weights); auto bcast_bias = mm->add_instruction( diff --git a/test/verify/test_conv_bn.cpp b/test/verify/test_conv_bn.cpp index 150ee68c0df..5f356636efe 100644 --- a/test/verify/test_conv_bn.cpp +++ b/test/verify/test_conv_bn.cpp @@ -54,7 +54,7 @@ struct test_conv_bn : verify_program> auto mean = mm->add_literal(migraphx::abs(migraphx::generate_literal(vars, 3))); 
auto variance = mm->add_literal(migraphx::abs(migraphx::generate_literal(vars, 4))); - auto rt = mm->add_literal(migraphx::literal{DType, {0.5}}); + auto rt = mm->add_literal(migraphx::literal{DType, {0.5}}); auto eps = mm->add_literal(migraphx::literal{DType, {1e-5f}}); if constexpr((DType) == migraphx::shape::fp8e4m3fnuz_type) diff --git a/test/verify/test_conv_bn_add.cpp b/test/verify/test_conv_bn_add.cpp index 99319f44b28..52a8486456a 100644 --- a/test/verify/test_conv_bn_add.cpp +++ b/test/verify/test_conv_bn_add.cpp @@ -74,7 +74,7 @@ struct test_conv_bn_add : verify_program> auto x = mm->add_parameter("x", {DType, {1, ichannels, 56, 56}}); auto w = mm->add_literal(migraphx::generate_literal({DType, {ochannels, ichannels, 1, 1}}, 1)); - auto y = mm->add_parameter("y", {DType, {1, ichannels, 56, 56}}); + auto y = mm->add_parameter("y", {DType, {1, ichannels, 56, 56}}); auto v = mm->add_literal(migraphx::generate_literal({DType, {ochannels, ichannels, 1, 1}}, 2)); auto relu1 = mm->add_instruction(migraphx::make_op("relu"), x); diff --git a/test/verify/test_conv_pooling.cpp b/test/verify/test_conv_pooling.cpp index d12c81a21ef..4fbe8f17c65 100644 --- a/test/verify/test_conv_pooling.cpp +++ b/test/verify/test_conv_pooling.cpp @@ -35,10 +35,8 @@ struct test_conv_pooling : verify_program> { migraphx::program p; auto* mm = p.get_main_module(); - auto input = - mm->add_parameter("x", migraphx::shape{DType, {4, 3, 32, 32}}); - auto weights = - mm->add_parameter("w", migraphx::shape{DType, {4, 3, 3, 3}}); + auto input = mm->add_parameter("x", migraphx::shape{DType, {4, 3, 32, 32}}); + auto weights = mm->add_parameter("w", migraphx::shape{DType, {4, 3, 3, 3}}); auto conv = mm->add_instruction(migraphx::make_op("convolution"), input, weights); auto pooling = mm->add_instruction( migraphx::make_op("pooling", {{"mode", migraphx::op::pooling_mode::max}}), conv); diff --git a/test/verify/test_conv_relu.cpp b/test/verify/test_conv_relu.cpp index 2d41deed401..aa9af88bf01 100644 
--- a/test/verify/test_conv_relu.cpp +++ b/test/verify/test_conv_relu.cpp @@ -34,7 +34,7 @@ struct test_conv_relu : verify_program> { migraphx::program p; auto* mm = p.get_main_module(); - auto input = mm->add_parameter("x", migraphx::shape{DType, {4, 3, 3, 3}}); + auto input = mm->add_parameter("x", migraphx::shape{DType, {4, 3, 3, 3}}); auto weights = mm->add_parameter("w", migraphx::shape{DType, {4, 3, 3, 3}}); auto conv = mm->add_instruction(migraphx::make_op("convolution"), input, weights); mm->add_instruction(migraphx::make_op("relu"), conv); From b6a436fe6532bd316ec95b0f7b572882ca3eee35 Mon Sep 17 00:00:00 2001 From: Umang Yadav Date: Sun, 3 Dec 2023 17:07:20 +0000 Subject: [PATCH 100/115] update MLIR commit hasH --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index e3e1fa8ac3c..7f7a0039fef 100755 --- a/requirements.txt +++ b/requirements.txt @@ -29,4 +29,4 @@ pybind/pybind11@d159a563383d10c821ba7b2a71905d1207db6de4 --build msgpack/msgpack-c@cpp-3.3.0 -DMSGPACK_BUILD_TESTS=Off sqlite3@3.43.2 -DCMAKE_POSITION_INDEPENDENT_CODE=On ROCmSoftwarePlatform/composable_kernel@70eefcf4f263aa5c25f3c9ff0db8f6f199ef0fb9 -DCK_BUILD_JIT_LIB=On -DCMAKE_POSITION_INDEPENDENT_CODE=On -ROCmSoftwarePlatform/rocMLIR@9e66e8050209f03349a41b6b497f0da2b285a53b -DBUILD_FAT_LIBROCKCOMPILER=On +ROCmSoftwarePlatform/rocMLIR@5085343bca363109ae9ebabb7ca2b65c52bc861c -DBUILD_FAT_LIBROCKCOMPILER=On From 5423577a5b84a6f86f3e4fa6c3fa22dde00cb40a Mon Sep 17 00:00:00 2001 From: Umang Yadav Date: Mon, 4 Dec 2023 22:01:34 +0000 Subject: [PATCH 101/115] use updated eliminate_fp8 pass --- src/eliminate_fp8.cpp | 82 ++++++++++++++++++++++++++++++++----------- 1 file changed, 61 insertions(+), 21 deletions(-) diff --git a/src/eliminate_fp8.cpp b/src/eliminate_fp8.cpp index 9a43253b94b..d8aa046e792 100644 --- a/src/eliminate_fp8.cpp +++ b/src/eliminate_fp8.cpp @@ -24,8 +24,10 @@ #include #include #include +#include #include 
#include +#include #include #include #include @@ -39,29 +41,67 @@ void eliminate_fp8::apply(module& m) const { for(auto ins : iterator_for(m)) { - if(not contains(op_names, ins->name()) or - ins->get_shape().type() != migraphx::shape::fp8e4m3fnuz_type) + if(not contains(op_names, ins->name())) continue; - migraphx::shape::type_t orig_type = ins->get_shape().type(); - std::vector orig_inputs = ins->inputs(); - std::vector new_inputs; - std::transform(orig_inputs.begin(), - orig_inputs.end(), - std::back_inserter(new_inputs), - [&](const auto& i) { - return m.insert_instruction( - ins, - migraphx::make_op( - "convert", {{"target_type", migraphx::to_value(target_type)}}), - i); - }); + migraphx::shape::type_t orig_type = ins->get_shape().type(); + std::vector inputs = ins->inputs(); + migraphx::transform_if( + inputs.begin(), + inputs.end(), + inputs.begin(), + [&](const auto& i) { return i->get_shape().type() == shape::fp8e4m3fnuz_type; }, + [&](const auto& i) { + return m.insert_instruction( + ins, + migraphx::make_op("convert", + {{"target_type", migraphx::to_value(target_type)}}), + i); + }); + if(inputs == ins->inputs()) + { + return; + } + auto op = ins->get_operator(); + auto attributes = op.attributes(); + if(attributes.contains("general_data_type")) + { + op = make_op(attributes["general_data_type"].to(), op.to_value()); + } + auto new_ins = m.insert_instruction(ins, op, inputs); + if(orig_type == shape::tuple_type) + { + auto orig_outs = ins->outputs(); + if(not std::all_of(orig_outs.begin(), orig_outs.end(), [&](const auto out_ins) { + return out_ins->name() == "get_tuple_elem"; + })) + MIGRAPHX_THROW("EliminateFP8: Instruction with tuple output doesn't have all its " + "usages as get_tuple_elem instruction"); - auto new_ins = m.insert_instruction(ins, ins->get_operator(), {new_inputs}); - auto convert_back_ins = m.insert_instruction( - ins, - migraphx::make_op("convert", {{"target_type", migraphx::to_value(orig_type)}}), - new_ins); - 
m.replace_instruction(ins, convert_back_ins); + std::transform( + orig_outs.begin(), orig_outs.end(), orig_outs.begin(), [&](const auto out_ins) { + auto gte_ins = m.insert_instruction(ins, out_ins->get_operator(), new_ins); + if(out_ins->get_shape().type() == shape::type_t::fp8e4m3fnuz_type) + { + auto gte_convert = m.insert_instruction( + ins, + make_op("convert", {{"target_type", shape::type_t::fp8e4m3fnuz_type}}), + gte_ins); + return m.replace_instruction(out_ins, gte_convert); + } + else + { + return m.replace_instruction(out_ins, gte_ins); + } + }); + } + else + { + auto convert_back_ins = m.insert_instruction( + ins, + migraphx::make_op("convert", {{"target_type", migraphx::to_value(orig_type)}}), + new_ins); + m.replace_instruction(ins, convert_back_ins); + } } } From 402c66ab683c1bcca843276361524427810ca7d7 Mon Sep 17 00:00:00 2001 From: Umang Yadav Date: Tue, 5 Dec 2023 00:34:12 +0000 Subject: [PATCH 102/115] use eliminate_data_type pass instead of eliminate_fp8 pass --- src/eliminate_data_type.cpp | 115 +++++++++++++++---- src/include/migraphx/eliminate_data_type.hpp | 3 +- src/targets/cpu/target.cpp | 2 +- src/targets/gpu/target.cpp | 5 +- test/eliminate_data_type_test.cpp | 14 ++- 5 files changed, 107 insertions(+), 32 deletions(-) diff --git a/src/eliminate_data_type.cpp b/src/eliminate_data_type.cpp index 87b7a39b414..1c8efce216f 100644 --- a/src/eliminate_data_type.cpp +++ b/src/eliminate_data_type.cpp @@ -31,6 +31,72 @@ namespace migraphx { inline namespace MIGRAPHX_INLINE_NS { +void insert_convert_to_supported_type(module& m, + instruction_ref ins, + migraphx::shape::type_t target_type, + std::set unsupported_types) +{ + migraphx::shape::type_t orig_type = ins->get_shape().type(); + std::vector inputs = ins->inputs(); + std::transform(inputs.begin(), inputs.end(), inputs.begin(), [&](const auto& i) { + if(contains(unsupported_types, i->get_shape().type())) + { + return m.insert_instruction( + ins, + migraphx::make_op("convert", {{"target_type", 
migraphx::to_value(target_type)}}), + i); + } + else + { + return i; + } + }); + // if no change + if(inputs == ins->inputs()) + return; + auto op = ins->get_operator(); + auto attributes = op.attributes(); + if(attributes.contains("general_data_type")) + { + op = make_op(attributes["general_data_type"].to(), op.to_value()); + } + auto new_ins = m.insert_instruction(ins, op, inputs); + if(orig_type == shape::tuple_type) + { + auto orig_outs = ins->outputs(); + if(not std::all_of(orig_outs.begin(), orig_outs.end(), [&](const auto out_ins) { + return out_ins->name() == "get_tuple_elem"; + })) + MIGRAPHX_THROW( + "eliminate_data_type: Instruction with tuple output doesn't have all its " + "usages as get_tuple_elem instruction"); + + std::transform( + orig_outs.begin(), orig_outs.end(), orig_outs.begin(), [&](const auto out_ins) { + auto gte_ins = m.insert_instruction(ins, out_ins->get_operator(), new_ins); + auto orig_out_type = out_ins->get_shape().type(); + if(contains(unsupported_types, orig_out_type)) + { + auto gte_convert = m.insert_instruction( + ins, make_op("convert", {{"target_type", orig_out_type}}), gte_ins); + return m.replace_instruction(out_ins, gte_convert); + } + else + { + return m.replace_instruction(out_ins, gte_ins); + } + }); + } + else + { + auto convert_back_ins = m.insert_instruction( + ins, + migraphx::make_op("convert", {{"target_type", migraphx::to_value(orig_type)}}), + new_ins); + m.replace_instruction(ins, convert_back_ins); + } +} + void eliminate_data_type::apply(module& m) const { static const std::vector skip_op_names = {"convert", @@ -42,31 +108,36 @@ void eliminate_data_type::apply(module& m) const "scatternd_add", "scatternd_mul", "scatternd_none"}; - for(auto ins : iterator_for(m)) + if(unsupported_types.empty() and unsupported_types.empty()) + { + return; + } + else if(not unsupported_fp8_ops.empty() and not unsupported_types.empty()) + { + MIGRAPHX_THROW("eliminate_data_type: specify either unsupported FP8 ops or unsupported " 
+ "data types not both."); + } + else if(unsupported_fp8_ops.empty()) + { + for(auto ins : iterator_for(m)) + { + if(ins->name()[0] == '@') + continue; + if(contains(skip_op_names, ins->name())) + continue; + insert_convert_to_supported_type(m, ins, target_type, unsupported_types); + } + } + else { - if(ins->name()[0] == '@') - continue; - if(contains(skip_op_names, ins->name())) - continue; - auto inputs = ins->inputs(); - std::transform(inputs.begin(), inputs.end(), inputs.begin(), [&](auto i) { - if(types.count(i->get_shape().type()) == 0) - return i; - return m.insert_instruction(ins, make_op("convert", {{"target_type", target_type}}), i); - }); - if(inputs == ins->inputs()) - continue; - auto op = ins->get_operator(); - auto attributes = op.attributes(); - if(attributes.contains("general_data_type")) + std::set unsupported_fp8_types = { + migraphx::shape::fp8e4m3fnuz_type}; + for(auto ins : iterator_for(m)) { - op = make_op(attributes["general_data_type"].to(), op.to_value()); + if(not contains(unsupported_fp8_ops, ins->name())) + continue; + insert_convert_to_supported_type(m, ins, target_type, unsupported_fp8_types); } - auto old_type = ins->get_shape().type(); - auto out = m.insert_instruction(ins, op, inputs); - auto convert = - m.insert_instruction(ins, make_op("convert", {{"target_type", old_type}}), out); - m.replace_instruction(ins, convert); } } diff --git a/src/include/migraphx/eliminate_data_type.hpp b/src/include/migraphx/eliminate_data_type.hpp index f0ddfc6e6ad..52dc5f187f1 100644 --- a/src/include/migraphx/eliminate_data_type.hpp +++ b/src/include/migraphx/eliminate_data_type.hpp @@ -40,7 +40,8 @@ struct module; */ struct MIGRAPHX_EXPORT eliminate_data_type { - std::set types; + std::set unsupported_types; + std::set unsupported_fp8_ops; shape::type_t target_type; std::string name() const { return "eliminate_data_type"; } void apply(module& m) const; diff --git a/src/targets/cpu/target.cpp b/src/targets/cpu/target.cpp index 
ebec22a501d..f0836f70423 100644 --- a/src/targets/cpu/target.cpp +++ b/src/targets/cpu/target.cpp @@ -70,7 +70,7 @@ std::vector target::get_passes(migraphx::context& gctx, const compile_opti return {normalize_ops{}, rewrite_quantization{}, dead_code_elimination{}, - eliminate_data_type{unsupported_types, shape::type_t::float_type}, + eliminate_data_type{unsupported_types, {}, shape::type_t::float_type}, dead_code_elimination{}, simplify_reshapes{}, eliminate_identity{}, diff --git a/src/targets/gpu/target.cpp b/src/targets/gpu/target.cpp index d6aadf52569..a25f43600d4 100644 --- a/src/targets/gpu/target.cpp +++ b/src/targets/gpu/target.cpp @@ -21,6 +21,7 @@ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN * THE SOFTWARE. */ +#include "migraphx/shape.hpp" #include #include #include @@ -123,7 +124,7 @@ std::vector target::get_passes(migraphx::context& gctx, const compile_opti simplify_qdq{}, enable_pass(not mlir_enabled(), rewrite_quantization{}), dead_code_elimination{}, - eliminate_data_type{unsupported_types, shape::type_t::float_type}, + eliminate_data_type{unsupported_types, {}, shape::type_t::float_type}, simplify_reshapes{}, eliminate_identity{}, eliminate_pad{}, @@ -142,7 +143,7 @@ std::vector target::get_passes(migraphx::context& gctx, const compile_opti prefuse_ops{}, dead_code_elimination{}, auto_contiguous{}, - eliminate_fp8{unsupported_fp8_ops}, + eliminate_data_type{{}, unsupported_fp8_ops, shape::float_type}, dead_code_elimination{}, optimize_module{}, fuse_pointwise{}, diff --git a/test/eliminate_data_type_test.cpp b/test/eliminate_data_type_test.cpp index 23ff9c32fec..e0a4cd0ddec 100644 --- a/test/eliminate_data_type_test.cpp +++ b/test/eliminate_data_type_test.cpp @@ -30,13 +30,15 @@ #include -void run_pass(migraphx::module& m, std::set types) +void run_pass(migraphx::module& m, + std::set types, + std::set unsupported_fp8_ops = {}) { - migraphx::run_passes( - m, - {migraphx::eliminate_data_type{std::move(types), 
migraphx::shape::float_type}, - migraphx::eliminate_identity{}, - migraphx::dead_code_elimination{}}); + migraphx::run_passes(m, + {migraphx::eliminate_data_type{ + std::move(types), unsupported_fp8_ops, migraphx::shape::float_type}, + migraphx::eliminate_identity{}, + migraphx::dead_code_elimination{}}); } TEST_CASE(simple) From 4ca90ec7af433b035e81d3aae539113457cce29c Mon Sep 17 00:00:00 2001 From: Umang Yadav Date: Tue, 5 Dec 2023 00:39:41 +0000 Subject: [PATCH 103/115] remove older files --- src/CMakeLists.txt | 1 - src/eliminate_fp8.cpp | 109 ------------------------- src/include/migraphx/eliminate_fp8.hpp | 52 ------------ 3 files changed, 162 deletions(-) delete mode 100644 src/eliminate_fp8.cpp delete mode 100644 src/include/migraphx/eliminate_fp8.hpp diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index c03df41b8ea..8cffafb57bf 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -49,7 +49,6 @@ add_library(migraphx eliminate_concat.cpp eliminate_contiguous.cpp eliminate_data_type.cpp - eliminate_fp8.cpp eliminate_identity.cpp eliminate_pad.cpp env.cpp diff --git a/src/eliminate_fp8.cpp b/src/eliminate_fp8.cpp deleted file mode 100644 index d8aa046e792..00000000000 --- a/src/eliminate_fp8.cpp +++ /dev/null @@ -1,109 +0,0 @@ -/* - * The MIT License (MIT) - * - * Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved. - * - * Permission is hereby granted, free of charge, to any person obtaining a copy - * of this software and associated documentation files (the "Software"), to deal - * in the Software without restriction, including without limitation the rights - * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell - * copies of the Software, and to permit persons to whom the Software is - * furnished to do so, subject to the following conditions: - * - * The above copyright notice and this permission notice shall be included in - * all copies or substantial portions of the Software. 
- * - * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR - * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, - * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE - * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER - * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, - * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN - * THE SOFTWARE. - */ -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include - -namespace migraphx { -inline namespace MIGRAPHX_INLINE_NS { - -void eliminate_fp8::apply(module& m) const -{ - for(auto ins : iterator_for(m)) - { - if(not contains(op_names, ins->name())) - continue; - migraphx::shape::type_t orig_type = ins->get_shape().type(); - std::vector inputs = ins->inputs(); - migraphx::transform_if( - inputs.begin(), - inputs.end(), - inputs.begin(), - [&](const auto& i) { return i->get_shape().type() == shape::fp8e4m3fnuz_type; }, - [&](const auto& i) { - return m.insert_instruction( - ins, - migraphx::make_op("convert", - {{"target_type", migraphx::to_value(target_type)}}), - i); - }); - if(inputs == ins->inputs()) - { - return; - } - auto op = ins->get_operator(); - auto attributes = op.attributes(); - if(attributes.contains("general_data_type")) - { - op = make_op(attributes["general_data_type"].to(), op.to_value()); - } - auto new_ins = m.insert_instruction(ins, op, inputs); - if(orig_type == shape::tuple_type) - { - auto orig_outs = ins->outputs(); - if(not std::all_of(orig_outs.begin(), orig_outs.end(), [&](const auto out_ins) { - return out_ins->name() == "get_tuple_elem"; - })) - MIGRAPHX_THROW("EliminateFP8: Instruction with tuple output doesn't have all its " - "usages as get_tuple_elem instruction"); - - std::transform( - orig_outs.begin(), orig_outs.end(), orig_outs.begin(), [&](const auto out_ins) { - auto gte_ins = 
m.insert_instruction(ins, out_ins->get_operator(), new_ins); - if(out_ins->get_shape().type() == shape::type_t::fp8e4m3fnuz_type) - { - auto gte_convert = m.insert_instruction( - ins, - make_op("convert", {{"target_type", shape::type_t::fp8e4m3fnuz_type}}), - gte_ins); - return m.replace_instruction(out_ins, gte_convert); - } - else - { - return m.replace_instruction(out_ins, gte_ins); - } - }); - } - else - { - auto convert_back_ins = m.insert_instruction( - ins, - migraphx::make_op("convert", {{"target_type", migraphx::to_value(orig_type)}}), - new_ins); - m.replace_instruction(ins, convert_back_ins); - } - } -} - -} // namespace MIGRAPHX_INLINE_NS -} // namespace migraphx diff --git a/src/include/migraphx/eliminate_fp8.hpp b/src/include/migraphx/eliminate_fp8.hpp deleted file mode 100644 index c3304dd054e..00000000000 --- a/src/include/migraphx/eliminate_fp8.hpp +++ /dev/null @@ -1,52 +0,0 @@ -/* - * The MIT License (MIT) - * - * Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved. - * - * Permission is hereby granted, free of charge, to any person obtaining a copy - * of this software and associated documentation files (the "Software"), to deal - * in the Software without restriction, including without limitation the rights - * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell - * copies of the Software, and to permit persons to whom the Software is - * furnished to do so, subject to the following conditions: - * - * The above copyright notice and this permission notice shall be included in - * all copies or substantial portions of the Software. - * - * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR - * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, - * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE - * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER - * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, - * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN - * THE SOFTWARE. - */ -#ifndef MIGRAPHX_GUARD_AMDMIGRAPHX_ELIMINATE_FP8_HPP -#define MIGRAPHX_GUARD_AMDMIGRAPHX_ELIMINATE_FP8_HPP - -#include -#include -#include -#include - -namespace migraphx { -inline namespace MIGRAPHX_INLINE_NS { - -struct module; - -/** -This will insert convert operators for the operators that are not implemented for FP8 dtypes - */ -struct MIGRAPHX_EXPORT eliminate_fp8 -{ - // TODO: Add all device ops as a later PR and add tests for those. - std::set op_names; - shape::type_t target_type = migraphx::shape::float_type; - std::string name() const { return "eliminate_fp8"; } - void apply(module& m) const; -}; - -} // namespace MIGRAPHX_INLINE_NS -} // namespace migraphx - -#endif From b099a7da2afcbf15279f3f44d93420a7baee612e Mon Sep 17 00:00:00 2001 From: Umang Yadav Date: Tue, 5 Dec 2023 00:42:10 +0000 Subject: [PATCH 104/115] remove header --- src/targets/gpu/target.cpp | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/targets/gpu/target.cpp b/src/targets/gpu/target.cpp index a25f43600d4..46d548a762f 100644 --- a/src/targets/gpu/target.cpp +++ b/src/targets/gpu/target.cpp @@ -21,7 +21,6 @@ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN * THE SOFTWARE. 
*/ -#include "migraphx/shape.hpp" #include #include #include @@ -53,7 +52,6 @@ #include #include #include -#include #include #include #include From 7d6e6ad74cf8dadf97bc7f445e745a0ec98174da Mon Sep 17 00:00:00 2001 From: Umang Yadav Date: Tue, 5 Dec 2023 01:06:35 +0000 Subject: [PATCH 105/115] fix typo --- src/eliminate_data_type.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/eliminate_data_type.cpp b/src/eliminate_data_type.cpp index 1c8efce216f..8e9bdfa39bc 100644 --- a/src/eliminate_data_type.cpp +++ b/src/eliminate_data_type.cpp @@ -108,7 +108,7 @@ void eliminate_data_type::apply(module& m) const "scatternd_add", "scatternd_mul", "scatternd_none"}; - if(unsupported_types.empty() and unsupported_types.empty()) + if(unsupported_types.empty() and unsupported_fp8_ops.empty()) { return; } From cf91c2b121bd9641b94926955fefac52f7ede565 Mon Sep 17 00:00:00 2001 From: Umang Yadav Date: Tue, 5 Dec 2023 14:14:53 +0000 Subject: [PATCH 106/115] add changes for the eliminate_data_type pass --- src/eliminate_data_type.cpp | 35 +++++--------------- src/include/migraphx/eliminate_data_type.hpp | 2 +- src/targets/cpu/target.cpp | 2 +- src/targets/gpu/target.cpp | 4 +-- test/eliminate_data_type_test.cpp | 14 ++++---- 5 files changed, 18 insertions(+), 39 deletions(-) diff --git a/src/eliminate_data_type.cpp b/src/eliminate_data_type.cpp index 8e9bdfa39bc..d25e8bffcb5 100644 --- a/src/eliminate_data_type.cpp +++ b/src/eliminate_data_type.cpp @@ -108,36 +108,17 @@ void eliminate_data_type::apply(module& m) const "scatternd_add", "scatternd_mul", "scatternd_none"}; - if(unsupported_types.empty() and unsupported_fp8_ops.empty()) - { + if(unsupported_types.empty()) return; - } - else if(not unsupported_fp8_ops.empty() and not unsupported_types.empty()) - { - MIGRAPHX_THROW("eliminate_data_type: specify either unsupported FP8 ops or unsupported " - "data types not both."); - } - else if(unsupported_fp8_ops.empty()) + + for(auto ins : iterator_for(m)) { - 
for(auto ins : iterator_for(m)) - { - if(ins->name()[0] == '@') - continue; - if(contains(skip_op_names, ins->name())) - continue; + if(ins->name()[0] == '@') + continue; + if(contains(skip_op_names, ins->name()) and not contains(unsupported_ops, ins->name())) + continue; + if(contains(unsupported_ops, "all") or contains(unsupported_ops, ins->name())) insert_convert_to_supported_type(m, ins, target_type, unsupported_types); - } - } - else - { - std::set unsupported_fp8_types = { - migraphx::shape::fp8e4m3fnuz_type}; - for(auto ins : iterator_for(m)) - { - if(not contains(unsupported_fp8_ops, ins->name())) - continue; - insert_convert_to_supported_type(m, ins, target_type, unsupported_fp8_types); - } } } diff --git a/src/include/migraphx/eliminate_data_type.hpp b/src/include/migraphx/eliminate_data_type.hpp index 52dc5f187f1..cf7a579f6a1 100644 --- a/src/include/migraphx/eliminate_data_type.hpp +++ b/src/include/migraphx/eliminate_data_type.hpp @@ -41,8 +41,8 @@ struct module; struct MIGRAPHX_EXPORT eliminate_data_type { std::set unsupported_types; - std::set unsupported_fp8_ops; shape::type_t target_type; + std::set unsupported_ops = {"all"}; std::string name() const { return "eliminate_data_type"; } void apply(module& m) const; }; diff --git a/src/targets/cpu/target.cpp b/src/targets/cpu/target.cpp index f0836f70423..ebec22a501d 100644 --- a/src/targets/cpu/target.cpp +++ b/src/targets/cpu/target.cpp @@ -70,7 +70,7 @@ std::vector target::get_passes(migraphx::context& gctx, const compile_opti return {normalize_ops{}, rewrite_quantization{}, dead_code_elimination{}, - eliminate_data_type{unsupported_types, {}, shape::type_t::float_type}, + eliminate_data_type{unsupported_types, shape::type_t::float_type}, dead_code_elimination{}, simplify_reshapes{}, eliminate_identity{}, diff --git a/src/targets/gpu/target.cpp b/src/targets/gpu/target.cpp index 46d548a762f..da54c264fc2 100644 --- a/src/targets/gpu/target.cpp +++ b/src/targets/gpu/target.cpp @@ -122,7 +122,7 @@ 
std::vector target::get_passes(migraphx::context& gctx, const compile_opti simplify_qdq{}, enable_pass(not mlir_enabled(), rewrite_quantization{}), dead_code_elimination{}, - eliminate_data_type{unsupported_types, {}, shape::type_t::float_type}, + eliminate_data_type{unsupported_types, shape::type_t::float_type}, simplify_reshapes{}, eliminate_identity{}, eliminate_pad{}, @@ -141,7 +141,7 @@ std::vector target::get_passes(migraphx::context& gctx, const compile_opti prefuse_ops{}, dead_code_elimination{}, auto_contiguous{}, - eliminate_data_type{{}, unsupported_fp8_ops, shape::float_type}, + eliminate_data_type{{migraphx::shape::fp8e4m3fnuz_type}, shape::float_type, unsupported_fp8_ops}, dead_code_elimination{}, optimize_module{}, fuse_pointwise{}, diff --git a/test/eliminate_data_type_test.cpp b/test/eliminate_data_type_test.cpp index e0a4cd0ddec..23ff9c32fec 100644 --- a/test/eliminate_data_type_test.cpp +++ b/test/eliminate_data_type_test.cpp @@ -30,15 +30,13 @@ #include -void run_pass(migraphx::module& m, - std::set types, - std::set unsupported_fp8_ops = {}) +void run_pass(migraphx::module& m, std::set types) { - migraphx::run_passes(m, - {migraphx::eliminate_data_type{ - std::move(types), unsupported_fp8_ops, migraphx::shape::float_type}, - migraphx::eliminate_identity{}, - migraphx::dead_code_elimination{}}); + migraphx::run_passes( + m, + {migraphx::eliminate_data_type{std::move(types), migraphx::shape::float_type}, + migraphx::eliminate_identity{}, + migraphx::dead_code_elimination{}}); } TEST_CASE(simple) From 82f98478fded8d7b435b55dc31bd94a3ce08c590 Mon Sep 17 00:00:00 2001 From: Umang Yadav Date: Tue, 5 Dec 2023 14:22:49 +0000 Subject: [PATCH 107/115] add comments --- src/targets/gpu/gemm_impl.cpp | 40 +++++++++++++++++++++++------------ 1 file changed, 26 insertions(+), 14 deletions(-) diff --git a/src/targets/gpu/gemm_impl.cpp b/src/targets/gpu/gemm_impl.cpp index 057e81395af..ac052a3e081 100644 --- a/src/targets/gpu/gemm_impl.cpp +++ 
b/src/targets/gpu/gemm_impl.cpp @@ -22,6 +22,7 @@ * THE SOFTWARE. */ +#include #include #include #include @@ -36,6 +37,20 @@ namespace migraphx { inline namespace MIGRAPHX_INLINE_NS { namespace gpu { +/* +Regular rocBLAS API takes compute_type as `rocblas_datatype` enum value v/s "ex3" BETA API takes it +as `rocblas_computetype` enum value. `rb_compute_type` is faciliator to implictly cast interger enum +value to required type that can be used inside `common_args` generator. +*/ +struct rb_compute_type +{ + int type = 0; + rb_compute_type(rocblas_datatype t) : type(static_cast(t)) {} + rb_compute_type(rocblas_computetype t) : type(static_cast(t)) {} + operator rocblas_datatype() const { return static_cast(type); } + operator rocblas_computetype() const { return static_cast(type); } +}; + // Convert rocBLAS datatypes to equivalent Migraphx data types rocblas_datatype get_type(shape::type_t type) { @@ -185,12 +200,17 @@ struct gemm_impl { output_type = rocblas_datatype_i32_r; } - compute_type = output_type; + compute_type = rb_compute_type{output_type}; if(compute_fp32) { if(arg_type == rocblas_datatype_f16_r) compute_type = rocblas_datatype_f32_r; } + else if(arg_type == rocblas_datatype_f8_r) + { + assert(get_type(input_shapes[1].type()) == rocblas_datatype_f8_r); + compute_type = rocblas_compute_type_f32; + } auto a_lens = input_shapes[0].lens(); auto b_lens = input_shapes[1].lens(); @@ -230,7 +250,6 @@ struct gemm_impl auto common_args = create_strided_batched_args_common(ctx, input_args); rocblas_invoke(&rocblas_gemm_strided_batched_ex3, common_args, - rocblas_compute_type_f32, rocblas_gemm_algo_standard, solution_idx, gemm_flags); @@ -240,7 +259,6 @@ struct gemm_impl auto common_args = create_gemm_ex_args_common(ctx, input_args); rocblas_invoke(&rocblas_gemm_ex3, common_args, - rocblas_compute_type_f32, rocblas_gemm_algo_standard, solution_idx, gemm_flags); @@ -254,7 +272,6 @@ struct gemm_impl auto common_args = create_strided_batched_args_common(ctx, 
input_args); rocblas_invoke(&rocblas_gemm_strided_batched_ex, common_args, - compute_type, rocblas_gemm_algo_solution_index, solution_idx, gemm_flags); @@ -264,7 +281,6 @@ struct gemm_impl auto common_args = create_gemm_ex_args_common(ctx, input_args); rocblas_invoke(&rocblas_gemm_ex, common_args, - compute_type, rocblas_gemm_algo_solution_index, solution_idx, gemm_flags); @@ -304,7 +320,6 @@ struct gemm_impl auto common_args = create_strided_batched_args_common(ctx, input_args); check_valid = rocblas_invoke(&rocblas_gemm_strided_batched_ex, common_args, - compute_type, rocblas_gemm_algo_solution_index, solution_idx, rocblas_gemm_flags_check_solution_index); @@ -314,7 +329,6 @@ struct gemm_impl auto common_args = create_gemm_ex_args_common(ctx, input_args); check_valid = rocblas_invoke(&rocblas_gemm_ex, common_args, - compute_type, rocblas_gemm_algo_solution_index, solution_idx, rocblas_gemm_flags_check_solution_index); @@ -365,7 +379,8 @@ struct gemm_impl output_type, ldd, d_stride, - num_matrices); + num_matrices, + compute_type); } /** * Helper method to create that subset of a long rocBLAS argument list that is common @@ -398,7 +413,8 @@ struct gemm_impl ldc, is_3inputs ? 
args[3].data() : args[2].data(), output_type, - ldd); + ldd, + compute_type); } #ifdef MIGRAPHX_USE_ROCBLAS_TUNING_API @@ -428,7 +444,6 @@ struct gemm_impl auto common_args = create_strided_batched_args_common(ctx, input_args); rocblas_invoke(&rocblas_gemm_strided_batched_ex_get_solutions, common_args, - compute_type, rocblas_gemm_algo_solution_index, gemm_flags, nullptr, @@ -438,7 +453,6 @@ struct gemm_impl auto common_sol_args = create_strided_batched_args_common(ctx, input_args); rocblas_invoke(&rocblas_gemm_strided_batched_ex_get_solutions, common_sol_args, - compute_type, rocblas_gemm_algo_solution_index, gemm_flags, solution_indices.data(), @@ -449,7 +463,6 @@ struct gemm_impl auto common_args = create_gemm_ex_args_common(ctx, input_args); rocblas_invoke(&rocblas_gemm_ex_get_solutions, common_args, - compute_type, rocblas_gemm_algo_solution_index, gemm_flags, nullptr, @@ -459,7 +472,6 @@ struct gemm_impl auto common_sol_args = create_gemm_ex_args_common(ctx, input_args); rocblas_invoke(&rocblas_gemm_ex_get_solutions, common_sol_args, - compute_type, rocblas_gemm_algo_solution_index, gemm_flags, solution_indices.data(), @@ -521,7 +533,7 @@ struct gemm_impl rocblas_int c_stride = 0; rocblas_int d_stride = 0; rocblas_datatype arg_type = rocblas_datatype_f32_r; - rocblas_datatype compute_type = rocblas_datatype_f32_r; + rb_compute_type compute_type = rocblas_datatype_f32_r; rocblas_datatype output_type = rocblas_datatype_f32_r; bool strided_batched = true; bool is_3inputs = true; From a9db2bf418e3f011c7acd989b34cb4d3b1c9de4b Mon Sep 17 00:00:00 2001 From: Umang Yadav Date: Tue, 5 Dec 2023 14:24:12 +0000 Subject: [PATCH 108/115] fix typo --- src/targets/gpu/gemm_impl.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/targets/gpu/gemm_impl.cpp b/src/targets/gpu/gemm_impl.cpp index ac052a3e081..ca93712c97a 100644 --- a/src/targets/gpu/gemm_impl.cpp +++ b/src/targets/gpu/gemm_impl.cpp @@ -39,7 +39,7 @@ namespace gpu { /* Regular rocBLAS API 
takes compute_type as `rocblas_datatype` enum value v/s "ex3" BETA API takes it -as `rocblas_computetype` enum value. `rb_compute_type` is faciliator to implictly cast interger enum +as `rocblas_computetype` enum value. `rb_compute_type` is faciliator to implictly cast integer enum value to required type that can be used inside `common_args` generator. */ struct rb_compute_type From aeaac20f9b84e8631a84de56b1250d9235ccb3c3 Mon Sep 17 00:00:00 2001 From: Umang Yadav Date: Tue, 5 Dec 2023 14:27:52 +0000 Subject: [PATCH 109/115] remove else --- src/targets/gpu/gemm_impl.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/targets/gpu/gemm_impl.cpp b/src/targets/gpu/gemm_impl.cpp index ca93712c97a..4dd56db6f2c 100644 --- a/src/targets/gpu/gemm_impl.cpp +++ b/src/targets/gpu/gemm_impl.cpp @@ -206,7 +206,7 @@ struct gemm_impl if(arg_type == rocblas_datatype_f16_r) compute_type = rocblas_datatype_f32_r; } - else if(arg_type == rocblas_datatype_f8_r) + if(arg_type == rocblas_datatype_f8_r) { assert(get_type(input_shapes[1].type()) == rocblas_datatype_f8_r); compute_type = rocblas_compute_type_f32; From a196e90e4c9b413721a4b5aab5812778e2f97b80 Mon Sep 17 00:00:00 2001 From: Umang Yadav Date: Tue, 5 Dec 2023 18:22:18 +0000 Subject: [PATCH 110/115] disable tests that uses CK --- test/verify/gemm_2args_mm_8.cpp | 2 +- test/verify/gemm_add_broadcast2.cpp | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/test/verify/gemm_2args_mm_8.cpp b/test/verify/gemm_2args_mm_8.cpp index 8779ee1b8a2..54b43179e9f 100644 --- a/test/verify/gemm_2args_mm_8.cpp +++ b/test/verify/gemm_2args_mm_8.cpp @@ -48,5 +48,5 @@ struct gemm_2args_mm_8 : verify_program> }; template struct gemm_2args_mm_8; -template struct gemm_2args_mm_8; +//template struct gemm_2args_mm_8; template struct gemm_2args_mm_8; diff --git a/test/verify/gemm_add_broadcast2.cpp b/test/verify/gemm_add_broadcast2.cpp index e35eefd9e59..c90946f8c5c 100644 --- a/test/verify/gemm_add_broadcast2.cpp 
+++ b/test/verify/gemm_add_broadcast2.cpp @@ -51,5 +51,5 @@ struct gemm_add_broadcast2 : verify_program> }; template struct gemm_add_broadcast2; -template struct gemm_add_broadcast2; +//template struct gemm_add_broadcast2; template struct gemm_add_broadcast2; From 7e80f62732df7afda9be8c592d70a8b263e73707 Mon Sep 17 00:00:00 2001 From: Umang Yadav Date: Tue, 5 Dec 2023 18:33:02 +0000 Subject: [PATCH 111/115] formatting --- test/verify/gemm_2args_mm_8.cpp | 2 +- test/verify/gemm_add_broadcast2.cpp | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/test/verify/gemm_2args_mm_8.cpp b/test/verify/gemm_2args_mm_8.cpp index 54b43179e9f..982dbc003ed 100644 --- a/test/verify/gemm_2args_mm_8.cpp +++ b/test/verify/gemm_2args_mm_8.cpp @@ -48,5 +48,5 @@ struct gemm_2args_mm_8 : verify_program> }; template struct gemm_2args_mm_8; -//template struct gemm_2args_mm_8; +// template struct gemm_2args_mm_8; template struct gemm_2args_mm_8; diff --git a/test/verify/gemm_add_broadcast2.cpp b/test/verify/gemm_add_broadcast2.cpp index c90946f8c5c..15f35ad0628 100644 --- a/test/verify/gemm_add_broadcast2.cpp +++ b/test/verify/gemm_add_broadcast2.cpp @@ -51,5 +51,5 @@ struct gemm_add_broadcast2 : verify_program> }; template struct gemm_add_broadcast2; -//template struct gemm_add_broadcast2; +// template struct gemm_add_broadcast2; template struct gemm_add_broadcast2; From a3d4b01380117e977ec05a8e0e928e06911194f5 Mon Sep 17 00:00:00 2001 From: Umang Yadav Date: Tue, 5 Dec 2023 23:30:40 +0000 Subject: [PATCH 112/115] use same SHA as develop branch --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index 7f7a0039fef..ea0609c0117 100755 --- a/requirements.txt +++ b/requirements.txt @@ -29,4 +29,4 @@ pybind/pybind11@d159a563383d10c821ba7b2a71905d1207db6de4 --build msgpack/msgpack-c@cpp-3.3.0 -DMSGPACK_BUILD_TESTS=Off sqlite3@3.43.2 -DCMAKE_POSITION_INDEPENDENT_CODE=On 
ROCmSoftwarePlatform/composable_kernel@70eefcf4f263aa5c25f3c9ff0db8f6f199ef0fb9 -DCK_BUILD_JIT_LIB=On -DCMAKE_POSITION_INDEPENDENT_CODE=On -ROCmSoftwarePlatform/rocMLIR@5085343bca363109ae9ebabb7ca2b65c52bc861c -DBUILD_FAT_LIBROCKCOMPILER=On +ROCmSoftwarePlatform/rocMLIR@a6880f1e6daec99876cd6a4820fbc69c57216401 -DBUILD_FAT_LIBROCKCOMPILER=On From de27b9198789d4fd6b988fd0c1198c94b660dca9 Mon Sep 17 00:00:00 2001 From: Umang Yadav Date: Tue, 5 Dec 2023 23:41:34 +0000 Subject: [PATCH 113/115] use angled brackets --- src/include/migraphx/op/quant_convolution.hpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/include/migraphx/op/quant_convolution.hpp b/src/include/migraphx/op/quant_convolution.hpp index dec50315f7a..5976f9163c2 100644 --- a/src/include/migraphx/op/quant_convolution.hpp +++ b/src/include/migraphx/op/quant_convolution.hpp @@ -24,10 +24,10 @@ #ifndef MIGRAPHX_GUARD_OPERATORS_QUANT_CONVOLUTION_HPP #define MIGRAPHX_GUARD_OPERATORS_QUANT_CONVOLUTION_HPP -#include "migraphx/shape.hpp" #include #include #include +#include #include #include #include From b6250a420579dfaba34ed2ea2b804be8979af513 Mon Sep 17 00:00:00 2001 From: Umang Yadav Date: Wed, 6 Dec 2023 00:17:36 +0000 Subject: [PATCH 114/115] add comment --- test/verify/run_verify.cpp | 3 ++- test/verify/test_conv_group_add.cpp | 1 + 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/test/verify/run_verify.cpp b/test/verify/run_verify.cpp index c91a4aba73e..464cfb9d529 100644 --- a/test/verify/run_verify.cpp +++ b/test/verify/run_verify.cpp @@ -142,7 +142,8 @@ std::vector run_verify::run_ref(migraphx::program p, { migraphx::target t = migraphx::make_target("ref"); auto_print pp{p, t.name()}; - compile_check(p, t, c_opts); + auto trace_target = migraphx::string_value_of(MIGRAPHX_TRACE_TEST_COMPILE{}); + compile_check(p, t, c_opts, (trace_target == "ref")); return p.eval(std::move(inputs)); } std::pair> diff --git a/test/verify/test_conv_group_add.cpp 
b/test/verify/test_conv_group_add.cpp index 28bd323dca9..dbb57a3630c 100644 --- a/test/verify/test_conv_group_add.cpp +++ b/test/verify/test_conv_group_add.cpp @@ -46,4 +46,5 @@ struct test_conv_group_add : verify_program> } }; template struct test_conv_group_add; +// grouped convolutions are not supported with MLIR therefore disable it // template struct test_conv_group_add; From b2542239d3d8b9cc07faadd80367ba7464eafada Mon Sep 17 00:00:00 2001 From: Umang Yadav Date: Wed, 6 Dec 2023 00:45:33 +0000 Subject: [PATCH 115/115] formatting --- test/verify/test_conv_group_add.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/verify/test_conv_group_add.cpp b/test/verify/test_conv_group_add.cpp index dbb57a3630c..ff8747b616d 100644 --- a/test/verify/test_conv_group_add.cpp +++ b/test/verify/test_conv_group_add.cpp @@ -46,5 +46,5 @@ struct test_conv_group_add : verify_program> } }; template struct test_conv_group_add; -// grouped convolutions are not supported with MLIR therefore disable it +// grouped convolutions are not supported with MLIR therefore disable it // template struct test_conv_group_add;