Migrate jagged_jagged_bmm_jagged_out SLL ops to OSS (pytorch#3459)
Summary:
X-link: facebookresearch/FBGEMM#543

Pull Request resolved: pytorch#3459

- Migrate `jagged_jagged_bmm_jagged_out` SLL ops to OSS

Reviewed By: brad-mengchi

Differential Revision: D66797818

fbshipit-source-id: 174081aa0cddf33b85428c1328bfd2d61151635e
q10 authored and facebook-github-bot committed Dec 6, 2024
1 parent 5a389d0 commit e2a1f8c
Showing 6 changed files with 523 additions and 1 deletion.
2 changes: 1 addition & 1 deletion .github/scripts/fbgemm_gpu_test.bash
@@ -85,6 +85,7 @@ __configure_fbgemm_gpu_test_cpu () {
./uvm/uvm_test.py
./sll/triton_sll_test.py
./sll/array_jagged_bmm_jagged_out_test.py
./sll/jagged_jagged_bmm_jagged_out_test.py
)
}

@@ -103,7 +104,6 @@ __configure_fbgemm_gpu_test_cuda () {

ignored_tests=(
)

}

__configure_fbgemm_gpu_test_rocm () {
31 changes: 31 additions & 0 deletions fbgemm_gpu/fbgemm_gpu/sll/__init__.py
@@ -19,6 +19,7 @@
cpu_jagged_dense_bmm,
cpu_jagged_dense_elementwise_mul_jagged_out,
cpu_jagged_jagged_bmm,
cpu_jagged_jagged_bmm_jagged_out,
cpu_jagged_self_substraction_jagged_out,
cpu_jagged_softmax,
meta_jagged_dense_elementwise_mul_jagged_out,
@@ -28,6 +29,7 @@
from fbgemm_gpu.sll.meta_sll import ( # noqa F401
meta_array_jagged_bmm_jagged_out,
meta_jagged2_softmax,
meta_jagged_jagged_bmm_jagged_out,
)

from fbgemm_gpu.sll.triton_sll import ( # noqa F401
@@ -38,6 +40,7 @@
jagged_dense_bmm,
jagged_dense_elementwise_mul_jagged_out,
jagged_jagged_bmm,
jagged_jagged_bmm_jagged_out,
jagged_softmax,
triton_jagged_self_substraction_jagged_out,
)
@@ -197,6 +200,23 @@ def register_sll_op(op_name: str, functors: Dict[str, Callable]) -> None:
"""
)

if "fbgemm::jagged_jagged_bmm_jagged_out" not in torch.library._defs:
lib.define(
"""jagged_jagged_bmm_jagged_out(
Tensor x,
Tensor y,
Tensor x_lengths,
Tensor x_offsets,
Tensor y_lengths,
Tensor y_offsets,
Tensor z_lengths,
Tensor z_offsets,
int max_seq_len,
bool allow_tf32
) -> Tensor
"""
)

# NOTE: Here we register the op for AutogradCUDA/CPU and CUDA/CPU with the same
# function. This is not ideal: in the inference case, the autograd forward does
# not need to save the context, because no backward pass will be run.
@@ -289,3 +309,14 @@ def register_sll_op(op_name: str, functors: Dict[str, Callable]) -> None:
"AutogradMeta": meta_array_jagged_bmm_jagged_out,
},
)

register_sll_op(
"jagged_jagged_bmm_jagged_out",
{
"CUDA": jagged_jagged_bmm_jagged_out,
"AutogradCUDA": jagged_jagged_bmm_jagged_out,
"CPU": cpu_jagged_jagged_bmm_jagged_out,
"AutogradCPU": cpu_jagged_jagged_bmm_jagged_out,
"AutogradMeta": meta_jagged_jagged_bmm_jagged_out,
},
)
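
[Editor's note] For reference, a minimal sketch of how the newly registered op could be invoked once this registration runs. The shapes and values below are illustrative, not from this PR, and the sketch assumes fbgemm_gpu.sll has been imported so the schema above is defined (the CPU/CUDA kernels themselves live in files not fully shown in this truncated diff):

import torch
import fbgemm_gpu.sll  # registers fbgemm::jagged_jagged_bmm_jagged_out on import

D = 4                                         # illustrative embedding dim
x_lengths = torch.tensor([3, 2])              # M_i per batch
y_lengths = torch.tensor([3, 2])              # N_i per batch (M_i == N_i)
x_offsets = torch.tensor([0, 3, 5])
y_offsets = torch.tensor([0, 3, 5])
z_lengths = x_lengths * y_lengths             # M_i * N_i per batch
z_offsets = torch.tensor([0, 9, 13])

x = torch.randn(int(x_lengths.sum()), D)      # [sum_B(M_i), D]
y = torch.randn(int(y_lengths.sum()), D)      # [sum_B(N_i), D]

z = torch.ops.fbgemm.jagged_jagged_bmm_jagged_out(
    x, y, x_lengths, x_offsets, y_lengths, y_offsets,
    z_lengths, z_offsets, 3, True,            # max_seq_len=3, allow_tf32=True
)
# z is jagged: [sum_B(M_i * N_i)] == [13]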
131 changes: 131 additions & 0 deletions fbgemm_gpu/fbgemm_gpu/sll/cpu_sll.py
@@ -764,3 +764,134 @@ def cpu_array_jagged_bmm_jagged_out(
max_seq_len,
allow_tf32,
)


class JaggedJaggedBmmNoPaddingCPU(torch.autograd.Function):
"""
Compute batch matrix multiplication between JaggedTensor and JaggedTensor without padding.
z = x x y^T
x: [sum_B(M_i), D]
y: [sum_B(N_i), D]
z: [sum_B(M_i * N_i)], assuming M_i = N_i
"""

@staticmethod
# pyre-fixme
def forward(
# pyre-fixme[2]: Parameter must be annotated.
ctx,
x: torch.Tensor,
y: torch.Tensor,
x_lengths: torch.Tensor,
x_offsets: torch.Tensor,
y_lengths: torch.Tensor,
y_offsets: torch.Tensor,
z_lengths: torch.Tensor,
z_offsets: torch.Tensor,
max_seq_len: int,
# pyre-fixme[2]: Parameter must be annotated.
allow_tf32,
):
ctx.allow_tf32 = allow_tf32
ctx.max_seq_len = max_seq_len

ctx.save_for_backward(
x,
y,
x_lengths,
y_lengths,
z_lengths,
x_offsets,
y_offsets,
z_offsets,
)

return cpu_jagged_jagged_bmm_jagged_out_kernel(
x,
y,
max_seq_len,
x_lengths,
y_lengths,
z_lengths,
x_offsets,
y_offsets,
z_offsets,
allow_tf32,
)

@staticmethod
# pyre-fixme
def backward(ctx, grad_output: torch.Tensor):
"""
z = x x y^T
x: [sum_B(M_i), D]
y: [sum_B(N_i), D]
z: [sum_B(M_i * N_i)], assuming M_i = N_i
dx = dz x (y^T)^T = > dx = dz x y
d(y^T) = x^T x dz => dy = dz^T x x
"""
(
x,
y,
x_lengths,
y_lengths,
z_lengths,
x_offsets,
y_offsets,
z_offsets,
) = ctx.saved_tensors

grad_x = cpu_array_jagged_bmm_jagged_out_kernel(
grad_output,
y,
z_lengths,
y_lengths,
x_lengths,
z_offsets,
y_offsets,
x_offsets,
ctx.max_seq_len,
ctx.allow_tf32,
transpose=0,
)
grad_y = cpu_array_jagged_bmm_jagged_out_kernel(
grad_output,
x,
z_lengths,
x_lengths,
y_lengths,
z_offsets,
x_offsets,
y_offsets,
ctx.max_seq_len,
ctx.allow_tf32,
transpose=1,
)
return grad_x, grad_y, None, None, None, None, None, None, None, None


# pyre-fixme[3]: Return type must be annotated.
def cpu_jagged_jagged_bmm_jagged_out(
x: torch.Tensor,
y: torch.Tensor,
x_lengths: torch.Tensor,
x_offsets: torch.Tensor,
y_lengths: torch.Tensor,
y_offsets: torch.Tensor,
z_lengths: torch.Tensor,
z_offsets: torch.Tensor,
max_seq_len: int,
allow_tf32: bool = True,
):
return JaggedJaggedBmmNoPaddingCPU.apply(
x,
y,
x_lengths,
x_offsets,
y_lengths,
y_offsets,
z_lengths,
z_offsets,
max_seq_len,
allow_tf32,
)
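
[Editor's note] As an aid to reading the forward above, here is a hedged plain-PyTorch reference for what it computes, following the docstring (z_b = x_b @ y_b^T per batch, flattened into the jagged output). This is a sketch for understanding only, not the kernel this PR ships (cpu_jagged_jagged_bmm_jagged_out_kernel is defined elsewhere in cpu_sll.py):

import torch

def jagged_jagged_bmm_reference(
    x: torch.Tensor,          # [sum_B(M_i), D]
    y: torch.Tensor,          # [sum_B(N_i), D]
    x_offsets: torch.Tensor,  # [B + 1] cumulative row offsets into x
    y_offsets: torch.Tensor,  # [B + 1] cumulative row offsets into y
    z_offsets: torch.Tensor,  # [B + 1] cumulative offsets into flat z
) -> torch.Tensor:
    B = x_offsets.numel() - 1
    z = x.new_empty(int(z_offsets[-1]))           # [sum_B(M_i * N_i)]
    for b in range(B):
        xb = x[x_offsets[b] : x_offsets[b + 1]]   # [M_b, D]
        yb = y[y_offsets[b] : y_offsets[b + 1]]   # [N_b, D]
        # per-batch matmul, flattened row-major into the jagged output
        z[z_offsets[b] : z_offsets[b + 1]] = (xb @ yb.t()).reshape(-1)
    return z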
85 changes: 85 additions & 0 deletions fbgemm_gpu/fbgemm_gpu/sll/meta_sll.py
@@ -178,3 +178,88 @@ def meta_array_jagged_bmm_jagged_out(
max_seq_len,
allow_tf32,
)


class JaggedJaggedBmmNoPaddingMeta(torch.autograd.Function):
@staticmethod
# pyre-fixme
def forward(
# pyre-fixme[2]: Parameter must be annotated.
ctx,
x: torch.Tensor,
y: torch.Tensor,
x_lengths: torch.Tensor,
x_offsets: torch.Tensor,
y_lengths: torch.Tensor,
y_offsets: torch.Tensor,
z_lengths: torch.Tensor,
z_offsets: torch.Tensor,
max_seq_len: int,
# pyre-fixme[2]: Parameter must be annotated.
allow_tf32,
):
assert x.size(1) == y.size(0), "incompatible dimensions"

ctx.allow_tf32 = allow_tf32
ctx.max_seq_len = max_seq_len

ctx.save_for_backward(
x,
y,
x_lengths,
y_lengths,
z_lengths,
x_offsets,
y_offsets,
z_offsets,
)

# pyre-fixme[6]: For 1st argument expected `Sequence[Union[int, SymInt]]`
# but got `Tensor`.
c = torch.rand((z_lengths.sum()), device=x.device, dtype=x.dtype)
return c

@staticmethod
# pyre-fixme
def backward(ctx, grad_output: torch.Tensor):
(
x,
y,
x_lengths,
y_lengths,
z_lengths,
x_offsets,
y_offsets,
z_offsets,
) = ctx.saved_tensors

grad_x = torch.rand(x.size(), device=x.device, dtype=x.dtype)
grad_y = torch.rand(y.size(), device=y.device, dtype=y.dtype)
return grad_x, grad_y, None, None, None, None, None, None, None, None


# pyre-fixme[3]: Return type must be annotated.
def meta_jagged_jagged_bmm_jagged_out(
x: torch.Tensor,
y: torch.Tensor,
x_lengths: torch.Tensor,
x_offsets: torch.Tensor,
y_lengths: torch.Tensor,
y_offsets: torch.Tensor,
z_lengths: torch.Tensor,
z_offsets: torch.Tensor,
max_seq_len: int,
allow_tf32: bool = True,
):
return JaggedJaggedBmmNoPaddingMeta.apply(
x,
y,
x_lengths,
x_offsets,
y_lengths,
y_offsets,
z_lengths,
z_offsets,
max_seq_len,
allow_tf32,
)
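
[Editor's note] A note on the Meta variants above: they exist so the op can be traced (e.g. under torch.compile or FakeTensor) where only output shape, dtype, and device matter, which is why the forward fabricates torch.rand of the right size and the backward returns placeholder gradients. A small sketch of the shape arithmetic they encode, with illustrative values:

import torch

# Per batch b, z holds one entry per (row of x_b, row of y_b) pair,
# so its flat length is sum_b(M_b * N_b) == z_lengths.sum().
M = torch.tensor([3, 2])        # illustrative M_i
N = torch.tensor([3, 2])        # illustrative N_i
z_lengths = M * N               # tensor([9, 4])
numel_z = int(z_lengths.sum())  # 13 -- the size the meta forward allocates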