diff --git a/fbgemm_gpu/fbgemm_gpu/sparse_ops.py b/fbgemm_gpu/fbgemm_gpu/sparse_ops.py
deleted file mode 100644
index 1a783c08e0..0000000000
--- a/fbgemm_gpu/fbgemm_gpu/sparse_ops.py
+++ /dev/null
@@ -1,125 +0,0 @@
-#!/usr/bin/env python3
-
-# Copyright (c) Meta Platforms, Inc. and its affiliates.
-# All rights reserved.
-# This source code is licensed under the BSD-style license found in the
-# LICENSE file in the root directory of this source tree.
-
-from typing import List, Tuple
-
-import torch
-
-try:
-    # pyre-ignore[21]
-    from fbgemm_gpu import open_source  # noqa: F401
-except Exception:
-    torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops")
-    torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops_cpu")
-
-
-class _StackedJagged2DToDenseFunction(torch.autograd.Function):
-    @staticmethod
-    # pyre-fixme[14]
-    def forward(
-        # pyre-fixme[2]
-        ctx,
-        values: torch.Tensor,
-        lengths: torch.Tensor,
-        offset_per_key: List[int],
-        max_lengths_per_key: List[int],
-    ) -> Tuple[torch.Tensor]:
-        ctx.B = lengths.size(1)
-        ctx.D = values.size(1)
-        ctx.total_L = values.size(0)
-        ctx.offset_per_key = offset_per_key
-        (
-            padded_values_per_key,
-            offsets_tensor_per_key,
-        ) = torch.ops.fbgemm.stacked_jagged_2d_to_dense_forward(
-            values,
-            lengths,
-            offset_per_key,
-            max_lengths_per_key,
-        )
-        ctx.offsets_tensor_per_key = offsets_tensor_per_key
-        return tuple(padded_values_per_key)
-
-    @staticmethod
-    def backward(
-        # pyre-fixme[2]
-        ctx,
-        *grad_padded_values_per_key: torch.Tensor
-    ) -> Tuple[torch.Tensor, None, None, None]:
-        B = ctx.B
-        D = ctx.D
-        total_L = ctx.total_L
-        offset_per_key = ctx.offset_per_key
-        offsets_tensor_per_key = ctx.offsets_tensor_per_key
-        grad_values = torch.ops.fbgemm.stacked_jagged_2d_to_dense_backward(
-            B,
-            D,
-            total_L,
-            list(grad_padded_values_per_key),
-            offsets_tensor_per_key,
-            offset_per_key,
-        )
-        return grad_values, None, None, None
-
-
-def jagged_1d_to_dense(
-    values: torch.Tensor,
-    offsets: torch.Tensor,
-    max_sequence_length: int,
-    padding_value: int,
-) -> torch.Tensor:
-    return torch.ops.fbgemm.jagged_1d_to_dense(
-        values=values,
-        offsets=offsets,
-        max_sequence_length=max_sequence_length,
-        padding_value=padding_value,
-    )
-
-
-def jagged_2d_to_dense(
-    values: torch.Tensor,
-    offsets: torch.Tensor,
-    max_sequence_length: int,
-) -> torch.Tensor:
-    return torch.ops.fbgemm.jagged_2d_to_dense(
-        values=values,
-        offsets=offsets,
-        max_sequence_length=max_sequence_length,
-    )
-
-
-def stacked_jagged_1d_to_dense(
-    values: torch.Tensor,
-    lengths: torch.Tensor,
-    offset_per_key: List[int],
-    max_lengths_per_key: List[int],
-    padding_value: int,
-) -> List[torch.Tensor]:
-    return torch.ops.fbgemm.stacked_jagged_1d_to_dense(
-        values=values,
-        lengths=lengths,
-        offset_per_key=offset_per_key,
-        max_lengths_per_key=max_lengths_per_key,
-        padding_value=padding_value,
-    )
-
-
-def stacked_jagged_2d_to_dense(
-    values: torch.Tensor,
-    lengths: torch.Tensor,
-    offset_per_key: List[int],
-    max_lengths_per_key: List[int],
-) -> List[torch.Tensor]:
-    return list(
-        # pyre-fixme[16]
-        _StackedJagged2DToDenseFunction.apply(
-            values,
-            lengths,
-            offset_per_key,
-            max_lengths_per_key,
-        )
-    )
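Note: with the thin Python wrappers above deleted, call sites move to the operators registered under torch.ops.fbgemm. A minimal before/after sketch of that migration (illustrative only; the shapes and values are made up, and it assumes an fbgemm_gpu install whose import registers the operator library):

import torch
import fbgemm_gpu  # noqa: F401  # importing registers the fbgemm ops (assumed install)

values = torch.rand(8, 4)             # (total_L, D) jagged values
offsets = torch.tensor([0, 3, 5, 8])  # B + 1 row offsets into values

# Before: from fbgemm_gpu.sparse_ops import jagged_2d_to_dense
# After: call the registered operator directly
dense = torch.ops.fbgemm.jagged_2d_to_dense(
    values=values,
    offsets=offsets,
    max_sequence_length=3,
)
print(dense.shape)  # torch.Size([3, 3, 4]): (B, max_sequence_length, D)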
Tensor"); m.def( "jagged_1d_to_dense(Tensor values, Tensor offsets, int max_sequence_length, int padding_value) -> Tensor"); - m.def( "stacked_jagged_2d_to_dense_forward(Tensor values, Tensor lengths, int[] offset_per_key, int[] max_lengths_per_key) -> (Tensor[], Tensor[])"); m.def( "stacked_jagged_2d_to_dense_backward(int B, int D, int total_L, Tensor[] grad_padded_values_per_key, Tensor[] offsets_tensor_per_key, int[] offset_per_key) -> Tensor"); - m.def( "stacked_jagged_1d_to_dense(Tensor values, Tensor lengths, int[] offset_per_key, int[] max_lengths_per_key, int padding_value) -> Tensor[]"); + m.def( + "stacked_jagged_2d_to_dense(Tensor values, Tensor lengths, int[] offset_per_key, int[] max_lengths_per_key) -> Tensor[]"); m.def( "histogram_binning_calibration(Tensor logit, Tensor bin_num_examples, Tensor bin_num_positives, float positive_weight, float lower_bound, float upper_bound, int bin_ctr_in_use_after, float bin_ctr_weight_value) -> (Tensor, Tensor)"); m.def( diff --git a/fbgemm_gpu/src/sparse_ops_gpu.cpp b/fbgemm_gpu/src/sparse_ops_gpu.cpp index 40e7a5da17..978723345e 100644 --- a/fbgemm_gpu/src/sparse_ops_gpu.cpp +++ b/fbgemm_gpu/src/sparse_ops_gpu.cpp @@ -10,8 +10,10 @@ #include #include +#include #include #include +#include // for logic_error using Tensor = at::Tensor; @@ -101,6 +103,65 @@ Tensor jagged_2d_to_dense_gpu( values, offsets, static_cast(max_sequence_length))[0]; } +class StackedJagged2DToDenseGPUOp + : public torch::autograd::Function { + public: + static torch::autograd::variable_list forward( + torch::autograd::AutogradContext* ctx, + Tensor values, + Tensor lengths, + const std::vector& offset_per_key, + const std::vector& max_lengths_per_key) { + int32_t total_L = values.size(0); + ctx->saved_data["B"] = lengths.size(1); + ctx->saved_data["D"] = values.size(1); + ctx->saved_data["total_L"] = total_L; + ctx->saved_data["offset_per_key"] = offset_per_key; + + auto [padded_values_per_key, offsets_tensor_per_key] = + stacked_jagged_2d_to_dense_forward_cuda( + values, lengths, offset_per_key, max_lengths_per_key); + ctx->saved_data["offsets_tensor_per_key"] = offsets_tensor_per_key; + + return padded_values_per_key; + } + + static torch::autograd::variable_list backward( + torch::autograd::AutogradContext* ctx, + torch::autograd::variable_list grad_outputs) { + auto B = ctx->saved_data["B"].toInt(); + auto D = ctx->saved_data["D"].toInt(); + auto total_L = ctx->saved_data["total_L"].toInt(); + auto offset_per_key = ctx->saved_data["offset_per_key"].toIntVector(); + auto offsets_tensor_per_key = + ctx->saved_data["offsets_tensor_per_key"].toTensorVector(); + + using torch::autograd::Variable; + auto grad_values = stacked_jagged_2d_to_dense_backward_cuda( + B, D, total_L, grad_outputs, offsets_tensor_per_key, offset_per_key); + return { + grad_values, + Variable(), // lengths + Variable(), // offset_per_key + Variable() // max_lengths_per_key + }; + } +}; + +std::vector stacked_jagged_2d_to_dense_gpu( + Tensor values, + Tensor lengths, + const std::vector& offset_per_key, + const std::vector& max_lengths_per_key) { + TENSOR_ON_CUDA_GPU(values); + TENSOR_ON_CUDA_GPU(lengths); + TENSORS_ON_SAME_DEVICE(values, lengths); + TORCH_CHECK(values.dim() == 2); + TORCH_CHECK(lengths.dim() == 2); + return StackedJagged2DToDenseGPUOp::apply( + values, lengths, offset_per_key, max_lengths_per_key); +} + } // namespace fbgemm_gpu TORCH_LIBRARY_IMPL(fbgemm, CUDA, m) { @@ -129,6 +190,8 @@ TORCH_LIBRARY_IMPL(fbgemm, CUDA, m) { DISPATCH_TO_CUDA("jagged_1d_to_dense", 
diff --git a/fbgemm_gpu/src/sparse_ops_gpu.cpp b/fbgemm_gpu/src/sparse_ops_gpu.cpp
index 40e7a5da17..978723345e 100644
--- a/fbgemm_gpu/src/sparse_ops_gpu.cpp
+++ b/fbgemm_gpu/src/sparse_ops_gpu.cpp
@@ -10,8 +10,10 @@
 #include
 #include

+#include
 #include
 #include
+#include <stdexcept> // for logic_error

 using Tensor = at::Tensor;
@@ -101,6 +103,65 @@
       values, offsets, static_cast<int64_t>(max_sequence_length))[0];
 }

+class StackedJagged2DToDenseGPUOp
+    : public torch::autograd::Function<StackedJagged2DToDenseGPUOp> {
+ public:
+  static torch::autograd::variable_list forward(
+      torch::autograd::AutogradContext* ctx,
+      Tensor values,
+      Tensor lengths,
+      const std::vector<int64_t>& offset_per_key,
+      const std::vector<int64_t>& max_lengths_per_key) {
+    int32_t total_L = values.size(0);
+    ctx->saved_data["B"] = lengths.size(1);
+    ctx->saved_data["D"] = values.size(1);
+    ctx->saved_data["total_L"] = total_L;
+    ctx->saved_data["offset_per_key"] = offset_per_key;
+
+    auto [padded_values_per_key, offsets_tensor_per_key] =
+        stacked_jagged_2d_to_dense_forward_cuda(
+            values, lengths, offset_per_key, max_lengths_per_key);
+    ctx->saved_data["offsets_tensor_per_key"] = offsets_tensor_per_key;
+
+    return padded_values_per_key;
+  }
+
+  static torch::autograd::variable_list backward(
+      torch::autograd::AutogradContext* ctx,
+      torch::autograd::variable_list grad_outputs) {
+    auto B = ctx->saved_data["B"].toInt();
+    auto D = ctx->saved_data["D"].toInt();
+    auto total_L = ctx->saved_data["total_L"].toInt();
+    auto offset_per_key = ctx->saved_data["offset_per_key"].toIntVector();
+    auto offsets_tensor_per_key =
+        ctx->saved_data["offsets_tensor_per_key"].toTensorVector();
+
+    using torch::autograd::Variable;
+    auto grad_values = stacked_jagged_2d_to_dense_backward_cuda(
+        B, D, total_L, grad_outputs, offsets_tensor_per_key, offset_per_key);
+    return {
+        grad_values,
+        Variable(), // lengths
+        Variable(), // offset_per_key
+        Variable() // max_lengths_per_key
+    };
+  }
+};
+
+std::vector<Tensor> stacked_jagged_2d_to_dense_gpu(
+    Tensor values,
+    Tensor lengths,
+    const std::vector<int64_t>& offset_per_key,
+    const std::vector<int64_t>& max_lengths_per_key) {
+  TENSOR_ON_CUDA_GPU(values);
+  TENSOR_ON_CUDA_GPU(lengths);
+  TENSORS_ON_SAME_DEVICE(values, lengths);
+  TORCH_CHECK(values.dim() == 2);
+  TORCH_CHECK(lengths.dim() == 2);
+  return StackedJagged2DToDenseGPUOp::apply(
+      values, lengths, offset_per_key, max_lengths_per_key);
+}
+
 } // namespace fbgemm_gpu

 TORCH_LIBRARY_IMPL(fbgemm, CUDA, m) {
@@ -129,6 +190,8 @@ TORCH_LIBRARY_IMPL(fbgemm, CUDA, m) {
   DISPATCH_TO_CUDA("jagged_1d_to_dense", fbgemm_gpu::jagged_1d_to_dense_gpu);
   DISPATCH_TO_CUDA(
       "stacked_jagged_1d_to_dense", fbgemm_gpu::stacked_jagged_1d_to_dense_gpu);
+  DISPATCH_TO_CUDA(
+      "stacked_jagged_2d_to_dense", fbgemm_gpu::stacked_jagged_2d_to_dense_gpu);
   DISPATCH_TO_CUDA(
       "stacked_jagged_2d_to_dense_forward",
       fbgemm_gpu::stacked_jagged_2d_to_dense_forward_cuda);
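Note: the C++ autograd wrapper above is what keeps the fused op differentiable now that the Python _StackedJagged2DToDenseFunction is gone: forward() stashes B, D, total_L, and the per-key offsets in ctx->saved_data, and backward() replays them into the backward CUDA kernel. A short gradient-flow sketch (hypothetical sizes; assumes a CUDA device and an fbgemm_gpu build with these ops registered):

import numpy as np
import torch
import fbgemm_gpu  # noqa: F401  # importing registers the fbgemm ops (assumed install)

T, B, D, max_len = 2, 3, 4, 5
lengths = torch.randint(0, max_len + 1, (T, B), device="cuda")
values = torch.rand(
    int(lengths.sum().item()), D, device="cuda", requires_grad=True
)

padded_per_table = torch.ops.fbgemm.stacked_jagged_2d_to_dense(
    values=values,
    lengths=lengths,
    offset_per_key=[0]
    + np.cumsum([lengths[t].sum().item() for t in range(T)]).tolist(),
    max_lengths_per_key=[max_len] * T,
)
# backward() routes through StackedJagged2DToDenseGPUOp::backward
torch.stack([p.sum() for p in padded_per_table]).sum().backward()
assert values.grad is not None and values.grad.shape == values.shape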
diff --git a/fbgemm_gpu/test/sparse_ops_test.py b/fbgemm_gpu/test/sparse_ops_test.py
index 975eb664ab..9409c4d23b 100644
--- a/fbgemm_gpu/test/sparse_ops_test.py
+++ b/fbgemm_gpu/test/sparse_ops_test.py
@@ -13,12 +13,6 @@
 import hypothesis.strategies as st
 import numpy as np
 import torch
-from fbgemm_gpu.sparse_ops import (
-    jagged_1d_to_dense,
-    jagged_2d_to_dense,
-    stacked_jagged_1d_to_dense,
-    stacked_jagged_2d_to_dense,
-)
 from hypothesis import Verbosity, given, settings

 try:
@@ -834,7 +828,7 @@ def test_jagged_2d_to_dense(
             values = ref_values.clone().half().detach().requires_grad_(True)
         else:
             values = ref_values.clone().detach().requires_grad_(True)
-        output_values = jagged_2d_to_dense(
+        output_values = torch.ops.fbgemm.jagged_2d_to_dense(
             values=values,
             offsets=offsets,
             max_sequence_length=max_sequence_length,
@@ -850,7 +844,7 @@ def test_jagged_2d_to_dense(
             values = ref_values.clone().detach().requires_grad_(True)
             offsets = offsets.cuda()
             ref_output_values = ref_output_values.cuda()
-            output_values = jagged_2d_to_dense(
+            output_values = torch.ops.fbgemm.jagged_2d_to_dense(
                 values=values,
                 offsets=offsets,
                 max_sequence_length=max_sequence_length,
@@ -880,7 +874,7 @@ def test_jagged_2d_to_dense_truncation(self) -> None:

         # test cpu forward
         values = ref_values.clone().detach().requires_grad_(True)
-        output_values = jagged_2d_to_dense(
+        output_values = torch.ops.fbgemm.jagged_2d_to_dense(
             values=values,
             offsets=offsets,
             max_sequence_length=max_sequence_length,
@@ -893,7 +887,7 @@ def test_jagged_2d_to_dense_truncation(self) -> None:
             values = ref_values.clone().detach().requires_grad_(True)
             offsets = offsets.cuda()
             ref_output_values = ref_output_values.cuda()
-            output_values = jagged_2d_to_dense(
+            output_values = torch.ops.fbgemm.jagged_2d_to_dense(
                 values=values,
                 offsets=offsets,
                 max_sequence_length=max_sequence_length,
@@ -942,14 +936,14 @@ def test_stacked_jagged_2d_to_dense(
         lengths = lengths.view(T, B)

         values = ref_values.clone().detach().requires_grad_(True)
-        output_values_per_table = stacked_jagged_2d_to_dense(
+        output_values_per_table = torch.ops.fbgemm.stacked_jagged_2d_to_dense(
             values=values,
             lengths=lengths,
             offset_per_key=[0]
             + np.cumsum([lengths[t].sum().item() for t in range(T)]).tolist(),
             max_lengths_per_key=[max_sequence_length] * T,
         )
-        ref_output_values = jagged_2d_to_dense(
+        ref_output_values = torch.ops.fbgemm.jagged_2d_to_dense(
             values=ref_values,
             offsets=offsets,
             max_sequence_length=max_sequence_length,
@@ -1030,7 +1024,7 @@ def var_list_to_coo(

         # test cpu forward
         values = ref_values.clone().detach().requires_grad_(False)
-        output_values = jagged_1d_to_dense(
+        output_values = torch.ops.fbgemm.jagged_1d_to_dense(
             values=values,
             offsets=offsets,
             max_sequence_length=max_sequence_length,
@@ -1044,7 +1038,7 @@ def var_list_to_coo(
             values = ref_values.clone().detach().requires_grad_(False)
             offsets = offsets.cuda()
             ref_output_values = ref_output_values.cuda()
-            output_values = jagged_1d_to_dense(
+            output_values = torch.ops.fbgemm.jagged_1d_to_dense(
                 values=values,
                 offsets=offsets,
                 max_sequence_length=max_sequence_length,
@@ -1062,7 +1056,7 @@ def test_jagged_1d_to_dense_truncation(self) -> None:

         # test cpu forward
         values = ref_values.clone().detach().requires_grad_(False)
-        output = jagged_1d_to_dense(
+        output = torch.ops.fbgemm.jagged_1d_to_dense(
             values=values,
             offsets=offsets,
             max_sequence_length=1,
@@ -1076,7 +1070,7 @@ def test_jagged_1d_to_dense_truncation(self) -> None:
             values = ref_values.clone().detach().requires_grad_(False)
             offsets = offsets.cuda()
             ref_output = ref_output.cuda()
-            output = jagged_1d_to_dense(
+            output = torch.ops.fbgemm.jagged_1d_to_dense(
                 values=values,
                 offsets=offsets,
                 max_sequence_length=1,
@@ -1142,7 +1136,7 @@ def var_list_to_coo(
         ref_values = torch.randint(low=0, high=1000000000, size=(total_lengths,)).cuda()

         values = ref_values.clone().detach().requires_grad_(False)
-        output_values_per_table = stacked_jagged_1d_to_dense(
+        output_values_per_table = torch.ops.fbgemm.stacked_jagged_1d_to_dense(
             values=values,
             lengths=lengths,
             offset_per_key=[0]
@@ -1150,7 +1144,7 @@ def var_list_to_coo(
             max_lengths_per_key=[max_sequence_length] * T,
             padding_value=padding_value,
         )
-        ref_output_values = jagged_1d_to_dense(
+        ref_output_values = torch.ops.fbgemm.jagged_1d_to_dense(
             values=ref_values,
             offsets=offsets,
             max_sequence_length=max_sequence_length,
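Note: for completeness, the same direct-call pattern for the 1-D op the last two hunks migrate (hypothetical sizes; assumes CUDA plus an fbgemm_gpu install, matching the stacked_jagged_1d_to_dense dispatch registered above):

import numpy as np
import torch
import fbgemm_gpu  # noqa: F401  # importing registers the fbgemm ops (assumed install)

T, B, max_len, padding_value = 2, 3, 4, -1
lengths = torch.randint(0, max_len + 1, (T, B), device="cuda")
values = torch.randint(
    0, 1000, (int(lengths.sum().item()),), device="cuda"
)

per_table = torch.ops.fbgemm.stacked_jagged_1d_to_dense(
    values=values,
    lengths=lengths,
    offset_per_key=[0]
    + np.cumsum([lengths[t].sum().item() for t in range(T)]).tolist(),
    max_lengths_per_key=[max_len] * T,
    padding_value=padding_value,
)
for dense in per_table:
    print(dense.shape)  # torch.Size([3, 4]) per table: (B, max_len)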