From f5d6d394740cf0b3096358c4bdf9e62c0bebc94b Mon Sep 17 00:00:00 2001
From: Huanyu He
Date: Sat, 22 Jun 2024 09:29:18 -0700
Subject: [PATCH] implementation of fbgemm op - permute_multi_embedding (#2738)

Summary:
X-link: https://github.com/pytorch/torchrec/pull/2120

# context
* currently we have a working function `permute_pooled_embs_auto_grad` that performs a full permute of KTs (KeyedTensors), including forward and backward
* it has several limitations: a) it has to be a full permute, duplicates are not supported; b) in the main [use case](https://fburl.com/code/89od0rqm) there has to be a torch.concat on the input KTs, which is not very efficient; c) the function outputs a single KT, which then requires a split operation
* there have been attempts to support duplicated outputs, but the backward pass doesn't work
* this diff creates a new kernel (named `permute_multi_embedding`) that supports a multiple-KT to multiple-KT mapping operation, with backward support

# notes
* this diff focuses on the implementation and tests of the operator
* performance analysis and benchmarks are in the next diff

# operator example usage
* used in python
```
# test inputs: 3 KTs with batch_size=2048
batch_size = 2048
keys = [["f1", "f2"], ["f3", "f4", "f5"], ["f6"]]
lengths = [[96, 256], [512, 128, 768], [1024]]
values = [
    torch.randn(batch_size, sum(lens), device="cuda", requires_grad=True)
    for lens in lengths
]

# target outputs: 4 KTs with re-arranged keys (features); duplicates are allowed
groups = [["f1", "f3"], ["f2"], ["f4", "f1", "f6"], ["f1", "f5"]]

# auxiliary arguments to the op/kernel
permutes, in_lengths, out_lengths = _multi_remap_to_groups(
    keys, lengths, groups
)

# call the op
outputs = torch.ops.fbgemm.permute_multi_embedding(
    values, permutes, in_lengths, out_lengths
)
```
* permutes (note: the rows below were generated from a smaller, illustrative set of feature lengths — here f1..f6 have lengths 3, 4, 5, 6, 7, 8 — not from the lengths used above)
```
permutes = tensor(
    [
        [0, 0, 0, 0, 3, 4],  # f1
        [1, 0, 0, 3, 5, 0],  # f3
        [0, 1, 3, 0, 4, 0],  # f2
        [1, 2, 5, 0, 6, 0],  # f4
        [0, 2, 0, 6, 3, -6],  # f1
        [2, 2, 0, 9, 8, 0],  # f6
        [0, 3, 0, 0, 3, -8],  # f1
        [1, 3, 11, 3, 7, 0],  # f5
    ]
)
```
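How to read one of these rows (the column layout is spelled out in the details below) — a minimal annotated sketch; the slicing shown is an illustration of the semantics, not code from this diff:
```
# row layout: [input_tensor_id, output_tensor_id, input_start_idx,
#              output_start_idx, feature_length, jump]
row = [1, 3, 11, 3, 7, 0]  # f5
# meaning: outputs[3][:, 3:3+7] = values[1][:, 11:11+7]
# jump == 0: f5 appears exactly once in the outputs
# a positive jump (first f1 row) holds the row index of the feature's next
# duplicate; negative values mark the duplicate rows themselves (see details)
```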
# details
1. from the example usage above, we can clearly see that the operator takes in the following:
a) values: List[torch.Tensor], which represents the input KTs
b) permutes: torch.Tensor, which contains the permute information, explained below
c) output_lengths_list: List[int], the lengths of the output tensors (KTs), which is needed to allocate memory on the device ahead of time
d) in_lengths: torch.Tensor, lengths of the input tensors, which resides on the device
e) out_lengths: torch.Tensor, lengths of the output tensors, which resides on the device
2. the operator returns a list of tensors, which represents the permuted KTs
3. `permutes` is the most critical argument of this operator:
a) it is a 2-D tensor
b) each row represents one key (feature) permute move
c) a permute move = [input_tensor_id, output_tensor_id, input_start_idx, output_start_idx, feature_length, jump]
d) jump is used in the backward pass when a key (feature) from an input tensor is mapped to multiple places in the output tensors: the feature's first row stores the (positive) row index of its next duplicate, each duplicate row stores the negated row index of the duplicate after it (or the negated total row count to terminate the chain), so a single backward worker can accumulate all of the feature's gradients; see the sketch after this list
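To make these semantics concrete before diving into the kernels, here is a minimal eager PyTorch sketch of the mapping (hypothetical helper names, lengths passed as plain Python ints; the real op runs the CPU/CUDA kernels added below):
```
import torch

def permute_multi_embedding_ref(values, permutes, out_lengths):
    # eager reference of the forward pass: one slice copy per permute row
    batch_size = values[0].size(0)
    outputs = [values[0].new_empty(batch_size, n) for n in out_lengths]
    for in_t, out_t, in_s, out_s, n, _jump in permutes.tolist():
        outputs[out_t][:, out_s : out_s + n] = values[in_t][:, in_s : in_s + n]
    return outputs

def permute_multi_embedding_ref_grad(grad_outputs, permutes, in_lengths):
    # gradient w.r.t. the inputs: the same mapping reversed; a feature that
    # was duplicated across outputs gets its gradient slices summed
    batch_size = grad_outputs[0].size(0)
    grad_inputs = [grad_outputs[0].new_zeros(batch_size, n) for n in in_lengths]
    for in_t, out_t, in_s, out_s, n, _jump in permutes.tolist():
        grad_inputs[in_t][:, in_s : in_s + n] += (
            grad_outputs[out_t][:, out_s : out_s + n]
        )
    return grad_inputs
```
The sketch accumulates duplicates with `+=` for simplicity; the CUDA kernel below instead makes rows with a negative jump return early and lets the feature's first row walk the jump chain, so exactly one warp writes each input-gradient slice and no atomics are needed.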
Differential Revision: D57055616
---
 fbgemm_gpu/FbgemmGpu.cmake                    |   1 +
 .../permute_pooled_embedding_modules.py       |   6 +
 .../permute_multi_embedding_function.h        |  64 +++++
 .../permute_multi_embedding_function.cpp      |  66 +++++
 .../permute_multi_embedding_ops.cu            | 232 ++++++++++++++++++
 .../permute_multi_embedding_ops_cpu.cpp       | 123 ++++++++++
 .../permute_multi_embedding_ops_gpu.cpp       |  16 ++
 7 files changed, 508 insertions(+)
 create mode 100644 fbgemm_gpu/include/fbgemm_gpu/permute_multi_embedding_function.h
 create mode 100644 fbgemm_gpu/src/permute_multi_embedding_ops/permute_multi_embedding_function.cpp
 create mode 100644 fbgemm_gpu/src/permute_multi_embedding_ops/permute_multi_embedding_ops.cu
 create mode 100644 fbgemm_gpu/src/permute_multi_embedding_ops/permute_multi_embedding_ops_cpu.cpp
 create mode 100644 fbgemm_gpu/src/permute_multi_embedding_ops/permute_multi_embedding_ops_gpu.cpp

diff --git a/fbgemm_gpu/FbgemmGpu.cmake b/fbgemm_gpu/FbgemmGpu.cmake
index c96b911c08..389661b3bb 100644
--- a/fbgemm_gpu/FbgemmGpu.cmake
+++ b/fbgemm_gpu/FbgemmGpu.cmake
@@ -446,6 +446,7 @@ set(fbgemm_gpu_sources_static_cpu
     codegen/training/backward/embedding_backward_dense_host_cpu.cpp
     codegen/utils/embedding_bounds_check_host_cpu.cpp
     src/merge_pooled_embedding_ops/merge_pooled_embedding_ops_cpu.cpp
+    src/permute_multi_embedding_ops/permute_multi_embedding_ops_cpu.cpp
     src/permute_pooled_embedding_ops/permute_pooled_embedding_function.cpp
     src/permute_pooled_embedding_ops/permute_pooled_embedding_ops_cpu.cpp
     src/permute_pooled_embedding_ops/permute_pooled_embedding_ops_split_cpu.cpp
diff --git a/fbgemm_gpu/fbgemm_gpu/permute_pooled_embedding_modules.py b/fbgemm_gpu/fbgemm_gpu/permute_pooled_embedding_modules.py
index a06214b981..5ea9472858 100644
--- a/fbgemm_gpu/fbgemm_gpu/permute_pooled_embedding_modules.py
+++ b/fbgemm_gpu/fbgemm_gpu/permute_pooled_embedding_modules.py
@@ -19,10 +19,16 @@
     torch.ops.load_library(
         "//deeplearning/fbgemm/fbgemm_gpu:permute_pooled_embedding_ops_cpu"
     )
+    torch.ops.load_library(
+        "//deeplearning/fbgemm/fbgemm_gpu:permute_multi_embedding_ops_cpu"
+    )
     try:
         torch.ops.load_library(
             "//deeplearning/fbgemm/fbgemm_gpu:permute_pooled_embedding_ops_gpu"
         )
+        torch.ops.load_library(
+            "//deeplearning/fbgemm/fbgemm_gpu:permute_multi_embedding_ops_gpu"
+        )
     except OSError:
         # This is for forward compatibility (new torch.package + old backend)
         # We should be able to remove it after this diff is picked up by all backend
diff --git a/fbgemm_gpu/include/fbgemm_gpu/permute_multi_embedding_function.h b/fbgemm_gpu/include/fbgemm_gpu/permute_multi_embedding_function.h
new file mode 100644
index 0000000000..c2f1a704a0
--- /dev/null
+++ b/fbgemm_gpu/include/fbgemm_gpu/permute_multi_embedding_function.h
@@ -0,0 +1,64 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#pragma once
+
+#include
+#include
+#include
+
+#include "fbgemm_gpu/dispatch_macros.h"
+#include "fbgemm_gpu/ops_utils.h"
+#include "fbgemm_gpu/sparse_ops_utils.h"
+
+namespace fbgemm_gpu {
+
+using Tensor = at::Tensor;
+using torch::autograd::AutogradContext;
+using torch::autograd::variable_list;
+
+using Tensor = at::Tensor;
+using torch::autograd::AutogradContext;
+using torch::autograd::variable_list;
+
+class PermuteMultiEmbeddingOp
+    : public torch::autograd::Function<PermuteMultiEmbeddingOp> {
+ public:
+  static variable_list forward(
+      AutogradContext* ctx,
+      const at::TensorList& pooled_embs,
+      const std::vector<int64_t>& permutes,
+      const std::vector<int64_t>& in_lengths,
+      const std::vector<int64_t>& out_lengths);
+
+  static variable_list backward(
+      AutogradContext* ctx,
+      variable_list grad_output);
+};
+
+std::vector<Tensor> permute_multi_embedding_cpu(
+    const at::TensorList& pooled_embs,
+    const std::vector<int64_t>& permutes,
+    const std::vector<int64_t>& in_lengths,
+    const std::vector<int64_t>& out_lengths,
+    const bool& reverse_permute);
+
+std::vector<Tensor> permute_multi_embedding_meta(
+    const at::TensorList& pooled_embs,
+    const std::vector<int64_t>& permutes,
+    const std::vector<int64_t>& in_lengths,
+    const std::vector<int64_t>& out_lengths,
+    const bool& reverse_permute);
+
+std::vector<Tensor> permute_multi_embedding_gpu(
+    const at::TensorList& pooled_embs,
+    const std::vector<int64_t>& permutes,
+    const std::vector<int64_t>& in_lengths,
+    const std::vector<int64_t>& out_lengths,
+    const bool& reverse_permute);
+} // namespace fbgemm_gpu
diff --git a/fbgemm_gpu/src/permute_multi_embedding_ops/permute_multi_embedding_function.cpp b/fbgemm_gpu/src/permute_multi_embedding_ops/permute_multi_embedding_function.cpp
new file mode 100644
index 0000000000..90b9403d12
--- /dev/null
+++ b/fbgemm_gpu/src/permute_multi_embedding_ops/permute_multi_embedding_function.cpp
@@ -0,0 +1,66 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#include "fbgemm_gpu/permute_multi_embedding_function.h"
+#include
+#include
+
+namespace fbgemm_gpu {
+
+using Tensor = at::Tensor;
+using torch::autograd::AutogradContext;
+using torch::autograd::variable_list;
+
+variable_list PermuteMultiEmbeddingOp::forward(
+    AutogradContext* ctx,
+    const at::TensorList& pooled_embs,
+    const std::vector<int64_t>& permutes,
+    const std::vector<int64_t>& in_lengths,
+    const std::vector<int64_t>& out_lengths) {
+  ctx->saved_data["permutes"] = permutes;
+  ctx->saved_data["in_lengths"] = in_lengths;
+  ctx->saved_data["out_lengths"] = out_lengths;
+
+  /*
+    select the correct dispatched (cpu/gpu) forward function
+    the cpu/gpu function needs to be registered in the dispatcher,
+    e.g., DISPATCH_TO_CPU, DISPATCH_TO_CUDA, etc.
+  */
+  const auto permute_op =
+      torch::Dispatcher::singleton()
+          .findSchemaOrThrow("fbgemm::permute_multi_embedding_function", "")
+          .typed<decltype(permute_multi_embedding_cpu)>();
+
+  return permute_op.call(pooled_embs, permutes, in_lengths, out_lengths, false);
+}
+
+variable_list PermuteMultiEmbeddingOp::backward(
+    AutogradContext* ctx,
+    variable_list grad_output) {
+  const auto permutes = ctx->saved_data["permutes"].toIntVector();
+  const auto in_lengths = ctx->saved_data["in_lengths"].toIntVector();
+  const auto out_lengths = ctx->saved_data["out_lengths"].toIntVector();
+
+  /*
+    select the correct dispatched (cpu/gpu) backward function
+    the cpu/gpu function needs to be registered in the dispatcher,
+    e.g., DISPATCH_TO_CPU, DISPATCH_TO_CUDA, etc.
+  */
+  const auto permute_op =
+      torch::Dispatcher::singleton()
+          .findSchemaOrThrow("fbgemm::permute_multi_embedding_function", "")
+          .typed<decltype(permute_multi_embedding_cpu)>();
+  auto grad_input =
+      permute_op.call(grad_output, permutes, out_lengths, in_lengths, true);
+  grad_input.push_back(torch::autograd::Variable()); // permutes
+  grad_input.push_back(torch::autograd::Variable()); // in_lengths
+  grad_input.push_back(torch::autograd::Variable()); // out_lengths
+  return grad_input;
+}
+
+} // namespace fbgemm_gpu
diff --git a/fbgemm_gpu/src/permute_multi_embedding_ops/permute_multi_embedding_ops.cu b/fbgemm_gpu/src/permute_multi_embedding_ops/permute_multi_embedding_ops.cu
new file mode 100644
index 0000000000..26467e73dd
--- /dev/null
+++ b/fbgemm_gpu/src/permute_multi_embedding_ops/permute_multi_embedding_ops.cu
@@ -0,0 +1,232 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+
+#include "fbgemm_gpu/fbgemm_cuda_utils.cuh"
+#include "fbgemm_gpu/permute_multi_embedding_function.h"
+
+using Tensor = at::Tensor;
+
+namespace fbgemm_gpu {
+
+// Kernel for the permute_multi_embedding op.
+// This kernel moves D elements per warp.
+template <typename scalar_t>
+__global__ void permute_multi_embs_kernel(
+    const scalar_t** __restrict__ inputs,
+    const scalar_t** __restrict__ outputs,
+    const int64_t* __restrict__ permutes,
+    const int64_t* __restrict__ input_lengths,
+    const int64_t* __restrict__ output_lengths,
+    const int64_t batch_size,
+    const int64_t permute_size,
+    const bool reverse_permute) {
+  // workers in a warp handle a feature
+  const int32_t worker_id = threadIdx.x % warpSize;
+  const int32_t worker_size = warpSize;
+  const int32_t permute_id =
+      blockIdx.x * (blockDim.x / warpSize) + threadIdx.x / warpSize;
+  const int32_t batch_id = blockIdx.y + gridDim.y * blockIdx.z;
+  if (batch_id >= batch_size) {
+    return;
+  }
+  if (permute_id >= permute_size) {
+    return;
+  }
+
+  // parse permutes
+  const int64_t params = 6;
+  int64_t in_tensor, out_tensor, in_start, out_start, length, jump;
+  if (reverse_permute) {
+    out_tensor = permutes[params * permute_id];
+    in_tensor = permutes[params * permute_id + 1];
+    out_start = permutes[params * permute_id + 2];
+    in_start = permutes[params * permute_id + 3];
+  } else {
+    in_tensor = permutes[params * permute_id];
+    out_tensor = permutes[params * permute_id + 1];
+    in_start = permutes[params * permute_id + 2];
+    out_start = permutes[params * permute_id + 3];
+  }
+  length = permutes[params * permute_id + 4];
+  jump = permutes[params * permute_id + 5];
+
+  if (worker_id >= length) {
+    return;
+  }
+  if (reverse_permute && jump < 0) {
+    return;
+  }
+
+  // locate the batch_id
+  int64_t in_length = input_lengths[in_tensor];
+  scalar_t* input_ptr = (scalar_t*)inputs[in_tensor];
+  input_ptr += batch_id * in_length;
+
+  int64_t out_length = output_lengths[out_tensor];
+  scalar_t* output_ptr = (scalar_t*)outputs[out_tensor];
+  output_ptr += batch_id * out_length;
+
+  // printf( // debug print
+  //     "input_tensors[%ld][%ld][%d] = %f\n",
+  //     in_tensor,
+  //     batch_id,
+  //     in_start + worker_id,
+  //     input_ptr[in_start + worker_id]);
+  if (fbgemm_gpu::is_aligned<fbgemm_gpu::Vec4T<scalar_t>>(
+          &output_ptr[out_start]) &&
+      fbgemm_gpu::is_aligned<fbgemm_gpu::Vec4T<scalar_t>>(
+          &input_ptr[in_start])) {
+    const int32_t vec_size = 4;
+    const int32_t loop_end = length / (vec_size) * (vec_size);
+    for (int32_t i = worker_id * vec_size; i < loop_end;
+         i += worker_size * vec_size) {
+      fbgemm_gpu::Vec4T<scalar_t>::copy(
+          &input_ptr[in_start + i], &output_ptr[out_start + i]);
+    }
+    // Use elementwise access for the last incomplete vector.
+    for (int32_t i = loop_end + worker_id; i < length; i += worker_size) {
+      output_ptr[out_start + i] = input_ptr[in_start + i];
+    }
+  } else { // Fallback if not aligned.
+    for (int32_t i = worker_id; i < length; i += worker_size) {
+      output_ptr[out_start + i] = input_ptr[in_start + i];
+    }
+  }
+
+  // for reverse_permute (backward) with jump
+  while (reverse_permute && jump > 0 && jump < permute_size) {
+    in_tensor = permutes[params * jump + 1];
+    in_start = permutes[params * jump + 3];
+    length = permutes[params * jump + 4];
+    jump = -permutes[params * jump + 5];
+
+    int64_t in_length = input_lengths[in_tensor];
+    scalar_t* input_ptr = (scalar_t*)inputs[in_tensor];
+    input_ptr += batch_id * in_length;
+
+    for (int32_t i = worker_id; i < length; i += worker_size) {
+      output_ptr[out_start + i] += input_ptr[in_start + i];
+    }
+  }
+}
+
+template <typename index_t>
+Tensor from_vec(const std::vector<index_t> input) {
+  const auto int_opts =
+      torch::TensorOptions().dtype(torch::kInt64).pinned_memory(true);
+  Tensor output = at::empty({static_cast<int64_t>(input.size())}, int_opts);
+  // Ensure that output is contiguous
+  TORCH_CHECK(output.is_contiguous());
+  std::memcpy(
+      output.data_ptr(), input.data(), input.size() * sizeof(index_t));
+  return output;
+}
+
+template <typename scalar_t>
+Tensor tensors_ptr(const at::TensorList& tensors) {
+  auto size = tensors.size();
+  Tensor ptr_tensor = at::empty(
+      {static_cast<int64_t>(size * sizeof(scalar_t*))},
+      at::TensorOptions().dtype(tensors[0].scalar_type()).pinned_memory(true));
+
+  // Ensure that ptr_tensor is contiguous
+  TORCH_CHECK(ptr_tensor.is_contiguous());
+  auto tp = reinterpret_cast<scalar_t**>(ptr_tensor.data_ptr());
+  for (int32_t i = 0; i < tensors.size(); i++) {
+    tp[i] = tensors[i].data_ptr<scalar_t>();
+  }
+  // Ensure that ptr_tensor is contiguous
+  TORCH_CHECK(ptr_tensor.is_contiguous());
+  return ptr_tensor;
+}
+
+std::vector<Tensor> permute_multi_embedding_gpu(
+    const at::TensorList& pooled_embs,
+    const std::vector<int64_t>& permutes,
+    const std::vector<int64_t>& in_lengths,
+    const std::vector<int64_t>& out_lengths,
+    const bool& reverse_permute) {
+  const int64_t permute_param = 6;
+  int64_t num_of_input_tensors = in_lengths.size();
+  int64_t num_of_output_tensors = out_lengths.size();
+  int64_t batch_size = pooled_embs[0].size(0);
+  int64_t permute_size = permutes.size() / permute_param;
+
+  // check input tensors
+  std::vector<Tensor> inputs;
+  inputs.reserve(pooled_embs.size());
+  for (int32_t i = 0; i < num_of_input_tensors; i++) {
+    Tensor cont_tensor = pooled_embs[i].contiguous();
+    inputs.push_back(cont_tensor);
+    TENSORS_ON_SAME_DEVICE(cont_tensor, pooled_embs[i]);
+    TENSORS_ON_SAME_DEVICE(pooled_embs[i], pooled_embs[0]);
+  }
+
+  // initiate output tensors
+  std::vector<Tensor> outputs;
+  outputs.reserve(num_of_output_tensors);
+  for (int32_t i = 0; i < num_of_output_tensors; i++) {
+    Tensor output =
+        at::empty({batch_size, out_lengths[i]}, pooled_embs[0].options());
+    outputs.push_back(output);
+  }
+
+  auto permutes_tensor = from_vec<int64_t>(permutes);
+  auto in_lengths_tensor = from_vec<int64_t>(in_lengths);
+  auto out_lengths_tensor = from_vec<int64_t>(out_lengths);
+
+  auto device = pooled_embs[0].device();
+  permutes_tensor = permutes_tensor.to(device, /*non_blocking=*/true);
+  in_lengths_tensor = in_lengths_tensor.to(device, /*non_blocking=*/true);
+  out_lengths_tensor = out_lengths_tensor.to(device, /*non_blocking=*/true);
+
+  // This kernel moves D elements per warp.
+  // We are launching ( div_round_up(T, warp_per_block), B ) blocks.
+  // The grid z dimension is also used by batch_size in case it's greater than
+  // 65535.
+  const int32_t warp_per_block =
+      fbgemm_gpu::kMaxThreads / fbgemm_gpu::kWarpSize;
+  const int32_t max_grid_dim_y =
+      32768; // The CUDA maximum is 65535, not a power of 2.
+  const dim3 threads(fbgemm_gpu::kMaxThreads);
+  const dim3 blocks(
+      fbgemm_gpu::div_round_up(permute_size, warp_per_block),
+      std::min(static_cast<int32_t>(batch_size), max_grid_dim_y),
+      (batch_size + max_grid_dim_y - 1) / max_grid_dim_y);
+
+  FBGEMM_DISPATCH_FLOATING_TYPES(
+      pooled_embs[0].scalar_type(), "permute_multi_embedding", [&] {
+        Tensor in_tensor = tensors_ptr<scalar_t>(inputs);
+        Tensor out_tensor = tensors_ptr<scalar_t>(outputs);
+        in_tensor = in_tensor.to(device, /*non_blocking=*/true);
+        out_tensor = out_tensor.to(device, /*non_blocking=*/true);
+        permute_multi_embs_kernel<scalar_t>
+            <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
+                (const scalar_t**)in_tensor.data_ptr(),
+                (const scalar_t**)out_tensor.data_ptr(),
+                permutes_tensor.data_ptr<int64_t>(),
+                in_lengths_tensor.data_ptr<int64_t>(),
+                out_lengths_tensor.data_ptr<int64_t>(),
+                batch_size,
+                permute_size,
+                reverse_permute);
+        C10_CUDA_KERNEL_LAUNCH_CHECK();
+      });
+  return outputs;
+}
+
+} // namespace fbgemm_gpu
diff --git a/fbgemm_gpu/src/permute_multi_embedding_ops/permute_multi_embedding_ops_cpu.cpp b/fbgemm_gpu/src/permute_multi_embedding_ops/permute_multi_embedding_ops_cpu.cpp
new file mode 100644
index 0000000000..59937cf893
--- /dev/null
+++ b/fbgemm_gpu/src/permute_multi_embedding_ops/permute_multi_embedding_ops_cpu.cpp
@@ -0,0 +1,123 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#include "fbgemm_gpu/permute_multi_embedding_function.h"
+
+namespace fbgemm_gpu {
+
+using Tensor = at::Tensor;
+using torch::autograd::AutogradContext;
+using torch::autograd::variable_list;
+
+std::vector<Tensor> permute_multi_embedding_cpu(
+    const at::TensorList& pooled_embs,
+    const std::vector<int64_t>& permutes,
+    const std::vector<int64_t>& in_lengths,
+    const std::vector<int64_t>& out_lengths,
+    const bool& reverse_permute) {
+  int64_t batch_size = pooled_embs[0].size(0);
+
+  std::vector<Tensor> outputs;
+  outputs.reserve(out_lengths.size());
+  for (const auto i : c10::irange(out_lengths.size())) {
+    outputs.push_back(
+        at::empty({batch_size, out_lengths[i]}, pooled_embs[0].options()));
+  }
+
+  int64_t in_tensor, out_tensor, in_start, out_start, length, jump;
+  const int64_t param = 6;
+  for (const auto i : c10::irange(permutes.size() / param)) {
+    if (reverse_permute) {
+      out_tensor = permutes[i * param];
+      in_tensor = permutes[i * param + 1];
+      out_start = permutes[i * param + 2];
+      in_start = permutes[i * param + 3];
+      jump = permutes[i * param + 5];
+    } else {
+      in_tensor = permutes[i * param];
+      out_tensor = permutes[i * param + 1];
+      in_start = permutes[i * param + 2];
+      out_start = permutes[i * param + 3];
+    }
+    length = permutes[i * param + 4];
+    if (reverse_permute && jump < 0) {
+      for (const auto b : c10::irange(batch_size)) {
+        for (const auto j : c10::irange(length)) {
+          outputs[out_tensor][b][j + out_start] +=
+              pooled_embs[in_tensor][b][j + in_start];
+        }
+      }
+    } else {
+      for (const auto b : c10::irange(batch_size)) {
+        for (const auto j : c10::irange(length)) {
+          outputs[out_tensor][b][j + out_start] =
+              pooled_embs[in_tensor][b][j + in_start];
+        }
+      }
+    }
+  }
+  return outputs;
+}
+
+std::vector<Tensor> permute_multi_embedding_meta(
+    const at::TensorList& pooled_embs,
+    const std::vector<int64_t>& permutes,
+    const std::vector<int64_t>& in_lengths,
+    const std::vector<int64_t>& out_lengths,
+    const bool& reverse_permute) {
+  int64_t batch_size = pooled_embs[0].size(0);
+
+  std::vector<Tensor> outputs;
+  outputs.reserve(out_lengths.size());
+  for (const auto i : c10::irange(out_lengths.size())) {
+    outputs.push_back(
+        at::empty({batch_size, out_lengths[i]}, pooled_embs[0].options()));
+  }
+  return outputs;
+}
+
+std::vector<Tensor> permute_multi_embedding_autograd(
+    const at::TensorList& pooled_embs,
+    const std::vector<int64_t>& permutes,
+    const std::vector<int64_t>& in_lengths,
+    const std::vector<int64_t>& out_lengths) {
+  return PermuteMultiEmbeddingOp::apply(
+      pooled_embs, permutes, in_lengths, out_lengths);
+}
+
+} // namespace fbgemm_gpu
+
+TORCH_LIBRARY_FRAGMENT(fbgemm, m) {
+  // register the forward function for internal (autograd) usage
+  m.def(
+      "permute_multi_embedding_function(Tensor[] pooled_embs, int[] permutes, SymInt[] in_lengths, SymInt[] out_lengths, bool reverse=False) -> Tensor[]",
+      {PT2_COMPLIANT_TAG});
+
+  // register the main function for external usage
+  m.def(
+      "permute_multi_embedding(Tensor[] pooled_embs, int[] permutes, SymInt[] in_lengths, SymInt[] out_lengths) -> Tensor[]",
+      {PT2_COMPLIANT_TAG});
+
+  // dispatch the forward function to CPU for internal (autograd) usage
+  DISPATCH_TO_CPU(
+      "permute_multi_embedding_function",
+      fbgemm_gpu::permute_multi_embedding_cpu);
+
+  // dispatch the forward function to Meta for internal (autograd) usage
+  DISPATCH_TO_META(
+      "permute_multi_embedding_function",
+      fbgemm_gpu::permute_multi_embedding_meta);
+
+  // dispatch the main function to Autograd for external usage
+  DISPATCH_TO_AUTOGRAD(
+      "permute_multi_embedding", fbgemm_gpu::permute_multi_embedding_autograd);
+
+  // dispatch the main function to CUDA for external usage
+  DISPATCH_TO_CUDA(
+      "permute_multi_embedding", fbgemm_gpu::permute_multi_embedding_autograd);
+}
diff --git a/fbgemm_gpu/src/permute_multi_embedding_ops/permute_multi_embedding_ops_gpu.cpp b/fbgemm_gpu/src/permute_multi_embedding_ops/permute_multi_embedding_ops_gpu.cpp
new file mode 100644
index 0000000000..4f137f6926
--- /dev/null
+++ b/fbgemm_gpu/src/permute_multi_embedding_ops/permute_multi_embedding_ops_gpu.cpp
@@ -0,0 +1,16 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#include "fbgemm_gpu/permute_multi_embedding_function.h"
+
+TORCH_LIBRARY_FRAGMENT(fbgemm, m) {
+  // dispatch the forward function to GPU for internal (autograd) usage
+  DISPATCH_TO_CUDA(
+      "permute_multi_embedding_function",
+      fbgemm_gpu::permute_multi_embedding_gpu);
+}