diff --git a/torchrec/distributed/comm_ops.py b/torchrec/distributed/comm_ops.py index 856f50d10..0d398e576 100644 --- a/torchrec/distributed/comm_ops.py +++ b/torchrec/distributed/comm_ops.py @@ -443,7 +443,7 @@ def all2all_pooled_sync( qcomm_ctx = None with record_function("## alltoall_fwd_single ##"): - sharded_output_embeddings = torch.ops.torchrec.all_to_all_single( + sharded_output_embeddings = AllToAllSingle.apply( sharded_input_embeddings, output_split_sizes, input_split_sizes, @@ -572,7 +572,7 @@ def variable_batch_all2all_pooled_sync( torch._check(s0 <= sharded_input_embeddings.size(0)) sharded_output_embeddings.copy_(sharded_input_embeddings[:s0]) else: - sharded_output_embeddings = torch.ops.torchrec.all_to_all_single( + sharded_output_embeddings = AllToAllSingle.apply( sharded_input_embeddings, output_split_sizes, input_split_sizes, @@ -722,7 +722,7 @@ def all2all_sequence_sync( qcomm_ctx = None with record_function("## alltoall_seq_embedding_fwd_single ##"): - sharded_output_embeddings = torch.ops.torchrec.all_to_all_single( + sharded_output_embeddings = AllToAllSingle.apply( sharded_input_embeddings, output_splits, input_splits, @@ -1004,7 +1004,7 @@ def reduce_scatter_v_sync( input_splits = rsi.input_splits output_splits = [rsi.input_splits[rank]] * world_size # TODO(ivankobzarev): Replace with _functional_collectives.reduce_scatter_v when it is added - a2a_output = torch.ops.torchrec.all_to_all_single( + a2a_output = AllToAllSingle.apply( input, output_splits, input_splits, @@ -2348,6 +2348,42 @@ def backward(ctx, grad_output: Tensor) -> Tuple[None, None, Tensor]: if not torch._running_with_deploy(): # noqa C901 # Torch Library op def can not be used in Deploy + class AllToAllSingle(torch.autograd.Function): + @staticmethod + # pyre-fixme[14]: `forward` overrides method defined in `Function` inconsistently. + def forward( + # pyre-fixme[2]: Parameter must be annotated. + ctx, + input: Tensor, + output_split_sizes: List[int], + input_split_sizes: List[int], + group_name: str, + group_size: int, + gradient_division: bool, + ) -> Tensor: + ctx.output_split_sizes = input_split_sizes + ctx.input_split_sizes = output_split_sizes + ctx.group_name = group_name + ctx.group_size = group_size + ctx.gradient_division = gradient_division + return torch.distributed._functional_collectives.all_to_all_single( + input, output_split_sizes, input_split_sizes, group_name + ) + + @staticmethod + # pyre-ignore + def backward(ctx, grad): + # TODO(ivankobzarev): Support codecs(quantization) on backward + grad = torch.distributed._functional_collectives.all_to_all_single( + grad, + ctx.output_split_sizes, + ctx.input_split_sizes, + ctx.group_name, + ) + if ctx.gradient_division: + grad.div_(ctx.group_size) + + return grad, None, None, None, None, None # torchrec::all_to_all_single @torch.library.custom_op("torchrec::all_to_all_single", mutates_args=()) @@ -2359,10 +2395,14 @@ def all_to_all_single( group_size: int, gradient_division: bool, ) -> Tensor: - out = torch.ops._c10d_functional.all_to_all_single( - input, output_split_sizes, input_split_sizes, group_name + return AllToAllSingle.apply( + input, + output_split_sizes, + input_split_sizes, + group_name, + group_size, + gradient_division, ) - return torch.ops._c10d_functional.wait_tensor(out) @torch.library.register_fake("torchrec::all_to_all_single") def all_to_all_single_fake( @@ -2377,43 +2417,6 @@ def all_to_all_single_fake( input, output_split_sizes, input_split_sizes, group_name ) - # pyre-ignore - def all_to_all_single_setup_context(ctx, inputs, output) -> None: - ( - _, - output_split_sizes, - input_split_sizes, - group_name, - group_size, - gradient_division, - ) = inputs - ctx.output_split_sizes = input_split_sizes - ctx.input_split_sizes = output_split_sizes - ctx.group_name = group_name - ctx.group_size = group_size - ctx.gradient_division = gradient_division - - # pyre-ignore - def all_to_all_single_backward(ctx, grad): - # TODO(ivankobzarev): Support codecs(quantization) on backward - a2a_out = torch.ops._c10d_functional.all_to_all_single( - grad, - ctx.output_split_sizes, - ctx.input_split_sizes, - ctx.group_name, - ) - grad = torch.ops._c10d_functional.wait_tensor(a2a_out) - if ctx.gradient_division: - grad.div_(ctx.group_size) - - return grad, None, None, None, None, None - - torch.library.register_autograd( - "torchrec::all_to_all_single", - all_to_all_single_backward, - setup_context=all_to_all_single_setup_context, - ) - # torchrec::reduce_scatter_tensor @torch.library.custom_op("torchrec::reduce_scatter_tensor", mutates_args=()) def reduce_scatter_tensor(