From 162bd9af22d1f74152cb6d500539afb336392535 Mon Sep 17 00:00:00 2001 From: tocean Date: Wed, 22 Nov 2023 09:58:41 +0000 Subject: [PATCH 1/6] optimize performance by fuse add high presicion tensor to fp8 tensor --- Makefile | 1 + msamp/{optim => common/include}/common.h | 68 ++- msamp/common/include/concurrency.h | 63 +++ msamp/common/include/poll.h | 51 +++ msamp/{optim => common/include}/utils.cuh | 0 msamp/megatron/distributed.py | 7 +- msamp/operators/arithmetic/__init__.py | 8 + msamp/operators/arithmetic/arithmetic.cu | 40 ++ msamp/operators/arithmetic/arithmetic.py | 33 ++ msamp/operators/arithmetic/setup.py | 37 ++ .../arithmetic/vectorized_pointwise.h | 387 ++++++++++++++++++ msamp/optim/adamw.cu | 4 +- tests/operators/test_arithmetic.py | 45 ++ 13 files changed, 736 insertions(+), 8 deletions(-) rename msamp/{optim => common/include}/common.h (54%) create mode 100644 msamp/common/include/concurrency.h create mode 100644 msamp/common/include/poll.h rename msamp/{optim => common/include}/utils.cuh (100%) create mode 100644 msamp/operators/arithmetic/__init__.py create mode 100644 msamp/operators/arithmetic/arithmetic.cu create mode 100644 msamp/operators/arithmetic/arithmetic.py create mode 100644 msamp/operators/arithmetic/setup.py create mode 100644 msamp/operators/arithmetic/vectorized_pointwise.h create mode 100644 tests/operators/test_arithmetic.py diff --git a/Makefile b/Makefile index bd7c5857..b5f2108d 100644 --- a/Makefile +++ b/Makefile @@ -23,4 +23,5 @@ lint: cpplint mdlint postinstall: cd msamp/operators/dist_op && bash build.sh && cd - + cd msamp/operators/arithmetics && pip install -v . && cd - cd msamp/optim && pip install -v . && cd - diff --git a/msamp/optim/common.h b/msamp/common/include/common.h similarity index 54% rename from msamp/optim/common.h rename to msamp/common/include/common.h index e559c437..39bc8790 100644 --- a/msamp/optim/common.h +++ b/msamp/common/include/common.h @@ -1,6 +1,9 @@ // Copyright (c) Microsoft Corporation. // Licensed under the MIT License. +#ifndef MSAMP_COMMON_H_ +#define MSAMP_COMMON_H_ + #include #include #include @@ -20,6 +23,11 @@ using bf16 = nv_bfloat16; using fp8e4m3 = __nv_fp8_e4m3; using fp8e5m2 = __nv_fp8_e5m2; +template +constexpr T DIVUP(const T &x, const T &y) { + return (((x) + ((y)-1)) / (y)); +} + #define TORCH_DTYPE_SWITCH(dtype, type, ...) \ switch (dtype) { \ case torch::kUInt8: { \ @@ -46,6 +54,36 @@ using fp8e5m2 = __nv_fp8_e5m2; throw "Unexcepted data type"; \ } +#define SELECT_FP8_TYPE(is_e4m3, type, ...) \ + if (is_e4m3){ \ + using type = fp8e4m3; \ + { __VA_ARGS__ } \ + } \ + else { \ + using type = fp8e5m2; \ + { __VA_ARGS__ } \ + } + + +#define TORCH_DTYPE_SWITCH_INPUT(dtype, type, ...) 
\ + switch (dtype) { \ + case torch::kFloat32: { \ + using type = float; \ + { __VA_ARGS__ } \ + } break; \ + case torch::kBFloat16: { \ + using type = bf16; \ + { __VA_ARGS__ } \ + } break; \ + case torch::kFloat16: { \ + using type = fp16; \ + { __VA_ARGS__ } \ + } break; \ + default: \ + throw "Unexcepted data type"; \ + } + + const int HIP_MAX_GRID_NUM = 65535; const int HIP_MAX_NUM_THREADS = 512; @@ -68,4 +106,32 @@ template <> __host__ __device__ bf16 cast_dtype(const float value) { return __fl template <> __host__ __device__ float cast_dtype(const fp16 value) { return __half2float(value); } -template <> __host__ __device__ float cast_dtype(const bf16 value) { return __bfloat162float(value); } \ No newline at end of file +template <> __host__ __device__ float cast_dtype(const bf16 value) { return __bfloat162float(value); } + +inline size_t product(const std::vector &shape) { + size_t ret = 1; + for (const auto &elem : shape) { + ret *= elem; + } + return ret; +} + +template +struct is_fp8 : std::false_type {}; + +template <> +struct is_fp8 : std::true_type {}; + +template <> +struct is_fp8 : std::true_type {}; + +template +struct is_fp16 : std::false_type {}; + +template <> +struct is_fp16 : std::true_type {}; + +template <> +struct is_fp16 : std::true_type {}; + +#endif // MSAMP_COMMON_H_ \ No newline at end of file diff --git a/msamp/common/include/concurrency.h b/msamp/common/include/concurrency.h new file mode 100644 index 00000000..49b0b1c2 --- /dev/null +++ b/msamp/common/include/concurrency.h @@ -0,0 +1,63 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT license. + +#ifndef MSAMP_CONCURRENCY_H_ +#define MSAMP_CONCURRENCY_H_ + +#include "poll.h" + +namespace msamp { + +/// A device-wide barrier. +struct DeviceSyncer { + public: + /// Construct a new DeviceSyncer object. + DeviceSyncer() = default; + + /// Destroy the DeviceSyncer object. + ~DeviceSyncer() = default; + +#ifdef __CUDACC__ + /// Synchronize all threads inside a kernel. Guarantee that all previous work of all threads in cooperating blocks is + /// finished. + /// @param blockNum The number of blocks that will synchronize. + /// @param maxSpinCount The maximum number of spin counts before asserting. Never assert if negative. + __forceinline__ __device__ void sync(int blockNum, int64_t maxSpinCount = 100000000) { + unsigned int maxOldCnt = blockNum - 1; + __syncthreads(); + if (blockNum == 1) return; + if (threadIdx.x == 0) { + // Need a `__threadfence()` before to flip `flag`. + __threadfence(); + int tmp = isIncFlag_ ^ 1; + if (tmp) { + if (atomicInc(&count_, maxOldCnt) == maxOldCnt) { + flag_ = 1; + } + POLL_MAYBE_JAILBREAK(!flag_, maxSpinCount); + } else { + if (atomicInc(&count_, maxOldCnt) == maxOldCnt) { + flag_ = 0; + } + POLL_MAYBE_JAILBREAK(flag_, maxSpinCount); + } + isIncFlag_ = tmp; + } + // We need sync here because only a single thread is checking whether + // the flag is flipped. + __syncthreads(); + } +#endif + + private: + /// The flag to indicate whether the barrier is reached by the latest thread. + volatile int flag_; + /// The counter of synchronized blocks. + unsigned int count_; + /// The flag to indicate whether to increase or decrease @ref flag_. 
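+  /// Thread 0 flips this after every multi-block sync(), so consecutive barriers poll @ref flag_ for opposite values.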
+ int isIncFlag_; +}; + +} // namespace msamp + +#endif // MSAMP_CONCURRENCY_H_ \ No newline at end of file diff --git a/msamp/common/include/poll.h b/msamp/common/include/poll.h new file mode 100644 index 00000000..bf66d5f4 --- /dev/null +++ b/msamp/common/include/poll.h @@ -0,0 +1,51 @@ +#ifndef MSAMP_POLL_H_ +#define MSAMP_POLL_H_ + +#include + +extern "C" __device__ void __assert_fail(const char *__assertion, const char *__file, unsigned int __line, + const char *__function) __THROW; + +// If a spin is stuck, escape from it and set status to 1. +#define POLL_MAYBE_JAILBREAK_ESCAPE(__cond, __max_spin_cnt, __status) \ + do { \ + int64_t __spin_cnt = 0; \ + __status = 0; \ + while (__cond) { \ + if (__max_spin_cnt >= 0 && __spin_cnt++ == __max_spin_cnt) { \ + __status = 1; \ + break; \ + } \ + } \ + } while (0); + +// If a spin is stuck, print a warning and keep spinning. +#define POLL_MAYBE_JAILBREAK(__cond, __max_spin_cnt) \ + do { \ + int64_t __spin_cnt = 0; \ + while (__cond) { \ + if (__max_spin_cnt >= 0 && __spin_cnt++ == __max_spin_cnt) { \ + __assert_fail(#__cond, __FILE__, __LINE__, __PRETTY_FUNCTION__); \ + } \ + } \ + } while (0); + +// the as POLL_MAYBE_JAILBREAK except that __cond1 is checked before __cond2 +// this is specially useful when __cond1 is faster to check +#define OR_POLL_MAYBE_JAILBREAK(__cond1, __cond2, __max_spin_cnt) \ + do { \ + int64_t __spin_cnt = 0; \ + while (true) { \ + if (!(__cond1)) { \ + break; \ + } else if (!(__cond2)) { \ + break; \ + } \ + if (__max_spin_cnt >= 0 && __spin_cnt++ == __max_spin_cnt) { \ + __assert_fail(#__cond1 #__cond2, __FILE__, __LINE__, __PRETTY_FUNCTION__); \ + } \ + } \ + } while (0); + + +#endif // MSAMP_POLL_H_ \ No newline at end of file diff --git a/msamp/optim/utils.cuh b/msamp/common/include/utils.cuh similarity index 100% rename from msamp/optim/utils.cuh rename to msamp/common/include/utils.cuh diff --git a/msamp/megatron/distributed.py b/msamp/megatron/distributed.py index 379b8741..faa4c679 100644 --- a/msamp/megatron/distributed.py +++ b/msamp/megatron/distributed.py @@ -11,7 +11,7 @@ from msamp.common.dtype import Dtypes from msamp.common.tensor import ScalingMeta, ScalingTensor - +from msamp.operators.arithmetic import Arithmetic class FP8DistributedDataParallel(DistributedDataParallelBase): """A DDP with contiguous buffers and FP8 spport.""" @@ -177,10 +177,7 @@ def _fp8_make_param_hook(self, param): def param_hook(*unused): # Add the gradient to the buffer. if param.grad is not None: - param.main_grad.copy_( - (param.main_grad.to(param.grad.dtype) + - param.grad).cast(self.wgrad_qtype, meta=param.main_grad.meta) - ) + Arithmetic.add_to_fp8(param.main_grad.value, param.main_grad.meta, param.grad) # Now we can deallocate grad memory. param.grad = None diff --git a/msamp/operators/arithmetic/__init__.py b/msamp/operators/arithmetic/__init__.py new file mode 100644 index 00000000..55fc214a --- /dev/null +++ b/msamp/operators/arithmetic/__init__.py @@ -0,0 +1,8 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. 
+ +"""Exposes the interface of MS-AMP arithmetic module.""" + +from msamp.operators.arithmetic.arithmetic import Arithmetic + +__all__ = ['Arithmetic'] diff --git a/msamp/operators/arithmetic/arithmetic.cu b/msamp/operators/arithmetic/arithmetic.cu new file mode 100644 index 00000000..a8786782 --- /dev/null +++ b/msamp/operators/arithmetic/arithmetic.cu @@ -0,0 +1,40 @@ +#include +#include + +#include "../../common/include/common.h" +#include "../../common/include/utils.cuh" +#include "../../common/include/concurrency.h" +#include "vectorized_pointwise.h" + +namespace msamp { +void add_to_fp8(at::Tensor fp8_tensor, + at::Tensor scale, + at::Tensor scale_inv, + at::Tensor amax, + const at::Tensor& other, + bool is_e4m3) { + const size_t N = other.numel(); + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + TORCH_DTYPE_SWITCH_INPUT(other.scalar_type(), IType, + SELECT_FP8_TYPE(is_e4m3, OType, + + constexpr int nvec = 32 / sizeof(IType); + + VectorizedAddToFp8KernelLauncher( + reinterpret_cast(other.data_ptr()), + reinterpret_cast(fp8_tensor.data_ptr()), + reinterpret_cast(scale.data_ptr()), + reinterpret_cast(scale_inv.data_ptr()), + reinterpret_cast(amax.data_ptr()), + N, + stream + ); + ); + ); +} + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("add_to_fp8", &add_to_fp8, "Add to fp8"); +} + +} // namespace msamp diff --git a/msamp/operators/arithmetic/arithmetic.py b/msamp/operators/arithmetic/arithmetic.py new file mode 100644 index 00000000..11d0b0d3 --- /dev/null +++ b/msamp/operators/arithmetic/arithmetic.py @@ -0,0 +1,33 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +"""FP8 Arithmetic module.""" + +import torch + +import msamp_arithmetic +from msamp.common.dtype import Dtypes + +class Arithmetic: + @staticmethod + def add_to_fp8(fp8_tensor, meta, other): + """Add high presicon tensor to fp8_tensor in-place. + + Args: + fp8_tensor (torch.Tensor): fp8 tensor to add to. + meta (ScalingTensorMeta): meta data of fp8_tensor. + other (torch.Tensor): high precision tensor to add. + """ + if not (fp8_tensor.is_cuda and fp8_tensor.is_contiguous): + raise ValueError('The fp8 tensor is not in cuda memory or contiguous.') + if not (other.is_cuda and other.is_contiguous): + raise ValueError('The other tensor is not in cuda memory or contiguous.') + if not (fp8_tensor.dtype == torch.uint8 or fp8_tensor.dtype == torch.int8): + raise ValueError('The fp8 tensor is not in uint8 or int8.') + + if not (meta.qtype == Dtypes.kfloat8_e4m3 or meta.qtype == Dtypes.kfloat8_e5m2): + raise ValueError('The fp8 tensor is not in e4m3 or e5m2 format.') + + is_e4m3 = meta.qtype == Dtypes.kfloat8_e4m3 + + msamp_arithmetic.add_to_fp8(fp8_tensor, meta.scale, meta.scale_inv, meta.amax[0], other, is_e4m3) diff --git a/msamp/operators/arithmetic/setup.py b/msamp/operators/arithmetic/setup.py new file mode 100644 index 00000000..e82d8c36 --- /dev/null +++ b/msamp/operators/arithmetic/setup.py @@ -0,0 +1,37 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. 
+ +"""The setuptools based setup module.""" + +from setuptools import setup +from torch.utils import cpp_extension + +ext_t = cpp_extension.CUDAExtension +ext_fnames = ['arithmetic.cu'] +define_macros = [] +nvcc_flags = [ + "-O3", + "-U__CUDA_NO_HALF_CONVERSIONS__", + "-U__CUDA_NO_BFLOAT16_CONVERSIONS__", + "--expt-relaxed-constexpr", + "--expt-extended-lambda", + "--use_fast_math" + ] + +extra_compile_args = dict(cxx=['-fopenmp', '-O3'], nvcc=nvcc_flags) + +define_macros.append(('WITH_CUDA', None)) + +setup( + name='msamp_arithmetic', + version='0.0.1', + ext_modules=[ + ext_t( + 'msamp_arithmetic', + ext_fnames, + define_macros=define_macros, + extra_compile_args=extra_compile_args + ) + ], + cmdclass={'build_ext': cpp_extension.BuildExtension} +) diff --git a/msamp/operators/arithmetic/vectorized_pointwise.h b/msamp/operators/arithmetic/vectorized_pointwise.h new file mode 100644 index 00000000..ea1e9569 --- /dev/null +++ b/msamp/operators/arithmetic/vectorized_pointwise.h @@ -0,0 +1,387 @@ +#ifndef MSAMP_VECTORIZED_POINTWISE_H +#define MSAMP_VECTORIZED_POINTWISE_H + +#include +#include +#include + +#include "../../common/include/common.h" +#include "../../common/include/utils.cuh" +#include "../../common/include/concurrency.h" + +namespace msamp { +/* \brief Helper class that enables storing multiple values of type DType + as 1 value of type LType. +*/ +template +class VectorizedStorage { + public: + using LType = typename transformer_engine::BytesToType::Type; + constexpr static int nvec = n; + union vectorized_storage { + LType aligned; + DType separate[nvec]; // NOLINT(*) + + inline __device__ vectorized_storage() {} + inline __device__ ~vectorized_storage() {} + } scratch_; + + inline __device__ VectorizedStorage() {} + inline __device__ VectorizedStorage(const VectorizedStorage& y2) { + scratch_.aligned = y2.scratch_.aligned; + } + inline __device__ VectorizedStorage(const LType &y2) { + scratch_.aligned = y2; + } + inline __device__ VectorizedStorage& operator+=( + const VectorizedStorage& rhs) { + #pragma unroll + for (int i = 0; i < nvec; ++i) { + scratch_.separate[i] = add_elem(scratch_.separate[i], rhs.scratch_.separate[i]); + } + return *this; + } + inline __device__ ~VectorizedStorage() {} +}; + +// Returns const LType is DType is const +template +struct select_const { + using type = LType; +}; + +template +struct select_const { + using type = const LType; +}; + + +/* \brief Helper class that enables accessing multiple values of type DType + as 1 value of type LType. Additional aligned template argument + allows performance optimizations if the pointer and the size of + the allocation is aligned to sizeof(LType) / sizeof(DType) elements. +*/ +template +class VectorizedAccessor { + public: + using StorageType = VectorizedStorage::type, + nvec>; + using LType = typename select_const::type; + StorageType storage_; + + LType* aligned_ptr_; + DType* unaligned_ptr_; + int alignment_; + size_t n_elems_; + + inline __device__ VectorizedAccessor(DType* const ptr, const size_t size) { + unaligned_ptr_ = ptr; + if (aligned) { + alignment_ = 0; + aligned_ptr_ = reinterpret_cast(ptr); + n_elems_ = (size + nvec - 1) / nvec; + } else { + size_t ptr_as_number = reinterpret_cast(ptr); + alignment_ = (ptr_as_number % sizeof(LType)) / sizeof(DType); + aligned_ptr_ = reinterpret_cast(ptr - alignment_); + n_elems_ = (size + alignment_ + nvec - 1) / nvec; + } + } + + /* \brief Alignment of the input pointer in elements. 
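     (the number of DType elements by which the pointer is offset from an LType boundary; 0 when fully aligned)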
*/ + inline __device__ int alignment() const { + return alignment_; + } + + /* \brief Access to separate elements. */ + inline __device__ DType* separate() { + return storage_.scratch_.separate; + } + + /* \brief Number of aligned elements that span the entire input tensor. */ + inline __device__ size_t num_aligned_elements() const { + return n_elems_; + } + + /* \brief Load values from the input. + \param id Aligned index of the element. + \param N size of the tensor. + */ + inline __device__ void load(const size_t id, const size_t N) { + if (aligned) { + storage_.scratch_.aligned = aligned_ptr_[id]; + } else { + if (id > 0 && id < n_elems_ - 1) { + storage_.scratch_.aligned = aligned_ptr_[id]; + } else { +#pragma unroll + for (int j = 0; j < nvec; ++j) { + DType* ptr = reinterpret_cast(&(aligned_ptr_[id])) + j; + if (reinterpret_cast(ptr) >= reinterpret_cast(unaligned_ptr_) && + reinterpret_cast(ptr) < reinterpret_cast(unaligned_ptr_ + N)) { + storage_.scratch_.separate[j] = *ptr; + } else { + storage_.scratch_.separate[j] = DType(); + } + } + } + } + } +}; + +/* \brief Class used for vectorized read-only access. */ +template +class VectorizedLoader : public VectorizedAccessor { + public: + inline __device__ VectorizedLoader(const DType* ptr, const size_t N) : + VectorizedAccessor(ptr, N) { + } +}; + +/* \brief Class used for vectorized writable access. */ +template +class VectorizedStorer : public VectorizedAccessor { + public: + inline __device__ VectorizedStorer(DType* ptr, const size_t N) : + VectorizedAccessor(ptr, N) { + } + + /* \brief Store values to the output. + \param id Aligned index of the element. + \param N size of the tensor. + */ + inline __device__ void store(const size_t id, const size_t N) { + if (aligned) { + this->aligned_ptr_[id] = this->storage_.scratch_.aligned; + } else { + if (id > 0 && id < this->n_elems_ - 1) { + this->aligned_ptr_[id] = this->storage_.scratch_.aligned; + } else { +#pragma unroll + for (int j = 0; j < nvec; ++j) { + DType* ptr = reinterpret_cast(&(this->aligned_ptr_[id])) + j; + if (reinterpret_cast(ptr) >= reinterpret_cast(this->unaligned_ptr_) && + reinterpret_cast(ptr) < reinterpret_cast(this->unaligned_ptr_ + N)) { + *ptr = this->storage_.scratch_.separate[j]; + } + } + } + } + } +}; + + +constexpr int unary_kernel_threads = 512; +extern __device__ msamp::DeviceSyncer device_syncer; + +template +__launch_bounds__(unary_kernel_threads) +__global__ void add_to_fp8_kernel(InputType *input, + OutputType *output, + ComputeType *scale, + ComputeType *scale_inv, + ComputeType *amax, + const size_t N, + const size_t num_aligned_elements) { + // input is high precision, output is fp8 + VectorizedStorer output_storer(output, N); + VectorizedStorer input_storer(input, N); + + ComputeType max = 0; + ComputeType s = 0; + if constexpr (is_fp8::value) { + if (scale_inv != nullptr) s = *scale_inv; + } + const int warp_id = threadIdx.x / THREADS_PER_WARP; + + const size_t M = num_aligned_elements; + + for (size_t tid = blockIdx.x * blockDim.x + threadIdx.x; + tid < M; + tid += gridDim.x * blockDim.x) { + input_storer.load(tid, N); + output_storer.load(tid, N); +#pragma unroll + for (int i = 0; i < nvec; ++i) { + const ComputeType val1 = static_cast(output_storer.separate()[i]); + const ComputeType val2 = static_cast(input_storer.separate()[i]); + + ComputeType temp = val1 * s + val2; + if constexpr (is_fp8::value) { + __builtin_assume(max >= 0); + max = fmaxf(fabsf(static_cast(temp)), max); + } + } + } + if constexpr (is_fp8::value) { + /* warp tile amax 
reduce*/ + max = transformer_engine::reduce_max(max, warp_id); + + if (threadIdx.x == 0 && amax != nullptr) { + static_assert(std::is_same::value); + transformer_engine::atomicMaxFloat(amax, max); + } + } + + device_syncer.sync(gridDim.x, -1); + + /* Compute scaling factor, translate the following python code to c++: + exp = torch.floor(torch.log2(fp_max / amax)) - margin + sf = torch.round(torch.pow(2, torch.abs(exp))) + sf = torch.where(amax > 0.0, sf, scale) + sf = torch.where(torch.isfinite(amax), sf, scale) + sf = torch.where(exp < 0, 1 / sf, sf) + */ + ComputeType fp_max = 0.0; + if (std::is_same::value) { + fp_max = 448.0; + } else if (std::is_same::value) { + fp_max = 57344.0; + } + + ComputeType amax_value = *amax; + ComputeType exp = floorf(log2f(fp_max/(amax_value))); + ComputeType sf = roundf(powf(2, fabsf(exp))); + + if (amax_value <= 0 || !isfinite(amax_value)) { + sf = *scale; + } + + if (exp < 0) { + sf = 1 / sf; + } + + *scale = sf; + + // using new scaling factor to quantize the input + for (size_t tid = blockIdx.x * blockDim.x + threadIdx.x; + tid < M; + tid += gridDim.x * blockDim.x) { + input_storer.load(tid, N); + output_storer.load(tid, N); +#pragma unroll + for (int i = 0; i < nvec; ++i) { + const InputType val1 = static_cast(input_storer.separate()[i]); + const ComputeType val2 = static_cast(output_storer.separate()[i]); + + InputType temp1 = static_cast(val2 * s); + + if constexpr (is_fp16::value) { + temp1 = static_cast(__hadd(temp1, val1)); + } else { + temp1 += val1; + } + ComputeType temp2 = sf * static_cast(temp1); + output_storer.separate()[i] = static_cast(temp2); + } + output_storer.store(tid, N); + } + *scale_inv = 1.0 / sf; + +} + + +namespace { + +inline size_t get_num_aligned_elements(const void *ptr, const size_t lead_dim, + const int nvec, const int size) { + size_t ptr_as_number = reinterpret_cast(ptr); + int alignment = (ptr_as_number % (nvec * size)) / size; + return DIVUP(lead_dim + alignment, static_cast(nvec)); +} + +enum class Alignment { + SAME_ALIGNED, // All tensors aligned + SAME_UNALIGNED, // All tensors have the same misalignment + DIFFERENT // Tensors have different alignment +}; + +inline int CalcAlignment(const void *ptr, const int size) { + size_t ptr_as_number = reinterpret_cast(ptr); + return ptr_as_number % size; +} + +/* \brief Check alignment of the inputs and outputs when using vectorized accesses. + \param lead_dim Leading dimension of the tensors. + \param other_dim The size of the other dimensions of the tensors. + \param nvec Length of the vector. + \param ptrs Inputs and Outputs to the operator. +*/ +template +Alignment CheckAlignment(const size_t lead_dim, + const int nvec, + const T... 
ptrs + ) { + std::vector alignments; + alignments.reserve(sizeof...(T)); + + // calculate the alignments of all ptrs and store them into alignments + (..., alignments.push_back(CalcAlignment(ptrs, sizeof(*ptrs) * nvec))); + + bool all_same = std::all_of(alignments.cbegin(), alignments.cend(), + [alignments](int val) {return val == alignments.front();}); + if (!all_same) { + return Alignment::DIFFERENT; + } + + if (alignments.front() == 0 && + lead_dim % nvec == 0) { + // all alignment are 0 + return Alignment::SAME_ALIGNED; + } else { + return Alignment::SAME_UNALIGNED; + } +} + +} + +template +void VectorizedAddToFp8KernelLauncher(InputType *input, + OutputType *output, + fp32 *scale, + fp32 *scale_inv, + fp32 *amax, + const size_t N, + cudaStream_t stream) { + if (N != 0) { + auto align = CheckAlignment(N, nvec, input, output); + + size_t num_aligned_elements = get_num_aligned_elements(input, N, nvec, + sizeof(InputType)); + constexpr size_t threads = unary_kernel_threads; + size_t num_blocks = DIVUP(num_aligned_elements, threads); + + // We use DeviceSyncer to sync the amax value between blocks, the block number should be less than + // (SMCount*MaxThreadsPerSM)/unary_kernel_threads, which is 132*2048/512 = 528 on H100 SXM. We set + // max_blocks to half of 528 to make sure it works on other H100 GPUs. + // constexpr size_t max_blocks = 65535; + constexpr size_t max_blocks = 264; + num_blocks = std::min(num_blocks, max_blocks); + + switch (align) { + case Alignment::SAME_ALIGNED: + add_to_fp8_kernel<<>>( + input, output, scale, scale_inv, amax, N, num_aligned_elements); + break; + case Alignment::SAME_UNALIGNED: + add_to_fp8_kernel<<>>( + input, output, scale, scale_inv, amax, N, num_aligned_elements); + break; + case Alignment::DIFFERENT: { + // If the pointers are aligned differently we cannot vectorize + add_to_fp8_kernel<1, true, fp32><<>>( + input, output, scale, scale_inv, amax, N, num_aligned_elements); + break; + } + } + } +} + +} // namespace msamp + +#endif // MSAMP_VECTORIZED_POINTWISE_H \ No newline at end of file diff --git a/msamp/optim/adamw.cu b/msamp/optim/adamw.cu index fb4c7468..a8f6c222 100644 --- a/msamp/optim/adamw.cu +++ b/msamp/optim/adamw.cu @@ -9,8 +9,8 @@ #include #include -#include "common.h" -#include "utils.cuh" +#include "../common/include/common.h" +#include "../common/include/utils.cuh" using namespace std; using namespace torch; diff --git a/tests/operators/test_arithmetic.py b/tests/operators/test_arithmetic.py new file mode 100644 index 00000000..e22ddd51 --- /dev/null +++ b/tests/operators/test_arithmetic.py @@ -0,0 +1,45 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. 
+ +"""Tests for arithmetic module.""" + +import unittest + +import torch + +from tests.helper import decorator +from msamp.common.dtype import Dtypes +from msamp.operators.arithmetic import Arithmetic +from msamp.common.tensor import TypeCast + +class ArithmeticTestCase(unittest.TestCase): + """A class for Arithmetic test cases.""" + + def _check_scaling_tensor(self, scaling_tensor1, scaling_tensor2): + assert torch.all(torch.eq(scaling_tensor1.value, scaling_tensor2.value)) + assert torch.all(torch.eq(scaling_tensor1.meta.scale, scaling_tensor2.meta.scale)) + assert torch.all(torch.eq(scaling_tensor1.meta.scale_inv, scaling_tensor2.meta.scale_inv)) + assert torch.all(torch.eq(scaling_tensor1.meta.amax, scaling_tensor2.meta.amax)) + + @decorator.cuda_test + def test_add_to_fp8(self): + """Test the function Arithmetic.add_to_fp8().""" + + torch.manual_seed(100) + sizes = (10000, 16384) + + for dtype in [torch.float16, torch.bfloat16, torch.float32]: + for qtype in [Dtypes.kfloat8_e4m3, Dtypes.kfloat8_e5m2]: + input1 = torch.rand(sizes, dtype=dtype, device='cuda') + input2 = torch.rand(sizes, dtype=dtype, device='cuda') + + scaling_tensor1 = input1.cast(qtype) + scaling_tensor2 = input1.cast(qtype) + + self._check_scaling_tensor(scaling_tensor1, scaling_tensor2) + + meta = scaling_tensor1.meta + Arithmetic.add_to_fp8(scaling_tensor1.value, meta, input2) + scaling_tensor2.copy_((scaling_tensor2.to(dtype) + input2).cast(qtype, meta=scaling_tensor2.meta)) + + self._check_scaling_tensor(scaling_tensor1, scaling_tensor2) From a0c784eee82ae4daf82afbe744c58d957c66c831 Mon Sep 17 00:00:00 2001 From: tocean Date: Wed, 22 Nov 2023 10:06:09 +0000 Subject: [PATCH 2/6] remove product in common.h --- msamp/common/include/common.h | 8 -------- 1 file changed, 8 deletions(-) diff --git a/msamp/common/include/common.h b/msamp/common/include/common.h index 39bc8790..09baff0c 100644 --- a/msamp/common/include/common.h +++ b/msamp/common/include/common.h @@ -108,14 +108,6 @@ template <> __host__ __device__ float cast_dtype(const fp16 value) { return __ha template <> __host__ __device__ float cast_dtype(const bf16 value) { return __bfloat162float(value); } -inline size_t product(const std::vector &shape) { - size_t ret = 1; - for (const auto &elem : shape) { - ret *= elem; - } - return ret; -} - template struct is_fp8 : std::false_type {}; From 8e0239d10a2f46c52de610b74a5266a2b27b446d Mon Sep 17 00:00:00 2001 From: tocean Date: Wed, 22 Nov 2023 10:13:06 +0000 Subject: [PATCH 3/6] fix typos in Makefile --- Makefile | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Makefile b/Makefile index b5f2108d..a37f66fb 100644 --- a/Makefile +++ b/Makefile @@ -23,5 +23,5 @@ lint: cpplint mdlint postinstall: cd msamp/operators/dist_op && bash build.sh && cd - - cd msamp/operators/arithmetics && pip install -v . && cd - + cd msamp/operators/arithmetic && pip install -v . && cd - cd msamp/optim && pip install -v . 
&& cd - From 4438d6296734e7537db438cf521039582480c4aa Mon Sep 17 00:00:00 2001 From: tocean Date: Wed, 22 Nov 2023 10:49:42 +0000 Subject: [PATCH 4/6] fix lint issues --- msamp/megatron/distributed.py | 1 + msamp/operators/arithmetic/arithmetic.py | 6 ++++-- msamp/operators/arithmetic/setup.py | 17 ++++------------- tests/operators/test_arithmetic.py | 3 +-- 4 files changed, 10 insertions(+), 17 deletions(-) diff --git a/msamp/megatron/distributed.py b/msamp/megatron/distributed.py index faa4c679..a2079f99 100644 --- a/msamp/megatron/distributed.py +++ b/msamp/megatron/distributed.py @@ -13,6 +13,7 @@ from msamp.common.tensor import ScalingMeta, ScalingTensor from msamp.operators.arithmetic import Arithmetic + class FP8DistributedDataParallel(DistributedDataParallelBase): """A DDP with contiguous buffers and FP8 spport.""" wgrad_qtype = Dtypes.kfloat8_e4m3 diff --git a/msamp/operators/arithmetic/arithmetic.py b/msamp/operators/arithmetic/arithmetic.py index 11d0b0d3..1c5ed748 100644 --- a/msamp/operators/arithmetic/arithmetic.py +++ b/msamp/operators/arithmetic/arithmetic.py @@ -8,11 +8,13 @@ import msamp_arithmetic from msamp.common.dtype import Dtypes + class Arithmetic: + """Arithmetic operator for FP8 tensor.""" @staticmethod def add_to_fp8(fp8_tensor, meta, other): """Add high presicon tensor to fp8_tensor in-place. - + Args: fp8_tensor (torch.Tensor): fp8 tensor to add to. meta (ScalingTensorMeta): meta data of fp8_tensor. @@ -24,7 +26,7 @@ def add_to_fp8(fp8_tensor, meta, other): raise ValueError('The other tensor is not in cuda memory or contiguous.') if not (fp8_tensor.dtype == torch.uint8 or fp8_tensor.dtype == torch.int8): raise ValueError('The fp8 tensor is not in uint8 or int8.') - + if not (meta.qtype == Dtypes.kfloat8_e4m3 or meta.qtype == Dtypes.kfloat8_e5m2): raise ValueError('The fp8 tensor is not in e4m3 or e5m2 format.') diff --git a/msamp/operators/arithmetic/setup.py b/msamp/operators/arithmetic/setup.py index e82d8c36..fc5ad369 100644 --- a/msamp/operators/arithmetic/setup.py +++ b/msamp/operators/arithmetic/setup.py @@ -10,13 +10,9 @@ ext_fnames = ['arithmetic.cu'] define_macros = [] nvcc_flags = [ - "-O3", - "-U__CUDA_NO_HALF_CONVERSIONS__", - "-U__CUDA_NO_BFLOAT16_CONVERSIONS__", - "--expt-relaxed-constexpr", - "--expt-extended-lambda", - "--use_fast_math" - ] + '-O3', '-U__CUDA_NO_HALF_CONVERSIONS__', '-U__CUDA_NO_BFLOAT16_CONVERSIONS__', '--expt-relaxed-constexpr', + '--expt-extended-lambda', '--use_fast_math' +] extra_compile_args = dict(cxx=['-fopenmp', '-O3'], nvcc=nvcc_flags) @@ -26,12 +22,7 @@ name='msamp_arithmetic', version='0.0.1', ext_modules=[ - ext_t( - 'msamp_arithmetic', - ext_fnames, - define_macros=define_macros, - extra_compile_args=extra_compile_args - ) + ext_t('msamp_arithmetic', ext_fnames, define_macros=define_macros, extra_compile_args=extra_compile_args) ], cmdclass={'build_ext': cpp_extension.BuildExtension} ) diff --git a/tests/operators/test_arithmetic.py b/tests/operators/test_arithmetic.py index e22ddd51..c92a84a2 100644 --- a/tests/operators/test_arithmetic.py +++ b/tests/operators/test_arithmetic.py @@ -10,11 +10,10 @@ from tests.helper import decorator from msamp.common.dtype import Dtypes from msamp.operators.arithmetic import Arithmetic -from msamp.common.tensor import TypeCast + class ArithmeticTestCase(unittest.TestCase): """A class for Arithmetic test cases.""" - def _check_scaling_tensor(self, scaling_tensor1, scaling_tensor2): assert torch.all(torch.eq(scaling_tensor1.value, scaling_tensor2.value)) assert 
torch.all(torch.eq(scaling_tensor1.meta.scale, scaling_tensor2.meta.scale)) From d5d14234eed55adc5096a8d5e0c7024b63a13ed2 Mon Sep 17 00:00:00 2001 From: tocean Date: Thu, 23 Nov 2023 07:02:23 +0000 Subject: [PATCH 5/6] fix bug and comments --- msamp/common/include/concurrency.h | 1 + msamp/common/include/poll.h | 5 +++ msamp/operators/arithmetic/arithmetic.cu | 3 ++ .../arithmetic/vectorized_pointwise.h | 45 +++++++++++++------ tests/operators/test_arithmetic.py | 31 +++++++------ 5 files changed, 57 insertions(+), 28 deletions(-) diff --git a/msamp/common/include/concurrency.h b/msamp/common/include/concurrency.h index 49b0b1c2..6075f967 100644 --- a/msamp/common/include/concurrency.h +++ b/msamp/common/include/concurrency.h @@ -1,5 +1,6 @@ // Copyright (c) Microsoft Corporation. // Licensed under the MIT license. +// The file is from https://github.com/microsoft/mscclpp/blob/main/include/mscclpp/concurrency.hpp. #ifndef MSAMP_CONCURRENCY_H_ #define MSAMP_CONCURRENCY_H_ diff --git a/msamp/common/include/poll.h b/msamp/common/include/poll.h index bf66d5f4..0ea53328 100644 --- a/msamp/common/include/poll.h +++ b/msamp/common/include/poll.h @@ -1,3 +1,8 @@ + +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. +// The file is from https://github.com/microsoft/mscclpp/blob/main/include/mscclpp/poll.hpp. + #ifndef MSAMP_POLL_H_ #define MSAMP_POLL_H_ diff --git a/msamp/operators/arithmetic/arithmetic.cu b/msamp/operators/arithmetic/arithmetic.cu index a8786782..ba07ab18 100644 --- a/msamp/operators/arithmetic/arithmetic.cu +++ b/msamp/operators/arithmetic/arithmetic.cu @@ -1,3 +1,6 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + #include #include diff --git a/msamp/operators/arithmetic/vectorized_pointwise.h b/msamp/operators/arithmetic/vectorized_pointwise.h index ea1e9569..e90d8ae9 100644 --- a/msamp/operators/arithmetic/vectorized_pointwise.h +++ b/msamp/operators/arithmetic/vectorized_pointwise.h @@ -1,3 +1,7 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. +// The file is adapted from https://github.com/NVIDIA/TransformerEngine/blob/main/transformer_engine/common/util/vectorized_pointwise.h. 
+ #ifndef MSAMP_VECTORIZED_POINTWISE_H #define MSAMP_VECTORIZED_POINTWISE_H @@ -172,6 +176,9 @@ class VectorizedStorer : public VectorizedAccessor { constexpr int unary_kernel_threads = 512; +constexpr float e4m3_max = 448.0; +constexpr float e5m2_max = 57344.0; + extern __device__ msamp::DeviceSyncer device_syncer; template output_storer(output, N); VectorizedStorer input_storer(input, N); @@ -206,16 +218,24 @@ __global__ void add_to_fp8_kernel(InputType *input, output_storer.load(tid, N); #pragma unroll for (int i = 0; i < nvec; ++i) { - const ComputeType val1 = static_cast(output_storer.separate()[i]); - const ComputeType val2 = static_cast(input_storer.separate()[i]); + const InputType val1 = static_cast(input_storer.separate()[i]); + const ComputeType val2 = static_cast(output_storer.separate()[i]); + + InputType temp = static_cast(val2 * s); + + if constexpr (is_fp16::value) { + temp = static_cast(__hadd(temp, val1)); + } else { + temp += val1; + } - ComputeType temp = val1 * s + val2; if constexpr (is_fp8::value) { __builtin_assume(max >= 0); - max = fmaxf(fabsf(static_cast(temp)), max); + max = fmaxf(fabsf(temp), max); } } } + if constexpr (is_fp8::value) { /* warp tile amax reduce*/ max = transformer_engine::reduce_max(max, warp_id); @@ -235,14 +255,10 @@ __global__ void add_to_fp8_kernel(InputType *input, sf = torch.where(torch.isfinite(amax), sf, scale) sf = torch.where(exp < 0, 1 / sf, sf) */ - ComputeType fp_max = 0.0; - if (std::is_same::value) { - fp_max = 448.0; - } else if (std::is_same::value) { - fp_max = 57344.0; - } - ComputeType amax_value = *amax; + + ComputeType fp_max = std::is_same::value ? e4m3_max : e5m2_max; + ComputeType exp = floorf(log2f(fp_max/(amax_value))); ComputeType sf = roundf(powf(2, fabsf(exp))); @@ -254,8 +270,6 @@ __global__ void add_to_fp8_kernel(InputType *input, sf = 1 / sf; } - *scale = sf; - // using new scaling factor to quantize the input for (size_t tid = blockIdx.x * blockDim.x + threadIdx.x; tid < M; @@ -279,8 +293,11 @@ __global__ void add_to_fp8_kernel(InputType *input, } output_storer.store(tid, N); } - *scale_inv = 1.0 / sf; + if (threadIdx.x == 0 && blockIdx.x == 0) { + *scale = sf; + *scale_inv = 1.0 / sf; + } } diff --git a/tests/operators/test_arithmetic.py b/tests/operators/test_arithmetic.py index c92a84a2..439ab46d 100644 --- a/tests/operators/test_arithmetic.py +++ b/tests/operators/test_arithmetic.py @@ -3,6 +3,7 @@ """Tests for arithmetic module.""" +import itertools import unittest import torch @@ -23,22 +24,24 @@ def _check_scaling_tensor(self, scaling_tensor1, scaling_tensor2): @decorator.cuda_test def test_add_to_fp8(self): """Test the function Arithmetic.add_to_fp8().""" - torch.manual_seed(100) - sizes = (10000, 16384) - - for dtype in [torch.float16, torch.bfloat16, torch.float32]: - for qtype in [Dtypes.kfloat8_e4m3, Dtypes.kfloat8_e5m2]: - input1 = torch.rand(sizes, dtype=dtype, device='cuda') - input2 = torch.rand(sizes, dtype=dtype, device='cuda') - - scaling_tensor1 = input1.cast(qtype) - scaling_tensor2 = input1.cast(qtype) - - self._check_scaling_tensor(scaling_tensor1, scaling_tensor2) - + sizes = [] + for i in range(1024, 8192, 1024): + for j in range(1024, 8192, 1024): + sizes.append((i, j)) + + sizes = list(range(1024, 8193, 1024)) + dtypes = [torch.float16, torch.bfloat16, torch.float32] + qtypes = [Dtypes.kfloat8_e4m3, Dtypes.kfloat8_e5m2] + for i, j, dtype, qtype, in itertools.product(sizes, sizes, dtypes, qtypes): + size = (i, j) + input1 = torch.rand(size, dtype=dtype, device='cuda') + scaling_tensor1 
= input1.cast(qtype) + scaling_tensor2 = input1.cast(qtype) + + for i in range(10): + input2 = torch.rand(size, dtype=dtype, device='cuda') meta = scaling_tensor1.meta Arithmetic.add_to_fp8(scaling_tensor1.value, meta, input2) scaling_tensor2.copy_((scaling_tensor2.to(dtype) + input2).cast(qtype, meta=scaling_tensor2.meta)) - self._check_scaling_tensor(scaling_tensor1, scaling_tensor2) From e1c8f216c82dca46800f0cf431b22fca940bec2c Mon Sep 17 00:00:00 2001 From: tocean Date: Fri, 24 Nov 2023 03:05:39 +0000 Subject: [PATCH 6/6] fix comments --- msamp/common/include/common.h | 6 +++--- msamp/operators/arithmetic/vectorized_pointwise.h | 4 ++-- tests/operators/test_arithmetic.py | 13 ++++--------- 3 files changed, 9 insertions(+), 14 deletions(-) diff --git a/msamp/common/include/common.h b/msamp/common/include/common.h index 09baff0c..9b205d69 100644 --- a/msamp/common/include/common.h +++ b/msamp/common/include/common.h @@ -118,12 +118,12 @@ template <> struct is_fp8 : std::true_type {}; template -struct is_fp16 : std::false_type {}; +struct is_half : std::false_type {}; template <> -struct is_fp16 : std::true_type {}; +struct is_half : std::true_type {}; template <> -struct is_fp16 : std::true_type {}; +struct is_half : std::true_type {}; #endif // MSAMP_COMMON_H_ \ No newline at end of file diff --git a/msamp/operators/arithmetic/vectorized_pointwise.h b/msamp/operators/arithmetic/vectorized_pointwise.h index e90d8ae9..bd765637 100644 --- a/msamp/operators/arithmetic/vectorized_pointwise.h +++ b/msamp/operators/arithmetic/vectorized_pointwise.h @@ -223,7 +223,7 @@ __global__ void add_to_fp8_kernel(InputType *input, InputType temp = static_cast(val2 * s); - if constexpr (is_fp16::value) { + if constexpr (is_half::value) { temp = static_cast(__hadd(temp, val1)); } else { temp += val1; @@ -283,7 +283,7 @@ __global__ void add_to_fp8_kernel(InputType *input, InputType temp1 = static_cast(val2 * s); - if constexpr (is_fp16::value) { + if constexpr (is_half::value) { temp1 = static_cast(__hadd(temp1, val1)); } else { temp1 += val1; diff --git a/tests/operators/test_arithmetic.py b/tests/operators/test_arithmetic.py index 439ab46d..23386991 100644 --- a/tests/operators/test_arithmetic.py +++ b/tests/operators/test_arithmetic.py @@ -16,20 +16,15 @@ class ArithmeticTestCase(unittest.TestCase): """A class for Arithmetic test cases.""" def _check_scaling_tensor(self, scaling_tensor1, scaling_tensor2): - assert torch.all(torch.eq(scaling_tensor1.value, scaling_tensor2.value)) - assert torch.all(torch.eq(scaling_tensor1.meta.scale, scaling_tensor2.meta.scale)) - assert torch.all(torch.eq(scaling_tensor1.meta.scale_inv, scaling_tensor2.meta.scale_inv)) - assert torch.all(torch.eq(scaling_tensor1.meta.amax, scaling_tensor2.meta.amax)) + self.assertTrue(torch.all(torch.eq(scaling_tensor1.value, scaling_tensor2.value))) + self.assertTrue(torch.all(torch.eq(scaling_tensor1.meta.scale, scaling_tensor2.meta.scale))) + self.assertTrue(torch.all(torch.eq(scaling_tensor1.meta.scale_inv, scaling_tensor2.meta.scale_inv))) + self.assertTrue(torch.all(torch.eq(scaling_tensor1.meta.amax, scaling_tensor2.meta.amax))) @decorator.cuda_test def test_add_to_fp8(self): """Test the function Arithmetic.add_to_fp8().""" torch.manual_seed(100) - sizes = [] - for i in range(1024, 8192, 1024): - for j in range(1024, 8192, 1024): - sizes.append((i, j)) - sizes = list(range(1024, 8193, 1024)) dtypes = [torch.float16, torch.bfloat16, torch.float32] qtypes = [Dtypes.kfloat8_e4m3, Dtypes.kfloat8_e5m2]