diff --git a/fbgemm_gpu/FbgemmGpu.cmake b/fbgemm_gpu/FbgemmGpu.cmake
index c96b911c08..389661b3bb 100644
--- a/fbgemm_gpu/FbgemmGpu.cmake
+++ b/fbgemm_gpu/FbgemmGpu.cmake
@@ -446,6 +446,7 @@ set(fbgemm_gpu_sources_static_cpu
     codegen/training/backward/embedding_backward_dense_host_cpu.cpp
     codegen/utils/embedding_bounds_check_host_cpu.cpp
     src/merge_pooled_embedding_ops/merge_pooled_embedding_ops_cpu.cpp
+    src/permute_multi_embedding_ops/permute_multi_embedding_ops_cpu.cpp
     src/permute_pooled_embedding_ops/permute_pooled_embedding_function.cpp
     src/permute_pooled_embedding_ops/permute_pooled_embedding_ops_cpu.cpp
     src/permute_pooled_embedding_ops/permute_pooled_embedding_ops_split_cpu.cpp
diff --git a/fbgemm_gpu/fbgemm_gpu/permute_pooled_embedding_modules.py b/fbgemm_gpu/fbgemm_gpu/permute_pooled_embedding_modules.py
index a06214b981..5ea9472858 100644
--- a/fbgemm_gpu/fbgemm_gpu/permute_pooled_embedding_modules.py
+++ b/fbgemm_gpu/fbgemm_gpu/permute_pooled_embedding_modules.py
@@ -19,10 +19,16 @@
     torch.ops.load_library(
         "//deeplearning/fbgemm/fbgemm_gpu:permute_pooled_embedding_ops_cpu"
     )
+    torch.ops.load_library(
+        "//deeplearning/fbgemm/fbgemm_gpu:permute_multi_embedding_ops_cpu"
+    )
     try:
         torch.ops.load_library(
             "//deeplearning/fbgemm/fbgemm_gpu:permute_pooled_embedding_ops_gpu"
         )
+        torch.ops.load_library(
+            "//deeplearning/fbgemm/fbgemm_gpu:permute_multi_embedding_ops_gpu"
+        )
     except OSError:
         # This is for forward compatibility (new torch.package + old backend)
         # We should be able to remove it after this diff is picked up by all backend
diff --git a/fbgemm_gpu/include/fbgemm_gpu/permute_multi_embedding_function.h b/fbgemm_gpu/include/fbgemm_gpu/permute_multi_embedding_function.h
new file mode 100644
index 0000000000..d79bf8ae8c
--- /dev/null
+++ b/fbgemm_gpu/include/fbgemm_gpu/permute_multi_embedding_function.h
@@ -0,0 +1,68 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#pragma once
+
+#include <ATen/ATen.h>
+#include <ATen/core/dispatch/Dispatcher.h>
+#include <torch/csrc/autograd/custom_function.h>
+
+#include "fbgemm_gpu/dispatch_macros.h"
+#include "fbgemm_gpu/ops_utils.h"
+#include "fbgemm_gpu/sparse_ops_utils.h"
+
+namespace fbgemm_gpu {
+
+using Tensor = at::Tensor;
+using torch::autograd::AutogradContext;
+using torch::autograd::variable_list;
+
+class PermuteMultiEmbeddingOp
+    : public torch::autograd::Function<PermuteMultiEmbeddingOp> {
+ public:
+  static variable_list forward(
+      AutogradContext* ctx,
+      const at::TensorList& pooled_embs_list,
+      const Tensor& permutes,
+      const std::vector<int64_t>& lengths,
+      const Tensor& input_lengths,
+      const Tensor& output_lengths);
+
+  static variable_list backward(
+      AutogradContext* ctx,
+      variable_list grad_output);
+};
+
+std::vector<Tensor> permute_multi_embedding_cpu(
+    const at::TensorList& pooled_embs,
+    const Tensor& permutes,
+    const std::vector<int64_t>& lengths,
+    const Tensor& input_lengths,
+    const Tensor& output_lengths,
+    const bool& reverse_permute);
+
+std::vector<Tensor> permute_multi_embedding_meta(
+    const at::TensorList& pooled_embs,
+    const Tensor& permutes,
+    const std::vector<int64_t>& lengths,
+    const Tensor& input_lengths,
+    const Tensor& output_lengths,
+    const bool& reverse_permute);
+
+std::vector<Tensor> permute_multi_embedding_gpu(
+    const at::TensorList& pooled_embs,
+    const Tensor& permutes,
+    const std::vector<int64_t>& lengths,
+    const Tensor& input_lengths,
+    const Tensor& output_lengths,
+    const bool& reverse_permute);
+
+} // namespace fbgemm_gpu
diff --git a/fbgemm_gpu/src/permute_multi_embedding_ops/permute_multi_embedding_function.cpp b/fbgemm_gpu/src/permute_multi_embedding_ops/permute_multi_embedding_function.cpp
new file mode 100644
index 0000000000..0711b65150
--- /dev/null
+++ b/fbgemm_gpu/src/permute_multi_embedding_ops/permute_multi_embedding_function.cpp
@@ -0,0 +1,76 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#include "fbgemm_gpu/permute_multi_embedding_function.h"
+#include <ATen/core/dispatch/Dispatcher.h>
+#include <c10/util/irange.h>
+
+namespace fbgemm_gpu {
+
+using Tensor = at::Tensor;
+using torch::autograd::AutogradContext;
+using torch::autograd::variable_list;
+
+variable_list PermuteMultiEmbeddingOp::forward(
+    AutogradContext* ctx,
+    const at::TensorList& pooled_embs,
+    const Tensor& permutes,
+    const std::vector<int64_t>& lengths,
+    const Tensor& input_lengths,
+    const Tensor& output_lengths) {
+  ctx->saved_data["permutes"] = permutes;
+  ctx->saved_data["input_lengths"] = input_lengths;
+  ctx->saved_data["output_lengths"] = output_lengths;
+
+  std::vector<int64_t> inv_lengths;
+  inv_lengths.reserve(pooled_embs.size());
+  for (const auto i : c10::irange(pooled_embs.size())) {
+    inv_lengths.push_back(pooled_embs[i].size(1));
+  }
+  ctx->saved_data["inv_lengths"] = inv_lengths;
+
+  /*
+    select the correct dispatched (cpu/gpu) forward function;
+    the cpu/gpu function needs to be registered in the dispatcher,
+    e.g., DISPATCH_TO_CPU, DISPATCH_TO_CUDA, etc.
+ */ + const auto permute_op = + torch::Dispatcher::singleton() + .findSchemaOrThrow("fbgemm::permute_multi_embedding_function", "") + .typed(); + + return permute_op.call( + pooled_embs, permutes, lengths, input_lengths, output_lengths, false); +} + +variable_list PermuteMultiEmbeddingOp::backward( + AutogradContext* ctx, + variable_list grad_output) { + const auto permutes = ctx->saved_data["permutes"].toTensor(); + const auto input_lengths = ctx->saved_data["input_lengths"].toTensor(); + const auto output_lengths = ctx->saved_data["output_lengths"].toTensor(); + const auto inv_lengths = ctx->saved_data["inv_lengths"].toIntVector(); + /* + select the correct dispatched (cpu/gpu) backward function + the cpu/gup function needs to be registered in the dispatcher, + e.g., DISPATCH_TO_CPU, DISPATCH_TO_CUDA, etc. + */ + const auto permute_op = + torch::Dispatcher::singleton() + .findSchemaOrThrow("fbgemm::permute_multi_embedding_function", "") + .typed(); + auto grad_input = permute_op.call( + grad_output, permutes, inv_lengths, output_lengths, input_lengths, true); + grad_input.push_back(torch::autograd::Variable()); // permutes + grad_input.push_back(torch::autograd::Variable()); // lengths + grad_input.push_back(torch::autograd::Variable()); // input_lengths + grad_input.push_back(torch::autograd::Variable()); // output_lengths + return grad_input; +} + +} // namespace fbgemm_gpu diff --git a/fbgemm_gpu/src/permute_multi_embedding_ops/permute_multi_embedding_ops.cu b/fbgemm_gpu/src/permute_multi_embedding_ops/permute_multi_embedding_ops.cu new file mode 100644 index 0000000000..13005b5806 --- /dev/null +++ b/fbgemm_gpu/src/permute_multi_embedding_ops/permute_multi_embedding_ops.cu @@ -0,0 +1,210 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include +#include +#include +#include +#include +#include +#include +#include + +#include "fbgemm_gpu/fbgemm_cuda_utils.cuh" +#include "fbgemm_gpu/permute_multi_embedding_function.h" + +using Tensor = at::Tensor; + +namespace fbgemm_gpu { + +// Kernerl for permute pooled embedding op. +// This kernel is moving D elements per warp. 
+template <typename scalar_t>
+__global__ void permute_multi_embs_kernel(
+    const scalar_t** __restrict__ inputs,
+    const scalar_t** __restrict__ outputs,
+    const int64_t* __restrict__ permutes,
+    const int64_t* __restrict__ input_lengths,
+    const int64_t* __restrict__ output_lengths,
+    const int64_t batch_size,
+    const int64_t permute_size,
+    const bool reverse_permute) {
+  // workers in a warp handle a feature
+  const int32_t worker_id = threadIdx.x % warpSize;
+  const int32_t worker_size = warpSize;
+  const int32_t permute_id =
+      blockIdx.x * (blockDim.x / warpSize) + threadIdx.x / warpSize;
+  const int32_t batch_id = blockIdx.y + gridDim.y * blockIdx.z;
+  if (batch_id >= batch_size) {
+    return;
+  }
+  if (permute_id >= permute_size) {
+    return;
+  }
+
+  // parse permutes
+  const int64_t params = 6;
+  int64_t in_tensor, out_tensor, in_start, out_start, length, jump;
+  if (reverse_permute) {
+    out_tensor = permutes[params * permute_id];
+    in_tensor = permutes[params * permute_id + 1];
+    out_start = permutes[params * permute_id + 2];
+    in_start = permutes[params * permute_id + 3];
+  } else {
+    in_tensor = permutes[params * permute_id];
+    out_tensor = permutes[params * permute_id + 1];
+    in_start = permutes[params * permute_id + 2];
+    out_start = permutes[params * permute_id + 3];
+  }
+  length = permutes[params * permute_id + 4];
+  jump = permutes[params * permute_id + 5];
+
+  if (worker_id >= length) {
+    return;
+  }
+  if (reverse_permute && jump < 0) {
+    return;
+  }
+
+  // locate the batch_id
+  int64_t in_length = input_lengths[in_tensor];
+  scalar_t* input_ptr = (scalar_t*)inputs[in_tensor];
+  input_ptr += batch_id * in_length;
+
+  int64_t out_length = output_lengths[out_tensor];
+  scalar_t* output_ptr = (scalar_t*)outputs[out_tensor];
+  output_ptr += batch_id * out_length;
+
+  // printf( // debug print
+  //     "input_tensors[%ld][%ld][%d] = %f\n",
+  //     in_tensor,
+  //     batch_id,
+  //     in_start + worker_id,
+  //     input_ptr[in_start + worker_id]);
+  if (fbgemm_gpu::is_aligned<fbgemm_gpu::Vec4T<scalar_t>>(
+          &output_ptr[out_start]) &&
+      fbgemm_gpu::is_aligned<fbgemm_gpu::Vec4T<scalar_t>>(
+          &input_ptr[in_start])) {
+    const int32_t vec_size = 4;
+    const int32_t loop_end = length / (vec_size) * (vec_size);
+    for (int32_t i = worker_id * vec_size; i < loop_end;
+         i += worker_size * vec_size) {
+      fbgemm_gpu::Vec4T<scalar_t>::copy(
+          &input_ptr[in_start + i], &output_ptr[out_start + i]);
+    }
+    // Use elementwise access for the last incomplete vector.
+    for (int32_t i = loop_end + worker_id; i < length; i += worker_size) {
+      output_ptr[out_start + i] = input_ptr[in_start + i];
+    }
+  } else { // Fallback if not aligned.
+    for (int32_t i = worker_id; i < length; i += worker_size) {
+      output_ptr[out_start + i] = input_ptr[in_start + i];
+    }
+  }
+
+  // for reverse_permute (backward) with jump
+  while (reverse_permute && jump > 0 && jump < permute_size) {
+    in_tensor = permutes[params * jump + 1];
+    in_start = permutes[params * jump + 3];
+    length = permutes[params * jump + 4];
+    jump = -permutes[params * jump + 5];
+
+    int64_t in_length = input_lengths[in_tensor];
+    scalar_t* input_ptr = (scalar_t*)inputs[in_tensor];
+    input_ptr += batch_id * in_length;
+
+    for (int32_t i = worker_id; i < length; i += worker_size) {
+      output_ptr[out_start + i] += input_ptr[in_start + i];
+    }
+  }
+}
+
+template <typename scalar_t>
+const scalar_t** tensors_ptr(const at::TensorList& tensors) {
+  auto size = tensors.size();
+  std::vector<const scalar_t*> tp_vec;
+  tp_vec.reserve(size);
+
+  const scalar_t** ptrVec;
+  cudaMalloc(&ptrVec, size * sizeof(scalar_t*));
+
+  for (int32_t i = 0; i < size; i++) {
+    tp_vec.push_back(tensors[i].data_ptr<scalar_t>());
+  }
+  cudaMemcpy(
+      ptrVec, tp_vec.data(), size * sizeof(scalar_t*), cudaMemcpyHostToDevice);
+  return ptrVec;
+}
+
+std::vector<Tensor> permute_multi_embedding_gpu(
+    const at::TensorList& pooled_embs,
+    const Tensor& permutes,
+    const std::vector<int64_t>& lengths,
+    const Tensor& input_lengths,
+    const Tensor& output_lengths,
+    const bool& reverse_permute = false) {
+  int64_t num_of_input_tensors = pooled_embs.size();
+  int64_t num_of_output_tensors = lengths.size();
+  int64_t batch_size = pooled_embs[0].size(0);
+  int64_t permute_size = permutes.size(0);
+  std::vector<Tensor> inputs;
+  inputs.reserve(num_of_input_tensors);
+  for (int32_t i = 0; i < num_of_input_tensors; i++) {
+    TENSORS_ON_SAME_DEVICE(pooled_embs[i], permutes);
+    // how does this contiguous() call impact forward and backward performance?
+    inputs.push_back(pooled_embs[i].contiguous());
+    TENSORS_ON_SAME_DEVICE(inputs[i], permutes);
+  }
+  TENSORS_ON_SAME_DEVICE(input_lengths, permutes);
+  TENSORS_ON_SAME_DEVICE(input_lengths, output_lengths);
+
+  // initiate output tensors
+  std::vector<Tensor> outputs;
+  outputs.reserve(num_of_output_tensors);
+  for (int32_t i = 0; i < num_of_output_tensors; i++) {
+    Tensor output =
+        at::empty({batch_size, lengths[i]}, pooled_embs[0].options());
+    outputs.push_back(output);
+  }
+
+  // This kernel moves D elements per warp.
+  // We are launching ( div_round_up(T, warp_per_block), B ) blocks.
+  // The grid z dimension also covers the batch in case batch_size is greater
+  // than 65535.
+  const int32_t warp_per_block =
+      fbgemm_gpu::kMaxThreads / fbgemm_gpu::kWarpSize;
+  const int32_t max_grid_dim_y =
+      32768; // The CUDA maximum is 65535, not a power of 2.
+  const dim3 threads(fbgemm_gpu::kMaxThreads);
+  const dim3 blocks(
+      fbgemm_gpu::div_round_up(permute_size, warp_per_block),
+      std::min(static_cast<int32_t>(batch_size), max_grid_dim_y),
+      (batch_size + max_grid_dim_y - 1) / max_grid_dim_y);
+
+  FBGEMM_DISPATCH_FLOATING_TYPES(
+      pooled_embs[0].scalar_type(), "permute_multi_embedding", [&] {
+        const scalar_t** input_ptr = tensors_ptr<scalar_t>(inputs);
+        const scalar_t** output_ptr = tensors_ptr<scalar_t>(outputs);
+        permute_multi_embs_kernel<scalar_t>
+            <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
+                (const scalar_t**)input_ptr,
+                (const scalar_t**)output_ptr,
+                (const int64_t*)permutes.data_ptr(),
+                (const int64_t*)input_lengths.data_ptr(),
+                (const int64_t*)output_lengths.data_ptr(),
+                batch_size,
+                permute_size,
+                reverse_permute);
+        C10_CUDA_KERNEL_LAUNCH_CHECK(); // why does this check fail?
+        cudaFree(input_ptr);
+        cudaFree(output_ptr);
+      });
+  return outputs;
+}
+
+} // namespace fbgemm_gpu
diff --git a/fbgemm_gpu/src/permute_multi_embedding_ops/permute_multi_embedding_ops_cpu.cpp b/fbgemm_gpu/src/permute_multi_embedding_ops/permute_multi_embedding_ops_cpu.cpp
new file mode 100644
index 0000000000..7aa353c01f
--- /dev/null
+++ b/fbgemm_gpu/src/permute_multi_embedding_ops/permute_multi_embedding_ops_cpu.cpp
@@ -0,0 +1,124 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#include "fbgemm_gpu/permute_multi_embedding_function.h"
+
+namespace fbgemm_gpu {
+
+using Tensor = at::Tensor;
+using torch::autograd::AutogradContext;
+using torch::autograd::variable_list;
+
+std::vector<Tensor> permute_multi_embedding_cpu(
+    const at::TensorList& pooled_embs,
+    const Tensor& permutes,
+    const std::vector<int64_t>& lengths,
+    const Tensor& input_lengths,
+    const Tensor& output_lengths,
+    const bool& reverse_permute) {
+  int64_t num_output_tensors = lengths.size();
+  int64_t batch_size = pooled_embs[0].size(0);
+
+  std::vector<Tensor> outputs;
+  outputs.reserve(num_output_tensors);
+  for (const auto i : c10::irange(num_output_tensors)) {
+    outputs.push_back(
+        at::empty({batch_size, lengths[i]}, pooled_embs[0].options()));
+  }
+
+  int64_t in_tensor, out_tensor, in_start, out_start, length, jump;
+  for (const auto i : c10::irange(permutes.size(0))) {
+    if (reverse_permute) {
+      out_tensor = permutes[i][0].item<int64_t>();
+      in_tensor = permutes[i][1].item<int64_t>();
+      out_start = permutes[i][2].item<int64_t>();
+      in_start = permutes[i][3].item<int64_t>();
+      jump = permutes[i][5].item<int64_t>();
+    } else {
+      in_tensor = permutes[i][0].item<int64_t>();
+      out_tensor = permutes[i][1].item<int64_t>();
+      in_start = permutes[i][2].item<int64_t>();
+      out_start = permutes[i][3].item<int64_t>();
+    }
+    length = permutes[i][4].item<int64_t>();
+    if (reverse_permute && jump < 0) {
+      for (const auto b : c10::irange(batch_size)) {
+        for (const auto j : c10::irange(length)) {
+          outputs[out_tensor][b][j + out_start] +=
+              pooled_embs[in_tensor][b][j + in_start];
+        }
+      }
+    } else {
+      for (const auto b : c10::irange(batch_size)) {
+        for (const auto j : c10::irange(length)) {
+          outputs[out_tensor][b][j + out_start] =
+              pooled_embs[in_tensor][b][j + in_start];
+        }
+      }
+    }
+  }
+
+  return outputs;
+}
+
+std::vector<Tensor> permute_multi_embedding_meta(
+    const at::TensorList& pooled_embs,
+    const Tensor& permutes,
+    const std::vector<int64_t>& lengths,
+    const Tensor& input_lengths,
+    const Tensor& output_lengths,
+    const bool& reverse) {
+  int64_t num_output_tensors = lengths.size();
+  int64_t batch_size = pooled_embs[0].size(0);
+
+  std::vector<Tensor> output;
+  output.reserve(num_output_tensors);
+  for (const auto i : c10::irange(num_output_tensors)) {
+    output.push_back(
+        at::empty({batch_size, lengths[i]}, pooled_embs[0].options()));
+  }
+  return output;
+}
+
+std::vector<Tensor> permute_multi_embedding_autograd(
+    const at::TensorList& pooled_embs,
+    const Tensor& permutes,
+    const std::vector<int64_t>& lengths,
+    const Tensor& input_lengths,
+    const Tensor& output_lengths) {
+  return PermuteMultiEmbeddingOp::apply(
+      pooled_embs, permutes, lengths, input_lengths, output_lengths);
+}
+
+} // namespace fbgemm_gpu
+
+TORCH_LIBRARY_FRAGMENT(fbgemm, m) {
+  // register the forward function for internal (autograd) usage
+  m.def(
+      "permute_multi_embedding_function(Tensor[] pooled_embs, Tensor permutes, SymInt[] lengths, Tensor in_lengths, Tensor out_lengths, bool reverse=False) -> Tensor[]",
Tensor[]", + {PT2_COMPLIANT_TAG}); + + // register the main function for external usage + m.def( + "permute_multi_embedding(Tensor[] pooled_embs, Tensor permutes, SymInt[] lengths, Tensor in_lengths, Tensor out_lengths) -> Tensor[]", + {PT2_COMPLIANT_TAG}); + + // dispatch the forward function to CPU for internal (autograd) usage + DISPATCH_TO_CPU( + "permute_multi_embedding_function", + fbgemm_gpu::permute_multi_embedding_cpu); + + // dispatch the forward function to CPU for internal (autograd) usage + DISPATCH_TO_META( + "permute_multi_embedding_function", + fbgemm_gpu::permute_multi_embedding_meta); + + // dispath the main function to Autograd for external usage + DISPATCH_TO_AUTOGRAD( + "permute_multi_embedding", fbgemm_gpu::permute_multi_embedding_autograd); +} diff --git a/fbgemm_gpu/src/permute_multi_embedding_ops/permute_multi_embedding_ops_gpu.cpp b/fbgemm_gpu/src/permute_multi_embedding_ops/permute_multi_embedding_ops_gpu.cpp new file mode 100644 index 0000000000..4a9ffbf8a4 --- /dev/null +++ b/fbgemm_gpu/src/permute_multi_embedding_ops/permute_multi_embedding_ops_gpu.cpp @@ -0,0 +1,18 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include "fbgemm_gpu/permute_multi_embedding_function.h" + +// do we really need this file, can we put this into +// permute_multi_embedding_ops.cu? +TORCH_LIBRARY_FRAGMENT(fbgemm, m) { + // dispatch the forward function to GPU for internal (autograd) usage + DISPATCH_TO_CUDA( + "permute_multi_embedding_function", + fbgemm_gpu::permute_multi_embedding_gpu); +}