Skip to content

Commit

Permalink
fix comments
Browse files Browse the repository at this point in the history
  • Loading branch information
BRUCE11111 committed Oct 8, 2024
1 parent d8e968f commit cc0b4c1
Show file tree
Hide file tree
Showing 4 changed files with 216 additions and 40 deletions.
34 changes: 34 additions & 0 deletions include/gc/Transforms/Utils/NumericUtils.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
//===-- NumericUtils.h - numeric utilities ----------------------*- C++ -*-===//
//
// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#ifndef GC_TRANSFORMS_UTILS_NUMERICUTILS_H
#define GC_TRANSFORMS_UTILS_NUMERICUTILS_H
#include "mlir/IR/TypeUtilities.h"
#include "mlir/IR/Types.h"
#include <limits>
#include <stdint.h>
#include <variant>

namespace mlir {
namespace gc {

/// Overlay of the bit pattern (`u`) and the value (`f`) of a 32-bit float.
/// NOTE(review): reading the member other than the one last written is type
/// punning, which is undefined behavior in standard C++ (though widely
/// supported by compilers); prefer std::memcpy / std::bit_cast at use sites.
union Float32Bits {
  uint32_t u;
  float f;
};

/// Converts a 32-bit float to its IEEE-754 half (f16) bit pattern.
uint16_t float2half(float floatValue);
/// Converts an IEEE-754 half (f16) bit pattern to a 32-bit float.
float half2float(uint16_t halfValue);
/// Converts a 32-bit float to its bfloat16 bit pattern
/// (round-to-nearest-even).
uint16_t float2bfloat(float floatValue);
/// Converts a bfloat16 bit pattern to a 32-bit float (exact widening).
float bfloat2float(uint16_t bfloatBits);
/// Returns the minimum representable value for the element type of `type`:
/// float for FP element types, int64_t for integer element types.
/// NOTE(review): name is inconsistent with camelCase `numericLimitsMaximum`
/// below; renaming would break existing callers, so it is only flagged here.
std::variant<float, int64_t> numeric_limits_minimum(Type type);
/// Returns the maximum representable value for the element type of `type`.
std::variant<float, int64_t> numericLimitsMaximum(Type type);

} // namespace gc
} // namespace mlir

#endif
13 changes: 2 additions & 11 deletions include/gc/Transforms/Utils/VectorUtils.h
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
//===-- VectorUtils.h - vector fusion analysis ------------------*- C++ -*-===//
//===-- VectorUtils.h - vector utilities ------------------------*- C++ -*-===//
//
// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
Expand All @@ -8,6 +8,7 @@

#ifndef GC_TRANSFORMS_UTILS_VECTORUTILS_H
#define GC_TRANSFORMS_UTILS_VECTORUTILS_H
#include "gc/Transforms/Utils/NumericUtils.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
Expand Down Expand Up @@ -96,16 +97,6 @@ int getNearestVectorStep(const int step);
/// prev-op, may need to use result vectortype
/// default will return the opeation result type
mlir::FailureOr<VectorType> getOperationMaxVectorType(Operation *op);
union Float32Bits {
uint32_t u;
float f;
};
uint16_t float2half(float floatValue);
float half2float(uint16_t halfValue);
uint16_t float2bfloat(float floatValue);
float bfloat2float(uint16_t bfloatBits);
std::variant<float, int64_t> numeric_limits_minimum(Type type);
std::variant<float, int64_t> numericLimitsMaximum(Type type);

