Integrating generic_float struct for adding datatypes #3522

Merged: 64 commits, Nov 8, 2024
Changes from 6 commits
Commits
c51c1ce
first pass at integrating generic float
richagadgil Oct 10, 2024
134b408
fix namespaces
richagadgil Oct 10, 2024
d4fa6eb
fix mantissa
richagadgil Oct 10, 2024
0b60841
refactor
richagadgil Oct 11, 2024
7a646f1
refactor
richagadgil Oct 11, 2024
ebe819b
add fp
richagadgil Oct 11, 2024
379a77a
fixed generic float class
richagadgil Oct 14, 2024
174384c
add fp32 test
richagadgil Oct 14, 2024
787b651
remove import
richagadgil Oct 14, 2024
1d1fa1c
update tests
richagadgil Oct 15, 2024
1791092
fp16 tests that work
richagadgil Oct 17, 2024
a2eb005
update tests
richagadgil Oct 18, 2024
ff8ffc7
updated fp16 and fp32 tests
richagadgil Oct 18, 2024
e36fd65
half tests
richagadgil Oct 22, 2024
9ac4e2a
underflow and overflow tests
richagadgil Oct 22, 2024
f05fd31
generate map
richagadgil Oct 22, 2024
cb4d92d
add more tests
richagadgil Oct 22, 2024
0cc1946
fix names
richagadgil Oct 22, 2024
85a761b
update tests
richagadgil Oct 23, 2024
65cf9ae
remove and
richagadgil Oct 24, 2024
fbabf54
disable warning
richagadgil Oct 24, 2024
549f5e6
fix tidy warning
richagadgil Oct 24, 2024
d302e5d
migraphx py fix
richagadgil Oct 25, 2024
8d475e3
add increments
richagadgil Oct 25, 2024
a0fd055
fix warnings
richagadgil Oct 25, 2024
41379fe
disable duplicate branch warning
richagadgil Oct 25, 2024
0c29c7b
add countzero_std
richagadgil Oct 28, 2024
4b012a8
ci error
richagadgil Oct 28, 2024
dbaa3a8
simplify countl
richagadgil Oct 28, 2024
b2bd2a0
fix ci
richagadgil Oct 28, 2024
6f328f0
src
richagadgil Oct 29, 2024
e6d9763
remove flag
richagadgil Oct 29, 2024
6538050
hide abi warning
richagadgil Oct 29, 2024
4e96d4d
revert changes
richagadgil Oct 29, 2024
ef11f1f
Merge branch 'develop' into generic_float
richagadgil Oct 29, 2024
e4a25bd
change half in tests
richagadgil Oct 29, 2024
3354c6e
Update generic_float.hpp
richagadgil Oct 29, 2024
6de079b
format
richagadgil Oct 29, 2024
7750874
Merge branch 'develop' into generic_float
richagadgil Oct 29, 2024
801f485
Merge branch 'develop' into generic_float
causten Oct 30, 2024
33e2c8d
fix bug
richagadgil Oct 30, 2024
9bb7198
Merge branch 'generic_float' of github.com:ROCm/AMDMIGraphX into gene…
richagadgil Oct 30, 2024
b3c345d
fix err
richagadgil Oct 30, 2024
03df6f9
edits
richagadgil Oct 31, 2024
ad817b2
tidy and format
richagadgil Oct 31, 2024
898417b
tidy etc
richagadgil Oct 31, 2024
aa5b9c9
gf
richagadgil Oct 31, 2024
6f72370
fix tidy errs
richagadgil Nov 1, 2024
0aab1a0
bf16 changes
richagadgil Nov 4, 2024
a337b16
Update generic_float.hpp
richagadgil Nov 4, 2024
894ed7f
Update float32.cpp
richagadgil Nov 4, 2024
4895a68
fix tidy
richagadgil Nov 4, 2024
b129bd5
format
richagadgil Nov 4, 2024
0463266
change tidy warnings
richagadgil Nov 5, 2024
2e3bd25
tidy
richagadgil Nov 5, 2024
ff3566e
Merge branch 'develop' into generic_float
richagadgil Nov 5, 2024
c02f3e3
windows build fix
richagadgil Nov 6, 2024
99802b9
Merge branch 'generic_float' of github.com:ROCm/AMDMIGraphX into gene…
richagadgil Nov 6, 2024
2db6e41
windows build
richagadgil Nov 6, 2024
b195514
replace w gnu flag
richagadgil Nov 6, 2024
5754c1c
align
richagadgil Nov 6, 2024
34554fb
cmake
richagadgil Nov 6, 2024
56cdd29
readd mvsc
richagadgil Nov 6, 2024
9ae05ae
redo compile options
richagadgil Nov 6, 2024
268 changes: 268 additions & 0 deletions src/include/migraphx/generic_float.hpp
@@ -0,0 +1,268 @@
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {

template<unsigned int N>
constexpr unsigned int all_ones() noexcept
{
return (1 << N) - 1;

Check warning on line 31 in src/include/migraphx/generic_float.hpp (GitHub Actions / tidy): use of a signed integer operand with a binary bitwise operator [hicpp-signed-bitwise,-warnings-as-errors]
}
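
The tidy warning above is triggered because the literal `1` is a signed `int`. A minimal sketch of one way to keep every operand unsigned follows. The name `all_ones_u` is hypothetical and not part of the PR, and the sketch assumes `N` is smaller than the bit width of `unsigned int`.

```cpp
#include <limits>

// Hypothetical variant of all_ones (not the PR's code): both the shifted
// literal and the subtracted literal are unsigned, so no signed operand
// reaches a bitwise operator.
template <unsigned int N>
constexpr unsigned int all_ones_u() noexcept
{
    // Assumption: shifting by the full width would be undefined, so reject it.
    static_assert(N < std::numeric_limits<unsigned int>::digits,
                  "shift count must stay below the operand width");
    return (1u << N) - 1u;
}

// Field widths used by the half layout (10 mantissa bits, 5 exponent bits).
static_assert(all_ones_u<5>() == 31u, "5-bit exponent mask");
static_assert(all_ones_u<10>() == 1023u, "10-bit mantissa mask");
```

Because the function stays `constexpr`, the masks can still be checked at compile time.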

struct float32_parts
{
unsigned int mantissa : 23;
unsigned int exponent : 8;
unsigned int sign : 1;

static constexpr unsigned int mantissa_width()
{
return 23;
}

static constexpr unsigned int max_exponent()
{
return all_ones<8>();
}

static constexpr int exponent_bias()
{
return all_ones<7>();
}

constexpr float to_float() const noexcept
{
return migraphx::bit_cast<float>(*this);
}
};

constexpr float32_parts get_parts(float f)
{
return migraphx::bit_cast<float32_parts>(f);
}

template<unsigned int MantissaSize, unsigned int ExponentSize, unsigned int Flags = 0>
struct generic_float
{
unsigned int mantissa : MantissaSize;
unsigned int exponent : ExponentSize;
unsigned int sign : 1;

static constexpr int exponent_bias()
{
return all_ones<ExponentSize - 1>();
}

explicit generic_float(float f = 0.0) noexcept
{
from_float(get_parts(f));
}

constexpr float to_float() const noexcept
{
float32_parts f{};
f.sign = sign;
f.mantissa = mantissa << (float32_parts::mantissa_width() - MantissaSize);
if(exponent == all_ones<ExponentSize>())
{
f.exponent = float32_parts::max_exponent();
}
else
{
constexpr const auto diff = float32_parts::exponent_bias() - exponent_bias();
f.exponent = exponent + diff;
}
return f.to_float();
}

constexpr void from_float(float32_parts f) noexcept
{
sign = f.sign;
mantissa = f.mantissa >> (float32_parts::mantissa_width() - MantissaSize);

if(f.exponent == 0)
{
exponent = 0;
}
else if(f.exponent == float32_parts::max_exponent())
{
exponent = all_ones<ExponentSize>();
}
else
{
constexpr const int diff = float32_parts::exponent_bias() - exponent_bias();
auto e = int(f.exponent) - diff;
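// Compare as int: comparing against the unsigned mask would convert a negative e to a large unsigned value.
if(e >= static_cast<int>(all_ones<ExponentSize>()))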
{
exponent = all_ones<ExponentSize>();
mantissa = 0;
}
else if(e < 0)
{
exponent = 0;
mantissa = 0;
}
else
{
exponent = f.exponent - diff;
}
}

// Clamp the rebiased exponent; clamping f.exponent here would overwrite the value computed above.
exponent = std::min(exponent, all_ones<ExponentSize>());
}

constexpr bool is_normal() const noexcept
{
return exponent != all_ones<ExponentSize>() and exponent != 0;
}

constexpr bool is_inf() const noexcept
{
return exponent == all_ones<ExponentSize>() and mantissa == 0;
}

constexpr bool is_nan() const noexcept
{
return exponent == all_ones<ExponentSize>() and mantissa != 0;
}

constexpr bool is_finite() const noexcept
{
return exponent != all_ones<ExponentSize>();
}

constexpr operator float() const noexcept
{
return this->to_float();
}

static constexpr generic_float infinity()
{
generic_float x{};
x.exponent = all_ones<ExponentSize>();
return x;
}

static constexpr generic_float snan()
{
generic_float x{};
x.exponent = all_ones<ExponentSize>();
x.mantissa = 1 << (MantissaSize - 2);

Check warning on line 172 in src/include/migraphx/generic_float.hpp (GitHub Actions / tidy): use of a signed integer operand with a binary bitwise operator [hicpp-signed-bitwise,-warnings-as-errors]
return x;
}

static constexpr generic_float qnan()
{
generic_float x{};
x.exponent = all_ones<ExponentSize>();
x.mantissa = 1 << (MantissaSize - 1);

Check warning on line 180 in src/include/migraphx/generic_float.hpp (GitHub Actions / tidy): use of a signed integer operand with a binary bitwise operator [hicpp-signed-bitwise,-warnings-as-errors]
return x;
}

static constexpr generic_float min()
{
generic_float x{};
x.exponent = 1;
x.mantissa = 0;
return x;
}

static constexpr generic_float denorm_min()
{
generic_float x{};
x.exponent = 0;
x.mantissa = 1;
x.sign = 0;
return x;
}

static constexpr generic_float lowest()
{
generic_float x{};
x.exponent = all_ones<ExponentSize>() - 1;
x.mantissa = all_ones<MantissaSize>();
x.sign = 1;
return x;
}

static constexpr generic_float max()
{
generic_float x{};
x.exponent = all_ones<ExponentSize>() - 1;
x.mantissa = all_ones<MantissaSize>();
x.sign = 0;
return x;
}

static constexpr generic_float epsilon()
{
generic_float x{1.0};
x.mantissa++;
return generic_float{x.to_float() - 1.0f};
}
// NOLINTNEXTLINE
#define MIGRAPHX_GENERIC_FLOAT_ASSIGN_OP(op) \
constexpr generic_float& operator op(const generic_float& rhs) \
{ \
float self = *this; \
float frhs = rhs; \
self op frhs; \
*this = generic_float(self); \
return *this; \
}
MIGRAPHX_GENERIC_FLOAT_ASSIGN_OP(*=)
MIGRAPHX_GENERIC_FLOAT_ASSIGN_OP(-=)
MIGRAPHX_GENERIC_FLOAT_ASSIGN_OP(+=)
MIGRAPHX_GENERIC_FLOAT_ASSIGN_OP(/=)
// NOLINTNEXTLINE
#define MIGRAPHX_GENERIC_FLOAT_BINARY_OP(op) \
friend constexpr generic_float operator op(const generic_float& x, const generic_float& y) \
{ \
return generic_float(float(x) op float(y)); \
}
MIGRAPHX_GENERIC_FLOAT_BINARY_OP(*)
MIGRAPHX_GENERIC_FLOAT_BINARY_OP(-)
MIGRAPHX_GENERIC_FLOAT_BINARY_OP(+)
MIGRAPHX_GENERIC_FLOAT_BINARY_OP(/)
MIGRAPHX_GENERIC_FLOAT_BINARY_OP(<)
MIGRAPHX_GENERIC_FLOAT_BINARY_OP(<=)
MIGRAPHX_GENERIC_FLOAT_BINARY_OP(>)
MIGRAPHX_GENERIC_FLOAT_BINARY_OP(>=)

// Equality yields bool; non-finite values (Inf, NaN) never compare equal here.
friend constexpr bool operator==(const generic_float& x, const generic_float& y)
{
if (not x.is_finite() or not y.is_finite())
return false;
return std::tie(x.mantissa, x.exponent, x.sign) == std::tie(y.mantissa, y.exponent, y.sign);
}

friend constexpr bool operator!=(const generic_float& x, const generic_float& y)
{
return not(x == y);
}
};

}

Check warning on line 267 in src/include/migraphx/generic_float.hpp (GitHub Actions / tidy): namespace 'MIGRAPHX_INLINE_NS' not terminated with a closing comment [llvm-namespace-comment,-warnings-as-errors]
}

Check warning on line 268 in src/include/migraphx/generic_float.hpp (GitHub Actions / tidy): namespace 'migraphx' not terminated with a closing comment [llvm-namespace-comment,-warnings-as-errors]
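
The `from_float` and `to_float` members above convert by truncating the mantissa and rebiasing the exponent, saturating on overflow. The same idea can be sketched on raw bit patterns, without bit-fields. This is an illustration under stated assumptions, not the PR's implementation: the names `narrow_bits`, `widen_bits`, and `round_trip` are hypothetical, the sketch assumes `1 <= M <= 23` and `2 <= E <= 8`, and it flushes subnormal inputs and underflowing results to zero.

```cpp
#include <cstdint>
#include <cstring>

// Hypothetical helper: narrow an IEEE-754 binary32 bit pattern to a layout
// with M mantissa bits and E exponent bits (sign | exponent | mantissa).
template <unsigned int M, unsigned int E>
constexpr std::uint32_t narrow_bits(std::uint32_t f32) noexcept
{
    constexpr std::uint32_t max_exp = (1u << E) - 1u; // all-ones exponent: Inf/NaN
    constexpr int rebias = 127 - static_cast<int>((1u << (E - 1u)) - 1u);

    const std::uint32_t sign = f32 >> 31;
    const std::uint32_t exp8 = (f32 >> 23) & 0xFFu;
    const std::uint32_t frac = f32 & 0x7FFFFFu;
    std::uint32_t mant       = frac >> (23u - M); // truncate, no rounding
    std::uint32_t out_exp    = 0;

    if(exp8 == 0xFFu)
    {
        out_exp = max_exp;
        if(frac != 0 && mant == 0)
            mant = 1u << (M - 1u); // keep a NaN from collapsing into Inf
    }
    else if(exp8 == 0)
    {
        mant = 0; // simplification: float32 zero and subnormals become zero
    }
    else
    {
        // Signed arithmetic, so an exponent below the target range stays negative.
        const int e = static_cast<int>(exp8) - rebias;
        if(e >= static_cast<int>(max_exp))
        {
            out_exp = max_exp; // overflow saturates to Inf
            mant    = 0;
        }
        else if(e <= 0)
        {
            mant = 0; // simplification: underflow flushes to zero
        }
        else
        {
            out_exp = static_cast<std::uint32_t>(e);
        }
    }
    return (sign << (M + E)) | (out_exp << M) | mant;
}

// Hypothetical inverse: widen the narrow layout back to a binary32 bit pattern.
template <unsigned int M, unsigned int E>
constexpr std::uint32_t widen_bits(std::uint32_t bits) noexcept
{
    constexpr std::uint32_t max_exp = (1u << E) - 1u;
    constexpr int rebias = 127 - static_cast<int>((1u << (E - 1u)) - 1u);

    const std::uint32_t sign    = (bits >> (M + E)) & 1u;
    const std::uint32_t in_exp  = (bits >> M) & max_exp;
    const std::uint32_t mant    = (bits & ((1u << M) - 1u)) << (23u - M);
    if(in_exp == 0)
        return sign << 31; // simplification: narrow subnormals read back as zero
    const std::uint32_t exp8 =
        in_exp == max_exp ? 0xFFu
                          : static_cast<std::uint32_t>(static_cast<int>(in_exp) + rebias);
    return (sign << 31) | (exp8 << 23) | mant;
}

// Round-trips a float through the narrow layout; memcpy gives portable bit access.
template <unsigned int M, unsigned int E>
inline float round_trip(float f) noexcept
{
    std::uint32_t u = 0;
    std::memcpy(&u, &f, sizeof(u));
    u = widen_bits<M, E>(narrow_bits<M, E>(u));
    std::memcpy(&f, &u, sizeof(f));
    return f;
}
```

For example, `narrow_bits<10, 5>(0x3F800000u)` yields `0x3C00u`, the half-precision encoding of 1.0, and `narrow_bits<7, 8>` on the same input yields `0x3F80u`, the bfloat16 encoding.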
68 changes: 32 additions & 36 deletions src/include/migraphx/half.hpp
@@ -28,11 +28,12 @@
 #include <half/half.hpp>
 #include <migraphx/config.hpp>
 #include <migraphx/float8.hpp>
+#include <migraphx/generic_float.hpp>

Review comment (Collaborator):
What does the ifdef on L45 do? If that is removed, could we remove the <half/half.hpp> include here?

Reply (Collaborator):
Yea I think we should remove <half/half.hpp> includes everywhere from migraphx, and we can remove the ifdef from below.

 namespace migraphx {
 inline namespace MIGRAPHX_INLINE_NS {

-using half = half_float::half;
+using half = migraphx::generic_float<10,5>;

 namespace detail {
 template <class T>
@@ -53,61 +54,56 @@
 template <class T>
 using deduce = typename detail::deduce<T>::type;

+template <class T>
+struct is_floating_point : std::false_type {};
+
+template <unsigned int E, unsigned int M, unsigned int F>
+struct is_floating_point<generic_float<E, M, F>> : std::true_type {};
+
 } // namespace MIGRAPHX_INLINE_NS
 } // namespace migraphx

 namespace std {

-template <class T>
-struct common_type<migraphx::half, T> : std::common_type<float, T> // NOLINT
+template<unsigned int E, unsigned int M, unsigned int F>
+class numeric_limits<migraphx::generic_float<E, M, F>>
Check warning on line 69 in src/include/migraphx/half.hpp (GitHub Actions / tidy): modification of 'std' namespace can result in undefined behavior [cert-dcl58-cpp,-warnings-as-errors]
 {
-};
+public:
+static constexpr bool has_infinity = false;
+static constexpr migraphx::generic_float<E, M, F> epsilon() { return migraphx::generic_float<E, M, F>::epsilon(); }

-template <class T>
-struct common_type<T, migraphx::half> : std::common_type<float, T> // NOLINT
-{
-};
+static constexpr migraphx::generic_float<E, M, F> quiet_NaN() { return migraphx::generic_float<E, M, F>::quiet_NaN(); }
Check warning on line 75 in src/include/migraphx/half.hpp (GitHub Actions / tidy): invalid case style for constexpr method 'quiet_NaN' [readability-identifier-naming,-warnings-as-errors]

-template <>
-struct common_type<migraphx::fp8::fp8e4m3fnuz, migraphx::half>
-{
-using type = float;
-};
+static constexpr migraphx::generic_float<E, M, F> max() { return migraphx::generic_float<E, M, F>::max(); }

-template <>
-struct common_type<migraphx::half, migraphx::fp8::fp8e4m3fnuz>
-{
-using type = float;
-};
+static constexpr migraphx::generic_float<E, M, F> min() { return migraphx::generic_float<E, M, F>::min(); }

-template <>
-struct common_type<migraphx::fp8::fp8e4m3fn, migraphx::half>
-{
-using type = float;
-};
+static constexpr migraphx::generic_float<E, M, F> lowest() { return migraphx::generic_float<E, M, F>::lowest(); }

-template <>
-struct common_type<migraphx::half, migraphx::fp8::fp8e4m3fn>
-{
-using type = float;
 };

-template <>
-struct common_type<migraphx::fp8::fp8e5m2, migraphx::half>
+template<unsigned int E, unsigned int M, unsigned int F, class T>
+struct common_type<migraphx::generic_float<E, M, F>, T> : std::common_type<float, T> // NOLINT
 {
-using type = float;
 };

-template <>
-struct common_type<migraphx::half, migraphx::fp8::fp8e5m2>
+template<unsigned int E, unsigned int M, unsigned int F, class T>
+struct common_type<T, migraphx::generic_float<E, M, F>> : std::common_type<float, T> // NOLINT
 {
-using type = float;
 };

-template <>
-struct common_type<migraphx::half, migraphx::half>
+template<unsigned int E, unsigned int M, unsigned int F, migraphx::fp8::f8_type T, bool FNUZ>
+struct common_type<migraphx::generic_float<E, M, F>, migraphx::fp8::float8<T, FNUZ>> : std::common_type<float, float>
Check warning on line 96 in src/include/migraphx/half.hpp (GitHub Actions / tidy): modification of 'std' namespace can result in undefined behavior [cert-dcl58-cpp,-warnings-as-errors]
+{};
+
+template<unsigned int E, unsigned int M, unsigned int F, migraphx::fp8::f8_type T, bool FNUZ>
+struct common_type<migraphx::fp8::float8<T, FNUZ>, migraphx::generic_float<E, M, F>> : std::common_type<float, float>
Check warning on line 100 in src/include/migraphx/half.hpp (GitHub Actions / tidy): modification of 'std' namespace can result in undefined behavior [cert-dcl58-cpp,-warnings-as-errors]
+{};
+
+template<unsigned int E, unsigned int M, unsigned int F>
+struct common_type<migraphx::generic_float<E, M, F>, migraphx::generic_float<E, M, F>>
Check warning on line 104 in src/include/migraphx/half.hpp (GitHub Actions / tidy): modification of 'std' namespace can result in undefined behavior [cert-dcl58-cpp,-warnings-as-errors]
 {
-using type = migraphx::half;
+using type = migraphx::generic_float<E, M, F>;
 };

 } // namespace std
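
The `std` specializations in this hunk can be illustrated in isolation. The sketch below uses a stand-in type, `demo::tiny_float`, which is hypothetical and not part of MIGraphX. It assumes, as in generic_float.hpp above, that the type names its NaN factory `qnan()`, so the `numeric_limits` member `quiet_NaN()` forwards to that name. Specializing `std::numeric_limits` and `std::common_type` for a program-defined type is a customization the standard library allows.

```cpp
#include <limits>
#include <type_traits>

namespace demo { // hypothetical namespace, for illustration only

// Stand-in for a narrow float type; it simply wraps a float.
struct tiny_float
{
    float value = 0.0f;
    constexpr operator float() const noexcept { return value; }
    static constexpr tiny_float qnan() noexcept
    {
        return tiny_float{std::numeric_limits<float>::quiet_NaN()};
    }
    static constexpr tiny_float max() noexcept { return tiny_float{65504.0f}; }
};

} // namespace demo

namespace std {

template <>
class numeric_limits<demo::tiny_float>
{
    public:
    static constexpr bool is_specialized = true;
    static constexpr bool has_quiet_NaN  = true;
    // Forward to the factory names the type actually provides.
    static constexpr demo::tiny_float quiet_NaN() noexcept { return demo::tiny_float::qnan(); }
    static constexpr demo::tiny_float max() noexcept { return demo::tiny_float::max(); }
};

// Mixed arithmetic with another type T is carried out in float.
template <class T>
struct common_type<demo::tiny_float, T> : std::common_type<float, T>
{
};

template <class T>
struct common_type<T, demo::tiny_float> : std::common_type<float, T>
{
};

// Same-type operations keep the narrow type.
template <>
struct common_type<demo::tiny_float, demo::tiny_float>
{
    using type = demo::tiny_float;
};

} // namespace std
```

With these in place, `std::common_type<demo::tiny_float, int>::type` is `float`, while the same-type case resolves to `demo::tiny_float` through the explicit specialization.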