
Commit

fix bf16 failures
Yancey1989 committed Apr 19, 2024
1 parent 7414b51 commit 3d293a4
Showing 14 changed files with 132 additions and 27 deletions.
@@ -57,10 +57,12 @@ void LaunchTransposeKernel(cudaStream_t stream, T* input,
template void LaunchTransposeKernel<float>(cudaStream_t stream, float* input,
std::vector<int64_t> input_dims,
float* output);

template void LaunchTransposeKernel<Eigen::half>(
cudaStream_t stream, Eigen::half* input, std::vector<int64_t> input_dims,
Eigen::half* output);
template void LaunchTransposeKernel<Eigen::bfloat16>(
cudaStream_t stream, Eigen::bfloat16* input,
std::vector<int64_t> input_dims, Eigen::bfloat16* output);
#endif

} // namespace ral
4 changes: 3 additions & 1 deletion tao_compiler/mlir/custom_ops/transpose_impl.cc
@@ -61,11 +61,13 @@ void ral_transpose(ExecutionContext* ctx, void* stream_handle,

LaunchTransposeKernel<T>(stream, d_in, input_dims, d_out);
}

DEFINE_TAO_TYPE_NAME_HELPER(Eigen::bfloat16, "bf16");
TAO_RAL_API("ral_transpose", "gpu", ral_transpose<float, 2>);
TAO_RAL_API("ral_transpose", "gpu", ral_transpose<float, 3>);
TAO_RAL_API("ral_transpose", "gpu", ral_transpose<Eigen::half, 2>);
TAO_RAL_API("ral_transpose", "gpu", ral_transpose<Eigen::half, 3>);
TAO_RAL_API("ral_transpose", "gpu", ral_transpose<Eigen::bfloat16, 2>);
TAO_RAL_API("ral_transpose", "gpu", ral_transpose<Eigen::bfloat16, 3>);
#endif

} // namespace ral
6 changes: 4 additions & 2 deletions tao_compiler/mlir/disc/disc_compiler.cc
@@ -563,10 +563,12 @@ LogicalResult LowerHLOToLLVM(ModuleOp m, const DISCLoweringOptions& options) {
// optimization. Then this pass will be enabled by default.
pm.addNestedPass<FuncOp>(disc_ral::createForLoopUnrollInterleavePass());
}
pm.addNestedPass<FuncOp>(arith::createArithExpandOpsPass());
mlir::arith::ArithExpandOpsOptions arith_option;
arith_option.includeBf16 = true;
pm.addNestedPass<FuncOp>(arith::createArithExpandOpsPass(arith_option));
// Origin: https://reviews.llvm.org/D147585
// Should be removed after rebasing to the latest llvm head
pm.addNestedPass<FuncOp>(disc_ral::createDiscBF16ExpansionPass());
// pm.addNestedPass<FuncOp>(disc_ral::createDiscBF16ExpansionPass());
pm.addNestedPass<FuncOp>(mlir::memref::createFoldMemRefAliasOpsPass());

// Flatten multi dim memref accesses to its 1D format to enable more
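For context, enabling includeBf16 lets the upstream arith expansion lower bf16 extf/truncf on its own, which appears to be why the local DiscBF16ExpansionPass (a port of https://reviews.llvm.org/D147585, per the comment above) is now commented out. A rough textual-pipeline equivalent, assuming the include-bf16 option spelling of upstream MLIR's arith-expand pass:

// Illustrative sketch only, not part of this commit:
// mlir-opt input.mlir \
//   --pass-pipeline="builtin.module(func.func(arith-expand{include-bf16=true}))"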
63 changes: 48 additions & 15 deletions tao_compiler/mlir/disc/transforms/disc_bf16_expansion.cc
@@ -64,7 +64,7 @@ struct BFloat16ExtFOpConverter : public OpRewritePattern<arith::ExtFOp> {
Type resultETy = getElementTypeOrSelf(resultTy);

if (!operandETy.isBF16() || !resultETy.isF32()) {
return failure();
return rewriter.notifyMatchFailure(op, "not an ext of bf16 to f32.");
}

Type i16Ty = b.getI16Type();
@@ -98,21 +98,9 @@ struct BFloat16TruncFOpConverter : public OpRewritePattern<arith::TruncFOp> {
Type resultETy = getElementTypeOrSelf(resultTy);

if (!operandETy.isF32() || !resultETy.isBF16()) {
return failure();
return rewriter.notifyMatchFailure(op, "not a trunc of f32 to bf16.");
}

#if defined(TAO_AARCH64)
if (isBFCVTEnabled()) {
auto intrinsicName =
StringAttr::get(rewriter.getContext(), "llvm.aarch64.neon.bfcvt");
SmallVector<Value, 2> args;
args.push_back(operand);
rewriter.replaceOpWithNewOp<LLVM::CallIntrinsicOp>(op, resultETy,
intrinsicName, args);
return success();
}
#endif

Type i1Ty = b.getI1Type();
Type i16Ty = b.getI16Type();
Type i32Ty = b.getI32Type();
@@ -125,8 +113,53 @@ struct BFloat16TruncFOpConverter : public OpRewritePattern<arith::TruncFOp> {
}

Value bitcast = b.create<arith::BitcastOp>(i32Ty, operand);

Value c23 = createConst(op.getLoc(), i32Ty, 23, rewriter);
Value c31 = createConst(op.getLoc(), i32Ty, 31, rewriter);
Value c23Mask = createConst(op.getLoc(), i32Ty, (1 << 23) - 1, rewriter);
Value expMask =
createConst(op.getLoc(), i32Ty, ((1 << 8) - 1) << 23, rewriter);
Value expMax =
createConst(op.getLoc(), i32Ty, ((1 << 8) - 2) << 23, rewriter);

// Grab the sign bit.
Value sign = b.create<arith::ShRUIOp>(bitcast, c31);

// Our mantissa rounding value depends on the sign bit and the last
// truncated bit.
Value cManRound = createConst(op.getLoc(), i32Ty, (1 << 15), rewriter);
cManRound = b.create<arith::SubIOp>(cManRound, sign);

// Grab out the mantissa and directly apply rounding.
Value man = b.create<arith::AndIOp>(bitcast, c23Mask);
Value manRound = b.create<arith::AddIOp>(man, cManRound);

// Grab the overflow bit and shift right if we overflow.
Value roundBit = b.create<arith::ShRUIOp>(manRound, c23);
Value manNew = b.create<arith::ShRUIOp>(manRound, roundBit);

// Grab the exponent and round using the mantissa's carry bit.
Value exp = b.create<arith::AndIOp>(bitcast, expMask);
Value expCarry = b.create<arith::AddIOp>(exp, manRound);
expCarry = b.create<arith::AndIOp>(expCarry, expMask);

// If the exponent is saturated, we keep the max value.
Value expCmp =
b.create<arith::CmpIOp>(arith::CmpIPredicate::uge, exp, expMax);
exp = b.create<arith::SelectOp>(expCmp, exp, expCarry);

// If the exponent is max and we rolled over, keep the old mantissa.
Value roundBitBool = b.create<arith::TruncIOp>(i1Ty, roundBit);
Value keepOldMan = b.create<arith::AndIOp>(expCmp, roundBitBool);
man = b.create<arith::SelectOp>(keepOldMan, man, manNew);

// Assemble the now rounded f32 value (as an i32).
Value rounded = b.create<arith::ShLIOp>(sign, c31);
rounded = b.create<arith::OrIOp>(rounded, exp);
rounded = b.create<arith::OrIOp>(rounded, man);

Value c16 = createConst(op.getLoc(), i32Ty, 16, rewriter);
Value shr = b.create<arith::ShRUIOp>(bitcast, c16);
Value shr = b.create<arith::ShRUIOp>(rounded, c16);
Value trunc = b.create<arith::TruncIOp>(i16Ty, shr);
Value result = b.create<arith::BitcastOp>(resultTy, trunc);

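To make the new rounding sequence easier to follow, here is a scalar C++ transcription of the bit manipulation the pattern emits above (an illustrative sketch only; the pass itself builds these steps as arith ops on i32):

#include <cstdint>
#include <cstring>

// Scalar mirror of the emitted IR: truncate f32 to bf16 with the same mantissa
// rounding, exponent carry, and saturation handling as the pattern above.
uint16_t TruncF32ToBF16Bits(float value) {
  uint32_t bits;
  std::memcpy(&bits, &value, sizeof(bits));

  const uint32_t kManMask = (1u << 23) - 1;         // c23Mask
  const uint32_t kExpMask = ((1u << 8) - 1) << 23;  // expMask
  const uint32_t kExpMax = ((1u << 8) - 2) << 23;   // expMax

  uint32_t sign = bits >> 31;
  uint32_t man = bits & kManMask;
  uint32_t exp = bits & kExpMask;

  // Mantissa rounding value depends on the sign and the last truncated bit.
  uint32_t manRound = man + ((1u << 15) - sign);
  uint32_t roundBit = manRound >> 23;  // overflow (carry) bit of the mantissa
  uint32_t manNew = manRound >> roundBit;

  // Round the exponent using the mantissa's carry.
  uint32_t expCarry = (exp + manRound) & kExpMask;

  // A saturated exponent keeps its value; if it is saturated and the mantissa
  // rolled over, keep the old mantissa as well.
  bool expSaturated = exp >= kExpMax;
  uint32_t expOut = expSaturated ? exp : expCarry;
  uint32_t manOut = (expSaturated && roundBit) ? man : manNew;

  // Reassemble the rounded f32 bits and keep the top 16 of them.
  uint32_t rounded = (sign << 31) | expOut | manOut;
  return static_cast<uint16_t>(rounded >> 16);
}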
(another changed file; name not shown)
@@ -22,7 +22,7 @@ limitations under the License.
lmhlo::AbsOp, lmhlo::CeilOp, lmhlo::FloorOp, lmhlo::ConvertOp, lmhlo::CosineOp,
lmhlo::ExpOp, lmhlo::LogOp, lmhlo::NegOp, lmhlo::RsqrtOp, lmhlo::SqrtOp,
lmhlo::SignOp, lmhlo::TanhOp, lmhlo::LogisticOp, lmhlo::Log1pOp,
lmhlo::SineOp, lmhlo::RoundOp, lmhlo::RoundNearestEvenOp,
lmhlo::SineOp, lmhlo::RoundOp, lmhlo::RoundNearestEvenOp, lmhlo::BitcastConvertOp,

// Binary Elementwise Ops
lmhlo::AddOp, lmhlo::DivOp, lmhlo::MaxOp, lmhlo::MinOp, lmhlo::MulOp,
13 changes: 12 additions & 1 deletion tao_compiler/mlir/disc/transforms/disc_to_llvm.cc
@@ -87,7 +87,18 @@ LogicalResult getTypeEncoding(MLIRContext* ctx, Type t, StrT& out) {
out.append(Twine("i").concat(Twine(int_type.getWidth())).str());
}
} else if (auto fp_type = t.dyn_cast<FloatType>()) {
out.append(Twine("f").concat(Twine(fp_type.getWidth())).str());
if (fp_type.isF16()) {
out.append("f16");
} else if (fp_type.isBF16()) {
out.append("bf16");
} else if (fp_type.isF32()) {
out.append("f32");
} else if (fp_type.isF64()) {
out.append("f64");
} else {
return failure();
}
// out.append(Twine("f").concat(Twine(fp_type.getWidth())).str());
} else if (auto ctx_type = t.dyn_cast<RalExecutionContextType>() ||
t == llvm_i8ptr_type || t == llvm_ptr_type) {
out.append("pvoid");
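The likely reason for spelling out the float cases above: f16 and bf16 report the same bit width, so the old width-based encoding could not tell them apart and produced "f16" for both. A small stand-alone illustration of the collision (an illustrative sketch using MLIR's Builder API, not code from this commit):

#include <cassert>

#include "mlir/IR/Builders.h"
#include "mlir/IR/MLIRContext.h"

// Both 16-bit float types have width 16, so Twine("f") + getWidth() would have
// encoded bf16 arguments as "f16" as well.
void ShowWidthCollision() {
  mlir::MLIRContext context;
  mlir::Builder b(&context);
  assert(b.getF16Type().getWidth() == 16);
  assert(b.getBF16Type().getWidth() == 16);  // same width, different type
}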
1 change: 1 addition & 0 deletions tao_compiler/mlir/disc/transforms/fusion_utils.cc
@@ -1450,6 +1450,7 @@ bool BaseCpuFusionStrategy::tryFuse(ShapeAnalysis& shapeAnalysis,

bool BaseGpuFusionStrategy::isFusible(Operation* op) {
// Only rank-2 tensor -> rank-1 tensor reduction are supported now.
// if (isa<lmhlo::ReduceOp>(op) && isRank2ScalarReduction(op)) return false;
if (isa<lmhlo::ReduceOp>(op) &&
(!isRank2RowReduction(op) && !isRank2ColReduction(op)))
return false;
(another changed file; name not shown)
@@ -51,6 +51,7 @@ bool findValidReductionOps(FusionPatternBase& target,
}

bool StitchGpuFusionStrategy::isFusible(Operation* op) {
// if (isa<lmhlo::ReduceOp>(op) && isRank2ScalarReduction(op)) return false;
if (isa<lmhlo::TransposeOp>(op) && isRank2or3Transpose(op)) return false;
return true;
}
17 changes: 14 additions & 3 deletions tao_compiler/mlir/disc/transforms/lhlo_elemental_utils.cc
@@ -1239,10 +1239,11 @@ Value elementalLower<lmhlo::ConcatenateOp>(OpBuilder* b, Location loc,

Value zero_element;
if (result_elem_type.isF16() || result_elem_type.isF32() ||
result_elem_type.isF64()) {
result_elem_type.isF64() || result_elem_type.isBF16()) {
auto float_result_elem_type = result_elem_type.cast<FloatType>();
zero_element = b->create<arith::ConstantFloatOp>(
loc, APFloat::getZero(float_result_elem_type.getFloatSemantics()),
loc,
APFloat::getZero(float_result_elem_type.getFloatSemantics(), false),
float_result_elem_type);
} else if (result_elem_type.isSignlessInteger() ||
result_elem_type.isSignedInteger() ||
@@ -1312,7 +1313,17 @@ Value elementalLower<lmhlo::ConcatenateOp>(OpBuilder* b, Location loc,

b->setInsertionPointToEnd(&if_inbound_ops[i].getElseRegion().front());
if (i == num_input_operands - 1) {
b->create<scf::YieldOp>(loc, zero_element); // expect never used
input_index[axis] = b->create<arith::SubIOp>(loc, out_idx, low_bound);
auto operand_memref = op.getOperand(i);
auto ret_value =
check_cache ? createLoadOrUseCachedValue(
loc, b, op.getOperation(), operand_memref,
input_index, b->saveInsertionPoint(), lower_config)
: createMaySpecificLoad(*b, loc, op.getOperation(),
operand_memref, input_index,
lower_config);
b->create<scf::YieldOp>(loc, ret_value);
// b->create<scf::YieldOp>(loc, zero_element); // expect never used
} else {
b->create<scf::YieldOp>(loc, if_inbound_ops[i + 1].getResults());
}
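In scalar terms, the concatenate lowering emits a chain of nested if/else branches over the input operands, and this hunk changes the final else branch to load from the last operand instead of yielding a placeholder zero. A rough C++ analogue of the generated control flow (hypothetical helper for illustration only, not code from the pass):

#include <cstdint>
#include <vector>

// Scalar analogue of the lowered lmhlo.concatenate read along the concat axis;
// upper_bounds[i] is the running end offset of operand i in the output.
float ConcatRead(const std::vector<std::vector<float>>& operands,
                 const std::vector<int64_t>& upper_bounds, int64_t out_idx) {
  int64_t low_bound = 0;
  for (size_t i = 0; i + 1 < operands.size(); ++i) {
    if (out_idx < upper_bounds[i])
      return operands[i][out_idx - low_bound];
    low_bound = upper_bounds[i];
  }
  // Final else branch: previously "zero_element  // expect never used",
  // now an in-bounds load from the last operand.
  return operands.back()[out_idx - low_bound];
}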
9 changes: 9 additions & 0 deletions tao_compiler/mlir/ral/context/base/base_context.cc
@@ -235,6 +235,7 @@ void ral_base_cuda_send_output_0d(ExecutionContext* ctx, int64_t output_idx,
TAO_RAL_API(tao::ral::kRalSendOutput, "cpu", ral_base_cuda_send_output_0d<T>);

RAL_REGISTER_IO_FUNC_0D(float);
RAL_REGISTER_IO_FUNC_0D(bfloat16);
RAL_REGISTER_IO_FUNC_0D(double);
RAL_REGISTER_IO_FUNC_0D(int8_t);
RAL_REGISTER_IO_FUNC_0D(int32_t);
@@ -306,5 +307,13 @@ RAL_REGISTER_IO_FUNC(Eigen::half, 5);
RAL_REGISTER_IO_FUNC(Eigen::half, 6);
RAL_REGISTER_IO_FUNC(Eigen::half, 7);
RAL_REGISTER_IO_FUNC(Eigen::half, 8);
RAL_REGISTER_IO_FUNC(bfloat16, 1);
RAL_REGISTER_IO_FUNC(bfloat16, 2);
RAL_REGISTER_IO_FUNC(bfloat16, 3);
RAL_REGISTER_IO_FUNC(bfloat16, 4);
RAL_REGISTER_IO_FUNC(bfloat16, 5);
RAL_REGISTER_IO_FUNC(bfloat16, 6);
RAL_REGISTER_IO_FUNC(bfloat16, 7);
RAL_REGISTER_IO_FUNC(bfloat16, 8);
} // namespace ral
} // namespace tao
1 change: 1 addition & 0 deletions tao_compiler/mlir/ral/context/base/cpu/cpu_context_impl.cc
@@ -370,6 +370,7 @@ RAL_REGISTER_BITCAST_FUNC_0D(double);
RAL_REGISTER_BITCAST_FUNC_0D(int32_t);
RAL_REGISTER_BITCAST_FUNC_0D(int64_t);
RAL_REGISTER_BITCAST_FUNC_0D(bool);
RAL_REGISTER_BITCAST_FUNC_0D(bfloat16);
RAL_REGISTER_BITCAST_FUNC(float, 1);
RAL_REGISTER_BITCAST_FUNC(float, 2);
RAL_REGISTER_BITCAST_FUNC(float, 3);
10 changes: 10 additions & 0 deletions tao_compiler/mlir/ral/context/base/cuda/cuda_context_impl.cc
@@ -671,6 +671,7 @@ void ral_base_cuda_d2h(ExecutionContext* ctx, void* stream_handle,
TAO_RAL_API(tao::ral::kRalBitcast, "gpu", ral_base_cuda_bitcast_0d<T, 8, 0>);

RAL_REGISTER_BITCAST_FUNC_0D(Eigen::half);
RAL_REGISTER_BITCAST_FUNC_0D(Eigen::bfloat16);
RAL_REGISTER_BITCAST_FUNC_0D(float);
RAL_REGISTER_BITCAST_FUNC_0D(double);
RAL_REGISTER_BITCAST_FUNC_0D(int32_t);
@@ -684,6 +685,14 @@ RAL_REGISTER_BITCAST_FUNC(Eigen::half, 5);
RAL_REGISTER_BITCAST_FUNC(Eigen::half, 6);
RAL_REGISTER_BITCAST_FUNC(Eigen::half, 7);
RAL_REGISTER_BITCAST_FUNC(Eigen::half, 8);
RAL_REGISTER_BITCAST_FUNC(Eigen::bfloat16, 1);
RAL_REGISTER_BITCAST_FUNC(Eigen::bfloat16, 2);
RAL_REGISTER_BITCAST_FUNC(Eigen::bfloat16, 3);
RAL_REGISTER_BITCAST_FUNC(Eigen::bfloat16, 4);
RAL_REGISTER_BITCAST_FUNC(Eigen::bfloat16, 5);
RAL_REGISTER_BITCAST_FUNC(Eigen::bfloat16, 6);
RAL_REGISTER_BITCAST_FUNC(Eigen::bfloat16, 7);
RAL_REGISTER_BITCAST_FUNC(Eigen::bfloat16, 8);
RAL_REGISTER_BITCAST_FUNC(float, 1);
RAL_REGISTER_BITCAST_FUNC(float, 2);
RAL_REGISTER_BITCAST_FUNC(float, 3);
@@ -745,5 +754,6 @@ TAO_RAL_API(tao::ral::gpu::kRalGpuSyncOnStream, "gpu",
TAO_RAL_API(tao::ral::gpu::kRalGpuMemset, "gpu", ral_base_cuda_memset);

} // namespace gpu
DEFINE_TAO_TYPE_NAME_HELPER(Eigen::bfloat16, "bf16");
} // namespace ral
} // namespace tao
18 changes: 15 additions & 3 deletions tao_compiler/mlir/ral/context/stream_executor_based_impl.cc
@@ -151,6 +151,11 @@ inline se::blas::ComputationType NativeTypeToBlasType<double>() {
return se::blas::ComputationType::kF64;
}

template <>
inline se::blas::ComputationType NativeTypeToBlasType<Eigen::bfloat16>() {
return se::blas::ComputationType::kBF16AsF32;
}

// The template was introduced, because not all instantiation of
// DoGemmWithAlgorithm template arguments was support by ThenBlasGemv.
template <typename InT, typename OutT, typename AlphaBeta>
@@ -293,7 +298,8 @@ se::blas::AlgorithmType tuningGemm(se::Stream* stream,
se::blas::ProfileResult profile_result;
DoGemmWithAlgorithm<InT, OutT, AlphaBeta>(
/*batch_size*/ 1, lhs_matrix, rhs_matrix, output_matrix,
/*alpha*/ 1., /*beta*/ 0., stream, algorithm, &profile_result);
/*alpha*/ AlphaBeta(1.0), /*beta*/ AlphaBeta(0.0), stream, algorithm,
&profile_result);

if (!profile_result.is_valid()) {
TAO_VLOG(1) << "algo: " << algorithm << " is invalid.";
@@ -408,8 +414,8 @@ void ral_gemm(ExecutionContext* ctx, void* stream_handle, MemRefType<InT, 2> A,

auto s = DoGemmWithAlgorithm<InT, OutT, E>(
/*batch_size*/ 1, lhs_matrix, rhs_matrix, output_matrix,
/*alpha*/ E(1.),
/*beta*/ E(0.), stream, best_algo_wrapper,
/*alpha*/ static_cast<E>(1.0),
/*beta*/ static_cast<E>(0.0), stream, best_algo_wrapper,
/*output_profile_result=*/nullptr);

if (!s) {
@@ -1952,10 +1958,14 @@ void ral_qconv(ExecutionContext* ctx, void* stream_handle,
} // namespace gpu

// gemm ops
DEFINE_TAO_TYPE_NAME_HELPER(Eigen::bfloat16, "bf16");
TAO_RAL_API("ral_gemm", "gpu", gpu::se_impl::ral_gemm<float, float>);
TAO_RAL_API("ral_gemm", "gpu", gpu::se_impl::ral_gemm<double, double, double>);
TAO_RAL_API("ral_gemm", "gpu",
gpu::se_impl::ral_gemm<Eigen::half, Eigen::half>);
TAO_RAL_API("ral_gemm", "gpu",
gpu::se_impl::ral_gemm<Eigen::bfloat16, Eigen::bfloat16>);

TAO_RAL_API("ral_qgemm", "gpu", gpu::se_impl::ral_qgemm);
TAO_RAL_API("ral_gemm", "gpu", gpu::se_impl::ral_batch_gemm<float, float, 3>);
TAO_RAL_API("ral_gemm", "gpu", gpu::se_impl::ral_batch_gemm<float, float, 4>);
@@ -1965,6 +1975,8 @@ TAO_RAL_API("ral_gemm", "gpu",
gpu::se_impl::ral_batch_gemm<double, double, 4, double>);
TAO_RAL_API("ral_gemm", "gpu",
gpu::se_impl::ral_batch_gemm<Eigen::half, Eigen::half, 3>);
TAO_RAL_API("ral_gemm", "gpu",
gpu::se_impl::ral_batch_gemm<Eigen::bfloat16, Eigen::bfloat16, 3>);
TAO_RAL_API("ral_gemm", "gpu",
gpu::se_impl::ral_batch_gemm<Eigen::half, Eigen::half, 4>);
#ifdef BLAZE_OPT
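The alpha/beta changes in this file are presumably needed because Eigen's half and bfloat16 types only construct explicitly from built-in numeric values, so bare 1./0. literals stop compiling once the scalar type is Eigen::bfloat16; the new NativeTypeToBlasType specialization maps bf16 to kBF16AsF32 (by its name, bf16 inputs with f32 computation). A minimal illustration of the construction issue (assumes Eigen's explicit numeric constructors; not code from the commit):

#include <Eigen/Core>

void AlphaBetaExample() {
  // Explicit conversion, as used for the gemm alpha/beta above.
  Eigen::bfloat16 alpha = static_cast<Eigen::bfloat16>(1.0);
  Eigen::bfloat16 beta(0.0f);
  // Eigen::bfloat16 beta2 = 0.;  // likely rejected: the constructor is explicit
  (void)alpha;
  (void)beta;
}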
10 changes: 10 additions & 0 deletions tao_compiler/mlir/ral/context/tensorflow/tf_context_impl.cc
@@ -909,6 +909,7 @@ void ral_tf_send_output_0d(ExecutionContext* ctx, int64_t output_idx,
TAO_RAL_API(::tao::ral::kRalBitcast, "cpu", ral_tf_bitcast_0d<T, 8, 0>);

RAL_REGISTER_IO_FUNC_0D(float);
RAL_REGISTER_IO_FUNC_0D(bfloat16);
RAL_REGISTER_IO_FUNC_0D(double);
RAL_REGISTER_IO_FUNC_0D(Eigen::half);
RAL_REGISTER_IO_FUNC_0D(int8_t);
@@ -980,13 +981,22 @@ RAL_REGISTER_IO_FUNC(bool, 5);
RAL_REGISTER_IO_FUNC(bool, 6);
RAL_REGISTER_IO_FUNC(bool, 7);
RAL_REGISTER_IO_FUNC(bool, 8);
RAL_REGISTER_IO_FUNC(bfloat16, 1);
RAL_REGISTER_IO_FUNC(bfloat16, 2);
RAL_REGISTER_IO_FUNC(bfloat16, 3);
RAL_REGISTER_IO_FUNC(bfloat16, 4);
RAL_REGISTER_IO_FUNC(bfloat16, 5);
RAL_REGISTER_IO_FUNC(bfloat16, 6);
RAL_REGISTER_IO_FUNC(bfloat16, 7);
RAL_REGISTER_IO_FUNC(bfloat16, 8);

} // namespace tensorflow

namespace tao {
namespace ral {

DEFINE_TAO_TYPE_NAME_HELPER(Eigen::half, "f16");
DEFINE_TAO_TYPE_NAME_HELPER(Eigen::bfloat16, "bf16");

TAO_RAL_API(::tao::ral::cpu::kRalCpuAlloc, "cpu", tensorflow::ral_tf_cpu_alloc);
TAO_RAL_API(::tao::ral::cpu::kRalCpuAllocPersistent, "cpu",