template <typename T = float>
T getInitValForReduce(vector::CombiningKind kind, Type t) {
Expand Down
164 changes: 164 additions & 0 deletions lib/gc/Transforms/Utils/NumericUtils.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,164 @@
//===- NumericUtils.cpp - numeric utilities ---------------------*- C++ -*-===//
//
// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "gc/Transforms/Utils/NumericUtils.h"

#include <cmath>
#include <cstring>

namespace mlir {
namespace gc {

// Number of mantissa bits in an IEEE-754 binary32 float.
const uint32_t kF32MantiBits = 23;
// Difference between the f32 (23) and f16 (10) mantissa widths.
const uint32_t kF32HalfMantiBitDiff = 13;
// Bit-width difference between f32 (32) and f16 (16); used to reposition the
// sign bit.
const uint32_t kF32HalfBitDiff = 16;
// f32 bit pattern with biased exponent 113, i.e. 2^-14 — the smallest normal
// f16 value; used to renormalize f16 subnormals.
const Float32Bits kF32Magic = {113 << kF32MantiBits};
// Exponent-bias difference between f32 (127) and f16 (15), pre-shifted into
// the f32 exponent field.
const uint32_t kF32HalfExpAdjust = (127 - 15) << kF32MantiBits;
// Shift between an f32 bit pattern and the bf16 pattern stored in its upper
// 16 bits.
const uint32_t kF32BfMantiBitDiff = 16;

/// Constructs the 16 bit representation for a half precision value from a float
/// value. This implementation is adapted from Eigen.
uint16_t float2half(float floatValue) {
const Float32Bits inf = {255 << kF32MantiBits};
const Float32Bits f16max = {(127 + 16) << kF32MantiBits};
const Float32Bits denormMagic = {((127 - 15) + (kF32MantiBits - 10) + 1)
<< kF32MantiBits};
uint32_t signMask = 0x80000000u;
uint16_t halfValue = static_cast<uint16_t>(0x0u);
Float32Bits f;
f.f = floatValue;
uint32_t sign = f.u & signMask;
f.u ^= sign;

if (f.u >= f16max.u) {
const uint32_t halfQnan = 0x7e00;
const uint32_t halfInf = 0x7c00;
// Inf or NaN (all exponent bits set).
halfValue = (f.u > inf.u) ? halfQnan : halfInf; // NaN->qNaN and Inf->Inf
} else {
// (De)normalized number or zero.
if (f.u < kF32Magic.u) {
// The resulting FP16 is subnormal or zero.
//
// Use a magic value to align our 10 mantissa bits at the bottom of the
// float. As long as FP addition is round-to-nearest-even this works.
f.f += denormMagic.f;

halfValue = static_cast<uint16_t>(f.u - denormMagic.u);
} else {
uint32_t mantOdd =
(f.u >> kF32HalfMantiBitDiff) & 1; // Resulting mantissa is odd.

// Update exponent, rounding bias part 1. The following expressions are
// equivalent to `f.u += ((unsigned int)(15 - 127) << kF32MantiBits) +
// 0xfff`, but without arithmetic overflow.
f.u += 0xc8000fffU;
// Rounding bias part 2.
f.u += mantOdd;
halfValue = static_cast<uint16_t>(f.u >> kF32HalfMantiBitDiff);
}
}

halfValue |= static_cast<uint16_t>(sign >> kF32HalfBitDiff);
return halfValue;
}

/// Converts the 16 bit representation of a half precision value to a float
/// value. This implementation is adapted from Eigen.
float half2float(uint16_t halfValue) {
  // Exponent mask after shifting the f16 payload into f32 position.
  const uint32_t shiftedExp = 0x7c00 << kF32HalfMantiBitDiff;

  // Start from the exponent/mantissa bits moved into f32 position
  // (type-pun through memcpy rather than a union: union punning is UB in C++).
  uint32_t bits =
      static_cast<uint32_t>((halfValue & 0x7fff) << kF32HalfMantiBitDiff);
  const uint32_t exp = shiftedExp & bits;
  bits += kF32HalfExpAdjust; // Rebias the exponent for f32.

  // Handle exponent special cases.
  if (exp == shiftedExp) {
    // Inf/NaN: push the exponent field to all-ones.
    bits += kF32HalfExpAdjust;
  } else if (exp == 0) {
    // Zero/subnormal: renormalize through an FP subtraction of the magic
    // value (2^-14).
    bits += 1 << kF32MantiBits;
    float f;
    std::memcpy(&f, &bits, sizeof(f));
    float magic;
    std::memcpy(&magic, &kF32Magic.u, sizeof(magic));
    f -= magic;
    std::memcpy(&bits, &f, sizeof(bits));
  }

  // Re-attach the sign bit; widen first so the shift stays in unsigned range.
  bits |= static_cast<uint32_t>(halfValue & 0x8000) << kF32HalfBitDiff;
  float result;
  std::memcpy(&result, &bits, sizeof(result));
  return result;
}

// Constructs the 16 bit representation for a bfloat value from a float value.
// This implementation is adapted from Eigen.
// NaN inputs are canonicalized to a quiet-NaN pattern (sign preserved);
// everything else is rounded to nearest, ties to even.
uint16_t float2bfloat(float floatValue) {
  if (std::isnan(floatValue))
    return std::signbit(floatValue) ? 0xFFC0 : 0x7FC0;

  // Type-pun through memcpy rather than a union: union punning is UB in C++.
  uint32_t bits;
  std::memcpy(&bits, &floatValue, sizeof(bits));

  // Round to nearest even: bias by 0x7fff plus the least significant bit of
  // the resulting bfloat, then truncate.
  const uint32_t lsb = (bits >> kF32BfMantiBitDiff) & 1;
  const uint32_t roundingBias = 0x7fff + lsb;
  bits += roundingBias;
  return static_cast<uint16_t>(bits >> kF32BfMantiBitDiff);
}

// Converts the 16 bit representation of a bfloat value to a float value. This
// implementation is adapted from Eigen. The conversion is exact: a bf16 is
// the upper half of an f32 bit pattern with the low 16 mantissa bits zero.
float bfloat2float(uint16_t bfloatBits) {
  const uint32_t bits = static_cast<uint32_t>(bfloatBits)
                        << kF32BfMantiBitDiff;
  // Type-pun through memcpy rather than a union: union punning is UB in C++.
  float result;
  std::memcpy(&result, &bits, sizeof(result));
  return result;
}

/// Returns the smallest representable value for the element type of `type`:
/// -inf (round-tripped through the narrower format for bf16/f16) for float
/// types, the signed minimum for si8/si32, and 0 for signless i8/i32 (which
/// are treated as unsigned here). Asserts on any other element type.
std::variant<float, int64_t> numeric_limits_minimum(Type type) {
  Type elemTy = getElementTypeOrSelf(type);
  if (elemTy.isF32()) {
    return -std::numeric_limits<float>::infinity();
  } else if (elemTy.isBF16()) {
    return bfloat2float(float2bfloat(-std::numeric_limits<float>::infinity()));
  } else if (elemTy.isF16()) {
    // half2float already returns float; no cast needed.
    return half2float(float2half(-std::numeric_limits<float>::infinity()));
  } else if (elemTy.isSignedInteger(8)) {
    return int64_t(-128);
  } else if (elemTy.isSignedInteger(32)) {
    return int64_t(std::numeric_limits<int32_t>::min());
  } else if (elemTy.isSignlessInteger(8) || elemTy.isSignlessInteger(32)) {
    // Signless integers are treated as unsigned; their minimum is 0.
    return int64_t(0);
  } else {
    llvm_unreachable("unsupported data type");
    return (int64_t)0;
  }
}

/// Returns the largest representable value for the element type of `type`:
/// +inf (round-tripped through the narrower format for bf16/f16) for float
/// types, the signed maximum for si8/si32, and the unsigned maximum for
/// signless i8/i32. Asserts on any other element type.
std::variant<float, int64_t> numericLimitsMaximum(Type type) {
  Type elemTy = getElementTypeOrSelf(type);
  if (elemTy.isF32()) {
    return std::numeric_limits<float>::infinity();
  } else if (elemTy.isBF16()) {
    return bfloat2float(float2bfloat(std::numeric_limits<float>::infinity()));
  } else if (elemTy.isF16()) {
    // half2float already returns float; no cast needed.
    return half2float(float2half(std::numeric_limits<float>::infinity()));
  } else if (elemTy.isSignedInteger(8)) {
    return int64_t(127);
  } else if (elemTy.isSignedInteger(32)) {
    return int64_t(std::numeric_limits<int32_t>::max());
  } else if (elemTy.isSignlessInteger(8)) {
    return int64_t(255);
  } else if (elemTy.isSignlessInteger(32)) {
    // Bug fix: this branch previously re-tested isSignedInteger(32), which is
    // already handled above, making the unsigned-32 maximum unreachable.
    // Signless i32 is treated as unsigned, matching the i8 branch.
    return int64_t(std::numeric_limits<uint32_t>::max());
  } else {
    llvm_unreachable("unsupported data type");
    return (int64_t)0;
  }
}

} // namespace gc
} // namespace mlir
45 changes: 16 additions & 29 deletions lib/gc/Transforms/Utils/VectorUtils.cpp
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
//===- VectorUtils.cpp - analysis vector ops --------------------*- C++ -*-===//
//===- VectorUtils.cpp - vector utilities -----------------------*- C++ -*-===//
//
// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "gc/Transforms/Utils/VectorUtils.h"
#include "mlir/IR/Value.h"
#include "mlir/Support/LLVM.h"

namespace mlir {
namespace gc {

Expand Down Expand Up @@ -37,13 +37,10 @@ OPPRIORITY operator++(OPPRIORITY &c) {
LogicalResult moveFront(Operation *op, IRRewriter &rewriter) {
Operation *backOperation = nullptr;
// check all the operand is block argument
bool allBlockArgs = true;
for (auto operand : op->getOperands()) {
if (!isa<BlockArgument>(operand)) {
allBlockArgs = false;
break;
}
}
bool allBlockArgs = llvm::all_of(op->getOperands(), [](Value operand) {
return isa<BlockArgument>(operand);
});

if (allBlockArgs) {
moveOpBeginingOfBlock(op, rewriter);
return success();
Expand Down Expand Up @@ -157,26 +154,16 @@ void getOperationPriority(
candidateOps.push(std::make_pair(op, OPPRIORITY::FIRST));
return;
})
.Case<tensor::ExtractSliceOp>([&](tensor::ExtractSliceOp extractOp) {
candidateOps.push(std::make_pair(op, OPPRIORITY::SECOND));
return;
})
.Case<tensor::EmptyOp>([&](tensor::EmptyOp emptyOp) {
candidateOps.push(std::make_pair(op, OPPRIORITY::SECOND));
return;
})
.Case<tensor::InsertSliceOp>([&](tensor::InsertSliceOp insertOp) {
candidateOps.push(std::make_pair(op, OPPRIORITY::SECOND));
return;
})
.Case<vector::TransferReadOp>([&](vector::TransferReadOp readOp) {
candidateOps.push(std::make_pair(op, OPPRIORITY::LAST));
return;
})
.Case<vector::TransferWriteOp>([&](vector::TransferWriteOp writeOp) {
candidateOps.push(std::make_pair(op, OPPRIORITY::LAST));
return;
})
.Case<tensor::EmptyOp, tensor::InsertSliceOp, tensor::ExtractSliceOp>(
[&](tensor::EmptyOp emptyOp) {
candidateOps.push(std::make_pair(op, OPPRIORITY::SECOND));
return;
})
.Case<vector::TransferWriteOp, vector::TransferReadOp>(
[&](vector::TransferWriteOp writeOp) {
candidateOps.push(std::make_pair(op, OPPRIORITY::LAST));
return;
})
.Case<vector::BroadcastOp>([&](vector::BroadcastOp bcOp) {
candidateOps.push(std::make_pair(op, OPPRIORITY::THIRD));
return;
Expand Down

0 comments on commit cc0b4c1

Please sign in to comment.