
Commit

fix bug and comments
tocean committed Nov 23, 2023
1 parent 4438d62 commit d5d1423
Showing 5 changed files with 57 additions and 28 deletions.
1 change: 1 addition & 0 deletions msamp/common/include/concurrency.h
@@ -1,5 +1,6 @@
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT license.
// The file is from https://github.com/microsoft/mscclpp/blob/main/include/mscclpp/concurrency.hpp.

#ifndef MSAMP_CONCURRENCY_H_
#define MSAMP_CONCURRENCY_H_
5 changes: 5 additions & 0 deletions msamp/common/include/poll.h
@@ -1,3 +1,8 @@

// Copyright (c) Microsoft Corporation.
// Licensed under the MIT License.
// The file is from https://github.com/microsoft/mscclpp/blob/main/include/mscclpp/poll.hpp.

#ifndef MSAMP_POLL_H_
#define MSAMP_POLL_H_

3 changes: 3 additions & 0 deletions msamp/operators/arithmetic/arithmetic.cu
@@ -1,3 +1,6 @@
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT License.

#include <torch/extension.h>
#include <c10/cuda/CUDAStream.h>

45 changes: 31 additions & 14 deletions msamp/operators/arithmetic/vectorized_pointwise.h
@@ -1,3 +1,7 @@
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT License.
// The file is adapted from https://github.com/NVIDIA/TransformerEngine/blob/main/transformer_engine/common/util/vectorized_pointwise.h.

#ifndef MSAMP_VECTORIZED_POINTWISE_H
#define MSAMP_VECTORIZED_POINTWISE_H

@@ -172,6 +176,9 @@ class VectorizedStorer : public VectorizedAccessor<DType, nvec, aligned> {


constexpr int unary_kernel_threads = 512;
constexpr float e4m3_max = 448.0;
constexpr float e5m2_max = 57344.0;

extern __device__ msamp::DeviceSyncer device_syncer;

template <int nvec, bool aligned,
@@ -186,6 +193,11 @@ __global__ void add_to_fp8_kernel(InputType *input,
ComputeType *amax,
const size_t N,
const size_t num_aligned_elements) {
if (threadIdx.x == 0 && blockIdx.x == 0) {
*amax = 0;
}
device_syncer.sync(gridDim.x, -1);

// input is high precision, output is fp8
VectorizedStorer<OutputType, nvec, aligned> output_storer(output, N);
VectorizedStorer<InputType, nvec, aligned> input_storer(input, N);
@@ -206,16 +218,24 @@
output_storer.load(tid, N);
#pragma unroll
for (int i = 0; i < nvec; ++i) {
const ComputeType val1 = static_cast<ComputeType>(output_storer.separate()[i]);
const ComputeType val2 = static_cast<ComputeType>(input_storer.separate()[i]);
const InputType val1 = static_cast<InputType>(input_storer.separate()[i]);
const ComputeType val2 = static_cast<ComputeType>(output_storer.separate()[i]);

InputType temp = static_cast<InputType>(val2 * s);

if constexpr (is_fp16<InputType>::value) {
temp = static_cast<ComputeType>(__hadd(temp, val1));
} else {
temp += val1;
}

ComputeType temp = val1 * s + val2;
if constexpr (is_fp8<OutputType>::value) {
__builtin_assume(max >= 0);
max = fmaxf(fabsf(static_cast<InputType>(temp)), max);
max = fmaxf(fabsf(temp), max);
}
}
}

if constexpr (is_fp8<OutputType>::value) {
/* warp tile amax reduce*/
max = transformer_engine::reduce_max<unary_kernel_threads / THREADS_PER_WARP>(max, warp_id);
@@ -235,14 +255,10 @@ __global__ void add_to_fp8_kernel(InputType *input,
sf = torch.where(torch.isfinite(amax), sf, scale)
sf = torch.where(exp < 0, 1 / sf, sf)
*/
ComputeType fp_max = 0.0;
if (std::is_same<OutputType, fp8e4m3>::value) {
fp_max = 448.0;
} else if (std::is_same<OutputType, fp8e5m2>::value) {
fp_max = 57344.0;
}

ComputeType amax_value = *amax;

ComputeType fp_max = std::is_same<OutputType, fp8e4m3>::value ? e4m3_max : e5m2_max;

ComputeType exp = floorf(log2f(fp_max/(amax_value)));
ComputeType sf = roundf(powf(2, fabsf(exp)));

@@ -254,8 +270,6 @@ __global__ void add_to_fp8_kernel(InputType *input,
sf = 1 / sf;
}

*scale = sf;

// using new scaling factor to quantize the input
for (size_t tid = blockIdx.x * blockDim.x + threadIdx.x;
tid < M;
@@ -279,8 +293,11 @@ __global__ void add_to_fp8_kernel(InputType *input,
}
output_storer.store(tid, N);
}
*scale_inv = 1.0 / sf;

if (threadIdx.x == 0 && blockIdx.x == 0) {
*scale = sf;
*scale_inv = 1.0 / sf;
}
}
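
Note: the scaling-factor update that the kernel's block comment describes (and that the code implements with floorf/log2f/roundf/powf before writing *scale and *scale_inv) can be written out in PyTorch for reference. This is an illustrative sketch only, not MS-AMP's actual API; the helper name compute_scaling_factor is assumed, and the parts of the comment cut off above the hunk are assumed to follow the usual amax > 0 fallback, with the optional margin term omitted:

import torch

def compute_scaling_factor(amax: torch.Tensor, scale: torch.Tensor, fp_max: float) -> torch.Tensor:
    # exp = floor(log2(fp_max / amax)), sf = round(2 ** |exp|)
    exp = torch.floor(torch.log2(fp_max / amax))
    sf = torch.round(torch.pow(2.0, torch.abs(exp)))
    # keep the previous scale when amax is zero or non-finite
    sf = torch.where(amax > 0.0, sf, scale)
    sf = torch.where(torch.isfinite(amax), sf, scale)
    # if amax already exceeds fp_max (exp < 0), shrink the scale instead of growing it
    sf = torch.where(exp < 0, 1.0 / sf, sf)
    return sf

With this sf, the kernel re-quantizes the accumulated values, writes sf to *scale, and writes 1.0 / sf to *scale_inv for later dequantization.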


31 changes: 17 additions & 14 deletions tests/operators/test_arithmetic.py
@@ -3,6 +3,7 @@

"""Tests for arithmetic module."""

import itertools
import unittest

import torch
@@ -23,22 +24,24 @@ def _check_scaling_tensor(self, scaling_tensor1, scaling_tensor2):
@decorator.cuda_test
def test_add_to_fp8(self):
"""Test the function Arithmetic.add_to_fp8()."""

torch.manual_seed(100)
sizes = (10000, 16384)

for dtype in [torch.float16, torch.bfloat16, torch.float32]:
for qtype in [Dtypes.kfloat8_e4m3, Dtypes.kfloat8_e5m2]:
input1 = torch.rand(sizes, dtype=dtype, device='cuda')
input2 = torch.rand(sizes, dtype=dtype, device='cuda')

scaling_tensor1 = input1.cast(qtype)
scaling_tensor2 = input1.cast(qtype)

self._check_scaling_tensor(scaling_tensor1, scaling_tensor2)

sizes = []
for i in range(1024, 8192, 1024):
for j in range(1024, 8192, 1024):
sizes.append((i, j))

sizes = list(range(1024, 8193, 1024))
dtypes = [torch.float16, torch.bfloat16, torch.float32]
qtypes = [Dtypes.kfloat8_e4m3, Dtypes.kfloat8_e5m2]
for i, j, dtype, qtype, in itertools.product(sizes, sizes, dtypes, qtypes):
size = (i, j)
input1 = torch.rand(size, dtype=dtype, device='cuda')
scaling_tensor1 = input1.cast(qtype)
scaling_tensor2 = input1.cast(qtype)

for i in range(10):
input2 = torch.rand(size, dtype=dtype, device='cuda')
meta = scaling_tensor1.meta
Arithmetic.add_to_fp8(scaling_tensor1.value, meta, input2)
scaling_tensor2.copy_((scaling_tensor2.to(dtype) + input2).cast(qtype, meta=scaling_tensor2.meta))

self._check_scaling_tensor(scaling_tensor1, scaling_tensor2)
