optimize ESUHM compute by fusing embedding lookups and padding
Summary: Fuse the embedding lookup and padding into a single step to eliminate a redundant memory copy. This diff both improves QPS and reduces the memory footprint.
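A usage sketch of the fused operator (illustrative only, not part of this diff: the shapes, table count, and max length below are made up, and importing fbgemm_gpu is assumed to load the library that registers the torch.ops.fbgemm operators):

import torch
import fbgemm_gpu  # noqa: F401  # registers torch.ops.fbgemm.*

T, B, D = 2, 3, 4  # tables, batch size, embedding dim
max_len = 3
lengths = torch.tensor([[1, 2, 0], [3, 1, 2]], device="cuda")  # (T, B)
values = torch.randn(int(lengths.sum()), D, device="cuda", requires_grad=True)
offset_per_key = [0] + lengths.sum(dim=1).cumsum(dim=0).tolist()
# Returns one dense (B, max_len, D) tensor per table, zero-padded, with autograd support.
padded_per_table = torch.ops.fbgemm.stacked_jagged_2d_to_dense(
    values=values,
    lengths=lengths,
    offset_per_key=offset_per_key,
    max_lengths_per_key=[max_len] * T,
)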

Reviewed By: zhangruiskyline

Differential Revision: D33421488

fbshipit-source-id: e54a8641a10ab3b0afb09cbef0e4a8b983ca19e0
xing-liu authored and facebook-github-bot committed Jan 12, 2022
1 parent 2f6c894 commit 5b945ac
Showing 4 changed files with 77 additions and 145 deletions.
125 changes: 0 additions & 125 deletions fbgemm_gpu/fbgemm_gpu/sparse_ops.py

This file was deleted.
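(Its Python wrapper functions are superseded by direct torch.ops.fbgemm.* calls; see the updated call sites in sparse_ops_test.py below.)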

4 changes: 2 additions & 2 deletions fbgemm_gpu/src/sparse_ops_cpu.cpp
@@ -1422,14 +1422,14 @@ TORCH_LIBRARY_FRAGMENT(fbgemm, m) {
"jagged_2d_to_dense(Tensor values, Tensor offsets, int max_sequence_length) -> Tensor");
m.def(
"jagged_1d_to_dense(Tensor values, Tensor offsets, int max_sequence_length, int padding_value) -> Tensor");

m.def(
"stacked_jagged_2d_to_dense_forward(Tensor values, Tensor lengths, int[] offset_per_key, int[] max_lengths_per_key) -> (Tensor[], Tensor[])");
m.def(
"stacked_jagged_2d_to_dense_backward(int B, int D, int total_L, Tensor[] grad_padded_values_per_key, Tensor[] offsets_tensor_per_key, int[] offset_per_key) -> Tensor");

m.def(
"stacked_jagged_1d_to_dense(Tensor values, Tensor lengths, int[] offset_per_key, int[] max_lengths_per_key, int padding_value) -> Tensor[]");
m.def(
"stacked_jagged_2d_to_dense(Tensor values, Tensor lengths, int[] offset_per_key, int[] max_lengths_per_key) -> Tensor[]");
m.def(
"histogram_binning_calibration(Tensor logit, Tensor bin_num_examples, Tensor bin_num_positives, float positive_weight, float lower_bound, float upper_bound, int bin_ctr_in_use_after, float bin_ctr_weight_value) -> (Tensor, Tensor)");
m.def(
63 changes: 63 additions & 0 deletions fbgemm_gpu/src/sparse_ops_gpu.cpp
@@ -10,8 +10,10 @@

#include <ATen/ATen.h>
#include <ATen/core/op_registration/op_registration.h>
#include <Python.h>
#include <torch/csrc/autograd/custom_function.h>
#include <torch/library.h>
#include <stdexcept> // for logic_error

using Tensor = at::Tensor;

@@ -101,6 +103,65 @@ Tensor jagged_2d_to_dense_gpu(
values, offsets, static_cast<int32_t>(max_sequence_length))[0];
}

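// Autograd wrapper for the fused op: forward invokes the fused CUDA
// implementation and stashes the shapes and per-key offsets; backward uses
// them to scatter the per-table dense gradients back into a single jagged
// values gradient.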
class StackedJagged2DToDenseGPUOp
: public torch::autograd::Function<StackedJagged2DToDenseGPUOp> {
public:
static torch::autograd::variable_list forward(
torch::autograd::AutogradContext* ctx,
Tensor values,
Tensor lengths,
const std::vector<int64_t>& offset_per_key,
const std::vector<int64_t>& max_lengths_per_key) {
int32_t total_L = values.size(0);
ctx->saved_data["B"] = lengths.size(1);
ctx->saved_data["D"] = values.size(1);
ctx->saved_data["total_L"] = total_L;
ctx->saved_data["offset_per_key"] = offset_per_key;

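// A single fused CUDA call pads each table's jagged values and also returns
// the per-key offset tensors that backward needs.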
auto [padded_values_per_key, offsets_tensor_per_key] =
stacked_jagged_2d_to_dense_forward_cuda(
values, lengths, offset_per_key, max_lengths_per_key);
ctx->saved_data["offsets_tensor_per_key"] = offsets_tensor_per_key;

return padded_values_per_key;
}

static torch::autograd::variable_list backward(
torch::autograd::AutogradContext* ctx,
torch::autograd::variable_list grad_outputs) {
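// Recover the metadata stashed by forward.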
auto B = ctx->saved_data["B"].toInt();
auto D = ctx->saved_data["D"].toInt();
auto total_L = ctx->saved_data["total_L"].toInt();
auto offset_per_key = ctx->saved_data["offset_per_key"].toIntVector();
auto offsets_tensor_per_key =
ctx->saved_data["offsets_tensor_per_key"].toTensorVector();

using torch::autograd::Variable;
auto grad_values = stacked_jagged_2d_to_dense_backward_cuda(
B, D, total_L, grad_outputs, offsets_tensor_per_key, offset_per_key);
return {
grad_values,
Variable(), // lengths
Variable(), // offset_per_key
Variable() // max_lengths_per_key
};
}
};

std::vector<Tensor> stacked_jagged_2d_to_dense_gpu(
Tensor values,
Tensor lengths,
const std::vector<int64_t>& offset_per_key,
const std::vector<int64_t>& max_lengths_per_key) {
TENSOR_ON_CUDA_GPU(values);
TENSOR_ON_CUDA_GPU(lengths);
TENSORS_ON_SAME_DEVICE(values, lengths);
TORCH_CHECK(values.dim() == 2);
TORCH_CHECK(lengths.dim() == 2);
return StackedJagged2DToDenseGPUOp::apply(
values, lengths, offset_per_key, max_lengths_per_key);
}

} // namespace fbgemm_gpu

TORCH_LIBRARY_IMPL(fbgemm, CUDA, m) {
@@ -129,6 +190,8 @@ TORCH_LIBRARY_IMPL(fbgemm, CUDA, m) {
DISPATCH_TO_CUDA("jagged_1d_to_dense", fbgemm_gpu::jagged_1d_to_dense_gpu);
DISPATCH_TO_CUDA(
"stacked_jagged_1d_to_dense", fbgemm_gpu::stacked_jagged_1d_to_dense_gpu);
DISPATCH_TO_CUDA(
"stacked_jagged_2d_to_dense", fbgemm_gpu::stacked_jagged_2d_to_dense_gpu);
DISPATCH_TO_CUDA(
"stacked_jagged_2d_to_dense_forward",
fbgemm_gpu::stacked_jagged_2d_to_dense_forward_cuda);
30 changes: 12 additions & 18 deletions fbgemm_gpu/test/sparse_ops_test.py
@@ -13,12 +13,6 @@
import hypothesis.strategies as st
import numpy as np
import torch
from fbgemm_gpu.sparse_ops import (
jagged_1d_to_dense,
jagged_2d_to_dense,
stacked_jagged_1d_to_dense,
stacked_jagged_2d_to_dense,
)
from hypothesis import Verbosity, given, settings

try:
@@ -834,7 +828,7 @@ def test_jagged_2d_to_dense(
values = ref_values.clone().half().detach().requires_grad_(True)
else:
values = ref_values.clone().detach().requires_grad_(True)
output_values = jagged_2d_to_dense(
output_values = torch.ops.fbgemm.jagged_2d_to_dense(
values=values,
offsets=offsets,
max_sequence_length=max_sequence_length,
@@ -850,7 +844,7 @@
values = ref_values.clone().detach().requires_grad_(True)
offsets = offsets.cuda()
ref_output_values = ref_output_values.cuda()
output_values = jagged_2d_to_dense(
output_values = torch.ops.fbgemm.jagged_2d_to_dense(
values=values,
offsets=offsets,
max_sequence_length=max_sequence_length,
@@ -880,7 +874,7 @@ def test_jagged_2d_to_dense_truncation(self) -> None:

# test cpu forward
values = ref_values.clone().detach().requires_grad_(True)
output_values = jagged_2d_to_dense(
output_values = torch.ops.fbgemm.jagged_2d_to_dense(
values=values,
offsets=offsets,
max_sequence_length=max_sequence_length,
@@ -893,7 +887,7 @@ def test_jagged_2d_to_dense_truncation(self) -> None:
values = ref_values.clone().detach().requires_grad_(True)
offsets = offsets.cuda()
ref_output_values = ref_output_values.cuda()
output_values = jagged_2d_to_dense(
output_values = torch.ops.fbgemm.jagged_2d_to_dense(
values=values,
offsets=offsets,
max_sequence_length=max_sequence_length,
@@ -942,14 +936,14 @@ def test_stacked_jagged_2d_to_dense(
lengths = lengths.view(T, B)

values = ref_values.clone().detach().requires_grad_(True)
output_values_per_table = stacked_jagged_2d_to_dense(
output_values_per_table = torch.ops.fbgemm.stacked_jagged_2d_to_dense(
values=values,
lengths=lengths,
offset_per_key=[0]
+ np.cumsum([lengths[t].sum().item() for t in range(T)]).tolist(),
max_lengths_per_key=[max_sequence_length] * T,
)
ref_output_values = jagged_2d_to_dense(
ref_output_values = torch.ops.fbgemm.jagged_2d_to_dense(
values=ref_values,
offsets=offsets,
max_sequence_length=max_sequence_length,
@@ -1030,7 +1024,7 @@ def var_list_to_coo(

# test cpu forward
values = ref_values.clone().detach().requires_grad_(False)
output_values = jagged_1d_to_dense(
output_values = torch.ops.fbgemm.jagged_1d_to_dense(
values=values,
offsets=offsets,
max_sequence_length=max_sequence_length,
@@ -1044,7 +1038,7 @@
values = ref_values.clone().detach().requires_grad_(False)
offsets = offsets.cuda()
ref_output_values = ref_output_values.cuda()
output_values = jagged_1d_to_dense(
output_values = torch.ops.fbgemm.jagged_1d_to_dense(
values=values,
offsets=offsets,
max_sequence_length=max_sequence_length,
@@ -1062,7 +1056,7 @@ def test_jagged_1d_to_dense_truncation(self) -> None:

# test cpu forward
values = ref_values.clone().detach().requires_grad_(False)
output = jagged_1d_to_dense(
output = torch.ops.fbgemm.jagged_1d_to_dense(
values=values,
offsets=offsets,
max_sequence_length=1,
@@ -1076,7 +1070,7 @@ def test_jagged_1d_to_dense_truncation(self) -> None:
values = ref_values.clone().detach().requires_grad_(False)
offsets = offsets.cuda()
ref_output = ref_output.cuda()
output = jagged_1d_to_dense(
output = torch.ops.fbgemm.jagged_1d_to_dense(
values=values,
offsets=offsets,
max_sequence_length=1,
@@ -1142,15 +1136,15 @@ def var_list_to_coo(
ref_values = torch.randint(low=0, high=1000000000, size=(total_lengths,)).cuda()

values = ref_values.clone().detach().requires_grad_(False)
output_values_per_table = stacked_jagged_1d_to_dense(
output_values_per_table = torch.ops.fbgemm.stacked_jagged_1d_to_dense(
values=values,
lengths=lengths,
offset_per_key=[0]
+ np.cumsum([lengths[t].sum().item() for t in range(T)]).tolist(),
max_lengths_per_key=[max_sequence_length] * T,
padding_value=padding_value,
)
ref_output_values = jagged_1d_to_dense(
ref_output_values = torch.ops.fbgemm.jagged_1d_to_dense(
values=ref_values,
offsets=offsets,
max_sequence_length=max_sequence_length,
