Optimize performance by fusing the addition of a high-precision tensor to an FP8 tensor #132

Merged · 6 commits · Nov 24, 2023
1 change: 1 addition & 0 deletions Makefile
@@ -23,4 +23,5 @@ lint: cpplint mdlint

postinstall:
cd msamp/operators/dist_op && bash build.sh && cd -
cd msamp/operators/arithmetic && pip install -v . && cd -
cd msamp/optim && pip install -v . && cd -
60 changes: 59 additions & 1 deletion msamp/optim/common.h → msamp/common/include/common.h
@@ -1,6 +1,9 @@
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT License.

#ifndef MSAMP_COMMON_H_
#define MSAMP_COMMON_H_

#include <cublasLt.h>
#include <cublas_v2.h>
#include <cuda_bf16.h>
@@ -20,6 +23,11 @@ using bf16 = nv_bfloat16;
using fp8e4m3 = __nv_fp8_e4m3;
using fp8e5m2 = __nv_fp8_e5m2;

template <typename T>
constexpr T DIVUP(const T &x, const T &y) {
return (((x) + ((y)-1)) / (y));
}

#define TORCH_DTYPE_SWITCH(dtype, type, ...) \
switch (dtype) { \
case torch::kUInt8: { \
@@ -46,6 +54,36 @@ using fp8e5m2 = __nv_fp8_e5m2;
throw "Unexcepted data type"; \
}

#define SELECT_FP8_TYPE(is_e4m3, type, ...) \
if (is_e4m3){ \
using type = fp8e4m3; \
{ __VA_ARGS__ } \
} \
else { \
using type = fp8e5m2; \
{ __VA_ARGS__ } \
}


#define TORCH_DTYPE_SWITCH_INPUT(dtype, type, ...) \
switch (dtype) { \
case torch::kFloat32: { \
using type = float; \
{ __VA_ARGS__ } \
} break; \
case torch::kBFloat16: { \
using type = bf16; \
{ __VA_ARGS__ } \
} break; \
case torch::kFloat16: { \
using type = fp16; \
{ __VA_ARGS__ } \
} break; \
default: \
throw "Unexcepted data type"; \
}


const int HIP_MAX_GRID_NUM = 65535;
const int HIP_MAX_NUM_THREADS = 512;

@@ -68,4 +106,24 @@ template <> __host__ __device__ bf16 cast_dtype(const float value) { return __float2bfloat16(value); }

template <> __host__ __device__ float cast_dtype(const fp16 value) { return __half2float(value); }

template <> __host__ __device__ float cast_dtype(const bf16 value) { return __bfloat162float(value); }
template <> __host__ __device__ float cast_dtype(const bf16 value) { return __bfloat162float(value); }

template <typename T>
struct is_fp8 : std::false_type {};

template <>
struct is_fp8<fp8e4m3> : std::true_type {};

template <>
struct is_fp8<fp8e5m2> : std::true_type {};

template <typename T>
struct is_fp16 : std::false_type {};

template <>
struct is_fp16<fp16> : std::true_type {};

template <>
struct is_fp16<bf16> : std::true_type {};

#endif // MSAMP_COMMON_H_
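
For context, a minimal usage sketch (not part of this diff) of the helpers added to common.h: DIVUP for ceiling-division grid sizing and the is_fp8<> trait for compile-time type constraints. The kernel and launcher names below are hypothetical.

// Hypothetical sketch: cast a float buffer to an FP8 type, using DIVUP to
// size the grid and is_fp8<> to constrain the template at compile time.
#include "common.h"

template <typename OType>
__global__ void cast_from_float_kernel(const float *in, OType *out, size_t n) {
    static_assert(is_fp8<OType>::value, "OType must be fp8e4m3 or fp8e5m2");
    size_t i = (size_t)blockIdx.x * blockDim.x + threadIdx.x;
    if (i < n) {
        out[i] = static_cast<OType>(in[i]);  // FP8 types are constructible from float
    }
}

template <typename OType>
void launch_cast_from_float(const float *in, OType *out, size_t n, cudaStream_t stream) {
    const size_t threads = HIP_MAX_NUM_THREADS;  // 512, defined in common.h
    const size_t blocks = DIVUP(n, threads);     // ceiling-division helper from common.h
    cast_from_float_kernel<OType><<<blocks, threads, 0, stream>>>(in, out, n);
}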
63 changes: 63 additions & 0 deletions msamp/common/include/concurrency.h
@@ -0,0 +1,63 @@
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT license.

#ifndef MSAMP_CONCURRENCY_H_
#define MSAMP_CONCURRENCY_H_

#include "poll.h"

namespace msamp {

/// A device-wide barrier.
struct DeviceSyncer {
public:
/// Construct a new DeviceSyncer object.
DeviceSyncer() = default;

/// Destroy the DeviceSyncer object.
~DeviceSyncer() = default;

#ifdef __CUDACC__
/// Synchronize all threads inside a kernel. Guarantee that all previous work of all threads in cooperating blocks is
/// finished.
/// @param blockNum The number of blocks that will synchronize.
/// @param maxSpinCount The maximum number of spin counts before asserting. Never assert if negative.
__forceinline__ __device__ void sync(int blockNum, int64_t maxSpinCount = 100000000) {
unsigned int maxOldCnt = blockNum - 1;
__syncthreads();
if (blockNum == 1) return;
if (threadIdx.x == 0) {
// Need a `__threadfence()` before to flip `flag`.
__threadfence();
int tmp = isIncFlag_ ^ 1;
if (tmp) {
if (atomicInc(&count_, maxOldCnt) == maxOldCnt) {
flag_ = 1;
}
POLL_MAYBE_JAILBREAK(!flag_, maxSpinCount);
} else {
if (atomicInc(&count_, maxOldCnt) == maxOldCnt) {
flag_ = 0;
}
POLL_MAYBE_JAILBREAK(flag_, maxSpinCount);
}
isIncFlag_ = tmp;
}
// We need sync here because only a single thread is checking whether
// the flag is flipped.
__syncthreads();
}
#endif

private:
/// The flag to indicate whether the barrier is reached by the latest thread.
volatile int flag_;
/// The counter of synchronized blocks.
unsigned int count_;
/// The flag to indicate whether to increase or decrease @ref flag_.
int isIncFlag_;
};

} // namespace msamp

#endif // MSAMP_CONCURRENCY_H_
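
For context, a sketch (not part of this diff) of the intended usage pattern for DeviceSyncer: a kernel in which every block must finish one phase before any block starts the next. The function and variable names are hypothetical, and this spin-based grid synchronization is only safe when all blocks of the grid are resident on the device at the same time.

// Hypothetical sketch: a two-phase kernel using DeviceSyncer as a grid-wide barrier.
#include "concurrency.h"

__device__ msamp::DeviceSyncer syncer;  // device globals are zero-initialized

__global__ void two_phase_sum(const float *in, float *block_sums, float *out, int n) {
    // Phase 1: thread 0 of each block accumulates a strided partial sum.
    if (threadIdx.x == 0) {
        float s = 0.f;
        for (int i = blockIdx.x; i < n; i += gridDim.x) s += in[i];
        block_sums[blockIdx.x] = s;
    }
    // Grid-wide barrier: every thread of every block must call sync().
    syncer.sync(gridDim.x);
    // Phase 2: block 0 can now safely read all partial sums.
    if (blockIdx.x == 0 && threadIdx.x == 0) {
        float total = 0.f;
        for (int b = 0; b < (int)gridDim.x; ++b) total += block_sums[b];
        *out = total;
    }
}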
51 changes: 51 additions & 0 deletions msamp/common/include/poll.h
@@ -0,0 +1,51 @@
#ifndef MSAMP_POLL_H_
#define MSAMP_POLL_H_

#include <cstdint>

extern "C" __device__ void __assert_fail(const char *__assertion, const char *__file, unsigned int __line,
const char *__function) __THROW;

// If a spin is stuck, escape from it and set status to 1.
#define POLL_MAYBE_JAILBREAK_ESCAPE(__cond, __max_spin_cnt, __status) \
do { \
int64_t __spin_cnt = 0; \
__status = 0; \
while (__cond) { \
if (__max_spin_cnt >= 0 && __spin_cnt++ == __max_spin_cnt) { \
__status = 1; \
break; \
} \
} \
} while (0);

// If a spin is stuck, fail with a device-side assertion.
#define POLL_MAYBE_JAILBREAK(__cond, __max_spin_cnt) \
do { \
int64_t __spin_cnt = 0; \
while (__cond) { \
if (__max_spin_cnt >= 0 && __spin_cnt++ == __max_spin_cnt) { \
__assert_fail(#__cond, __FILE__, __LINE__, __PRETTY_FUNCTION__); \
} \
} \
} while (0);

// The same as POLL_MAYBE_JAILBREAK, except that __cond1 is checked before __cond2.
// This is especially useful when __cond1 is faster to check.
#define OR_POLL_MAYBE_JAILBREAK(__cond1, __cond2, __max_spin_cnt) \
do { \
int64_t __spin_cnt = 0; \
while (true) { \
if (!(__cond1)) { \
break; \
} else if (!(__cond2)) { \
break; \
} \
if (__max_spin_cnt >= 0 && __spin_cnt++ == __max_spin_cnt) { \
__assert_fail(#__cond1 #__cond2, __FILE__, __LINE__, __PRETTY_FUNCTION__); \
} \
} \
} while (0);


#endif // MSAMP_POLL_H_
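
For context, a short sketch (not part of this diff) of the two polling styles: POLL_MAYBE_JAILBREAK triggers a device-side assertion once the spin budget is exhausted, while POLL_MAYBE_JAILBREAK_ESCAPE breaks out of the loop and reports through a status variable. The function name is hypothetical.

// Hypothetical sketch: waiting on a flag in global memory with the poll helpers.
#include "poll.h"

__device__ void wait_until_set(volatile int *flag, int *status) {
    // Variant 1: assert (via __assert_fail) if the flag never flips within ~1e8 spins.
    POLL_MAYBE_JAILBREAK(!(*flag), 100000000);

    // Variant 2: give up after the spin budget; *status is set to 1 on timeout, 0 otherwise.
    POLL_MAYBE_JAILBREAK_ESCAPE(!(*flag), 100000000, *status);
}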
File renamed without changes.
6 changes: 2 additions & 4 deletions msamp/megatron/distributed.py
@@ -11,6 +11,7 @@

from msamp.common.dtype import Dtypes
from msamp.common.tensor import ScalingMeta, ScalingTensor
from msamp.operators.arithmetic import Arithmetic


class FP8DistributedDataParallel(DistributedDataParallelBase):
@@ -177,10 +178,7 @@ def _fp8_make_param_hook(self, param):
def param_hook(*unused):
# Add the gradient to the buffer.
if param.grad is not None:
-param.main_grad.copy_(
-(param.main_grad.to(param.grad.dtype) +
-param.grad).cast(self.wgrad_qtype, meta=param.main_grad.meta)
-)
+Arithmetic.add_to_fp8(param.main_grad.value, param.main_grad.meta, param.grad)
# Now we can deallocate grad memory.
param.grad = None

8 changes: 8 additions & 0 deletions msamp/operators/arithmetic/__init__.py
@@ -0,0 +1,8 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

"""Exposes the interface of MS-AMP arithmetic module."""

from msamp.operators.arithmetic.arithmetic import Arithmetic

__all__ = ['Arithmetic']
40 changes: 40 additions & 0 deletions msamp/operators/arithmetic/arithmetic.cu
@@ -0,0 +1,40 @@
#include <torch/extension.h>
#include <c10/cuda/CUDAStream.h>

#include "../../common/include/common.h"
#include "../../common/include/utils.cuh"
#include "../../common/include/concurrency.h"
#include "vectorized_pointwise.h"

namespace msamp {
void add_to_fp8(at::Tensor fp8_tensor,
at::Tensor scale,
at::Tensor scale_inv,
at::Tensor amax,
const at::Tensor& other,
bool is_e4m3) {
const size_t N = other.numel();
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
TORCH_DTYPE_SWITCH_INPUT(other.scalar_type(), IType,
SELECT_FP8_TYPE(is_e4m3, OType,

constexpr int nvec = 32 / sizeof(IType);

VectorizedAddToFp8KernelLauncher<nvec>(
reinterpret_cast<IType*>(other.data_ptr()),
reinterpret_cast<OType*>(fp8_tensor.data_ptr()),
reinterpret_cast<fp32*>(scale.data_ptr()),
reinterpret_cast<fp32*>(scale_inv.data_ptr()),
reinterpret_cast<fp32*>(amax.data_ptr()),
N,
stream
);
);
);
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("add_to_fp8", &add_to_fp8, "Add to fp8");
}

} // namespace msamp
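
For reference, this is approximately what one branch of the nested TORCH_DTYPE_SWITCH_INPUT / SELECT_FP8_TYPE dispatch above resolves to after macro expansion, here for a bf16 other tensor and an e4m3 destination; the standalone function is illustrative only and not part of this diff.

// Hypothetical expansion of one dispatch branch of add_to_fp8 (bf16 input, e4m3 output).
void add_bf16_to_e4m3(at::Tensor fp8_tensor, at::Tensor scale, at::Tensor scale_inv,
                      at::Tensor amax, const at::Tensor& other) {
    using IType = bf16;
    using OType = fp8e4m3;
    constexpr int nvec = 32 / sizeof(IType);  // 16 bf16 elements per 32-byte vector
    VectorizedAddToFp8KernelLauncher<nvec>(
        reinterpret_cast<IType*>(other.data_ptr()),
        reinterpret_cast<OType*>(fp8_tensor.data_ptr()),
        reinterpret_cast<fp32*>(scale.data_ptr()),
        reinterpret_cast<fp32*>(scale_inv.data_ptr()),
        reinterpret_cast<fp32*>(amax.data_ptr()),
        other.numel(),
        at::cuda::getCurrentCUDAStream());
}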
35 changes: 35 additions & 0 deletions msamp/operators/arithmetic/arithmetic.py
@@ -0,0 +1,35 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

"""FP8 Arithmetic module."""

import torch

import msamp_arithmetic
from msamp.common.dtype import Dtypes


class Arithmetic:
"""Arithmetic operator for FP8 tensor."""
@staticmethod
def add_to_fp8(fp8_tensor, meta, other):
"""Add high presicon tensor to fp8_tensor in-place.

Args:
fp8_tensor (torch.Tensor): fp8 tensor to add to.
meta (ScalingMeta): scaling metadata of fp8_tensor.
other (torch.Tensor): high-precision tensor to add.
"""
if not (fp8_tensor.is_cuda and fp8_tensor.is_contiguous()):
raise ValueError('The fp8 tensor is not in CUDA memory or not contiguous.')
if not (other.is_cuda and other.is_contiguous()):
raise ValueError('The other tensor is not in CUDA memory or not contiguous.')
if not (fp8_tensor.dtype == torch.uint8 or fp8_tensor.dtype == torch.int8):
raise ValueError('The fp8 tensor is not in uint8 or int8.')

if not (meta.qtype == Dtypes.kfloat8_e4m3 or meta.qtype == Dtypes.kfloat8_e5m2):
raise ValueError('The fp8 tensor is not in e4m3 or e5m2 format.')

is_e4m3 = meta.qtype == Dtypes.kfloat8_e4m3

msamp_arithmetic.add_to_fp8(fp8_tensor, meta.scale, meta.scale_inv, meta.amax[0], other, is_e4m3)
28 changes: 28 additions & 0 deletions msamp/operators/arithmetic/setup.py
@@ -0,0 +1,28 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

"""The setuptools based setup module."""

from setuptools import setup
from torch.utils import cpp_extension

ext_t = cpp_extension.CUDAExtension
ext_fnames = ['arithmetic.cu']
define_macros = []
nvcc_flags = [
'-O3', '-U__CUDA_NO_HALF_CONVERSIONS__', '-U__CUDA_NO_BFLOAT16_CONVERSIONS__', '--expt-relaxed-constexpr',
'--expt-extended-lambda', '--use_fast_math'
]

extra_compile_args = dict(cxx=['-fopenmp', '-O3'], nvcc=nvcc_flags)

define_macros.append(('WITH_CUDA', None))

setup(
name='msamp_arithmetic',
version='0.0.1',
ext_modules=[
ext_t('msamp_arithmetic', ext_fnames, define_macros=define_macros, extra_compile_args=extra_compile_args)
],
cmdclass={'build_ext': cpp_extension.BuildExtension}
)