From bc0994aba19be3c15c392b648dcfa7509ee5d95e Mon Sep 17 00:00:00 2001 From: Yan Xu Date: Tue, 21 May 2024 15:48:58 +0800 Subject: [PATCH] support Llama bf16 amp training (#1293) support llama bf16 amp training --- .../platforms/tao/gpu/env.conf.cuda11_8 | 8 +++++ .../custom_library/transpose_gpu.cu.cc | 4 ++- .../mlir/custom_ops/transpose_impl.cc | 4 ++- tao_compiler/mlir/disc/disc_compiler.cc | 4 ++- .../disc/transforms/disc_bf16_expansion.cc | 30 ++++++++++--------- .../disc/transforms/disc_supported_list.h.inc | 2 +- .../mlir/disc/transforms/disc_to_llvm.cc | 13 +++++++- .../disc/transforms/lhlo_elemental_utils.cc | 16 ++++++++-- .../mlir/ral/context/base/base_context.cc | 9 ++++++ .../ral/context/base/cpu/cpu_context_impl.cc | 1 + .../context/base/cuda/cuda_context_impl.cc | 10 +++++++ .../ral/context/stream_executor_based_impl.cc | 21 +++++++++++-- .../ral/context/tensorflow/tf_context_impl.cc | 10 +++++++ 13 files changed, 107 insertions(+), 25 deletions(-) create mode 100644 tao_compiler/ci_build/platforms/tao/gpu/env.conf.cuda11_8 diff --git a/tao_compiler/ci_build/platforms/tao/gpu/env.conf.cuda11_8 b/tao_compiler/ci_build/platforms/tao/gpu/env.conf.cuda11_8 new file mode 100644 index 00000000000..487d57794f5 --- /dev/null +++ b/tao_compiler/ci_build/platforms/tao/gpu/env.conf.cuda11_8 @@ -0,0 +1,8 @@ +TF_NEED_CUDA=1 +TF_CUDA_CLANG=0 +TF_CUDA_VERSION=11.8 +TF_CUDNN_VERSION=8 +TF_CUDA_COMPUTE_CAPABILITIES="6.0,6.1,7.0,7.5,8.0,8.6" +TF_NEED_TENSORRT=0 +TF_NEED_ROCM=0 +TF_SET_ANDROID_WORKSPACE=0 diff --git a/tao_compiler/mlir/custom_ops/custom_library/transpose_gpu.cu.cc b/tao_compiler/mlir/custom_ops/custom_library/transpose_gpu.cu.cc index d4ab9742110..1ac3a5254a8 100644 --- a/tao_compiler/mlir/custom_ops/custom_library/transpose_gpu.cu.cc +++ b/tao_compiler/mlir/custom_ops/custom_library/transpose_gpu.cu.cc @@ -57,10 +57,12 @@ void LaunchTransposeKernel(cudaStream_t stream, T* input, template void LaunchTransposeKernel(cudaStream_t stream, float* input, std::vector input_dims, float* output); - template void LaunchTransposeKernel( cudaStream_t stream, Eigen::half* input, std::vector input_dims, Eigen::half* output); +template void LaunchTransposeKernel( + cudaStream_t stream, Eigen::bfloat16* input, + std::vector input_dims, Eigen::bfloat16* output); #endif } // namespace ral diff --git a/tao_compiler/mlir/custom_ops/transpose_impl.cc b/tao_compiler/mlir/custom_ops/transpose_impl.cc index 1c4fde328b4..fb275a94f73 100644 --- a/tao_compiler/mlir/custom_ops/transpose_impl.cc +++ b/tao_compiler/mlir/custom_ops/transpose_impl.cc @@ -61,11 +61,13 @@ void ral_transpose(ExecutionContext* ctx, void* stream_handle, LaunchTransposeKernel(stream, d_in, input_dims, d_out); } - +DEFINE_TAO_TYPE_NAME_HELPER(Eigen::bfloat16, "bf16"); TAO_RAL_API("ral_transpose", "gpu", ral_transpose); TAO_RAL_API("ral_transpose", "gpu", ral_transpose); TAO_RAL_API("ral_transpose", "gpu", ral_transpose); TAO_RAL_API("ral_transpose", "gpu", ral_transpose); +TAO_RAL_API("ral_transpose", "gpu", ral_transpose); +TAO_RAL_API("ral_transpose", "gpu", ral_transpose); #endif } // namespace ral diff --git a/tao_compiler/mlir/disc/disc_compiler.cc b/tao_compiler/mlir/disc/disc_compiler.cc index b55b07366fa..98df6e118d7 100644 --- a/tao_compiler/mlir/disc/disc_compiler.cc +++ b/tao_compiler/mlir/disc/disc_compiler.cc @@ -563,10 +563,12 @@ LogicalResult LowerHLOToLLVM(ModuleOp m, const DISCLoweringOptions& options) { // optimization. Then this pass will be enabled by default. 
pm.addNestedPass(disc_ral::createForLoopUnrollInterleavePass());
}
- pm.addNestedPass(arith::createArithExpandOpsPass());
// Origin: https://reviews.llvm.org/D147585
// Should be removed after rebasing to the latest llvm head
pm.addNestedPass(disc_ral::createDiscBF16ExpansionPass());
+ mlir::arith::ArithExpandOpsOptions arith_option;
+ arith_option.includeBf16 = true;
+ pm.addNestedPass(arith::createArithExpandOpsPass(arith_option));
pm.addNestedPass(mlir::memref::createFoldMemRefAliasOpsPass());
// Flatten multi dim memref accesses to its 1D format to enable more
diff --git a/tao_compiler/mlir/disc/transforms/disc_bf16_expansion.cc b/tao_compiler/mlir/disc/transforms/disc_bf16_expansion.cc
index 294d0a03ca8..41019f8f1a3 100644
--- a/tao_compiler/mlir/disc/transforms/disc_bf16_expansion.cc
+++ b/tao_compiler/mlir/disc/transforms/disc_bf16_expansion.cc
@@ -64,7 +64,7 @@ struct BFloat16ExtFOpConverter : public OpRewritePattern {
  Type resultETy = getElementTypeOrSelf(resultTy);
  if (!operandETy.isBF16() || !resultETy.isF32()) {
-   return failure();
+   return rewriter.notifyMatchFailure(op, "not an ext of bf16 to f32.");
  }
  Type i16Ty = b.getI16Type();
@@ -98,21 +98,9 @@ struct BFloat16TruncFOpConverter : public OpRewritePattern {
  Type resultETy = getElementTypeOrSelf(resultTy);
  if (!operandETy.isF32() || !resultETy.isBF16()) {
-   return failure();
+   return rewriter.notifyMatchFailure(op, "not a trunc of f32 to bf16.");
  }
-#if defined(TAO_AARCH64)
- if (isBFCVTEnabled()) {
-   auto intrinsicName =
-       StringAttr::get(rewriter.getContext(), "llvm.aarch64.neon.bfcvt");
-   SmallVector args;
-   args.push_back(operand);
-   rewriter.replaceOpWithNewOp(op, resultETy,
-                               intrinsicName, args);
-   return success();
- }
-#endif
-
  Type i1Ty = b.getI1Type();
  Type i16Ty = b.getI16Type();
  Type i32Ty = b.getI32Type();
@@ -125,7 +113,21 @@ struct BFloat16TruncFOpConverter : public OpRewritePattern {
  }
  Value bitcast = b.create(i32Ty, operand);
+
+ // fast rounding algorithm to truncate fp32 to bf16:
+ //   uint32_t lsb = (input >> 16) & 1;
+ //   uint32_t rounding_bias = 0x7fff + lsb;
+ //   input += rounding_bias;
+ //   output.value = static_cast<uint16_t>(input >> 16);
+ // ref:
+ // https://hhhhhojeihsu.github.io/tensorflow_1.8_woboq/tensorflow_1.8_xla/tensorflow/tensorflow/core/lib/bfloat16/bfloat16.h.html#196
  Value c16 = createConst(op.getLoc(), i32Ty, 16, rewriter);
+ Value c1 = createConst(op.getLoc(), i32Ty, 1, rewriter);
+ Value lsb = b.create(bitcast, c16);
+ lsb = b.create(lsb, c1);
+ Value rounding_bias = createConst(op.getLoc(), i32Ty, 0x7fff, rewriter);
+ rounding_bias = b.create(rounding_bias, lsb);
+ bitcast = b.create(bitcast, rounding_bias);
  Value shr = b.create(bitcast, c16);
  Value trunc = b.create(i16Ty, shr);
  Value result = b.create(resultTy, trunc);
diff --git a/tao_compiler/mlir/disc/transforms/disc_supported_list.h.inc b/tao_compiler/mlir/disc/transforms/disc_supported_list.h.inc
index 91c49dee13e..fdfb08e7384 100644
--- a/tao_compiler/mlir/disc/transforms/disc_supported_list.h.inc
+++ b/tao_compiler/mlir/disc/transforms/disc_supported_list.h.inc
@@ -22,7 +22,7 @@ limitations under the License.
lmhlo::AbsOp, lmhlo::CeilOp, lmhlo::FloorOp, lmhlo::ConvertOp, lmhlo::CosineOp, lmhlo::ExpOp, lmhlo::LogOp, lmhlo::NegOp, lmhlo::RsqrtOp, lmhlo::SqrtOp, lmhlo::SignOp, lmhlo::TanhOp, lmhlo::LogisticOp, lmhlo::Log1pOp, -lmhlo::SineOp, lmhlo::RoundOp, lmhlo::RoundNearestEvenOp, +lmhlo::SineOp, lmhlo::RoundOp, lmhlo::RoundNearestEvenOp, lmhlo::BitcastConvertOp, // Binary Elementwise Ops lmhlo::AddOp, lmhlo::DivOp, lmhlo::MaxOp, lmhlo::MinOp, lmhlo::MulOp, diff --git a/tao_compiler/mlir/disc/transforms/disc_to_llvm.cc b/tao_compiler/mlir/disc/transforms/disc_to_llvm.cc index ed0128097bc..5e4e84c400b 100644 --- a/tao_compiler/mlir/disc/transforms/disc_to_llvm.cc +++ b/tao_compiler/mlir/disc/transforms/disc_to_llvm.cc @@ -87,7 +87,18 @@ LogicalResult getTypeEncoding(MLIRContext* ctx, Type t, StrT& out) { out.append(Twine("i").concat(Twine(int_type.getWidth())).str()); } } else if (auto fp_type = t.dyn_cast()) { - out.append(Twine("f").concat(Twine(fp_type.getWidth())).str()); + if (fp_type.isF16()) { + out.append("f16"); + } else if (fp_type.isBF16()) { + out.append("bf16"); + } else if (fp_type.isF32()) { + out.append("f32"); + } else if (fp_type.isF64()) { + out.append("f64"); + } else { + return failure(); + } + // out.append(Twine("f").concat(Twine(fp_type.getWidth())).str()); } else if (auto ctx_type = t.dyn_cast() || t == llvm_i8ptr_type || t == llvm_ptr_type) { out.append("pvoid"); diff --git a/tao_compiler/mlir/disc/transforms/lhlo_elemental_utils.cc b/tao_compiler/mlir/disc/transforms/lhlo_elemental_utils.cc index 745973be732..fcf33a6e4e5 100644 --- a/tao_compiler/mlir/disc/transforms/lhlo_elemental_utils.cc +++ b/tao_compiler/mlir/disc/transforms/lhlo_elemental_utils.cc @@ -1231,10 +1231,11 @@ Value elementalLower(OpBuilder* b, Location loc, Value zero_element; if (result_elem_type.isF16() || result_elem_type.isF32() || - result_elem_type.isF64()) { + result_elem_type.isF64() || result_elem_type.isBF16()) { auto float_result_elem_type = result_elem_type.cast(); zero_element = b->create( - loc, APFloat::getZero(float_result_elem_type.getFloatSemantics()), + loc, + APFloat::getZero(float_result_elem_type.getFloatSemantics(), false), float_result_elem_type); } else if (result_elem_type.isSignlessInteger() || result_elem_type.isSignedInteger() || @@ -1304,7 +1305,16 @@ Value elementalLower(OpBuilder* b, Location loc, b->setInsertionPointToEnd(&if_inbound_ops[i].getElseRegion().front()); if (i == num_input_operands - 1) { - b->create(loc, zero_element); // expect never used + input_index[axis] = b->create(loc, out_idx, low_bound); + auto operand_memref = op.getOperand(i); + auto ret_value = + check_cache ? 
createLoadOrUseCachedValue( + loc, b, op.getOperation(), operand_memref, + input_index, b->saveInsertionPoint(), lower_config) + : createMaySpecificLoad(*b, loc, op.getOperation(), + operand_memref, input_index, + lower_config); + b->create(loc, ret_value); } else { b->create(loc, if_inbound_ops[i + 1].getResults()); } diff --git a/tao_compiler/mlir/ral/context/base/base_context.cc b/tao_compiler/mlir/ral/context/base/base_context.cc index 734cd663878..8e7c0df3c25 100644 --- a/tao_compiler/mlir/ral/context/base/base_context.cc +++ b/tao_compiler/mlir/ral/context/base/base_context.cc @@ -235,6 +235,7 @@ void ral_base_cuda_send_output_0d(ExecutionContext* ctx, int64_t output_idx, TAO_RAL_API(tao::ral::kRalSendOutput, "cpu", ral_base_cuda_send_output_0d); RAL_REGISTER_IO_FUNC_0D(float); +RAL_REGISTER_IO_FUNC_0D(bfloat16); RAL_REGISTER_IO_FUNC_0D(double); RAL_REGISTER_IO_FUNC_0D(int8_t); RAL_REGISTER_IO_FUNC_0D(int32_t); @@ -306,5 +307,13 @@ RAL_REGISTER_IO_FUNC(Eigen::half, 5); RAL_REGISTER_IO_FUNC(Eigen::half, 6); RAL_REGISTER_IO_FUNC(Eigen::half, 7); RAL_REGISTER_IO_FUNC(Eigen::half, 8); +RAL_REGISTER_IO_FUNC(bfloat16, 1); +RAL_REGISTER_IO_FUNC(bfloat16, 2); +RAL_REGISTER_IO_FUNC(bfloat16, 3); +RAL_REGISTER_IO_FUNC(bfloat16, 4); +RAL_REGISTER_IO_FUNC(bfloat16, 5); +RAL_REGISTER_IO_FUNC(bfloat16, 6); +RAL_REGISTER_IO_FUNC(bfloat16, 7); +RAL_REGISTER_IO_FUNC(bfloat16, 8); } // namespace ral } // namespace tao diff --git a/tao_compiler/mlir/ral/context/base/cpu/cpu_context_impl.cc b/tao_compiler/mlir/ral/context/base/cpu/cpu_context_impl.cc index 68f540667c7..0af0bfced76 100644 --- a/tao_compiler/mlir/ral/context/base/cpu/cpu_context_impl.cc +++ b/tao_compiler/mlir/ral/context/base/cpu/cpu_context_impl.cc @@ -370,6 +370,7 @@ RAL_REGISTER_BITCAST_FUNC_0D(double); RAL_REGISTER_BITCAST_FUNC_0D(int32_t); RAL_REGISTER_BITCAST_FUNC_0D(int64_t); RAL_REGISTER_BITCAST_FUNC_0D(bool); +RAL_REGISTER_BITCAST_FUNC_0D(bfloat16); RAL_REGISTER_BITCAST_FUNC(float, 1); RAL_REGISTER_BITCAST_FUNC(float, 2); RAL_REGISTER_BITCAST_FUNC(float, 3); diff --git a/tao_compiler/mlir/ral/context/base/cuda/cuda_context_impl.cc b/tao_compiler/mlir/ral/context/base/cuda/cuda_context_impl.cc index 0af72f978f2..b350d8f78ce 100644 --- a/tao_compiler/mlir/ral/context/base/cuda/cuda_context_impl.cc +++ b/tao_compiler/mlir/ral/context/base/cuda/cuda_context_impl.cc @@ -671,6 +671,7 @@ void ral_base_cuda_d2h(ExecutionContext* ctx, void* stream_handle, TAO_RAL_API(tao::ral::kRalBitcast, "gpu", ral_base_cuda_bitcast_0d); RAL_REGISTER_BITCAST_FUNC_0D(Eigen::half); +RAL_REGISTER_BITCAST_FUNC_0D(Eigen::bfloat16); RAL_REGISTER_BITCAST_FUNC_0D(float); RAL_REGISTER_BITCAST_FUNC_0D(double); RAL_REGISTER_BITCAST_FUNC_0D(int32_t); @@ -684,6 +685,14 @@ RAL_REGISTER_BITCAST_FUNC(Eigen::half, 5); RAL_REGISTER_BITCAST_FUNC(Eigen::half, 6); RAL_REGISTER_BITCAST_FUNC(Eigen::half, 7); RAL_REGISTER_BITCAST_FUNC(Eigen::half, 8); +RAL_REGISTER_BITCAST_FUNC(Eigen::bfloat16, 1); +RAL_REGISTER_BITCAST_FUNC(Eigen::bfloat16, 2); +RAL_REGISTER_BITCAST_FUNC(Eigen::bfloat16, 3); +RAL_REGISTER_BITCAST_FUNC(Eigen::bfloat16, 4); +RAL_REGISTER_BITCAST_FUNC(Eigen::bfloat16, 5); +RAL_REGISTER_BITCAST_FUNC(Eigen::bfloat16, 6); +RAL_REGISTER_BITCAST_FUNC(Eigen::bfloat16, 7); +RAL_REGISTER_BITCAST_FUNC(Eigen::bfloat16, 8); RAL_REGISTER_BITCAST_FUNC(float, 1); RAL_REGISTER_BITCAST_FUNC(float, 2); RAL_REGISTER_BITCAST_FUNC(float, 3); @@ -745,5 +754,6 @@ TAO_RAL_API(tao::ral::gpu::kRalGpuSyncOnStream, "gpu", TAO_RAL_API(tao::ral::gpu::kRalGpuMemset, "gpu", 
ral_base_cuda_memset); } // namespace gpu +DEFINE_TAO_TYPE_NAME_HELPER(Eigen::bfloat16, "bf16"); } // namespace ral } // namespace tao diff --git a/tao_compiler/mlir/ral/context/stream_executor_based_impl.cc b/tao_compiler/mlir/ral/context/stream_executor_based_impl.cc index 7c7aabbab2e..a5b41fe57f0 100644 --- a/tao_compiler/mlir/ral/context/stream_executor_based_impl.cc +++ b/tao_compiler/mlir/ral/context/stream_executor_based_impl.cc @@ -151,6 +151,11 @@ inline se::blas::ComputationType NativeTypeToBlasType() { return se::blas::ComputationType::kF64; } +template <> +inline se::blas::ComputationType NativeTypeToBlasType() { + return se::blas::ComputationType::kBF16AsF32; +} + // The template was introduced, because not all instantiation of // DoGemmWithAlgorithm template arguments was support by ThenBlasGemv. template @@ -293,7 +298,8 @@ se::blas::AlgorithmType tuningGemm(se::Stream* stream, se::blas::ProfileResult profile_result; DoGemmWithAlgorithm( /*batch_size*/ 1, lhs_matrix, rhs_matrix, output_matrix, - /*alpha*/ 1., /*beta*/ 0., stream, algorithm, &profile_result); + /*alpha*/ AlphaBeta(1.0), /*beta*/ AlphaBeta(0.0), stream, algorithm, + &profile_result); if (!profile_result.is_valid()) { TAO_VLOG(1) << "algo: " << algorithm << " is invalid."; @@ -1952,10 +1958,18 @@ void ral_qconv(ExecutionContext* ctx, void* stream_handle, } // namespace gpu // gemm ops -TAO_RAL_API("ral_gemm", "gpu", gpu::se_impl::ral_gemm); -TAO_RAL_API("ral_gemm", "gpu", gpu::se_impl::ral_gemm); +#ifndef DISC_BUILD_FROM_TF_BRIDGE +DEFINE_TAO_TYPE_NAME_HELPER(Eigen::bfloat16, "bf16"); TAO_RAL_API("ral_gemm", "gpu", gpu::se_impl::ral_gemm); +TAO_RAL_API("ral_gemm", "gpu", + gpu::se_impl::ral_gemm); +TAO_RAL_API("ral_gemm", "gpu", + gpu::se_impl::ral_batch_gemm); +#endif +TAO_RAL_API("ral_gemm", "gpu", gpu::se_impl::ral_gemm); +TAO_RAL_API("ral_gemm", "gpu", gpu::se_impl::ral_gemm); + TAO_RAL_API("ral_qgemm", "gpu", gpu::se_impl::ral_qgemm); TAO_RAL_API("ral_gemm", "gpu", gpu::se_impl::ral_batch_gemm); TAO_RAL_API("ral_gemm", "gpu", gpu::se_impl::ral_batch_gemm); @@ -1965,6 +1979,7 @@ TAO_RAL_API("ral_gemm", "gpu", gpu::se_impl::ral_batch_gemm); TAO_RAL_API("ral_gemm", "gpu", gpu::se_impl::ral_batch_gemm); + TAO_RAL_API("ral_gemm", "gpu", gpu::se_impl::ral_batch_gemm); #ifdef BLAZE_OPT diff --git a/tao_compiler/mlir/ral/context/tensorflow/tf_context_impl.cc b/tao_compiler/mlir/ral/context/tensorflow/tf_context_impl.cc index b265a23936e..2ae95723d4d 100644 --- a/tao_compiler/mlir/ral/context/tensorflow/tf_context_impl.cc +++ b/tao_compiler/mlir/ral/context/tensorflow/tf_context_impl.cc @@ -909,6 +909,7 @@ void ral_tf_send_output_0d(ExecutionContext* ctx, int64_t output_idx, TAO_RAL_API(::tao::ral::kRalBitcast, "cpu", ral_tf_bitcast_0d); RAL_REGISTER_IO_FUNC_0D(float); +RAL_REGISTER_IO_FUNC_0D(bfloat16); RAL_REGISTER_IO_FUNC_0D(double); RAL_REGISTER_IO_FUNC_0D(Eigen::half); RAL_REGISTER_IO_FUNC_0D(int8_t); @@ -980,6 +981,14 @@ RAL_REGISTER_IO_FUNC(bool, 5); RAL_REGISTER_IO_FUNC(bool, 6); RAL_REGISTER_IO_FUNC(bool, 7); RAL_REGISTER_IO_FUNC(bool, 8); +RAL_REGISTER_IO_FUNC(bfloat16, 1); +RAL_REGISTER_IO_FUNC(bfloat16, 2); +RAL_REGISTER_IO_FUNC(bfloat16, 3); +RAL_REGISTER_IO_FUNC(bfloat16, 4); +RAL_REGISTER_IO_FUNC(bfloat16, 5); +RAL_REGISTER_IO_FUNC(bfloat16, 6); +RAL_REGISTER_IO_FUNC(bfloat16, 7); +RAL_REGISTER_IO_FUNC(bfloat16, 8); } // namespace tensorflow @@ -987,6 +996,7 @@ namespace tao { namespace ral { DEFINE_TAO_TYPE_NAME_HELPER(Eigen::half, "f16"); +DEFINE_TAO_TYPE_NAME_HELPER(Eigen::bfloat16, "bf16"); 
TAO_RAL_API(::tao::ral::cpu::kRalCpuAlloc, "cpu", tensorflow::ral_tf_cpu_alloc); TAO_RAL_API(::tao::ral::cpu::kRalCpuAllocPersistent, "cpu",
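
For context on the numerics this patch relies on, below is a minimal, self-contained C++ sketch (not part of the commit; the helper names and test values are illustrative assumptions) of the two bf16 bit tricks that BFloat16ExtFOpConverter and BFloat16TruncFOpConverter expand to, plus the "bf16 storage, f32 accumulation" idea behind registering the bf16 gemms with se::blas::ComputationType::kBF16AsF32.

#include <cstdint>
#include <cstdio>
#include <cstring>

// bf16 -> f32: the 16 stored bits become the high half of the 32-bit word
// (what the ExtF expansion does with a shift and bitcast).
float bf16_to_f32(uint16_t bits) {
  uint32_t u = static_cast<uint32_t>(bits) << 16;
  float f;
  std::memcpy(&f, &u, sizeof(f));
  return f;
}

// f32 -> bf16 with the 0x7fff + lsb rounding bias used by the TruncF
// expansion (round to nearest, ties to even). NaN payloads are not
// treated specially in this sketch.
uint16_t f32_to_bf16(float f) {
  uint32_t u;
  std::memcpy(&u, &f, sizeof(u));
  uint32_t lsb = (u >> 16) & 1u;         // lowest bit that will be kept
  uint32_t rounding_bias = 0x7fffu + lsb;
  u += rounding_bias;                    // carries into the kept bits when needed
  return static_cast<uint16_t>(u >> 16);
}

// "bf16 storage, f32 compute" in miniature: operands live as bf16 bits,
// the multiply-accumulate happens in float, as kBF16AsF32 implies.
float dot_bf16(const uint16_t* a, const uint16_t* b, int n) {
  float acc = 0.0f;
  for (int i = 0; i < n; ++i) acc += bf16_to_f32(a[i]) * bf16_to_f32(b[i]);
  return acc;
}

int main() {
  // 1.005f lies between the bf16 values 1.0f and 1.0078125f; plain
  // truncation would return 1.0f, the bias picks the nearer 1.0078125f.
  float x = 1.005f;
  uint16_t bx = f32_to_bf16(x);
  std::printf("%.7f -> 0x%04x -> %.7f\n", x, bx, bf16_to_f32(bx));

  uint16_t a[2] = {f32_to_bf16(1.5f), f32_to_bf16(-2.25f)};
  uint16_t b[2] = {f32_to_bf16(4.0f), f32_to_bf16(0.5f)};
  std::printf("dot = %.4f\n", dot_bf16(a, b, 2));  // 1.5*4.0 + (-2.25)*0.5 = 4.875
  return 0;
}

The bias trick matches the TensorFlow bfloat16.h reference cited in the pass comment, so the expanded arith sequence rounds the same way as the framework-side conversion.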