support Llama bf16 amp training (#1293)
Yancey1989 authored May 21, 2024
1 parent 52cb669 commit bc0994a
Showing 13 changed files with 107 additions and 25 deletions.
8 changes: 8 additions & 0 deletions tao_compiler/ci_build/platforms/tao/gpu/env.conf.cuda11_8
@@ -0,0 +1,8 @@
TF_NEED_CUDA=1
TF_CUDA_CLANG=0
TF_CUDA_VERSION=11.8
TF_CUDNN_VERSION=8
TF_CUDA_COMPUTE_CAPABILITIES="6.0,6.1,7.0,7.5,8.0,8.6"
TF_NEED_TENSORRT=0
TF_NEED_ROCM=0
TF_SET_ANDROID_WORKSPACE=0
@@ -57,10 +57,12 @@ void LaunchTransposeKernel(cudaStream_t stream, T* input,
template void LaunchTransposeKernel<float>(cudaStream_t stream, float* input,
std::vector<int64_t> input_dims,
float* output);

template void LaunchTransposeKernel<Eigen::half>(
cudaStream_t stream, Eigen::half* input, std::vector<int64_t> input_dims,
Eigen::half* output);
template void LaunchTransposeKernel<Eigen::bfloat16>(
cudaStream_t stream, Eigen::bfloat16* input,
std::vector<int64_t> input_dims, Eigen::bfloat16* output);
#endif

} // namespace ral
4 changes: 3 additions & 1 deletion tao_compiler/mlir/custom_ops/transpose_impl.cc
@@ -61,11 +61,13 @@ void ral_transpose(ExecutionContext* ctx, void* stream_handle,

LaunchTransposeKernel<T>(stream, d_in, input_dims, d_out);
}

DEFINE_TAO_TYPE_NAME_HELPER(Eigen::bfloat16, "bf16");
TAO_RAL_API("ral_transpose", "gpu", ral_transpose<float, 2>);
TAO_RAL_API("ral_transpose", "gpu", ral_transpose<float, 3>);
TAO_RAL_API("ral_transpose", "gpu", ral_transpose<Eigen::half, 2>);
TAO_RAL_API("ral_transpose", "gpu", ral_transpose<Eigen::half, 3>);
TAO_RAL_API("ral_transpose", "gpu", ral_transpose<Eigen::bfloat16, 2>);
TAO_RAL_API("ral_transpose", "gpu", ral_transpose<Eigen::bfloat16, 3>);
#endif

} // namespace ral
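
DEFINE_TAO_TYPE_NAME_HELPER(Eigen::bfloat16, "bf16") presumably supplies the element-type string that the TAO_RAL_API registrations above key on. A minimal C++ sketch of that type-name-trait pattern, not the commit's code — the struct name and key format below are assumptions for illustration:

#include <string>
#include <Eigen/Core>  // Eigen::half, Eigen::bfloat16

// Hypothetical trait mapping an element type to the short name used in
// kernel registration keys; the real macro is assumed to generate something
// along these lines.
template <typename T>
struct TaoTypeName;  // primary template intentionally left undefined

template <>
struct TaoTypeName<float> {
  static constexpr const char* value = "f32";
};
template <>
struct TaoTypeName<Eigen::half> {
  static constexpr const char* value = "f16";
};
template <>
struct TaoTypeName<Eigen::bfloat16> {
  static constexpr const char* value = "bf16";
};

// Illustrative key builder: a missing specialization becomes a compile-time
// error at the registration site instead of a silent name collision, which is
// why a bf16 helper has to exist before ral_transpose<Eigen::bfloat16, N>
// can be registered.
template <typename T>
std::string MakeKernelKey(const std::string& op, const std::string& device) {
  return op + "___" + device + "___" + TaoTypeName<T>::value;
}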
4 changes: 3 additions & 1 deletion tao_compiler/mlir/disc/disc_compiler.cc
@@ -563,10 +563,12 @@ LogicalResult LowerHLOToLLVM(ModuleOp m, const DISCLoweringOptions& options) {
// optimization. Then this pass will be enabled by default.
pm.addNestedPass<FuncOp>(disc_ral::createForLoopUnrollInterleavePass());
}
pm.addNestedPass<FuncOp>(arith::createArithExpandOpsPass());
// Origin: https://reviews.llvm.org/D147585
// Should be removed after rebasing to the latest llvm head
pm.addNestedPass<FuncOp>(disc_ral::createDiscBF16ExpansionPass());
mlir::arith::ArithExpandOpsOptions arith_option;
arith_option.includeBf16 = true;
pm.addNestedPass<FuncOp>(arith::createArithExpandOpsPass(arith_option));
pm.addNestedPass<FuncOp>(mlir::memref::createFoldMemRefAliasOpsPass());

// Flatten multi dim memref accesses to its 1D format to enable more
30 changes: 16 additions & 14 deletions tao_compiler/mlir/disc/transforms/disc_bf16_expansion.cc
@@ -64,7 +64,7 @@ struct BFloat16ExtFOpConverter : public OpRewritePattern<arith::ExtFOp> {
Type resultETy = getElementTypeOrSelf(resultTy);

if (!operandETy.isBF16() || !resultETy.isF32()) {
return failure();
return rewriter.notifyMatchFailure(op, "not an ext of bf16 to f32.");
}

Type i16Ty = b.getI16Type();
@@ -98,21 +98,9 @@ struct BFloat16TruncFOpConverter : public OpRewritePattern<arith::TruncFOp> {
Type resultETy = getElementTypeOrSelf(resultTy);

if (!operandETy.isF32() || !resultETy.isBF16()) {
return failure();
return rewriter.notifyMatchFailure(op, "not a trunc of f32 to bf16.");
}

#if defined(TAO_AARCH64)
if (isBFCVTEnabled()) {
auto intrinsicName =
StringAttr::get(rewriter.getContext(), "llvm.aarch64.neon.bfcvt");
SmallVector<Value, 2> args;
args.push_back(operand);
rewriter.replaceOpWithNewOp<LLVM::CallIntrinsicOp>(op, resultETy,
intrinsicName, args);
return success();
}
#endif

Type i1Ty = b.getI1Type();
Type i16Ty = b.getI16Type();
Type i32Ty = b.getI32Type();
@@ -125,7 +113,21 @@ struct BFloat16TruncFOpConverter : public OpRewritePattern<arith::TruncFOp> {
}

Value bitcast = b.create<arith::BitcastOp>(i32Ty, operand);

// fast rounding algorithm to truncate fp32 to bf16:
// uint32_t lsb = (input >> 16) & 1;
// uint32_t rounding_bias = 0x7fff + lsb;
// input += rounding_bias;
// output.value = static_cast<uint16_t>(input >> 16);
// ref:
// https://hhhhhojeihsu.github.io/tensorflow_1.8_woboq/tensorflow_1.8_xla/tensorflow/tensorflow/core/lib/bfloat16/bfloat16.h.html#196
Value c16 = createConst(op.getLoc(), i32Ty, 16, rewriter);
Value c1 = createConst(op.getLoc(), i32Ty, 1, rewriter);
Value lsb = b.create<arith::ShRUIOp>(bitcast, c16);
lsb = b.create<arith::AndIOp>(lsb, c1);
Value rounding_bias = createConst(op.getLoc(), i32Ty, 0x7fff, rewriter);
rounding_bias = b.create<arith::AddIOp>(rounding_bias, lsb);
bitcast = b.create<arith::AddIOp>(bitcast, rounding_bias);
Value shr = b.create<arith::ShRUIOp>(bitcast, c16);
Value trunc = b.create<arith::TruncIOp>(i16Ty, shr);
Value result = b.create<arith::BitcastOp>(resultTy, trunc);
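
The lowering above is the standard round-to-nearest-even trick. A standalone C++ sketch of both directions the expansion pass handles — bf16 widened by a 16-bit shift, f32 truncated with the lsb + 0x7fff bias — using plain bit manipulation, independent of MLIR; helper names are ad hoc and NaN handling is omitted:

#include <cstdint>
#include <cstdio>
#include <cstring>

// bf16 -> f32: the 16 stored bits become the high half of the f32 word.
float bf16_to_f32(uint16_t b) {
  uint32_t bits = static_cast<uint32_t>(b) << 16;
  float f;
  std::memcpy(&f, &bits, sizeof(f));
  return f;
}

// f32 -> bf16, round to nearest even: the lsb of the kept half decides
// whether the effective bias is 0x7fff or 0x8000 before the final shift.
uint16_t f32_to_bf16(float f) {
  uint32_t input;
  std::memcpy(&input, &f, sizeof(input));
  uint32_t lsb = (input >> 16) & 1u;
  uint32_t rounding_bias = 0x7fffu + lsb;
  input += rounding_bias;
  return static_cast<uint16_t>(input >> 16);
}

int main() {
  // 1.0f survives the round trip exactly; 1 + 2^-8 is a tie and rounds to 1.0.
  for (float x : {1.0f, 1.00390625f}) {
    uint16_t b = f32_to_bf16(x);
    std::printf("%.8g -> 0x%04x -> %.8g\n", x, static_cast<unsigned>(b),
                bf16_to_f32(b));
  }
  return 0;
}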
@@ -22,7 +22,7 @@ limitations under the License.
lmhlo::AbsOp, lmhlo::CeilOp, lmhlo::FloorOp, lmhlo::ConvertOp, lmhlo::CosineOp,
lmhlo::ExpOp, lmhlo::LogOp, lmhlo::NegOp, lmhlo::RsqrtOp, lmhlo::SqrtOp,
lmhlo::SignOp, lmhlo::TanhOp, lmhlo::LogisticOp, lmhlo::Log1pOp,
lmhlo::SineOp, lmhlo::RoundOp, lmhlo::RoundNearestEvenOp,
lmhlo::SineOp, lmhlo::RoundOp, lmhlo::RoundNearestEvenOp, lmhlo::BitcastConvertOp,

// Binary Elementwise Ops
lmhlo::AddOp, lmhlo::DivOp, lmhlo::MaxOp, lmhlo::MinOp, lmhlo::MulOp,
13 changes: 12 additions & 1 deletion tao_compiler/mlir/disc/transforms/disc_to_llvm.cc
@@ -87,7 +87,18 @@ LogicalResult getTypeEncoding(MLIRContext* ctx, Type t, StrT& out) {
out.append(Twine("i").concat(Twine(int_type.getWidth())).str());
}
} else if (auto fp_type = t.dyn_cast<FloatType>()) {
out.append(Twine("f").concat(Twine(fp_type.getWidth())).str());
if (fp_type.isF16()) {
out.append("f16");
} else if (fp_type.isBF16()) {
out.append("bf16");
} else if (fp_type.isF32()) {
out.append("f32");
} else if (fp_type.isF64()) {
out.append("f64");
} else {
return failure();
}
// out.append(Twine("f").concat(Twine(fp_type.getWidth())).str());
} else if (auto ctx_type = t.dyn_cast<RalExecutionContextType>() ||
t == llvm_i8ptr_type || t == llvm_ptr_type) {
out.append("pvoid");
16 changes: 13 additions & 3 deletions tao_compiler/mlir/disc/transforms/lhlo_elemental_utils.cc
@@ -1231,10 +1231,11 @@ Value elementalLower<lmhlo::ConcatenateOp>(OpBuilder* b, Location loc,

Value zero_element;
if (result_elem_type.isF16() || result_elem_type.isF32() ||
result_elem_type.isF64()) {
result_elem_type.isF64() || result_elem_type.isBF16()) {
auto float_result_elem_type = result_elem_type.cast<FloatType>();
zero_element = b->create<arith::ConstantFloatOp>(
loc, APFloat::getZero(float_result_elem_type.getFloatSemantics()),
loc,
APFloat::getZero(float_result_elem_type.getFloatSemantics(), false),
float_result_elem_type);
} else if (result_elem_type.isSignlessInteger() ||
result_elem_type.isSignedInteger() ||
@@ -1304,7 +1305,16 @@ Value elementalLower<lmhlo::ConcatenateOp>(OpBuilder* b, Location loc,

b->setInsertionPointToEnd(&if_inbound_ops[i].getElseRegion().front());
if (i == num_input_operands - 1) {
b->create<scf::YieldOp>(loc, zero_element); // expect never used
input_index[axis] = b->create<arith::SubIOp>(loc, out_idx, low_bound);
auto operand_memref = op.getOperand(i);
auto ret_value =
check_cache ? createLoadOrUseCachedValue(
loc, b, op.getOperation(), operand_memref,
input_index, b->saveInsertionPoint(), lower_config)
: createMaySpecificLoad(*b, loc, op.getOperation(),
operand_memref, input_index,
lower_config);
b->create<scf::YieldOp>(loc, ret_value);
} else {
b->create<scf::YieldOp>(loc, if_inbound_ops[i + 1].getResults());
}
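
The final else branch used to yield a dummy zero on the assumption it was unreachable; it now computes the local index and loads from the last operand like every other branch. A 1-D C++ analogue of the index arithmetic the lowering emits (illustrative only, not the generated IR):

#include <cstdint>
#include <vector>

// Concatenation read along one axis: find the operand whose span contains
// out_idx and load at the local offset. The last operand is the unconditional
// fallback, mirroring the new else branch that performs a real load.
float concat_read(const std::vector<std::vector<float>>& inputs,
                  int64_t out_idx) {
  int64_t low_bound = 0;
  for (size_t i = 0; i + 1 < inputs.size(); ++i) {
    int64_t high_bound = low_bound + static_cast<int64_t>(inputs[i].size());
    if (out_idx < high_bound) return inputs[i][out_idx - low_bound];
    low_bound = high_bound;
  }
  return inputs.back()[out_idx - low_bound];
}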
9 changes: 9 additions & 0 deletions tao_compiler/mlir/ral/context/base/base_context.cc
@@ -235,6 +235,7 @@ void ral_base_cuda_send_output_0d(ExecutionContext* ctx, int64_t output_idx,
TAO_RAL_API(tao::ral::kRalSendOutput, "cpu", ral_base_cuda_send_output_0d<T>);

RAL_REGISTER_IO_FUNC_0D(float);
RAL_REGISTER_IO_FUNC_0D(bfloat16);
RAL_REGISTER_IO_FUNC_0D(double);
RAL_REGISTER_IO_FUNC_0D(int8_t);
RAL_REGISTER_IO_FUNC_0D(int32_t);
@@ -306,5 +307,13 @@ RAL_REGISTER_IO_FUNC(Eigen::half, 5);
RAL_REGISTER_IO_FUNC(Eigen::half, 6);
RAL_REGISTER_IO_FUNC(Eigen::half, 7);
RAL_REGISTER_IO_FUNC(Eigen::half, 8);
RAL_REGISTER_IO_FUNC(bfloat16, 1);
RAL_REGISTER_IO_FUNC(bfloat16, 2);
RAL_REGISTER_IO_FUNC(bfloat16, 3);
RAL_REGISTER_IO_FUNC(bfloat16, 4);
RAL_REGISTER_IO_FUNC(bfloat16, 5);
RAL_REGISTER_IO_FUNC(bfloat16, 6);
RAL_REGISTER_IO_FUNC(bfloat16, 7);
RAL_REGISTER_IO_FUNC(bfloat16, 8);
} // namespace ral
} // namespace tao
1 change: 1 addition & 0 deletions tao_compiler/mlir/ral/context/base/cpu/cpu_context_impl.cc
@@ -370,6 +370,7 @@ RAL_REGISTER_BITCAST_FUNC_0D(double);
RAL_REGISTER_BITCAST_FUNC_0D(int32_t);
RAL_REGISTER_BITCAST_FUNC_0D(int64_t);
RAL_REGISTER_BITCAST_FUNC_0D(bool);
RAL_REGISTER_BITCAST_FUNC_0D(bfloat16);
RAL_REGISTER_BITCAST_FUNC(float, 1);
RAL_REGISTER_BITCAST_FUNC(float, 2);
RAL_REGISTER_BITCAST_FUNC(float, 3);
10 changes: 10 additions & 0 deletions tao_compiler/mlir/ral/context/base/cuda/cuda_context_impl.cc
@@ -671,6 +671,7 @@ void ral_base_cuda_d2h(ExecutionContext* ctx, void* stream_handle,
TAO_RAL_API(tao::ral::kRalBitcast, "gpu", ral_base_cuda_bitcast_0d<T, 8, 0>);

RAL_REGISTER_BITCAST_FUNC_0D(Eigen::half);
RAL_REGISTER_BITCAST_FUNC_0D(Eigen::bfloat16);
RAL_REGISTER_BITCAST_FUNC_0D(float);
RAL_REGISTER_BITCAST_FUNC_0D(double);
RAL_REGISTER_BITCAST_FUNC_0D(int32_t);
@@ -684,6 +685,14 @@ RAL_REGISTER_BITCAST_FUNC(Eigen::half, 5);
RAL_REGISTER_BITCAST_FUNC(Eigen::half, 6);
RAL_REGISTER_BITCAST_FUNC(Eigen::half, 7);
RAL_REGISTER_BITCAST_FUNC(Eigen::half, 8);
RAL_REGISTER_BITCAST_FUNC(Eigen::bfloat16, 1);
RAL_REGISTER_BITCAST_FUNC(Eigen::bfloat16, 2);
RAL_REGISTER_BITCAST_FUNC(Eigen::bfloat16, 3);
RAL_REGISTER_BITCAST_FUNC(Eigen::bfloat16, 4);
RAL_REGISTER_BITCAST_FUNC(Eigen::bfloat16, 5);
RAL_REGISTER_BITCAST_FUNC(Eigen::bfloat16, 6);
RAL_REGISTER_BITCAST_FUNC(Eigen::bfloat16, 7);
RAL_REGISTER_BITCAST_FUNC(Eigen::bfloat16, 8);
RAL_REGISTER_BITCAST_FUNC(float, 1);
RAL_REGISTER_BITCAST_FUNC(float, 2);
RAL_REGISTER_BITCAST_FUNC(float, 3);
@@ -745,5 +754,6 @@ TAO_RAL_API(tao::ral::gpu::kRalGpuSyncOnStream, "gpu",
TAO_RAL_API(tao::ral::gpu::kRalGpuMemset, "gpu", ral_base_cuda_memset);

} // namespace gpu
DEFINE_TAO_TYPE_NAME_HELPER(Eigen::bfloat16, "bf16");
} // namespace ral
} // namespace tao
21 changes: 18 additions & 3 deletions tao_compiler/mlir/ral/context/stream_executor_based_impl.cc
@@ -151,6 +151,11 @@ inline se::blas::ComputationType NativeTypeToBlasType<double>() {
return se::blas::ComputationType::kF64;
}

template <>
inline se::blas::ComputationType NativeTypeToBlasType<Eigen::bfloat16>() {
return se::blas::ComputationType::kBF16AsF32;
}
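
kBF16AsF32 requests bf16 operands with fp32 accumulation, which is also why the bf16 ral_gemm registration later in this diff passes float for alpha/beta. At the cuBLAS level this corresponds roughly to cublasGemmEx with CUDA_R_16BF inputs and CUBLAS_COMPUTE_32F; a standalone sketch that bypasses StreamExecutor entirely, with arbitrary sizes and data and error handling trimmed:

#include <cublas_v2.h>
#include <cuda_bf16.h>
#include <cuda_runtime.h>
#include <vector>

// Minimal bf16 GEMM with fp32 accumulation: C = alpha*A*B + beta*C, where A,
// B, C are stored as bf16 but the math (and alpha/beta) are fp32. Matrices
// are column-major per the cuBLAS convention.
int main() {
  const int m = 4, n = 4, k = 4;
  std::vector<__nv_bfloat16> hA(m * k), hB(k * n), hC(m * n);
  for (int i = 0; i < m * k; ++i) hA[i] = __float2bfloat16(1.0f);
  for (int i = 0; i < k * n; ++i) hB[i] = __float2bfloat16(2.0f);
  for (int i = 0; i < m * n; ++i) hC[i] = __float2bfloat16(0.0f);

  __nv_bfloat16 *dA, *dB, *dC;
  cudaMalloc(&dA, hA.size() * sizeof(__nv_bfloat16));
  cudaMalloc(&dB, hB.size() * sizeof(__nv_bfloat16));
  cudaMalloc(&dC, hC.size() * sizeof(__nv_bfloat16));
  cudaMemcpy(dA, hA.data(), hA.size() * sizeof(__nv_bfloat16), cudaMemcpyHostToDevice);
  cudaMemcpy(dB, hB.data(), hB.size() * sizeof(__nv_bfloat16), cudaMemcpyHostToDevice);
  cudaMemcpy(dC, hC.data(), hC.size() * sizeof(__nv_bfloat16), cudaMemcpyHostToDevice);

  cublasHandle_t handle;
  cublasCreate(&handle);
  const float alpha = 1.0f, beta = 0.0f;  // fp32 scalars, matching kBF16AsF32
  cublasGemmEx(handle, CUBLAS_OP_N, CUBLAS_OP_N, m, n, k,
               &alpha, dA, CUDA_R_16BF, m, dB, CUDA_R_16BF, k,
               &beta, dC, CUDA_R_16BF, m,
               CUBLAS_COMPUTE_32F, CUBLAS_GEMM_DEFAULT);
  cudaDeviceSynchronize();

  cublasDestroy(handle);
  cudaFree(dA);
  cudaFree(dB);
  cudaFree(dC);
  return 0;
}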

// The template was introduced because not all instantiations of the
// DoGemmWithAlgorithm template arguments are supported by ThenBlasGemv.
template <typename InT, typename OutT, typename AlphaBeta>
@@ -293,7 +298,8 @@ se::blas::AlgorithmType tuningGemm(se::Stream* stream,
se::blas::ProfileResult profile_result;
DoGemmWithAlgorithm<InT, OutT, AlphaBeta>(
/*batch_size*/ 1, lhs_matrix, rhs_matrix, output_matrix,
/*alpha*/ 1., /*beta*/ 0., stream, algorithm, &profile_result);
/*alpha*/ AlphaBeta(1.0), /*beta*/ AlphaBeta(0.0), stream, algorithm,
&profile_result);

if (!profile_result.is_valid()) {
TAO_VLOG(1) << "algo: " << algorithm << " is invalid.";
@@ -1952,10 +1958,18 @@ void ral_qconv(ExecutionContext* ctx, void* stream_handle,
} // namespace gpu

// gemm ops
TAO_RAL_API("ral_gemm", "gpu", gpu::se_impl::ral_gemm<float, float>);
TAO_RAL_API("ral_gemm", "gpu", gpu::se_impl::ral_gemm<double, double, double>);
#ifndef DISC_BUILD_FROM_TF_BRIDGE
DEFINE_TAO_TYPE_NAME_HELPER(Eigen::bfloat16, "bf16");
TAO_RAL_API("ral_gemm", "gpu",
gpu::se_impl::ral_gemm<Eigen::half, Eigen::half>);
TAO_RAL_API("ral_gemm", "gpu",
gpu::se_impl::ral_gemm<Eigen::bfloat16, Eigen::bfloat16, float>);
TAO_RAL_API("ral_gemm", "gpu",
gpu::se_impl::ral_batch_gemm<Eigen::bfloat16, Eigen::bfloat16, 3>);
#endif
TAO_RAL_API("ral_gemm", "gpu", gpu::se_impl::ral_gemm<float, float>);
TAO_RAL_API("ral_gemm", "gpu", gpu::se_impl::ral_gemm<double, double, double>);

TAO_RAL_API("ral_qgemm", "gpu", gpu::se_impl::ral_qgemm);
TAO_RAL_API("ral_gemm", "gpu", gpu::se_impl::ral_batch_gemm<float, float, 3>);
TAO_RAL_API("ral_gemm", "gpu", gpu::se_impl::ral_batch_gemm<float, float, 4>);
@@ -1965,6 +1979,7 @@ TAO_RAL_API("ral_gemm", "gpu",
gpu::se_impl::ral_batch_gemm<double, double, 4, double>);
TAO_RAL_API("ral_gemm", "gpu",
gpu::se_impl::ral_batch_gemm<Eigen::half, Eigen::half, 3>);

TAO_RAL_API("ral_gemm", "gpu",
gpu::se_impl::ral_batch_gemm<Eigen::half, Eigen::half, 4>);
#ifdef BLAZE_OPT
10 changes: 10 additions & 0 deletions tao_compiler/mlir/ral/context/tensorflow/tf_context_impl.cc
@@ -909,6 +909,7 @@ void ral_tf_send_output_0d(ExecutionContext* ctx, int64_t output_idx,
TAO_RAL_API(::tao::ral::kRalBitcast, "cpu", ral_tf_bitcast_0d<T, 8, 0>);

RAL_REGISTER_IO_FUNC_0D(float);
RAL_REGISTER_IO_FUNC_0D(bfloat16);
RAL_REGISTER_IO_FUNC_0D(double);
RAL_REGISTER_IO_FUNC_0D(Eigen::half);
RAL_REGISTER_IO_FUNC_0D(int8_t);
@@ -980,13 +981,22 @@ RAL_REGISTER_IO_FUNC(bool, 5);
RAL_REGISTER_IO_FUNC(bool, 6);
RAL_REGISTER_IO_FUNC(bool, 7);
RAL_REGISTER_IO_FUNC(bool, 8);
RAL_REGISTER_IO_FUNC(bfloat16, 1);
RAL_REGISTER_IO_FUNC(bfloat16, 2);
RAL_REGISTER_IO_FUNC(bfloat16, 3);
RAL_REGISTER_IO_FUNC(bfloat16, 4);
RAL_REGISTER_IO_FUNC(bfloat16, 5);
RAL_REGISTER_IO_FUNC(bfloat16, 6);
RAL_REGISTER_IO_FUNC(bfloat16, 7);
RAL_REGISTER_IO_FUNC(bfloat16, 8);

} // namespace tensorflow

namespace tao {
namespace ral {

DEFINE_TAO_TYPE_NAME_HELPER(Eigen::half, "f16");
DEFINE_TAO_TYPE_NAME_HELPER(Eigen::bfloat16, "bf16");

TAO_RAL_API(::tao::ral::cpu::kRalCpuAlloc, "cpu", tensorflow::ral_tf_cpu_alloc);
TAO_RAL_API(::tao::ral::cpu::kRalCpuAllocPersistent, "cpu",