[Inference] Delete duplicated copy_vector (hpcaitech#5716)
Courtesy-Xs authored May 14, 2024
1 parent 7806842 commit 121d7ad
Showing 6 changed files with 28 additions and 47 deletions.
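
Annotation: every call site below swaps the deprecated copy_vector(dst, src) for the retained copy(src, dst) helper, so the argument order flips from destination-first to source-first. The definition of copy is not part of this diff; the following is a minimal sketch of what it plausibly looks like, assuming it mirrors the deleted copy_vector body in utils/vec_copy.h with the parameters swapped. The VecTypeTrait specializations here are illustrative stand-ins for the repo's common::VecTypeTrait, which is not shown either.

// Sketch only: stand-in for common::VecTypeTrait, which maps an
// (element type, vector width) pair to a packed CUDA vector type.
template <typename T, int VecSize>
struct VecTypeTrait;
template <> struct VecTypeTrait<float, 1> { using Type = float; };
template <> struct VecTypeTrait<float, 2> { using Type = float2; };
template <> struct VecTypeTrait<float, 4> { using Type = float4; };

// Plausible shape of the retained helper: one packed load and one packed
// store, source pointer first -- the reverse of the deleted copy_vector.
template <typename T, int VecSize>
__device__ __inline__ void copy(const T *src, T *dst) {
  using VT = typename VecTypeTrait<T, VecSize>::Type;
  *reinterpret_cast<VT *>(dst) = *reinterpret_cast<const VT *>(src);
}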
@@ -5,7 +5,6 @@
#include "funcs/cast_functor.h"
#include "common/micros.h"

-using colossalAI::cuda::utils::copy_vector;
using colossalAI::cuda::utils::get_vec_size;
using colossalAI::cuda::utils::copy;
using colossalAI::funcs::CastFunctor;
@@ -8,7 +8,6 @@
#include "funcs/cast_functor.h"
#include "funcs/binary_functor.h"

-using colossalAI::cuda::utils::copy_vector;
using colossalAI::cuda::utils::get_vec_size;
using colossalAI::cuda::utils::copy;
using colossalAI::funcs::CastFunctor;
6 changes: 3 additions & 3 deletions extensions/csrc/kernel/cuda/get_cos_and_sin_kernel.cu
@@ -4,7 +4,7 @@
#include "utils/vec_copy.h"
#include "common/micros.h"

-using colossalAI::cuda::utils::copy_vector;
+using colossalAI::cuda::utils::copy;
using colossalAI::cuda::utils::get_vec_size;


@@ -23,8 +23,8 @@ __device__ void apply_cos_and_sin_memcopy(
int begin_id = threadIdx.x * VecSize;

for (; begin_id <= head_dim - VecSize; begin_id += blockDim.x){
-    copy_vector<scalar_t, VecSize>(cos + dest_offset_id + begin_id, cos_cache_ptr + src_offset_id + begin_id);
-    copy_vector<scalar_t, VecSize>(sin + dest_offset_id + begin_id, sin_cache_ptr + src_offset_id + begin_id);
+    copy<scalar_t, VecSize>(cos_cache_ptr + src_offset_id + begin_id, cos + dest_offset_id + begin_id);
+    copy<scalar_t, VecSize>(sin_cache_ptr + src_offset_id + begin_id, sin + dest_offset_id + begin_id);
}

if (!Aligned) {
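
Annotation: the if (!Aligned) branch is collapsed in this view. Assuming it follows the usual tail-handling pattern for vectorized copy kernels, it would fall back to scalar copies for the head_dim % VecSize elements that the vectorized loop above cannot cover. A hypothetical sketch, not the committed code:

// Hypothetical tail handling for the unaligned case: one-element copies
// for the remainder past the last full vector of the row.
if (!Aligned) {
  int tail_start = head_dim - head_dim % VecSize;
  for (int id = tail_start + threadIdx.x; id < head_dim; id += blockDim.x) {
    cos[dest_offset_id + id] = cos_cache_ptr[src_offset_id + id];
    sin[dest_offset_id + id] = sin_cache_ptr[src_offset_id + id];
  }
}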
22 changes: 11 additions & 11 deletions extensions/csrc/kernel/cuda/scaled_masked_softmax_kernel.cu
@@ -23,7 +23,7 @@ using colossalAI::funcs::UnaryOpFunctor;
using colossalAI::funcs::UnaryOpType;
using colossalAI::funcs::warp_reduce;
using colossalAI::funcs::ReduceType;
-using colossalAI::cuda::utils::copy_vector;
+using colossalAI::cuda::utils::copy;


/*
@@ -87,8 +87,8 @@ __global__ void scaled_masked_softmax_warp_forward(

if (element_index < batch_element_count) {
int itr_idx = i * element_count + it * WARP_SIZE;
-      copy_vector<input_t, ELEMENTS_PER_LDG_STG>(temp_data, src + itr_idx);
-      copy_vector<uint8_t, ELEMENTS_PER_LDG_STG>(temp_mask, mask + itr_idx);
+      copy<input_t, ELEMENTS_PER_LDG_STG>(src + itr_idx, temp_data);
+      copy<uint8_t, ELEMENTS_PER_LDG_STG>(mask + itr_idx, temp_mask);

#pragma unroll
for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
@@ -144,8 +144,8 @@ __global__ void scaled_masked_softmax_warp_forward(
for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
out[element] = elements[i][it + element] / sum[i];
}
-        copy_vector<output_t, ELEMENTS_PER_LDG_STG>(
-            dst + i * element_count + it * WARP_SIZE, out);
+        copy<output_t, ELEMENTS_PER_LDG_STG>(
+            out, dst + i * element_count + it * WARP_SIZE);
} else {
break;
}
@@ -200,10 +200,10 @@ __global__ void scaled_masked_softmax_warp_backward(
for (int it = 0; it < WARP_ITERATIONS; it += ELEMENTS_PER_LDG_STG) {
int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE;
if (element_index < batch_element_count) {
-      copy_vector<input_t, ELEMENTS_PER_LDG_STG>(
-          temp_grad, grad + i * element_count + it * WARP_SIZE);
-      copy_vector<input_t, ELEMENTS_PER_LDG_STG>(
-          temp_output, output + i * element_count + it * WARP_SIZE);
+      copy<input_t, ELEMENTS_PER_LDG_STG>(
+          grad + i * element_count + it * WARP_SIZE, temp_grad);
+      copy<input_t, ELEMENTS_PER_LDG_STG>(
+          output + i * element_count + it * WARP_SIZE, temp_output);

#pragma unroll
for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
@@ -245,8 +245,8 @@ __global__ void scaled_masked_softmax_warp_backward(
(output_t)(scale * (grad_reg[i][it + element] -
output_reg[i][it + element] * sum[i]));
}
-        copy_vector<output_t, ELEMENTS_PER_LDG_STG>(
-            gradInput + i * element_count + it * WARP_SIZE, out);
+        copy<output_t, ELEMENTS_PER_LDG_STG>(
+            out, gradInput + i * element_count + it * WARP_SIZE);
}
}
}
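
Annotation: the write-back expression above, scale * (grad_reg[i][it + element] - output_reg[i][it + element] * sum[i]), is the standard softmax gradient. Writing y = softmax(s * x) for scale s and g for the upstream gradient, and assuming the elided earlier loop stores the elementwise product g with y in grad_reg (as in the Megatron-style kernels this file resembles), the identity being computed is

\frac{\partial L}{\partial x_i} = s \, y_i \Big( g_i - \sum_j g_j y_j \Big) = s \Big( (g \odot y)_i - y_i \sum_j (g \odot y)_j \Big),

with sum[i] holding \sum_j (g \odot y)_j for row i.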
@@ -23,8 +23,8 @@ using colossalAI::funcs::UnaryOpFunctor;
using colossalAI::funcs::UnaryOpType;
using colossalAI::funcs::warp_reduce;
using colossalAI::funcs::ReduceType;
-using colossalAI::cuda::utils::copy_vector;
-using colossalAI::cuda::utils::copy_zero_vector;
+using colossalAI::cuda::utils::copy;
+using colossalAI::cuda::utils::copy_zero;

/*
* Extended softmax (from native aten pytorch) with following additional
@@ -75,8 +75,8 @@ __global__ void scaled_upper_triang_masked_softmax_warp_forward(
int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE;

if (element_index < batch_element_count) {
-      copy_vector<input_t, ELEMENTS_PER_LDG_STG>(
-          temp_data, src + i * element_count * stride + it * WARP_SIZE);
+      copy<input_t, ELEMENTS_PER_LDG_STG>(
+          src + i * element_count * stride + it * WARP_SIZE, temp_data);

#pragma unroll
for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
@@ -140,10 +140,10 @@ __global__ void scaled_upper_triang_masked_softmax_warp_forward(
out[element] = 0;
}
}
-        copy_vector<output_t, ELEMENTS_PER_LDG_STG>(
-            dst + i * element_count * stride + it * WARP_SIZE, out);
+        copy<output_t, ELEMENTS_PER_LDG_STG>(
+            out, dst + i * element_count * stride + it * WARP_SIZE);
} else if (element_index < element_count) {
-        copy_zero_vector<output_t, ELEMENTS_PER_LDG_STG>(
+        copy_zero<output_t, ELEMENTS_PER_LDG_STG>(
dst + i * element_count * stride + it * WARP_SIZE);
} else {
break;
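
Annotation: positions past the causal boundary but still inside the row (the else-if branch above) are zero-filled with copy_zero, whose definition appears in the vec_copy.h hunk at the end of this diff. Element-wise it amounts to the sketch below, where dst_ptr stands in for the destination argument; the real helper issues a single packed store of a zeroed vector instead.

// Scalar equivalent of copy_zero<output_t, ELEMENTS_PER_LDG_STG>(dst_ptr):
#pragma unroll
for (int e = 0; e < ELEMENTS_PER_LDG_STG; ++e) {
  dst_ptr[e] = static_cast<output_t>(0);
}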
@@ -199,10 +199,10 @@ __global__ void scaled_upper_triang_masked_softmax_warp_backward(
for (int it = 0; it < WARP_ITERATIONS; it += ELEMENTS_PER_LDG_STG) {
int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE;
if (element_index < batch_element_count) {
-      copy_vector<input_t, ELEMENTS_PER_LDG_STG>(
-          temp_grad, grad + i * element_count * stride + it * WARP_SIZE);
-      copy_vector<input_t, ELEMENTS_PER_LDG_STG>(
-          temp_output, output + i * element_count * stride + it * WARP_SIZE);
+      copy<input_t, ELEMENTS_PER_LDG_STG>(
+          grad + i * element_count * stride + it * WARP_SIZE, temp_grad);
+      copy<input_t, ELEMENTS_PER_LDG_STG>(
+          output + i * element_count * stride + it * WARP_SIZE, temp_output);

#pragma unroll
for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
@@ -248,8 +248,8 @@ __global__ void scaled_upper_triang_masked_softmax_warp_backward(
(output_t)(scale * (grad_reg[i][it + element] -
output_reg[i][it + element] * sum[i]));
}
-        copy_vector<output_t, ELEMENTS_PER_LDG_STG>(
-            gradInput + i * element_count * stride + it * WARP_SIZE, out);
+        copy<output_t, ELEMENTS_PER_LDG_STG>(
+            out, gradInput + i * element_count * stride + it * WARP_SIZE);
}
}
}
19 changes: 1 addition & 18 deletions extensions/csrc/kernel/cuda/utils/vec_copy.h
@@ -8,25 +8,8 @@ namespace colossalAI {
namespace cuda {
namespace utils {

-// Note(LiuYang): Depreciated
-template <typename T, int VecSize>
-__device__ __inline__ void copy_vector(T *dst, const T *src) {
-  using VT = typename common::VecTypeTrait<T, VecSize>::Type;
-  *(reinterpret_cast<VT *>(dst)) = *(reinterpret_cast<const VT *>(src));
-}
-
-template <>
-__device__ __inline__ void copy_vector<float, 8>(float *dst, const float *src) {
-  // Since the maximum memory alignment length is 128 bits, we choose float4
-  // here.
-  *(reinterpret_cast<float4 *>(dst)) = *(reinterpret_cast<const float4 *>(src));
-  *(reinterpret_cast<float4 *>(dst + 4)) =
-      *(reinterpret_cast<const float4 *>(src + 4));
-}
-
-// Note(LiuYang): Depreciated
 template <typename T, int VecSize>
-__device__ __inline__ void copy_zero_vector(T *dst) {
+__device__ __inline__ void copy_zero(T *dst) {
using VT = typename common::VecTypeTrait<T, VecSize>::Type;
*(reinterpret_cast<VT *>(dst)) = funcs::CastFunctor<float, VT>()(0.0f);
}