-
Notifications
You must be signed in to change notification settings - Fork 15
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
d8e968f
commit cc0b4c1
Showing
4 changed files
with
216 additions
and
40 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,34 @@ | ||
//===-- NumericUtils.h - numeric utilities ----------------------*- C++ -*-===// | ||
// | ||
// This file is licensed under the Apache License v2.0 with LLVM Exceptions. | ||
// See https://llvm.org/LICENSE.txt for license information. | ||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | ||
// | ||
//===----------------------------------------------------------------------===// | ||
|
||
#ifndef GC_TRANSFORMS_UTILS_NEMURICUTILS_H | ||
#define GC_TRANSFORMS_UTILS_NEMURICUTILS_H | ||
#include "mlir/IR/TypeUtilities.h" | ||
#include "mlir/IR/Types.h" | ||
#include <limits> | ||
#include <stdint.h> | ||
#include <variant> | ||
|
||
namespace mlir { | ||
namespace gc { | ||
|
||
union Float32Bits { | ||
uint32_t u; | ||
float f; | ||
}; | ||
uint16_t float2half(float floatValue); | ||
float half2float(uint16_t halfValue); | ||
uint16_t float2bfloat(float floatValue); | ||
float bfloat2float(uint16_t bfloatBits); | ||
std::variant<float, int64_t> numeric_limits_minimum(Type type); | ||
std::variant<float, int64_t> numericLimitsMaximum(Type type); | ||
|
||
} // namespace gc | ||
} // namespace mlir | ||
|
||
#endif |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,164 @@ | ||
//===- NumericUtils.cpp - numeric utilities ---------------------*- C++ -*-===// | ||
// | ||
// This file is licensed under the Apache License v2.0 with LLVM Exceptions. | ||
// See https://llvm.org/LICENSE.txt for license information. | ||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | ||
// | ||
//===----------------------------------------------------------------------===// | ||
#include "gc/Transforms/Utils/NumericUtils.h" | ||
|
||
namespace mlir { | ||
namespace gc { | ||
|
||
const uint32_t kF32MantiBits = 23; | ||
const uint32_t kF32HalfMantiBitDiff = 13; | ||
const uint32_t kF32HalfBitDiff = 16; | ||
const Float32Bits kF32Magic = {113 << kF32MantiBits}; | ||
const uint32_t kF32HalfExpAdjust = (127 - 15) << kF32MantiBits; | ||
const uint32_t kF32BfMantiBitDiff = 16; | ||
|
||
/// Constructs the 16 bit representation for a half precision value from a float | ||
/// value. This implementation is adapted from Eigen. | ||
uint16_t float2half(float floatValue) { | ||
const Float32Bits inf = {255 << kF32MantiBits}; | ||
const Float32Bits f16max = {(127 + 16) << kF32MantiBits}; | ||
const Float32Bits denormMagic = {((127 - 15) + (kF32MantiBits - 10) + 1) | ||
<< kF32MantiBits}; | ||
uint32_t signMask = 0x80000000u; | ||
uint16_t halfValue = static_cast<uint16_t>(0x0u); | ||
Float32Bits f; | ||
f.f = floatValue; | ||
uint32_t sign = f.u & signMask; | ||
f.u ^= sign; | ||
|
||
if (f.u >= f16max.u) { | ||
const uint32_t halfQnan = 0x7e00; | ||
const uint32_t halfInf = 0x7c00; | ||
// Inf or NaN (all exponent bits set). | ||
halfValue = (f.u > inf.u) ? halfQnan : halfInf; // NaN->qNaN and Inf->Inf | ||
} else { | ||
// (De)normalized number or zero. | ||
if (f.u < kF32Magic.u) { | ||
// The resulting FP16 is subnormal or zero. | ||
// | ||
// Use a magic value to align our 10 mantissa bits at the bottom of the | ||
// float. As long as FP addition is round-to-nearest-even this works. | ||
f.f += denormMagic.f; | ||
|
||
halfValue = static_cast<uint16_t>(f.u - denormMagic.u); | ||
} else { | ||
uint32_t mantOdd = | ||
(f.u >> kF32HalfMantiBitDiff) & 1; // Resulting mantissa is odd. | ||
|
||
// Update exponent, rounding bias part 1. The following expressions are | ||
// equivalent to `f.u += ((unsigned int)(15 - 127) << kF32MantiBits) + | ||
// 0xfff`, but without arithmetic overflow. | ||
f.u += 0xc8000fffU; | ||
// Rounding bias part 2. | ||
f.u += mantOdd; | ||
halfValue = static_cast<uint16_t>(f.u >> kF32HalfMantiBitDiff); | ||
} | ||
} | ||
|
||
halfValue |= static_cast<uint16_t>(sign >> kF32HalfBitDiff); | ||
return halfValue; | ||
} | ||
|
||
/// Converts the 16 bit representation of a half precision value to a float | ||
/// value. This implementation is adapted from Eigen. | ||
float half2float(uint16_t halfValue) { | ||
const uint32_t shiftedExp = | ||
0x7c00 << kF32HalfMantiBitDiff; // Exponent mask after shift. | ||
|
||
// Initialize the float representation with the exponent/mantissa bits. | ||
Float32Bits f = { | ||
static_cast<uint32_t>((halfValue & 0x7fff) << kF32HalfMantiBitDiff)}; | ||
const uint32_t exp = shiftedExp & f.u; | ||
f.u += kF32HalfExpAdjust; // Adjust the exponent | ||
|
||
// Handle exponent special cases. | ||
if (exp == shiftedExp) { | ||
// Inf/NaN | ||
f.u += kF32HalfExpAdjust; | ||
} else if (exp == 0) { | ||
// Zero/Denormal? | ||
f.u += 1 << kF32MantiBits; | ||
f.f -= kF32Magic.f; | ||
} | ||
|
||
f.u |= (halfValue & 0x8000) << kF32HalfBitDiff; // Sign bit. | ||
return f.f; | ||
} | ||
|
||
// Constructs the 16 bit representation for a bfloat value from a float value. | ||
// This implementation is adapted from Eigen. | ||
uint16_t float2bfloat(float floatValue) { | ||
if (std::isnan(floatValue)) | ||
return std::signbit(floatValue) ? 0xFFC0 : 0x7FC0; | ||
|
||
Float32Bits floatBits; | ||
floatBits.f = floatValue; | ||
uint16_t bfloatBits; | ||
|
||
// Least significant bit of resulting bfloat. | ||
uint32_t lsb = (floatBits.u >> kF32BfMantiBitDiff) & 1; | ||
uint32_t roundingBias = 0x7fff + lsb; | ||
floatBits.u += roundingBias; | ||
bfloatBits = static_cast<uint16_t>(floatBits.u >> kF32BfMantiBitDiff); | ||
return bfloatBits; | ||
} | ||
|
||
// Converts the 16 bit representation of a bfloat value to a float value. This | ||
// implementation is adapted from Eigen. | ||
float bfloat2float(uint16_t bfloatBits) { | ||
Float32Bits floatBits; | ||
floatBits.u = static_cast<uint32_t>(bfloatBits) << kF32BfMantiBitDiff; | ||
return floatBits.f; | ||
} | ||
|
||
std::variant<float, int64_t> numeric_limits_minimum(Type type) { | ||
Type t1 = getElementTypeOrSelf(type); | ||
if (t1.isF32()) { | ||
return -std::numeric_limits<float>::infinity(); | ||
} else if (t1.isBF16()) { | ||
return bfloat2float(float2bfloat(-std::numeric_limits<float>::infinity())); | ||
} else if (t1.isF16()) { | ||
return (float)half2float( | ||
float2half(-std::numeric_limits<float>::infinity())); | ||
} else if (t1.isSignedInteger(8)) { | ||
return int64_t(-128); | ||
} else if (t1.isSignedInteger(32)) { | ||
return int64_t(std::numeric_limits<int32_t>::min()); | ||
} else if (t1.isSignlessInteger(8) or t1.isSignlessInteger(32)) { | ||
return int64_t(0); | ||
} else { | ||
llvm_unreachable("unsupported data type"); | ||
return (int64_t)0; | ||
} | ||
} | ||
|
||
std::variant<float, int64_t> numericLimitsMaximum(Type type) { | ||
Type t1 = getElementTypeOrSelf(type); | ||
if (t1.isF32()) { | ||
return std::numeric_limits<float>::infinity(); | ||
} else if (t1.isBF16()) { | ||
return bfloat2float(float2bfloat(std::numeric_limits<float>::infinity())); | ||
} else if (t1.isF16()) { | ||
return (float)half2float( | ||
float2half(std::numeric_limits<float>::infinity())); | ||
} else if (t1.isSignedInteger(8)) { | ||
return int64_t(127); | ||
} else if (t1.isSignedInteger(32)) { | ||
return int64_t(std::numeric_limits<int32_t>::max()); | ||
} else if (t1.isSignlessInteger(8)) { | ||
return int64_t(255); | ||
} else if (t1.isSignedInteger(32)) { | ||
return int64_t(std::numeric_limits<uint32_t>::max()); | ||
} else { | ||
llvm_unreachable("unsupported data type"); | ||
return (int64_t)0; | ||
} | ||
} | ||
|
||
} // namespace gc | ||
} // namespace mlir |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters