implementation of fbgemm op - permute_multi_embedding (#2738)
Summary:

X-link: pytorch/torchrec#2120

# context
* currently we have a working operator `permute_pooled_embs_auto_grad` that does a full permute of KTs, including forward and backward
* it has several limitations (a rough sketch of the existing flow follows this list):
a) it has to be a full permute; duplicates are not supported;
b) in the main [use case](https://fburl.com/code/89od0rqm) there has to be a torch.concat on the input KTs, which is not very efficient;
c) the function outputs a single KT, which then requires a split operation
* there have been attempts to support duplicated outputs, but the backward pass doesn't work
* this diff creates a new kernel (named `permute_multi_embedding`) to support a multiple-KT to multiple-KT mapping operation with backward support
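A minimal sketch of that concat-permute-split flow (my illustration only; the actual `permute_pooled_embs_auto_grad` call and its permute/offset arguments are elided, and a simple column reversal stands in for the real permute):
```
import torch

batch_size = 2
values = [torch.randn(batch_size, d) for d in (2, 3, 1)]

# limitation b): the input KTs first have to be concatenated into one tensor
cat = torch.cat(values, dim=1)

# stand-in for the full permute of the concatenated KT (here: reverse the columns)
permuted = cat[:, torch.arange(cat.size(1) - 1, -1, -1)]

# limitation c): the single output KT then has to be split back into pieces
outputs = permuted.split([1, 3, 2], dim=1)
```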

# notes
* this diff focuses on the implementation and tests of the operator
* performance analysis and benchmark are in the next diff

# operator example usage
* usage in Python
```
# test inputs: 3 KTs with batch_size=2048
batch_size = 2048
keys = [["f1", "f2"], ["f3", "f4", "f5"], ["f6"]]
lengths = [[96, 256], [512, 128, 768], [1024]]
values = [
    torch.randn(batch_size, sum(lens), device="cuda", requires_grad=True)
    for lens in lengths
]

# target outputs: 4 KTs with re-arranged keys (features), duplicates are allowed
groups = [["f1", "f3"], ["f2"], ["f4", "f1", "f6"], ["f1", "f5"]]

# auxiliary arguments for the op/kernel
permutes, in_lengths, out_lengths = _multi_remap_to_groups(
    keys, lengths, groups
)

# invoke the operator with the prepared arguments
outputs = torch.ops.fbgemm.permute_multi_embedding(
    values, permutes, in_lengths, out_lengths
)
```
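For reference, with the lengths and groups above, the four returned KTs have widths 96 + 512 = 608, 256, 128 + 96 + 1024 = 1248, and 96 + 768 = 864, each with batch_size = 2048 rows.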
* permutes
```
permutes = tensor(
            [
                [0, 0, 0, 0, 3, 4],  # f1
                [1, 0, 0, 3, 5, 0],  # f3
                [0, 1, 3, 0, 4, 0],  # f2
                [1, 2, 5, 0, 6, 0],  # f4
                [0, 2, 0, 6, 3, -6],  # f1
                [2, 2, 0, 9, 8, 0],  # f6
                [0, 3, 0, 0, 3, -8],  # f1
                [1, 3, 11, 3, 7, 0],  # f5
            ]
)
```
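Note: the numbers above appear to come from a smaller toy setup rather than the 96/256/512/... lengths in the usage snippet; each row follows the layout [input_tensor_id, output_tensor_id, input_start, output_start, length, jump] described under details below, and the duplicated `f1` rows carry negative `jump` values that the backward pass uses to chain the duplicates.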

# details
1. from the example usage above, the operator takes in the following:
a) values: List[torch.Tensor], the input KTs
b) permutes: torch.Tensor, which contains the permute information (explained below)
c) output_lengths_list: List[int], the lengths of the output tensors (KTs), which are needed to allocate device memory ahead of time
d) in_lengths: torch.Tensor, lengths of the input tensors, resident on the device
e) out_lengths: torch.Tensor, lengths of the output tensors, resident on the device
2. the operator returns a list of tensors, which represent the permuted KTs
3. `permutes` is the most critical argument in this operator (a reference sketch of its semantics follows this list):
a) a 2-D tensor
b) each row represents one key (feature) permute move
c) a permute move = [input_tensor_id, output_tensor_id, input_start_idx, output_start_idx, feature_length, jump]
d) jump is used in the backward pass when a key (feature) from an input tensor is mapped to multiple places in the output tensors
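
To make the `permutes` layout concrete, here is a minimal Python sketch of the forward semantics (the helper name `permute_multi_embedding_reference` is made up for this note; the real CUDA kernel is in the diff below):
```
import torch

def permute_multi_embedding_reference(values, permutes, out_lengths):
    # values: list of input KTs, each of shape [batch_size, in_length]
    # permutes: 2-D int tensor, one row per feature move:
    #   [in_tensor, out_tensor, in_start, out_start, length, jump]
    batch_size = values[0].size(0)
    outputs = [
        torch.zeros(batch_size, length, dtype=values[0].dtype, device=values[0].device)
        for length in out_lengths
    ]
    for in_t, out_t, in_start, out_start, length, _jump in permutes.tolist():
        outputs[out_t][:, out_start:out_start + length] = (
            values[in_t][:, in_start:in_start + length]
        )
    return outputs
```
In the backward pass the kernel below walks the same rows with the input/output roles swapped; rows with a negative jump return early, and the duplicates of a feature are instead reached through the positive jump chain of its first occurrence, so their gradients are accumulated into the single input slice.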

Differential Revision: D57055616
TroyGarden authored and facebook-github-bot committed Jun 22, 2024
1 parent 7f77444 commit f5d6d39
Showing 7 changed files with 508 additions and 0 deletions.
1 change: 1 addition & 0 deletions fbgemm_gpu/FbgemmGpu.cmake
@@ -446,6 +446,7 @@ set(fbgemm_gpu_sources_static_cpu
codegen/training/backward/embedding_backward_dense_host_cpu.cpp
codegen/utils/embedding_bounds_check_host_cpu.cpp
src/merge_pooled_embedding_ops/merge_pooled_embedding_ops_cpu.cpp
src/permute_pooled_embedding_ops/permute_multi_embedding_ops_cpu.cpp
src/permute_pooled_embedding_ops/permute_pooled_embedding_function.cpp
src/permute_pooled_embedding_ops/permute_pooled_embedding_ops_cpu.cpp
src/permute_pooled_embedding_ops/permute_pooled_embedding_ops_split_cpu.cpp
6 changes: 6 additions & 0 deletions fbgemm_gpu/fbgemm_gpu/permute_pooled_embedding_modules.py
@@ -19,10 +19,16 @@
torch.ops.load_library(
    "//deeplearning/fbgemm/fbgemm_gpu:permute_pooled_embedding_ops_cpu"
)
torch.ops.load_library(
    "//deeplearning/fbgemm/fbgemm_gpu:permute_multi_embedding_ops_cpu"
)
try:
    torch.ops.load_library(
        "//deeplearning/fbgemm/fbgemm_gpu:permute_pooled_embedding_ops_gpu"
    )
    torch.ops.load_library(
        "//deeplearning/fbgemm/fbgemm_gpu:permute_multi_embedding_ops_gpu"
    )
except OSError:
    # This is for forward compatibility (new torch.package + old backend)
    # We should be able to remove it after this diff is picked up by all backend
64 changes: 64 additions & 0 deletions fbgemm_gpu/include/fbgemm_gpu/permute_multi_embedding_function.h
@@ -0,0 +1,64 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/

#pragma once

#include <ATen/ATen.h>
#include <torch/csrc/api/include/torch/types.h>
#include <torch/csrc/autograd/custom_function.h>

#include "fbgemm_gpu/dispatch_macros.h"
#include "fbgemm_gpu/ops_utils.h"
#include "fbgemm_gpu/sparse_ops_utils.h"

namespace fbgemm_gpu {

using Tensor = at::Tensor;
using torch::autograd::AutogradContext;
using torch::autograd::variable_list;

class PermuteMultiEmbeddingOp
    : public torch::autograd::Function<PermuteMultiEmbeddingOp> {
 public:
  static variable_list forward(
      AutogradContext* ctx,
      const at::TensorList& pooled_embs,
      const std::vector<int64_t>& permutes,
      const std::vector<int64_t>& in_lengths,
      const std::vector<int64_t>& out_lengths);

  static variable_list backward(
      AutogradContext* ctx,
      variable_list grad_output);
};

std::vector<Tensor> permute_multi_embedding_cpu(
    const at::TensorList& pooled_embs,
    const std::vector<int64_t>& permutes,
    const std::vector<int64_t>& in_lengths,
    const std::vector<int64_t>& out_lengths,
    const bool& reverse_permute);

std::vector<Tensor> permute_multi_embedding_meta(
    const at::TensorList& pooled_embs,
    const std::vector<int64_t>& permutes,
    const std::vector<int64_t>& in_lengths,
    const std::vector<int64_t>& out_lengths,
    const bool& reverse_permute);

std::vector<Tensor> permute_multi_embedding_gpu(
    const at::TensorList& pooled_embs,
    const std::vector<int64_t>& permutes,
    const std::vector<int64_t>& in_lengths,
    const std::vector<int64_t>& out_lengths,
    const bool& reverse_permute);
} // namespace fbgemm_gpu
@@ -0,0 +1,66 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/

#include "fbgemm_gpu/permute_multi_embedding_function.h"
#include <cstdint>
#include <iostream>

namespace fbgemm_gpu {

using Tensor = at::Tensor;
using torch::autograd::AutogradContext;
using torch::autograd::variable_list;

variable_list PermuteMultiEmbeddingOp::forward(
    AutogradContext* ctx,
    const at::TensorList& pooled_embs,
    const std::vector<int64_t>& permutes,
    const std::vector<int64_t>& in_lengths,
    const std::vector<int64_t>& out_lengths) {
  ctx->saved_data["permutes"] = permutes;
  ctx->saved_data["in_lengths"] = in_lengths;
  ctx->saved_data["out_lengths"] = out_lengths;

  /*
    select the correct dispatched (cpu/gpu) forward function
    the cpu/gpu function needs to be registered in the dispatcher,
    e.g., DISPATCH_TO_CPU, DISPATCH_TO_CUDA, etc.
  */
  const auto permute_op =
      torch::Dispatcher::singleton()
          .findSchemaOrThrow("fbgemm::permute_multi_embedding_function", "")
          .typed<decltype(permute_multi_embedding_cpu)>();

  return permute_op.call(pooled_embs, permutes, in_lengths, out_lengths, false);
}

variable_list PermuteMultiEmbeddingOp::backward(
    AutogradContext* ctx,
    variable_list grad_output) {
  const auto permutes = ctx->saved_data["permutes"].toIntVector();
  const auto in_lengths = ctx->saved_data["in_lengths"].toIntVector();
  const auto out_lengths = ctx->saved_data["out_lengths"].toIntVector();

  /*
    select the correct dispatched (cpu/gpu) backward function
    the cpu/gpu function needs to be registered in the dispatcher,
    e.g., DISPATCH_TO_CPU, DISPATCH_TO_CUDA, etc.
  */
  const auto permute_op =
      torch::Dispatcher::singleton()
          .findSchemaOrThrow("fbgemm::permute_multi_embedding_function", "")
          .typed<decltype(permute_multi_embedding_cpu)>();
  auto grad_input =
      permute_op.call(grad_output, permutes, out_lengths, in_lengths, true);
  grad_input.push_back(torch::autograd::Variable()); // permutes
  grad_input.push_back(torch::autograd::Variable()); // in_lengths
  grad_input.push_back(torch::autograd::Variable()); // out_lengths
  return grad_input;
}

} // namespace fbgemm_gpu
@@ -0,0 +1,232 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/

#include <ATen/AccumulateType.h>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/Exceptions.h>
#include <c10/cuda/CUDAGuard.h>
#include <cuda.h>
#include <cuda_runtime.h>
#include <cstdint>
#include <ostream>

#include "fbgemm_gpu/fbgemm_cuda_utils.cuh"
#include "fbgemm_gpu/permute_multi_embedding_function.h"

using Tensor = at::Tensor;

namespace fbgemm_gpu {

// Kernel for the permute_multi_embedding op.
// This kernel is moving D elements per warp.
template <typename scalar_t>
__global__ void permute_multi_embs_kernel(
    const scalar_t** __restrict__ inputs,
    const scalar_t** __restrict__ outputs,
    const int64_t* __restrict__ permutes,
    const int64_t* __restrict__ input_lengths,
    const int64_t* __restrict__ output_lengths,
    const int64_t batch_size,
    const int64_t permute_size,
    const bool reverse_permute) {
  // workers in a warp handle a feature
  const int32_t worker_id = threadIdx.x % warpSize;
  const int32_t worker_size = warpSize;
  const int32_t permute_id =
      blockIdx.x * (blockDim.x / warpSize) + threadIdx.x / warpSize;
  const int32_t batch_id = blockIdx.y + gridDim.y * blockIdx.z;
  if (batch_id >= batch_size) {
    return;
  }
  if (permute_id >= permute_size) {
    return;
  }

  // parse permutes
  const int64_t params = 6;
  int64_t in_tensor, out_tensor, in_start, out_start, length, jump;
  if (reverse_permute) {
    out_tensor = permutes[params * permute_id];
    in_tensor = permutes[params * permute_id + 1];
    out_start = permutes[params * permute_id + 2];
    in_start = permutes[params * permute_id + 3];
  } else {
    in_tensor = permutes[params * permute_id];
    out_tensor = permutes[params * permute_id + 1];
    in_start = permutes[params * permute_id + 2];
    out_start = permutes[params * permute_id + 3];
  }
  length = permutes[params * permute_id + 4];
  jump = permutes[params * permute_id + 5];

  if (worker_id >= length) {
    return;
  }
  if (reverse_permute && jump < 0) {
    return;
  }

  // locate the batch_id
  int64_t in_length = input_lengths[in_tensor];
  scalar_t* input_ptr = (scalar_t*)inputs[in_tensor];
  input_ptr += batch_id * in_length;

  int64_t out_length = output_lengths[out_tensor];
  scalar_t* output_ptr = (scalar_t*)outputs[out_tensor];
  output_ptr += batch_id * out_length;

  // printf( // debug print
  //     "input_tensors[%ld][%ld][%d] = %f\n",
  //     in_tensor,
  //     batch_id,
  //     in_start + worker_id,
  //     input_ptr[in_start + worker_id]);
  if (fbgemm_gpu::is_aligned<fbgemm_gpu::Vec4T<scalar_t>>(
          &output_ptr[out_start]) &&
      fbgemm_gpu::is_aligned<fbgemm_gpu::Vec4T<scalar_t>>(
          &input_ptr[in_start])) {
    const int32_t vec_size = 4;
    const int32_t loop_end = length / (vec_size) * (vec_size);
    for (int32_t i = worker_id * vec_size; i < loop_end;
         i += worker_size * vec_size) {
      fbgemm_gpu::Vec4T<scalar_t>::copy(
          &input_ptr[in_start + i], &output_ptr[out_start + i]);
    }
    // Use elementwise access for the last incomplete vector.
    for (int32_t i = loop_end + worker_id; i < length; i += worker_size) {
      output_ptr[out_start + i] = input_ptr[in_start + i];
    }
  } else { // Fallback if not aligned.
    for (int32_t i = worker_id; i < length; i += worker_size) {
      output_ptr[out_start + i] = input_ptr[in_start + i];
    }
  }

  // for reverse_permute (backward) with jump
  while (reverse_permute && jump > 0 && jump < permute_size) {
    in_tensor = permutes[params * jump + 1];
    in_start = permutes[params * jump + 3];
    length = permutes[params * jump + 4];
    jump = -permutes[params * jump + 5];

    int64_t in_length = input_lengths[in_tensor];
    scalar_t* input_ptr = (scalar_t*)inputs[in_tensor];
    input_ptr += batch_id * in_length;

    for (int32_t i = worker_id; i < length; i += worker_size) {
      output_ptr[out_start + i] += input_ptr[in_start + i];
    }
  }
}

template <typename index_t>
Tensor from_vec(const std::vector<index_t> input) {
  const auto int_opts =
      torch::TensorOptions().dtype(torch::kInt64).pinned_memory(true);
  Tensor output = at::empty({static_cast<index_t>(input.size())}, int_opts);
  // Ensure that output is contiguous
  TORCH_CHECK(output.is_contiguous());
  std::memcpy(
      output.data_ptr<index_t>(), input.data(), input.size() * sizeof(index_t));
  return output;
}

template <typename scalar_t>
Tensor tensors_ptr(const at::TensorList& tensors) {
  auto size = tensors.size();
  Tensor ptr_tensor = at::empty(
      {static_cast<long>(size * sizeof(scalar_t*))},
      at::TensorOptions().dtype(tensors[0].scalar_type()).pinned_memory(true));

  // Ensure that ptr_tensor is contiguous
  TORCH_CHECK(ptr_tensor.is_contiguous());
  auto tp = reinterpret_cast<scalar_t**>(ptr_tensor.data_ptr());
  for (int32_t i = 0; i < tensors.size(); i++) {
    tp[i] = tensors[i].data_ptr<scalar_t>();
  }
  // Ensure that ptr_tensor is contiguous
  TORCH_CHECK(ptr_tensor.is_contiguous());
  return ptr_tensor;
}

std::vector<Tensor> permute_multi_embedding_gpu(
    const at::TensorList& pooled_embs,
    const std::vector<int64_t>& permutes,
    const std::vector<int64_t>& in_lengths,
    const std::vector<int64_t>& out_lengths,
    const bool& reverse_permute) {
  const int64_t permute_param = 6;
  int64_t num_of_input_tensors = in_lengths.size();
  int64_t num_of_output_tensors = out_lengths.size();
  int64_t batch_size = pooled_embs[0].size(0);
  int64_t permute_size = permutes.size() / permute_param;

  // check input tensors
  std::vector<Tensor> inputs;
  inputs.reserve(pooled_embs.size());
  for (int32_t i = 0; i < num_of_input_tensors; i++) {
    Tensor cont_tensor = pooled_embs[i].contiguous();
    inputs.push_back(cont_tensor);
    TENSORS_ON_SAME_DEVICE(cont_tensor, pooled_embs[i]);
    TENSORS_ON_SAME_DEVICE(pooled_embs[i], pooled_embs[0]);
  }

  // initiate output tensors
  std::vector<Tensor> outputs;
  outputs.reserve(num_of_output_tensors);
  for (int32_t i = 0; i < num_of_output_tensors; i++) {
    Tensor output =
        at::empty({batch_size, out_lengths[i]}, pooled_embs[0].options());
    outputs.push_back(output);
  }

  auto permutes_tensor = from_vec<int64_t>(permutes);
  auto in_lengths_tensor = from_vec<int64_t>(in_lengths);
  auto out_lengths_tensor = from_vec<int64_t>(out_lengths);

  auto device = pooled_embs[0].device();
  permutes_tensor = permutes_tensor.to(device, /*non_blocking=*/true);
  in_lengths_tensor = in_lengths_tensor.to(device, /*non_blocking=*/true);
  out_lengths_tensor = out_lengths_tensor.to(device, /*non_blocking=*/true);

  // This kernel is moving D elements per warp.
  // We are launching ( div_round_up(T, warp_per_block), B ) blocks.
  // The grid z dimension is also used by batch_size in case it's greater than
  // 65535.
  const int32_t warp_per_block =
      fbgemm_gpu::kMaxThreads / fbgemm_gpu::kWarpSize;
  const int32_t max_grid_dim_y =
      32768; // The CUDA maximum is 65535, not a power of 2.
  const dim3 threads(fbgemm_gpu::kMaxThreads);
  const dim3 blocks(
      fbgemm_gpu::div_round_up(permute_size, warp_per_block),
      std::min(static_cast<int32_t>(batch_size), max_grid_dim_y),
      (batch_size + max_grid_dim_y - 1) / max_grid_dim_y);

  FBGEMM_DISPATCH_FLOATING_TYPES(
      pooled_embs[0].scalar_type(), "permute_multi_embedding", [&] {
        Tensor in_tensor = tensors_ptr<scalar_t>(inputs);
        Tensor out_tensor = tensors_ptr<scalar_t>(outputs);
        in_tensor = in_tensor.to(device, /*non_blocking=*/true);
        out_tensor = out_tensor.to(device, /*non_blocking=*/true);
        permute_multi_embs_kernel<scalar_t>
            <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
                (const scalar_t**)in_tensor.data_ptr(),
                (const scalar_t**)out_tensor.data_ptr(),
                permutes_tensor.data_ptr<int64_t>(),
                in_lengths_tensor.data_ptr<int64_t>(),
                out_lengths_tensor.data_ptr<int64_t>(),
                batch_size,
                permute_size,
                reverse_permute);
        C10_CUDA_KERNEL_LAUNCH_CHECK();
      });
  return outputs;
}

} // namespace fbgemm_gpu