From 121d7ad629c746e52a96ec53d6e26c0194016a03 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E5=82=85=E5=89=91=E5=AF=92?=
Date: Tue, 14 May 2024 14:35:33 +0800
Subject: [PATCH] [Inference] Delete duplicated copy_vector (#5716)

---
 .../cuda/decode_kv_cache_memcpy_kernel.cu     |  1 -
 .../cuda/fused_rotary_emb_and_cache_kernel.cu |  1 -
 .../kernel/cuda/get_cos_and_sin_kernel.cu     |  6 ++---
 .../cuda/scaled_masked_softmax_kernel.cu      | 22 ++++++++--------
 ...aled_upper_triang_masked_softmax_kernel.cu | 26 +++++++++----------
 extensions/csrc/kernel/cuda/utils/vec_copy.h  | 19 +------------
 6 files changed, 28 insertions(+), 47 deletions(-)

diff --git a/extensions/csrc/kernel/cuda/decode_kv_cache_memcpy_kernel.cu b/extensions/csrc/kernel/cuda/decode_kv_cache_memcpy_kernel.cu
index 19ea5bb8aca2..3d011a4e48ff 100644
--- a/extensions/csrc/kernel/cuda/decode_kv_cache_memcpy_kernel.cu
+++ b/extensions/csrc/kernel/cuda/decode_kv_cache_memcpy_kernel.cu
@@ -5,7 +5,6 @@
 #include "funcs/cast_functor.h"
 #include "common/micros.h"
 
-using colossalAI::cuda::utils::copy_vector;
 using colossalAI::cuda::utils::get_vec_size;
 using colossalAI::cuda::utils::copy;
 using colossalAI::funcs::CastFunctor;
diff --git a/extensions/csrc/kernel/cuda/fused_rotary_emb_and_cache_kernel.cu b/extensions/csrc/kernel/cuda/fused_rotary_emb_and_cache_kernel.cu
index 4f96c7c42c1f..6dc9495ef7d9 100644
--- a/extensions/csrc/kernel/cuda/fused_rotary_emb_and_cache_kernel.cu
+++ b/extensions/csrc/kernel/cuda/fused_rotary_emb_and_cache_kernel.cu
@@ -8,7 +8,6 @@
 #include "funcs/cast_functor.h"
 #include "funcs/binary_functor.h"
 
-using colossalAI::cuda::utils::copy_vector;
 using colossalAI::cuda::utils::get_vec_size;
 using colossalAI::cuda::utils::copy;
 using colossalAI::funcs::CastFunctor;
diff --git a/extensions/csrc/kernel/cuda/get_cos_and_sin_kernel.cu b/extensions/csrc/kernel/cuda/get_cos_and_sin_kernel.cu
index 9c78666e68bd..d5fda83ebb56 100644
--- a/extensions/csrc/kernel/cuda/get_cos_and_sin_kernel.cu
+++ b/extensions/csrc/kernel/cuda/get_cos_and_sin_kernel.cu
@@ -4,7 +4,7 @@
 #include "utils/vec_copy.h"
 #include "common/micros.h"
 
-using colossalAI::cuda::utils::copy_vector;
+using colossalAI::cuda::utils::copy;
 using colossalAI::cuda::utils::get_vec_size;
 
 
@@ -23,8 +23,8 @@ __device__ void apply_cos_and_sin_memcopy(
   int begin_id = threadIdx.x * VecSize;
 
   for (; begin_id <= head_dim - VecSize; begin_id += blockDim.x){
-    copy_vector<scalar_t, VecSize>(cos + dest_offset_id + begin_id, cos_cache_ptr + src_offset_id + begin_id);
-    copy_vector<scalar_t, VecSize>(sin + dest_offset_id + begin_id, sin_cache_ptr + src_offset_id + begin_id);
+    copy<scalar_t, VecSize>(cos_cache_ptr + src_offset_id + begin_id, cos + dest_offset_id + begin_id);
+    copy<scalar_t, VecSize>(sin_cache_ptr + src_offset_id + begin_id, sin + dest_offset_id + begin_id);
   }
 
   if (!Aligned) {
diff --git a/extensions/csrc/kernel/cuda/scaled_masked_softmax_kernel.cu b/extensions/csrc/kernel/cuda/scaled_masked_softmax_kernel.cu
index db9a2bbd609a..00455897ebb3 100644
--- a/extensions/csrc/kernel/cuda/scaled_masked_softmax_kernel.cu
+++ b/extensions/csrc/kernel/cuda/scaled_masked_softmax_kernel.cu
@@ -23,7 +23,7 @@ using colossalAI::funcs::UnaryOpFunctor;
 using colossalAI::funcs::UnaryOpType;
 using colossalAI::funcs::warp_reduce;
 using colossalAI::funcs::ReduceType;
-using colossalAI::cuda::utils::copy_vector;
+using colossalAI::cuda::utils::copy;
 
 
 /*
@@ -87,8 +87,8 @@ __global__ void scaled_masked_softmax_warp_forward(
 
       if (element_index < batch_element_count) {
         int itr_idx = i * element_count + it * WARP_SIZE;
-        copy_vector<input_t, ELEMENTS_PER_LDG_STG>(temp_data, src + itr_idx);
-        copy_vector<uint8_t, ELEMENTS_PER_LDG_STG>(temp_mask, mask + itr_idx);
+        copy<input_t, ELEMENTS_PER_LDG_STG>(src + itr_idx, temp_data);
+        copy<uint8_t, ELEMENTS_PER_LDG_STG>(mask + itr_idx, temp_mask);
 
 #pragma unroll
         for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
@@ -144,8 +144,8 @@ __global__ void scaled_masked_softmax_warp_forward(
         for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
           out[element] = elements[i][it + element] / sum[i];
         }
-        copy_vector<output_t, ELEMENTS_PER_LDG_STG>(
-            dst + i * element_count + it * WARP_SIZE, out);
+        copy<output_t, ELEMENTS_PER_LDG_STG>(
+            out, dst + i * element_count + it * WARP_SIZE);
       } else {
         break;
       }
@@ -200,10 +200,10 @@ __global__ void scaled_masked_softmax_warp_backward(
     for (int it = 0; it < WARP_ITERATIONS; it += ELEMENTS_PER_LDG_STG) {
       int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE;
       if (element_index < batch_element_count) {
-        copy_vector<input_t, ELEMENTS_PER_LDG_STG>(
-            temp_grad, grad + i * element_count + it * WARP_SIZE);
-        copy_vector<input_t, ELEMENTS_PER_LDG_STG>(
-            temp_output, output + i * element_count + it * WARP_SIZE);
+        copy<input_t, ELEMENTS_PER_LDG_STG>(
+            grad + i * element_count + it * WARP_SIZE, temp_grad);
+        copy<input_t, ELEMENTS_PER_LDG_STG>(
+            output + i * element_count + it * WARP_SIZE, temp_output);
 
 #pragma unroll
         for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
@@ -245,8 +245,8 @@ __global__ void scaled_masked_softmax_warp_backward(
               (output_t)(scale * (grad_reg[i][it + element] -
                                   output_reg[i][it + element] * sum[i]));
         }
-        copy_vector<output_t, ELEMENTS_PER_LDG_STG>(
-            gradInput + i * element_count + it * WARP_SIZE, out);
+        copy<output_t, ELEMENTS_PER_LDG_STG>(
+            out, gradInput + i * element_count + it * WARP_SIZE);
       }
     }
   }
diff --git a/extensions/csrc/kernel/cuda/scaled_upper_triang_masked_softmax_kernel.cu b/extensions/csrc/kernel/cuda/scaled_upper_triang_masked_softmax_kernel.cu
index db90916f3894..42d14b423749 100644
--- a/extensions/csrc/kernel/cuda/scaled_upper_triang_masked_softmax_kernel.cu
+++ b/extensions/csrc/kernel/cuda/scaled_upper_triang_masked_softmax_kernel.cu
@@ -23,8 +23,8 @@ using colossalAI::funcs::UnaryOpFunctor;
 using colossalAI::funcs::UnaryOpType;
 using colossalAI::funcs::warp_reduce;
 using colossalAI::funcs::ReduceType;
-using colossalAI::cuda::utils::copy_vector;
-using colossalAI::cuda::utils::copy_zero_vector;
+using colossalAI::cuda::utils::copy;
+using colossalAI::cuda::utils::copy_zero;
 
 /*
  * Extended softmax (from native aten pytorch) with following additional
@@ -75,8 +75,8 @@ __global__ void scaled_upper_triang_masked_softmax_warp_forward(
       int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE;
 
       if (element_index < batch_element_count) {
-        copy_vector<input_t, ELEMENTS_PER_LDG_STG>(
-            temp_data, src + i * element_count * stride + it * WARP_SIZE);
+        copy<input_t, ELEMENTS_PER_LDG_STG>(
+            src + i * element_count * stride + it * WARP_SIZE, temp_data);
 
 #pragma unroll
         for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
@@ -140,10 +140,10 @@ __global__ void scaled_upper_triang_masked_softmax_warp_forward(
             out[element] = 0;
           }
         }
-        copy_vector<output_t, ELEMENTS_PER_LDG_STG>(
-            dst + i * element_count * stride + it * WARP_SIZE, out);
+        copy<output_t, ELEMENTS_PER_LDG_STG>(
+            out, dst + i * element_count * stride + it * WARP_SIZE);
       } else if (element_index < element_count) {
-        copy_zero_vector<output_t, ELEMENTS_PER_LDG_STG>(
+        copy_zero<output_t, ELEMENTS_PER_LDG_STG>(
             dst + i * element_count * stride + it * WARP_SIZE);
       } else {
         break;
@@ -199,10 +199,10 @@ __global__ void scaled_upper_triang_masked_softmax_warp_backward(
     for (int it = 0; it < WARP_ITERATIONS; it += ELEMENTS_PER_LDG_STG) {
       int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE;
       if (element_index < batch_element_count) {
-        copy_vector<input_t, ELEMENTS_PER_LDG_STG>(
-            temp_grad, grad + i * element_count * stride + it * WARP_SIZE);
-        copy_vector<input_t, ELEMENTS_PER_LDG_STG>(
-            temp_output, output + i * element_count * stride + it * WARP_SIZE);
+        copy<input_t, ELEMENTS_PER_LDG_STG>(
+            grad + i * element_count * stride + it * WARP_SIZE, temp_grad);
+        copy<input_t, ELEMENTS_PER_LDG_STG>(
+            output + i * element_count * stride + it * WARP_SIZE, temp_output);
 
 #pragma unroll
         for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
@@ -248,8 +248,8 @@ __global__ void scaled_upper_triang_masked_softmax_warp_backward(
               (output_t)(scale * (grad_reg[i][it + element] -
                                   output_reg[i][it + element] * sum[i]));
         }
-        copy_vector<output_t, ELEMENTS_PER_LDG_STG>(
-            gradInput + i * element_count * stride + it * WARP_SIZE, out);
+        copy<output_t, ELEMENTS_PER_LDG_STG>(
+            out, gradInput + i * element_count * stride + it * WARP_SIZE);
       }
     }
   }
diff --git a/extensions/csrc/kernel/cuda/utils/vec_copy.h b/extensions/csrc/kernel/cuda/utils/vec_copy.h
index 6c099df695f9..465703a743a8 100644
--- a/extensions/csrc/kernel/cuda/utils/vec_copy.h
+++ b/extensions/csrc/kernel/cuda/utils/vec_copy.h
@@ -8,25 +8,8 @@ namespace colossalAI {
 namespace cuda {
 namespace utils {
 
-// Note(LiuYang): Depreciated
-template <typename T, int VecSize>
-__device__ __inline__ void copy_vector(T *dst, const T *src) {
-  using VT = typename common::VecTypeTrait<T, VecSize>::Type;
-  *(reinterpret_cast<VT *>(dst)) = *(reinterpret_cast<const VT *>(src));
-}
-
-template <>
-__device__ __inline__ void copy_vector<float, 8>(float *dst, const float *src) {
-  // Since the maximum memory alignment length is 128 bits, we choose float4
-  // here.
-  *(reinterpret_cast<float4 *>(dst)) = *(reinterpret_cast<const float4 *>(src));
-  *(reinterpret_cast<float4 *>(dst + 4)) =
-      *(reinterpret_cast<const float4 *>(src + 4));
-}
-
-// Note(LiuYang): Depreciated
 template <typename T, int VecSize>
-__device__ __inline__ void copy_zero_vector(T *dst) {
+__device__ __inline__ void copy_zero(T *dst) {
   using VT = typename common::VecTypeTrait<T, VecSize>::Type;
   *(reinterpret_cast<VT *>(dst)) = funcs::CastFunctor<float, VT>()(0.0f);
 }
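
Two things change at every call site in this patch: the helper's name (copy_vector becomes copy, copy_zero_vector becomes copy_zero) and the argument order, since the deleted copy_vector took (dst, src) while the surviving copy takes (src, dst); e.g. copy_vector(temp_data, src + itr_idx) becomes copy(src + itr_idx, temp_data). Below is a minimal, self-contained sketch of that calling convention, not the actual ColossalAI header: the simplified VecTypeTrait, the float-only specialization, and demo_kernel are assumptions made so the file compiles on its own, whereas the real vec_copy.h uses common::VecTypeTrait and builds its zero value with funcs::CastFunctor.

// vec_copy_demo.cu: hedged sketch of the (src, dst) copy convention.
// Not the ColossalAI implementation; VecTypeTrait below is a simplified
// float-only stand-in for common::VecTypeTrait.
#include <cstdio>
#include <cuda_runtime.h>

template <typename T, int VecSize>
struct VecTypeTrait;  // primary template intentionally left undefined

template <>
struct VecTypeTrait<float, 4> {
  using Type = float4;  // one 128-bit vectorized load/store
};

// Unified helper in (src, dst) order, mirroring the copy() the patch keeps.
template <typename T, int VecSize>
__device__ __inline__ void copy(const T *src, T *dst) {
  using VT = typename VecTypeTrait<T, VecSize>::Type;
  *(reinterpret_cast<VT *>(dst)) = *(reinterpret_cast<const VT *>(src));
}

// Zero-fill counterpart of copy_zero(dst); the real one generalizes over
// vector widths via funcs::CastFunctor<float, VT> instead of brace init.
template <typename T, int VecSize>
__device__ __inline__ void copy_zero(T *dst) {
  using VT = typename VecTypeTrait<T, VecSize>::Type;
  *(reinterpret_cast<VT *>(dst)) = VT{0.0f, 0.0f, 0.0f, 0.0f};
}

__global__ void demo_kernel(const float *src, float *dst, int n) {
  int idx = (blockIdx.x * blockDim.x + threadIdx.x) * 4;
  if (idx + 4 <= n / 2) {
    copy<float, 4>(src + idx, dst + idx);  // was copy_vector(dst + idx, src + idx)
  } else if (idx + 4 <= n) {
    copy_zero<float, 4>(dst + idx);        // was copy_zero_vector(dst + idx)
  }
}

int main() {
  const int n = 1024;
  float *src = nullptr, *dst = nullptr;
  cudaMallocManaged(&src, n * sizeof(float));
  cudaMallocManaged(&dst, n * sizeof(float));
  for (int i = 0; i < n; ++i) src[i] = static_cast<float>(i);
  demo_kernel<<<(n / 4 + 255) / 256, 256>>>(src, dst, n);
  cudaDeviceSynchronize();
  printf("dst[42] = %.1f, dst[600] = %.1f\n", dst[42], dst[600]);
  cudaFree(src);
  cudaFree(dst);
  return 0;
}

Built with nvcc vec_copy_demo.cu, this should print dst[42] = 42.0 (the copied half) and dst[600] = 0.0 (the zero-filled half), confirming both vectorized paths behave like their scalar equivalents.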