diff --git a/msamp/common/include/concurrency.h b/msamp/common/include/concurrency.h
index 49b0b1c2..6075f967 100644
--- a/msamp/common/include/concurrency.h
+++ b/msamp/common/include/concurrency.h
@@ -1,5 +1,6 @@
 // Copyright (c) Microsoft Corporation.
 // Licensed under the MIT license.
+// The file is from https://github.com/microsoft/mscclpp/blob/main/include/mscclpp/concurrency.hpp.
 
 #ifndef MSAMP_CONCURRENCY_H_
 #define MSAMP_CONCURRENCY_H_
diff --git a/msamp/common/include/poll.h b/msamp/common/include/poll.h
index bf66d5f4..0ea53328 100644
--- a/msamp/common/include/poll.h
+++ b/msamp/common/include/poll.h
@@ -1,3 +1,8 @@
+
+// Copyright (c) Microsoft Corporation.
+// Licensed under the MIT License.
+// The file is from https://github.com/microsoft/mscclpp/blob/main/include/mscclpp/poll.hpp.
+
 #ifndef MSAMP_POLL_H_
 #define MSAMP_POLL_H_
 
diff --git a/msamp/operators/arithmetic/arithmetic.cu b/msamp/operators/arithmetic/arithmetic.cu
index a8786782..ba07ab18 100644
--- a/msamp/operators/arithmetic/arithmetic.cu
+++ b/msamp/operators/arithmetic/arithmetic.cu
@@ -1,3 +1,6 @@
+// Copyright (c) Microsoft Corporation.
+// Licensed under the MIT License.
+
 #include
 #include
 
diff --git a/msamp/operators/arithmetic/vectorized_pointwise.h b/msamp/operators/arithmetic/vectorized_pointwise.h
index ea1e9569..e90d8ae9 100644
--- a/msamp/operators/arithmetic/vectorized_pointwise.h
+++ b/msamp/operators/arithmetic/vectorized_pointwise.h
@@ -1,3 +1,7 @@
+// Copyright (c) Microsoft Corporation.
+// Licensed under the MIT License.
+// The file is adapted from https://github.com/NVIDIA/TransformerEngine/blob/main/transformer_engine/common/util/vectorized_pointwise.h.
+
 #ifndef MSAMP_VECTORIZED_POINTWISE_H
 #define MSAMP_VECTORIZED_POINTWISE_H
 
@@ -172,6 +176,9 @@ class VectorizedStorer : public VectorizedAccessor {
 
 constexpr int unary_kernel_threads = 512;
 
+constexpr float e4m3_max = 448.0;
+constexpr float e5m2_max = 57344.0;
+
 extern __device__ msamp::DeviceSyncer device_syncer;
 
 template <int nvec, bool aligned, typename ComputeType, typename InputType, typename OutputType>
@@ -204,18 +216,26 @@ __global__ void add_to_fp8_kernel(InputType *input,
   VectorizedStorer<OutputType, nvec, aligned> output_storer(output, N);
   VectorizedStorer<InputType, nvec, aligned> input_storer(input, N);
     output_storer.load(tid, N);
 #pragma unroll
     for (int i = 0; i < nvec; ++i) {
-      const ComputeType val1 = static_cast<ComputeType>(output_storer.separate()[i]);
-      const ComputeType val2 = static_cast<ComputeType>(input_storer.separate()[i]);
+      const InputType val1 = static_cast<InputType>(input_storer.separate()[i]);
+      const ComputeType val2 = static_cast<ComputeType>(output_storer.separate()[i]);
+
+      InputType temp = static_cast<InputType>(val2 * s);
+
+      if constexpr (is_fp16<InputType>::value) {
+        temp = static_cast<InputType>(__hadd(temp, val1));
+      } else {
+        temp += val1;
+      }
 
-      ComputeType temp = val1 * s + val2;
       if constexpr (is_fp8<OutputType>::value) {
         __builtin_assume(max >= 0);
-        max = fmaxf(fabsf(static_cast<ComputeType>(temp)), max);
+        max = fmaxf(fabsf(temp), max);
       }
     }
   }
+
   if constexpr (is_fp8<OutputType>::value) {
     /* warp tile amax reduce*/
     max = transformer_engine::reduce_max<unary_kernel_threads / THREADS_PER_WARP>(max, warp_id);
@@ -235,14 +255,10 @@ __global__ void add_to_fp8_kernel(InputType *input,
       sf = torch.where(torch.isfinite(amax), sf, scale)
       sf = torch.where(exp < 0, 1 / sf, sf)
     */
-    ComputeType fp_max = 0.0;
-    if (std::is_same<OutputType, fp8e4m3>::value) {
-      fp_max = 448.0;
-    } else if (std::is_same<OutputType, fp8e5m2>::value) {
-      fp_max = 57344.0;
-    }
-
     ComputeType amax_value = *amax;
+
+    ComputeType fp_max = std::is_same<OutputType, fp8e4m3>::value ? e4m3_max : e5m2_max;
+
     ComputeType exp = floorf(log2f(fp_max/(amax_value)));
     ComputeType sf = roundf(powf(2, fabsf(exp)));
@@ -254,8 +270,6 @@ __global__ void add_to_fp8_kernel(InputType *input,
       sf = 1 / sf;
     }
 
-    *scale = sf;
-
     // using new scaling factor to quantize the input
     for (size_t tid = blockIdx.x * blockDim.x + threadIdx.x;
          tid < M;
@@ -279,8 +293,11 @@ __global__ void add_to_fp8_kernel(InputType *input,
     }
     output_storer.store(tid, N);
   }
-  *scale_inv = 1.0 / sf;
 
+  if (threadIdx.x == 0 && blockIdx.x == 0) {
+    *scale = sf;
+    *scale_inv = 1.0 / sf;
+  }
 }
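The numerical core of this change: the kernel now accumulates in the input's own precision (using `__hadd` for FP16) instead of widening both operands to `ComputeType`, and the new `e4m3_max`/`e5m2_max` constants replace the per-call if/else lookup of the FP8 format's largest finite value (448 for E4M3, 57344 for E5M2). The scale update itself follows the PyTorch pseudocode quoted in the kernel comment. Below is a minimal sketch of that update; `compute_scaling_factor` is a hypothetical helper name, and the zero-amax fallback is assumed from the usual delayed-scaling recipe (only the `isfinite` and `exp < 0` guards are visible in the hunk above):

    import torch

    e4m3_max = 448.0    # largest finite FP8 E4M3 value, as in the kernel
    e5m2_max = 57344.0  # largest finite FP8 E5M2 value, as in the kernel

    def compute_scaling_factor(amax: torch.Tensor, scale: torch.Tensor, fp_max: float) -> torch.Tensor:
        """Hypothetical reference for the scale update in add_to_fp8_kernel."""
        exp = torch.floor(torch.log2(fp_max / amax))      # exp = floorf(log2f(fp_max / amax_value))
        sf = torch.round(torch.pow(2.0, torch.abs(exp)))  # sf = roundf(powf(2, fabsf(exp)))
        # Assumed guard: keep the previous scale when amax is zero or non-finite.
        sf = torch.where(amax > 0.0, sf, scale)
        sf = torch.where(torch.isfinite(amax), sf, scale)
        sf = torch.where(exp < 0, 1 / sf, sf)
        return sf

A negative `exp` means `amax` already exceeds `fp_max`, so the factor becomes `1 / 2^|exp| < 1` and shrinks values into the representable range; otherwise `sf >= 1` stretches them toward the format's full dynamic range.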
diff --git a/tests/operators/test_arithmetic.py b/tests/operators/test_arithmetic.py
index c92a84a2..439ab46d 100644
--- a/tests/operators/test_arithmetic.py
+++ b/tests/operators/test_arithmetic.py
@@ -3,6 +3,7 @@
 
 """Tests for arithmetic module."""
 
+import itertools
 import unittest
 
 import torch
@@ -23,22 +24,18 @@ def _check_scaling_tensor(self, scaling_tensor1, scaling_tensor2):
     @decorator.cuda_test
     def test_add_to_fp8(self):
         """Test the function Arithmetic.add_to_fp8()."""
-        torch.manual_seed(100)
-        sizes = (10000, 16384)
-
-        for dtype in [torch.float16, torch.bfloat16, torch.float32]:
-            for qtype in [Dtypes.kfloat8_e4m3, Dtypes.kfloat8_e5m2]:
-                input1 = torch.rand(sizes, dtype=dtype, device='cuda')
-                input2 = torch.rand(sizes, dtype=dtype, device='cuda')
-
-                scaling_tensor1 = input1.cast(qtype)
-                scaling_tensor2 = input1.cast(qtype)
-
-                self._check_scaling_tensor(scaling_tensor1, scaling_tensor2)
-
+        sizes = list(range(1024, 8193, 1024))
+        dtypes = [torch.float16, torch.bfloat16, torch.float32]
+        qtypes = [Dtypes.kfloat8_e4m3, Dtypes.kfloat8_e5m2]
+        for i, j, dtype, qtype in itertools.product(sizes, sizes, dtypes, qtypes):
+            size = (i, j)
+            input1 = torch.rand(size, dtype=dtype, device='cuda')
+            scaling_tensor1 = input1.cast(qtype)
+            scaling_tensor2 = input1.cast(qtype)
+
+            for _ in range(10):
+                input2 = torch.rand(size, dtype=dtype, device='cuda')
                 meta = scaling_tensor1.meta
                 Arithmetic.add_to_fp8(scaling_tensor1.value, meta, input2)
                 scaling_tensor2.copy_((scaling_tensor2.to(dtype) + input2).cast(qtype, meta=scaling_tensor2.meta))
-
                 self._check_scaling_tensor(scaling_tensor1, scaling_tensor2)
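The rewritten test sweeps every size/dtype/qtype combination and runs ten accumulation rounds per combination, checking the fused CUDA path against a pure-PyTorch reference after each round; that repetition is what exercises both the precision fix and the `*scale`/`*scale_inv` change, since a single thread now commits the new scale after the quantized values are stored instead of every thread in the grid writing it. A standalone sketch of one round of that check, assuming the test module's import paths (`msamp.common.dtype.Dtypes`, `msamp.operators.arithmetic.Arithmetic`) and that `meta.scale` is a comparable tensor; the real test delegates the comparison to its `_check_scaling_tensor` helper:

    import torch
    from msamp.common.dtype import Dtypes              # assumed import path
    from msamp.operators.arithmetic import Arithmetic  # assumed import path

    size = (1024, 1024)
    dtype, qtype = torch.float16, Dtypes.kfloat8_e4m3

    input1 = torch.rand(size, dtype=dtype, device='cuda')
    fused = input1.cast(qtype)  # updated in place by the CUDA kernel
    ref = input1.cast(qtype)    # updated through the PyTorch fallback path

    input2 = torch.rand(size, dtype=dtype, device='cuda')
    Arithmetic.add_to_fp8(fused.value, fused.meta, input2)
    ref.copy_((ref.to(dtype) + input2).cast(qtype, meta=ref.meta))

    # Both paths should produce identical FP8 payloads and metadata.
    assert torch.equal(fused.value, ref.value)
    assert torch.equal(fused.meta.scale, ref.meta.scale)  # scale assumed to be a tensor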