diff --git a/.github/scripts/fbgemm_gpu_test.bash b/.github/scripts/fbgemm_gpu_test.bash index 15acc26e2a..05e3a3f375 100644 --- a/.github/scripts/fbgemm_gpu_test.bash +++ b/.github/scripts/fbgemm_gpu_test.bash @@ -85,6 +85,7 @@ __configure_fbgemm_gpu_test_cpu () { ./uvm/uvm_test.py ./sll/triton_sll_test.py ./sll/array_jagged_bmm_jagged_out_test.py + ./sll/jagged_jagged_bmm_jagged_out_test.py ) } @@ -103,7 +104,6 @@ __configure_fbgemm_gpu_test_cuda () { ignored_tests=( ) - } __configure_fbgemm_gpu_test_rocm () { diff --git a/fbgemm_gpu/fbgemm_gpu/sll/__init__.py b/fbgemm_gpu/fbgemm_gpu/sll/__init__.py index 4ee2888b8c..bfd9c2eb8d 100644 --- a/fbgemm_gpu/fbgemm_gpu/sll/__init__.py +++ b/fbgemm_gpu/fbgemm_gpu/sll/__init__.py @@ -19,6 +19,7 @@ cpu_jagged_dense_bmm, cpu_jagged_dense_elementwise_mul_jagged_out, cpu_jagged_jagged_bmm, + cpu_jagged_jagged_bmm_jagged_out, cpu_jagged_self_substraction_jagged_out, cpu_jagged_softmax, meta_jagged_dense_elementwise_mul_jagged_out, @@ -28,6 +29,7 @@ from fbgemm_gpu.sll.meta_sll import ( # noqa F401 meta_array_jagged_bmm_jagged_out, meta_jagged2_softmax, + meta_jagged_jagged_bmm_jagged_out, ) from fbgemm_gpu.sll.triton_sll import ( # noqa F401 @@ -38,6 +40,7 @@ jagged_dense_bmm, jagged_dense_elementwise_mul_jagged_out, jagged_jagged_bmm, + jagged_jagged_bmm_jagged_out, jagged_softmax, triton_jagged_self_substraction_jagged_out, ) @@ -197,6 +200,23 @@ def register_sll_op(op_name: str, functors: Dict[str, Callable]) -> None: """ ) +if "fbgemm::jagged_jagged_bmm_jagged_out" not in torch.library._defs: + lib.define( + """jagged_jagged_bmm_jagged_out( + Tensor x, + Tensor y, + Tensor x_lengths, + Tensor x_offsets, + Tensor y_lengths, + Tensor y_offsets, + Tensor z_lengths, + Tensor z_offsets, + int max_seq_len, + bool allow_tf32 + ) -> Tensor + """ + ) + # NOTE: here we register the op for AutogradCUDA/CPU and CUDA/CPU with the same function # however, this is not ideal because in the inference case, we don't need the autograd forward # to save the context because we don't need to do backward. @@ -289,3 +309,14 @@ def register_sll_op(op_name: str, functors: Dict[str, Callable]) -> None: "AutogradMeta": meta_array_jagged_bmm_jagged_out, }, ) + +register_sll_op( + "jagged_jagged_bmm_jagged_out", + { + "CUDA": jagged_jagged_bmm_jagged_out, + "AutogradCUDA": jagged_jagged_bmm_jagged_out, + "CPU": cpu_jagged_jagged_bmm_jagged_out, + "AutogradCPU": cpu_jagged_jagged_bmm_jagged_out, + "AutogradMeta": meta_jagged_jagged_bmm_jagged_out, + }, +) diff --git a/fbgemm_gpu/fbgemm_gpu/sll/cpu_sll.py b/fbgemm_gpu/fbgemm_gpu/sll/cpu_sll.py index 2a176b07b3..a4dd8e66f3 100644 --- a/fbgemm_gpu/fbgemm_gpu/sll/cpu_sll.py +++ b/fbgemm_gpu/fbgemm_gpu/sll/cpu_sll.py @@ -764,3 +764,134 @@ def cpu_array_jagged_bmm_jagged_out( max_seq_len, allow_tf32, ) + + +class JaggedJaggedBmmNoPaddingCPU(torch.autograd.Function): + """ + Compute batch matrix multiplication between JaggedTensor and JaggedTensor without padding. + z = x x y^T + x: [sum_B(M_i), D] + y: [sum_B(N_i), D] + z: [sum_B(M_i * N_i)], assuming M_i = N_i + """ + + @staticmethod + # pyre-fixme + def forward( + # pyre-fixme[2]: Parameter must be annotated. + ctx, + x: torch.Tensor, + y: torch.Tensor, + x_lengths: torch.Tensor, + x_offsets: torch.Tensor, + y_lengths: torch.Tensor, + y_offsets: torch.Tensor, + z_lengths: torch.Tensor, + z_offsets: torch.Tensor, + max_seq_len: int, + # pyre-fixme[2]: Parameter must be annotated. 
+ allow_tf32, + ): + ctx.allow_tf32 = allow_tf32 + ctx.max_seq_len = max_seq_len + + ctx.save_for_backward( + x, + y, + x_lengths, + y_lengths, + z_lengths, + x_offsets, + y_offsets, + z_offsets, + ) + + return cpu_jagged_jagged_bmm_jagged_out_kernel( + x, + y, + max_seq_len, + x_lengths, + y_lengths, + z_lengths, + x_offsets, + y_offsets, + z_offsets, + allow_tf32, + ) + + @staticmethod + # pyre-fixme + def backward(ctx, grad_output: torch.Tensor): + """ + z = x x y^T + x: [sum_B(M_i), D] + y: [sum_B(N_i), D] + z: [sum_B(M_i * N_i)], assuming M_i = N_i + dx = dz x (y^T)^T = > dx = dz x y + d(y^T) = x^T x dz => dy = dz^T x x + """ + ( + x, + y, + x_lengths, + y_lengths, + z_lengths, + x_offsets, + y_offsets, + z_offsets, + ) = ctx.saved_tensors + + grad_x = cpu_array_jagged_bmm_jagged_out_kernel( + grad_output, + y, + z_lengths, + y_lengths, + x_lengths, + z_offsets, + y_offsets, + x_offsets, + ctx.max_seq_len, + ctx.allow_tf32, + transpose=0, + ) + grad_y = cpu_array_jagged_bmm_jagged_out_kernel( + grad_output, + x, + z_lengths, + x_lengths, + y_lengths, + z_offsets, + x_offsets, + y_offsets, + ctx.max_seq_len, + ctx.allow_tf32, + transpose=1, + ) + return grad_x, grad_y, None, None, None, None, None, None, None, None + + +# pyre-fixme[3]: Return type must be annotated. +def cpu_jagged_jagged_bmm_jagged_out( + x: torch.Tensor, + y: torch.Tensor, + x_lengths: torch.Tensor, + x_offsets: torch.Tensor, + y_lengths: torch.Tensor, + y_offsets: torch.Tensor, + z_lengths: torch.Tensor, + z_offsets: torch.Tensor, + max_seq_len: int, + allow_tf32: bool = True, +): + return JaggedJaggedBmmNoPaddingCPU.apply( + x, + y, + x_lengths, + x_offsets, + y_lengths, + y_offsets, + z_lengths, + z_offsets, + max_seq_len, + allow_tf32, + ) diff --git a/fbgemm_gpu/fbgemm_gpu/sll/meta_sll.py b/fbgemm_gpu/fbgemm_gpu/sll/meta_sll.py index 0ee298b50c..924f13b260 100644 --- a/fbgemm_gpu/fbgemm_gpu/sll/meta_sll.py +++ b/fbgemm_gpu/fbgemm_gpu/sll/meta_sll.py @@ -178,3 +178,88 @@ def meta_array_jagged_bmm_jagged_out( max_seq_len, allow_tf32, ) + + +class JaggedJaggedBmmNoPaddingMeta(torch.autograd.Function): + @staticmethod + # pyre-fixme + def forward( + # pyre-fixme[2]: Parameter must be annotated. + ctx, + x: torch.Tensor, + y: torch.Tensor, + x_lengths: torch.Tensor, + x_offsets: torch.Tensor, + y_lengths: torch.Tensor, + y_offsets: torch.Tensor, + z_lengths: torch.Tensor, + z_offsets: torch.Tensor, + max_seq_len: int, + # pyre-fixme[2]: Parameter must be annotated. + allow_tf32, + ): + assert x.size(1) == y.size(0), "incompatible dimensions" + + ctx.allow_tf32 = allow_tf32 + ctx.max_seq_len = max_seq_len + + ctx.save_for_backward( + x, + y, + x_lengths, + y_lengths, + z_lengths, + x_offsets, + y_offsets, + z_offsets, + ) + + # pyre-fixme[6]: For 1st argument expected `Sequence[Union[int, SymInt]]` + # but got `Tensor`. + c = torch.rand((z_lengths.sum()), device=x.device, dtype=x.dtype) + return c + + @staticmethod + # pyre-fixme + def backward(ctx, grad_output: torch.Tensor): + ( + x, + y, + x_lengths, + y_lengths, + z_lengths, + x_offsets, + y_offsets, + z_offsets, + ) = ctx.saved_tensors + + grad_x = torch.rand(x.size(), device=x.device, dtype=x.dtype) + grad_y = torch.rand(y.size(), device=y.device, dtype=y.dtype) + return grad_x, grad_y, None, None, None, None, None, None, None, None + + +# pyre-fixme[3]: Return type must be annotated. 
+def meta_jagged_jagged_bmm_jagged_out( + x: torch.Tensor, + y: torch.Tensor, + x_lengths: torch.Tensor, + x_offsets: torch.Tensor, + y_lengths: torch.Tensor, + y_offsets: torch.Tensor, + z_lengths: torch.Tensor, + z_offsets: torch.Tensor, + max_seq_len: int, + allow_tf32: bool = True, +): + return JaggedJaggedBmmNoPaddingMeta.apply( + x, + y, + x_lengths, + x_offsets, + y_lengths, + y_offsets, + z_lengths, + z_offsets, + max_seq_len, + allow_tf32, + ) diff --git a/fbgemm_gpu/fbgemm_gpu/sll/triton_sll.py b/fbgemm_gpu/fbgemm_gpu/sll/triton_sll.py index d1408848ab..04a77b5d3e 100644 --- a/fbgemm_gpu/fbgemm_gpu/sll/triton_sll.py +++ b/fbgemm_gpu/fbgemm_gpu/sll/triton_sll.py @@ -1249,6 +1249,108 @@ def backward(ctx, grad_output: torch.Tensor): return grad_x, grad_y, None, None, None, None, None, None, None, None +class JaggedJaggedBmmNoPadding(torch.autograd.Function): + """ + Compute batch matrix multiplication between JaggedTensor and JaggedTensor without padding. + z = x x y^T + x: [sum_B(M_i), D] + y: [sum_B(N_i), D] + z: [sum_B(M_i * N_i)], assuming M_i = N_i + """ + + @staticmethod + # pyre-fixme + def forward( + ctx, + x: torch.Tensor, + y: torch.Tensor, + x_lengths: torch.Tensor, + x_offsets: torch.Tensor, + y_lengths: torch.Tensor, + y_offsets: torch.Tensor, + z_lengths: torch.Tensor, + z_offsets: torch.Tensor, + max_seq_len: int, + allow_tf32, + ): + ctx.allow_tf32 = allow_tf32 + ctx.max_seq_len = max_seq_len + + ctx.save_for_backward( + x, + y, + x_lengths, + y_lengths, + z_lengths, + x_offsets, + y_offsets, + z_offsets, + ) + + return triton_jagged_jagged_bmm_jagged_out( + x, + y.T, + max_seq_len, + x_lengths, + y_lengths, + z_lengths, + x_offsets, + y_offsets, + z_offsets, + allow_tf32, + ) + + @staticmethod + # pyre-fixme + def backward(ctx, grad_output: torch.Tensor): + """ + z = x x y^T + x: [sum_B(M_i), D] + y: [sum_B(N_i), D] + z: [sum_B(M_i * N_i)], assuming M_i = N_i + dx = dz x (y^T)^T = > dx = dz x y + d(y^T) = x^T x dz => dy = dz^T x x + """ + ( + x, + y, + x_lengths, + y_lengths, + z_lengths, + x_offsets, + y_offsets, + z_offsets, + ) = ctx.saved_tensors + + grad_x = triton_array_jagged_bmm_jagged_out( + grad_output, + y, + z_lengths, + y_lengths, + x_lengths, + z_offsets, + y_offsets, + x_offsets, + ctx.max_seq_len, + ctx.allow_tf32, + transpose=0, + ) + grad_y = triton_array_jagged_bmm_jagged_out( + grad_output, + x, + z_lengths, + x_lengths, + y_lengths, + z_offsets, + x_offsets, + y_offsets, + ctx.max_seq_len, + ctx.allow_tf32, + transpose=1, + ) + return grad_x, grad_y, None, None, None, None, None, None, None, None + + def jagged_dense_bmm( x: torch.Tensor, y: torch.Tensor, @@ -1798,3 +1900,29 @@ def array_jagged_bmm_jagged_out( max_seq_len, allow_tf32, ) + + +def jagged_jagged_bmm_jagged_out( + x: torch.Tensor, + y: torch.Tensor, + x_lengths: torch.Tensor, + x_offsets: torch.Tensor, + y_lengths: torch.Tensor, + y_offsets: torch.Tensor, + z_lengths: torch.Tensor, + z_offsets: torch.Tensor, + max_seq_len: int, + allow_tf32: bool = True, +): + return JaggedJaggedBmmNoPadding.apply( + x, + y, + x_lengths, + x_offsets, + y_lengths, + y_offsets, + z_lengths, + z_offsets, + max_seq_len, + allow_tf32, + ) diff --git a/fbgemm_gpu/test/sll/jagged_jagged_bmm_jagged_out_test.py b/fbgemm_gpu/test/sll/jagged_jagged_bmm_jagged_out_test.py new file mode 100644 index 0000000000..ae46cf1510 --- /dev/null +++ b/fbgemm_gpu/test/sll/jagged_jagged_bmm_jagged_out_test.py @@ -0,0 +1,147 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. 
+# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-strict + +import unittest + +import hypothesis.strategies as st +import torch +from fbgemm_gpu.sll.cpu_sll import cpu_jagged_jagged_bmm_jagged_out # noqa +from fbgemm_gpu.sll.meta_sll import meta_jagged_jagged_bmm_jagged_out # noqa +from fbgemm_gpu.sll.triton_sll import triton_jagged_jagged_bmm_jagged_out +from hypothesis import given, settings + +from .common import open_source # noqa + + +class JaggedJaggedBmmJaggedOutTest(unittest.TestCase): + # pyre-fixme[56]: Pyre was not able to infer the type of argument + @given( + B=st.integers(10, 512), + max_L=st.integers(1, 200), + K=st.integers(1, 100), + ) + @settings(deadline=20000) + def test_triton_jagged_jagged_bmm_jagged_out( + self, + B: int, + max_L: int, + K: int, + ) -> None: + torch.backends.cudnn.allow_tf32 = False + torch.backends.cuda.matmul.allow_tf32 = False + + lengths_m = torch.randint(1, max_L + 1, (B,)).cuda() + lengths_n = lengths_m + + offsets_m = torch.cat( + [torch.IntTensor([0]).cuda(), lengths_m.cumsum(dim=0)], dim=0 + ) + offsets_n = torch.cat( + [torch.IntTensor([0]).cuda(), lengths_n.cumsum(dim=0)], dim=0 + ) + lengths_mn = lengths_m * lengths_n + offsets_mn = torch.cat( + [torch.IntTensor([0]).cuda(), lengths_mn.cumsum(dim=0)], dim=0 + ) + + jagged_A = torch.rand(int(lengths_m.sum().item()), K).cuda() + jagged_B = torch.rand(int(lengths_n.sum().item()), K).cuda() + + def ref_jagged_jagged_bmm_jagged_out( + B: int, + jagged_A: torch.Tensor, + jagged_B: torch.Tensor, + lengths_mn: torch.Tensor, + offsets_mn: torch.Tensor, + offsets_m: torch.Tensor, + offsets_n: torch.Tensor, + ) -> torch.Tensor: + jagged_C = torch.empty( + (int(lengths_mn.sum().item())), dtype=jagged_A.dtype + ).to(jagged_A.device) + + for i in range(B): + jagged_C[offsets_mn[i] : offsets_mn[i + 1]] = torch.matmul( + jagged_A[offsets_m[i] : offsets_m[i + 1]], + jagged_B[offsets_n[i] : offsets_n[i + 1]].T, + ).flatten() + return jagged_C + + jagged_C_ref = ref_jagged_jagged_bmm_jagged_out( + B, jagged_A, jagged_B, lengths_mn, offsets_mn, offsets_m, offsets_n + ) + jagged_C_test = triton_jagged_jagged_bmm_jagged_out( + jagged_A, + jagged_B.T, + max_L, + lengths_m, + lengths_n, + lengths_mn, + offsets_m, + offsets_n, + offsets_mn, + allow_tf32=False, + ) + + assert torch.allclose(jagged_C_ref, jagged_C_test) + + # pyre-fixme[56]: Pyre was not able to infer the type of argument + # `hypothesis.strategies.integers(10, 512)` to decorator factory + # `hypothesis.given`. 
+ @given( + B=st.integers(10, 512), + max_L=st.integers(1, 200), + K=st.integers(1, 100), + device_type=st.sampled_from(["meta"]), + ) + @settings(deadline=20000) + def test_triton_jagged_jagged_bmm_jagged_out_meta_backend( + self, + B: int, + max_L: int, + K: int, + device_type: str, + ) -> None: + lengths_m = torch.randint(1, max_L + 1, (B,)) + lengths_n = lengths_m + device = torch.device(device_type) + + offsets_m = torch.cat([torch.IntTensor([0]), lengths_m.cumsum(dim=0)], dim=0) + offsets_n = torch.cat([torch.IntTensor([0]), lengths_n.cumsum(dim=0)], dim=0) + lengths_mn = lengths_m * lengths_n + offsets_mn = torch.cat([torch.IntTensor([0]), lengths_mn.cumsum(dim=0)], dim=0) + + jagged_A = torch.rand( + int(lengths_m.sum().item()), K, requires_grad=True, device=device + ) + jagged_B = torch.rand( + int(lengths_n.sum().item()), K, requires_grad=True, device=device + ) + + jagged_C_ref = torch.rand(int(lengths_mn.sum().item()), device=device) + jagged_C_test = torch.ops.fbgemm.jagged_jagged_bmm_jagged_out( + jagged_A, + jagged_B.T, + lengths_m.to(device_type), + offsets_m.to(device_type), + lengths_n.to(device_type), + offsets_n.to(device_type), + lengths_mn, + offsets_mn, + max_L, + allow_tf32=False, + ) + assert jagged_C_test.is_meta and jagged_C_ref.size() == jagged_C_test.size() + + grad_output = torch.rand((jagged_C_test.shape), device=device_type) * 0.01 + jagged_C_test.backward(grad_output) + + # pyre-fixme[16]: Optional type has no attribute `is_meta`. + # pyre-fixme[16]: Optional type has no attribute `size`. + assert jagged_A.grad.is_meta and jagged_A.grad.size() == jagged_A.size() + assert jagged_B.grad.is_meta and jagged_B.grad.size() == jagged_B.size()
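A minimal pure-PyTorch sketch of the new op's semantics, mirroring ref_jagged_jagged_bmm_jagged_out from the test above: per batch element i, z_i = A_i @ B_i^T is flattened and concatenated into a single 1-D jagged output. The batch size, per-batch lengths, and feature dimension below are illustrative, and no FBGEMM import is needed.

import torch

# Toy jagged layout (illustrative sizes only).
B, D = 2, 4
lengths_m = torch.tensor([2, 3])                      # M_i per batch
lengths_n = lengths_m                                 # the op assumes M_i == N_i
offsets_m = torch.cat([torch.tensor([0]), lengths_m.cumsum(0)])
offsets_n = offsets_m
lengths_mn = lengths_m * lengths_n
offsets_mn = torch.cat([torch.tensor([0]), lengths_mn.cumsum(0)])

x = torch.rand(int(lengths_m.sum()), D)               # [sum_B(M_i), D]
y = torch.rand(int(lengths_n.sum()), D)               # [sum_B(N_i), D]

# Jagged output z: [sum_B(M_i * N_i)], built batch by batch.
z = torch.empty(int(lengths_mn.sum()))
for i in range(B):
    z[offsets_mn[i] : offsets_mn[i + 1]] = (
        x[offsets_m[i] : offsets_m[i + 1]]
        @ y[offsets_n[i] : offsets_n[i + 1]].T
    ).flatten()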
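The gradient identities quoted in the backward docstrings (for z = x @ y^T: dx = dz @ y and dy = dz^T @ x) can likewise be sanity-checked with autograd on a single dense batch element; the sizes here are arbitrary.

import torch

M, N, D = 3, 3, 4                          # single batch element, M == N
x = torch.rand(M, D, requires_grad=True)
y = torch.rand(N, D, requires_grad=True)

z = x @ y.T                                # forward: z = x @ y^T
dz = torch.rand(M, N)                      # upstream gradient
z.backward(dz)

assert torch.allclose(x.grad, dz @ y)      # dx = dz @ y
assert torch.allclose(y.grad, dz.T @ x)    # dy = dz^T @ x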