diff --git a/fbgemm_gpu/FbgemmGpu.cmake b/fbgemm_gpu/FbgemmGpu.cmake index c96b911c08..389661b3bb 100644 --- a/fbgemm_gpu/FbgemmGpu.cmake +++ b/fbgemm_gpu/FbgemmGpu.cmake @@ -446,6 +446,7 @@ set(fbgemm_gpu_sources_static_cpu codegen/training/backward/embedding_backward_dense_host_cpu.cpp codegen/utils/embedding_bounds_check_host_cpu.cpp src/merge_pooled_embedding_ops/merge_pooled_embedding_ops_cpu.cpp + src/permute_pooled_embedding_ops/permute_multi_embedding_ops_cpu.cpp src/permute_pooled_embedding_ops/permute_pooled_embedding_function.cpp src/permute_pooled_embedding_ops/permute_pooled_embedding_ops_cpu.cpp src/permute_pooled_embedding_ops/permute_pooled_embedding_ops_split_cpu.cpp diff --git a/fbgemm_gpu/fbgemm_gpu/permute_pooled_embedding_modules.py b/fbgemm_gpu/fbgemm_gpu/permute_pooled_embedding_modules.py index a06214b981..5ea9472858 100644 --- a/fbgemm_gpu/fbgemm_gpu/permute_pooled_embedding_modules.py +++ b/fbgemm_gpu/fbgemm_gpu/permute_pooled_embedding_modules.py @@ -19,10 +19,16 @@ torch.ops.load_library( "//deeplearning/fbgemm/fbgemm_gpu:permute_pooled_embedding_ops_cpu" ) + torch.ops.load_library( + "//deeplearning/fbgemm/fbgemm_gpu:permute_multi_embedding_ops_cpu" + ) try: torch.ops.load_library( "//deeplearning/fbgemm/fbgemm_gpu:permute_pooled_embedding_ops_gpu" ) + torch.ops.load_library( + "//deeplearning/fbgemm/fbgemm_gpu:permute_multi_embedding_ops_gpu" + ) except OSError: # This is for forward compatibility (new torch.package + old backend) # We should be able to remove it after this diff is picked up by all backend diff --git a/fbgemm_gpu/include/fbgemm_gpu/permute_multi_embedding_function.h b/fbgemm_gpu/include/fbgemm_gpu/permute_multi_embedding_function.h new file mode 100644 index 0000000000..c2f1a704a0 --- /dev/null +++ b/fbgemm_gpu/include/fbgemm_gpu/permute_multi_embedding_function.h @@ -0,0 +1,64 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once + +#include +#include +#include + +#include "fbgemm_gpu/dispatch_macros.h" +#include "fbgemm_gpu/ops_utils.h" +#include "fbgemm_gpu/sparse_ops_utils.h" + +namespace fbgemm_gpu { + +using Tensor = at::Tensor; +using torch::autograd::AutogradContext; +using torch::autograd::variable_list; + +using Tensor = at::Tensor; +using torch::autograd::AutogradContext; +using torch::autograd::variable_list; + +class PermuteMultiEmbeddingOp + : public torch::autograd::Function { + public: + static variable_list forward( + AutogradContext* ctx, + const at::TensorList& pooled_embs, + const std::vector& permutes, + const std::vector& in_lengths, + const std::vector& out_lengths); + + static variable_list backward( + AutogradContext* ctx, + variable_list grad_output); +}; + +std::vector permute_multi_embedding_cpu( + const at::TensorList& pooled_embs, + const std::vector& permutes, + const std::vector& in_lengths, + const std::vector& out_lengths, + const bool& reverse_permute); + +std::vector permute_multi_embedding_meta( + const at::TensorList& pooled_embs, + const std::vector& permutes, + const std::vector& in_lengths, + const std::vector& out_lengths, + const bool& reverse_permute); + +std::vector permute_multi_embedding_gpu( + const at::TensorList& pooled_embs, + const std::vector& permutes, + const std::vector& in_lengths, + const std::vector& out_lengths, + const bool& reverse_permute); +} // namespace fbgemm_gpu diff --git a/fbgemm_gpu/src/permute_multi_embedding_ops/permute_multi_embedding_function.cpp b/fbgemm_gpu/src/permute_multi_embedding_ops/permute_multi_embedding_function.cpp new file mode 100644 index 0000000000..90b9403d12 --- /dev/null +++ b/fbgemm_gpu/src/permute_multi_embedding_ops/permute_multi_embedding_function.cpp @@ -0,0 +1,66 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include "fbgemm_gpu/permute_multi_embedding_function.h" +#include +#include + +namespace fbgemm_gpu { + +using Tensor = at::Tensor; +using torch::autograd::AutogradContext; +using torch::autograd::variable_list; + +variable_list PermuteMultiEmbeddingOp::forward( + AutogradContext* ctx, + const at::TensorList& pooled_embs, + const std::vector& permutes, + const std::vector& in_lengths, + const std::vector& out_lengths) { + ctx->saved_data["permutes"] = permutes; + ctx->saved_data["in_lengths"] = in_lengths; + ctx->saved_data["out_lengths"] = out_lengths; + + /* + select the correct dispatched (cpu/gpu) forward function + the cpu/gup function needs to be registered in the dispatcher, + e.g., DISPATCH_TO_CPU, DISPATCH_TO_CUDA, etc. + */ + const auto permute_op = + torch::Dispatcher::singleton() + .findSchemaOrThrow("fbgemm::permute_multi_embedding_function", "") + .typed(); + + return permute_op.call(pooled_embs, permutes, in_lengths, out_lengths, false); +} + +variable_list PermuteMultiEmbeddingOp::backward( + AutogradContext* ctx, + variable_list grad_output) { + const auto permutes = ctx->saved_data["permutes"].toIntVector(); + const auto in_lengths = ctx->saved_data["in_lengths"].toIntVector(); + const auto out_lengths = ctx->saved_data["out_lengths"].toIntVector(); + + /* + select the correct dispatched (cpu/gpu) backward function + the cpu/gup function needs to be registered in the dispatcher, + e.g., DISPATCH_TO_CPU, DISPATCH_TO_CUDA, etc. + */ + const auto permute_op = + torch::Dispatcher::singleton() + .findSchemaOrThrow("fbgemm::permute_multi_embedding_function", "") + .typed(); + auto grad_input = + permute_op.call(grad_output, permutes, out_lengths, in_lengths, true); + grad_input.push_back(torch::autograd::Variable()); // permutes + grad_input.push_back(torch::autograd::Variable()); // in_lengths + grad_input.push_back(torch::autograd::Variable()); // out_lengths + return grad_input; +} + +} // namespace fbgemm_gpu diff --git a/fbgemm_gpu/src/permute_multi_embedding_ops/permute_multi_embedding_ops.cu b/fbgemm_gpu/src/permute_multi_embedding_ops/permute_multi_embedding_ops.cu new file mode 100644 index 0000000000..26467e73dd --- /dev/null +++ b/fbgemm_gpu/src/permute_multi_embedding_ops/permute_multi_embedding_ops.cu @@ -0,0 +1,232 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include +#include +#include +#include +#include +#include +#include +#include + +#include "fbgemm_gpu/fbgemm_cuda_utils.cuh" +#include "fbgemm_gpu/permute_multi_embedding_function.h" + +using Tensor = at::Tensor; + +namespace fbgemm_gpu { + +// Kernerl for permute pooled embedding op. +// This kernel is moving D elements per warp. +template +__global__ void permute_multi_embs_kernel( + const scalar_t** __restrict__ inputs, + const scalar_t** __restrict__ outputs, + const int64_t* __restrict__ permutes, + const int64_t* __restrict__ input_lengths, + const int64_t* __restrict__ output_lengths, + const int64_t batch_size, + const int64_t permute_size, + const bool reverse_permute) { + // workers in a warp handle a feature + const int32_t worker_id = threadIdx.x % warpSize; + const int32_t worker_size = warpSize; + const int32_t permute_id = + blockIdx.x * (blockDim.x / warpSize) + threadIdx.x / warpSize; + const int32_t batch_id = blockIdx.y + gridDim.y * blockIdx.z; + if (batch_id >= batch_size) { + return; + } + if (permute_id >= permute_size) { + return; + } + + // parse permutes + const int64_t params = 6; + int64_t in_tensor, out_tensor, in_start, out_start, length, jump; + if (reverse_permute) { + out_tensor = permutes[params * permute_id]; + in_tensor = permutes[params * permute_id + 1]; + out_start = permutes[params * permute_id + 2]; + in_start = permutes[params * permute_id + 3]; + } else { + in_tensor = permutes[params * permute_id]; + out_tensor = permutes[params * permute_id + 1]; + in_start = permutes[params * permute_id + 2]; + out_start = permutes[params * permute_id + 3]; + } + length = permutes[params * permute_id + 4]; + jump = permutes[params * permute_id + 5]; + + if (worker_id >= length) { + return; + } + if (reverse_permute && jump < 0) { + return; + } + + // locate the batch_id + int64_t in_length = input_lengths[in_tensor]; + scalar_t* input_ptr = (scalar_t*)inputs[in_tensor]; + input_ptr += batch_id * in_length; + + int64_t out_length = output_lengths[out_tensor]; + scalar_t* output_ptr = (scalar_t*)outputs[out_tensor]; + output_ptr += batch_id * out_length; + + // printf( // debug print + // "input_tensors[%ld][%ld][%d] = %f\n", + // in_tensor, + // batch_id, + // in_start + worker_id, + // input_ptr[in_start + worker_id]); + if (fbgemm_gpu::is_aligned>( + &output_ptr[out_start]) && + fbgemm_gpu::is_aligned>( + &input_ptr[in_start])) { + const int32_t vec_size = 4; + const int32_t loop_end = length / (vec_size) * (vec_size); + for (int32_t i = worker_id * vec_size; i < loop_end; + i += worker_size * vec_size) { + fbgemm_gpu::Vec4T::copy( + &input_ptr[in_start + i], &output_ptr[out_start + i]); + } + // Use elementwise access for the last incomplete vector. + for (int32_t i = loop_end + worker_id; i < length; i += worker_size) { + output_ptr[out_start + i] = input_ptr[in_start + i]; + } + } else { // Fallback if not aligned. + for (int32_t i = worker_id; i < length; i += worker_size) { + output_ptr[out_start + i] = input_ptr[in_start + i]; + } + } + + // for reverse_permute (backward) with jump + while (reverse_permute && jump > 0 && jump < permute_size) { + in_tensor = permutes[params * jump + 1]; + in_start = permutes[params * jump + 3]; + length = permutes[params * jump + 4]; + jump = -permutes[params * jump + 5]; + + int64_t in_length = input_lengths[in_tensor]; + scalar_t* input_ptr = (scalar_t*)inputs[in_tensor]; + input_ptr += batch_id * in_length; + + for (int32_t i = worker_id; i < length; i += worker_size) { + output_ptr[out_start + i] += input_ptr[in_start + i]; + } + } +} + +template +Tensor from_vec(const std::vector input) { + const auto int_opts = + torch::TensorOptions().dtype(torch::kInt64).pinned_memory(true); + Tensor output = at::empty({static_cast(input.size())}, int_opts); + // Ensure that output is contiguous + TORCH_CHECK(output.is_contiguous()); + std::memcpy( + output.data_ptr(), input.data(), input.size() * sizeof(index_t)); + return output; +} + +template +Tensor tensors_ptr(const at::TensorList& tensors) { + auto size = tensors.size(); + Tensor ptr_tensor = at::empty( + {static_cast(size * sizeof(scalar_t*))}, + at::TensorOptions().dtype(tensors[0].scalar_type()).pinned_memory(true)); + + // Ensure that ptr_tensor is contiguous + TORCH_CHECK(ptr_tensor.is_contiguous()); + auto tp = reinterpret_cast(ptr_tensor.data_ptr()); + for (int32_t i = 0; i < tensors.size(); i++) { + tp[i] = tensors[i].data_ptr(); + } + // Ensure that ptr_tensor is contiguous + TORCH_CHECK(ptr_tensor.is_contiguous()); + return ptr_tensor; +} + +std::vector permute_multi_embedding_gpu( + const at::TensorList& pooled_embs, + const std::vector& permutes, + const std::vector& in_lengths, + const std::vector& out_lengths, + const bool& reverse_permute) { + const int64_t permute_param = 6; + int64_t num_of_input_tensors = in_lengths.size(); + int64_t num_of_output_tensors = out_lengths.size(); + int64_t batch_size = pooled_embs[0].size(0); + int64_t permute_size = permutes.size() / permute_param; + + // check input tensors + std::vector inputs; + inputs.reserve(pooled_embs.size()); + for (int32_t i = 0; i < num_of_input_tensors; i++) { + Tensor cont_tensor = pooled_embs[i].contiguous(); + inputs.push_back(cont_tensor); + TENSORS_ON_SAME_DEVICE(cont_tensor, pooled_embs[i]); + TENSORS_ON_SAME_DEVICE(pooled_embs[i], pooled_embs[0]); + } + + // initiate output tensors + std::vector outputs; + outputs.reserve(num_of_output_tensors); + for (int32_t i = 0; i < num_of_output_tensors; i++) { + Tensor output = + at::empty({batch_size, out_lengths[i]}, pooled_embs[0].options()); + outputs.push_back(output); + } + + auto permutes_tensor = from_vec(permutes); + auto in_lengths_tensor = from_vec(in_lengths); + auto out_lengths_tensor = from_vec(out_lengths); + + auto device = pooled_embs[0].device(); + permutes_tensor = permutes_tensor.to(device, /*non_blocking=*/true); + in_lengths_tensor = in_lengths_tensor.to(device, /*non_blocking=*/true); + out_lengths_tensor = out_lengths_tensor.to(device, /*non_blocking=*/true); + + // This kernel is moving D elements per warp. + // We are launching ( div_round_up(T, warp_per_block), B ) blocks. + // The grid z dimension is also used by batch_size in case it's greater than + // 65535. + const int32_t warp_per_block = + fbgemm_gpu::kMaxThreads / fbgemm_gpu::kWarpSize; + const int32_t max_grid_dim_y = + 32768; // The CUDA maximum is 65535, not a power of 2. + const dim3 threads(fbgemm_gpu::kMaxThreads); + const dim3 blocks( + fbgemm_gpu::div_round_up(permute_size, warp_per_block), + std::min(static_cast(batch_size), max_grid_dim_y), + (batch_size + max_grid_dim_y - 1) / max_grid_dim_y); + + FBGEMM_DISPATCH_FLOATING_TYPES( + pooled_embs[0].scalar_type(), "permute_multi_embedding", [&] { + Tensor in_tensor = tensors_ptr(inputs); + Tensor out_tensor = tensors_ptr(outputs); + in_tensor = in_tensor.to(device, /*non_blocking=*/true); + out_tensor = out_tensor.to(device, /*non_blocking=*/true); + permute_multi_embs_kernel + <<>>( + (const scalar_t**)in_tensor.data_ptr(), + (const scalar_t**)out_tensor.data_ptr(), + permutes_tensor.data_ptr(), + in_lengths_tensor.data_ptr(), + out_lengths_tensor.data_ptr(), + batch_size, + permute_size, + reverse_permute); + C10_CUDA_KERNEL_LAUNCH_CHECK(); + }); + return outputs; +} + +} // namespace fbgemm_gpu diff --git a/fbgemm_gpu/src/permute_multi_embedding_ops/permute_multi_embedding_ops_cpu.cpp b/fbgemm_gpu/src/permute_multi_embedding_ops/permute_multi_embedding_ops_cpu.cpp new file mode 100644 index 0000000000..59937cf893 --- /dev/null +++ b/fbgemm_gpu/src/permute_multi_embedding_ops/permute_multi_embedding_ops_cpu.cpp @@ -0,0 +1,123 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include "fbgemm_gpu/permute_multi_embedding_function.h" + +namespace fbgemm_gpu { + +using Tensor = at::Tensor; +using torch::autograd::AutogradContext; +using torch::autograd::variable_list; + +std::vector permute_multi_embedding_cpu( + const at::TensorList& pooled_embs, + const std::vector& permutes, + const std::vector& in_lengths, + const std::vector& out_lengths, + const bool& reverse_permute) { + int64_t batch_size = pooled_embs[0].size(0); + + std::vector outputs; + outputs.reserve(out_lengths.size()); + for (const auto i : c10::irange(out_lengths.size())) { + outputs.push_back( + at::empty({batch_size, out_lengths[i]}, pooled_embs[0].options())); + } + + int64_t in_tensor, out_tensor, in_start, out_start, length, jump; + const int64_t param = 6; + for (const auto i : c10::irange(permutes.size() / param)) { + if (reverse_permute) { + out_tensor = permutes[i * param]; + in_tensor = permutes[i * param + 1]; + out_start = permutes[i * param + 2]; + in_start = permutes[i * param + 3]; + jump = permutes[i * param + 5]; + } else { + in_tensor = permutes[i * param]; + out_tensor = permutes[i * param + 1]; + in_start = permutes[i * param + 2]; + out_start = permutes[i * param + 3]; + } + length = permutes[i * param + 4]; + if (reverse_permute && jump < 0) { + for (const auto b : c10::irange(batch_size)) { + for (const auto j : c10::irange(length)) { + outputs[out_tensor][b][j + out_start] += + pooled_embs[in_tensor][b][j + in_start]; + } + } + } else { + for (const auto b : c10::irange(batch_size)) { + for (const auto j : c10::irange(length)) { + outputs[out_tensor][b][j + out_start] = + pooled_embs[in_tensor][b][j + in_start]; + } + } + } + } + return outputs; +} + +std::vector permute_multi_embedding_meta( + const at::TensorList& pooled_embs, + const std::vector& permutes, + const std::vector& in_lengths, + const std::vector& out_lengths, + const bool& reverse_permute) { + int64_t batch_size = pooled_embs[0].size(0); + + std::vector outputs; + outputs.reserve(out_lengths.size()); + for (const auto i : c10::irange(out_lengths.size())) { + outputs.push_back( + at::empty({batch_size, out_lengths[i]}, pooled_embs[0].options())); + } + return outputs; +} + +std::vector permute_multi_embedding_autograd( + const at::TensorList& pooled_embs, + const std::vector& permutes, + const std::vector& in_lengths, + const std::vector& out_lengths) { + return PermuteMultiEmbeddingOp::apply( + pooled_embs, permutes, in_lengths, out_lengths); +} + +} // namespace fbgemm_gpu + +TORCH_LIBRARY_FRAGMENT(fbgemm, m) { + // register the forward function for internal (autograd) usage + m.def( + "permute_multi_embedding_function(Tensor[] pooled_embs, int[] permutes, SymInt[] in_lengths, SymInt[] out_lengths, bool reverse=False) -> Tensor[]", + {PT2_COMPLIANT_TAG}); + + // register the main function for external usage + m.def( + "permute_multi_embedding(Tensor[] pooled_embs, int[] permutes, SymInt[] in_lengths, SymInt[] out_lengths) -> Tensor[]", + {PT2_COMPLIANT_TAG}); + + // dispatch the forward function to CPU for internal (autograd) usage + DISPATCH_TO_CPU( + "permute_multi_embedding_function", + fbgemm_gpu::permute_multi_embedding_cpu); + + // dispatch the forward function to CPU for internal (autograd) usage + DISPATCH_TO_META( + "permute_multi_embedding_function", + fbgemm_gpu::permute_multi_embedding_meta); + + // dispath the main function to Autograd for external usage + DISPATCH_TO_AUTOGRAD( + "permute_multi_embedding", fbgemm_gpu::permute_multi_embedding_autograd); + + // dispath the main function to Autograd for external usage + DISPATCH_TO_CUDA( + "permute_multi_embedding", fbgemm_gpu::permute_multi_embedding_autograd); +} diff --git a/fbgemm_gpu/src/permute_multi_embedding_ops/permute_multi_embedding_ops_gpu.cpp b/fbgemm_gpu/src/permute_multi_embedding_ops/permute_multi_embedding_ops_gpu.cpp new file mode 100644 index 0000000000..4f137f6926 --- /dev/null +++ b/fbgemm_gpu/src/permute_multi_embedding_ops/permute_multi_embedding_ops_gpu.cpp @@ -0,0 +1,16 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include "fbgemm_gpu/permute_multi_embedding_function.h" + +TORCH_LIBRARY_FRAGMENT(fbgemm, m) { + // dispatch the forward function to GPU for internal (autograd) usage + DISPATCH_TO_CUDA( + "permute_multi_embedding_function", + fbgemm_gpu::permute_multi_embedding_gpu); +}