remove wait in all_to_all_single custom op (pytorch#2646)
Summary:

# context
* remove the `torch.ops._c10d_functional.wait_tensor` call in `all_to_all_single`.
* use an `autograd.Function` implementation to create an `AllToAllSingle` function (a minimal sketch of the pattern follows this summary).
* there is a `wait_tensor` after the `all_to_all_single` call: pytorch/pytorch#143533

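For orientation before the diff, here is a self-contained sketch of the `autograd.Function` pattern this commit switches to. It mirrors the `AllToAllSingle` class added in `torchrec/distributed/comm_ops.py` below, but simplified: the pyre annotations are dropped, and the explicit submodule import, docstring, and comments are illustrative additions rather than part of the commit.

```python
from typing import List

import torch
import torch.distributed._functional_collectives  # noqa: F401  # functional (async) collectives
from torch import Tensor


class AllToAllSingle(torch.autograd.Function):
    """Sketch: all_to_all_single wrapped in an autograd.Function, with no explicit wait_tensor."""

    @staticmethod
    def forward(
        ctx,
        input: Tensor,
        output_split_sizes: List[int],
        input_split_sizes: List[int],
        group_name: str,
        group_size: int,
        gradient_division: bool,
    ) -> Tensor:
        # Save swapped split sizes: the backward all-to-all runs in the opposite direction.
        ctx.output_split_sizes = input_split_sizes
        ctx.input_split_sizes = output_split_sizes
        ctx.group_name = group_name
        ctx.group_size = group_size
        ctx.gradient_division = gradient_division
        # The functional collective returns an AsyncCollectiveTensor that is
        # synchronized lazily on first use, so no wait_tensor call is needed here.
        return torch.distributed._functional_collectives.all_to_all_single(
            input, output_split_sizes, input_split_sizes, group_name
        )

    @staticmethod
    def backward(ctx, grad: Tensor):
        grad = torch.distributed._functional_collectives.all_to_all_single(
            grad,
            ctx.output_split_sizes,
            ctx.input_split_sizes,
            ctx.group_name,
        )
        if ctx.gradient_division:
            grad.div_(ctx.group_size)
        # One return value per forward argument; only `input` receives a gradient.
        return grad, None, None, None, None, None
```

At the call sites changed in the diff, `torch.ops.torchrec.all_to_all_single(...)` is replaced by `AllToAllSingle.apply(...)` with the same argument order.
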
Differential Revision: D64666999
TroyGarden authored and facebook-github-bot committed Dec 19, 2024
1 parent b1bd136 commit dd1fcdb
Showing 1 changed file with 38 additions and 67 deletions.
105 changes: 38 additions & 67 deletions torchrec/distributed/comm_ops.py
@@ -443,7 +443,7 @@ def all2all_pooled_sync(
         qcomm_ctx = None
 
     with record_function("## alltoall_fwd_single ##"):
-        sharded_output_embeddings = torch.ops.torchrec.all_to_all_single(
+        sharded_output_embeddings = AllToAllSingle.apply(
             sharded_input_embeddings,
             output_split_sizes,
             input_split_sizes,
@@ -572,7 +572,7 @@ def variable_batch_all2all_pooled_sync(
             torch._check(s0 <= sharded_input_embeddings.size(0))
             sharded_output_embeddings.copy_(sharded_input_embeddings[:s0])
         else:
-            sharded_output_embeddings = torch.ops.torchrec.all_to_all_single(
+            sharded_output_embeddings = AllToAllSingle.apply(
                 sharded_input_embeddings,
                 output_split_sizes,
                 input_split_sizes,
@@ -722,7 +722,7 @@ def all2all_sequence_sync(
         qcomm_ctx = None
 
     with record_function("## alltoall_seq_embedding_fwd_single ##"):
-        sharded_output_embeddings = torch.ops.torchrec.all_to_all_single(
+        sharded_output_embeddings = AllToAllSingle.apply(
             sharded_input_embeddings,
             output_splits,
             input_splits,
@@ -1004,7 +1004,7 @@ def reduce_scatter_v_sync(
         input_splits = rsi.input_splits
         output_splits = [rsi.input_splits[rank]] * world_size
         # TODO(ivankobzarev): Replace with _functional_collectives.reduce_scatter_v when it is added
-        a2a_output = torch.ops.torchrec.all_to_all_single(
+        a2a_output = AllToAllSingle.apply(
             input,
             output_splits,
             input_splits,
@@ -2348,71 +2348,42 @@ def backward(ctx, grad_output: Tensor) -> Tuple[None, None, Tensor]:
 
 if not torch._running_with_deploy():  # noqa C901
     # Torch Library op def can not be used in Deploy
+    class AllToAllSingle(torch.autograd.Function):
+        @staticmethod
+        # pyre-fixme[14]: `forward` overrides method defined in `Function` inconsistently.
+        def forward(
+            # pyre-fixme[2]: Parameter must be annotated.
+            ctx,
+            input: Tensor,
+            output_split_sizes: List[int],
+            input_split_sizes: List[int],
+            group_name: str,
+            group_size: int,
+            gradient_division: bool,
+        ) -> Tensor:
+            ctx.output_split_sizes = input_split_sizes
+            ctx.input_split_sizes = output_split_sizes
+            ctx.group_name = group_name
+            ctx.group_size = group_size
+            ctx.gradient_division = gradient_division
+            return torch.distributed._functional_collectives.all_to_all_single(
+                input, output_split_sizes, input_split_sizes, group_name
+            )
 
-    # torchrec::all_to_all_single
-    @torch.library.custom_op("torchrec::all_to_all_single", mutates_args=())
-    def all_to_all_single(
-        input: Tensor,
-        output_split_sizes: List[int],
-        input_split_sizes: List[int],
-        group_name: str,
-        group_size: int,
-        gradient_division: bool,
-    ) -> Tensor:
-        out = torch.ops._c10d_functional.all_to_all_single(
-            input, output_split_sizes, input_split_sizes, group_name
-        )
-        return torch.ops._c10d_functional.wait_tensor(out)
-
-    @torch.library.register_fake("torchrec::all_to_all_single")
-    def all_to_all_single_fake(
-        input: Tensor,
-        output_split_sizes: List[int],
-        input_split_sizes: List[int],
-        group_name: str,
-        group_size: int,
-        gradient_division: bool,
-    ) -> Tensor:
-        return torch.ops._c10d_functional.all_to_all_single(
-            input, output_split_sizes, input_split_sizes, group_name
-        )
-
-    # pyre-ignore
-    def all_to_all_single_setup_context(ctx, inputs, output) -> None:
-        (
-            _,
-            output_split_sizes,
-            input_split_sizes,
-            group_name,
-            group_size,
-            gradient_division,
-        ) = inputs
-        ctx.output_split_sizes = input_split_sizes
-        ctx.input_split_sizes = output_split_sizes
-        ctx.group_name = group_name
-        ctx.group_size = group_size
-        ctx.gradient_division = gradient_division
-
-    # pyre-ignore
-    def all_to_all_single_backward(ctx, grad):
-        # TODO(ivankobzarev): Support codecs(quantization) on backward
-        a2a_out = torch.ops._c10d_functional.all_to_all_single(
-            grad,
-            ctx.output_split_sizes,
-            ctx.input_split_sizes,
-            ctx.group_name,
-        )
-        grad = torch.ops._c10d_functional.wait_tensor(a2a_out)
-        if ctx.gradient_division:
-            grad.div_(ctx.group_size)
-
-        return grad, None, None, None, None, None
+        @staticmethod
+        # pyre-ignore
+        def backward(ctx, grad):
+            # TODO(ivankobzarev): Support codecs(quantization) on backward
+            grad = torch.distributed._functional_collectives.all_to_all_single(
+                grad,
+                ctx.output_split_sizes,
+                ctx.input_split_sizes,
+                ctx.group_name,
+            )
+            if ctx.gradient_division:
+                grad.div_(ctx.group_size)
 
-    torch.library.register_autograd(
-        "torchrec::all_to_all_single",
-        all_to_all_single_backward,
-        setup_context=all_to_all_single_setup_context,
-    )
+            return grad, None, None, None, None, None
 
     # torchrec::reduce_scatter_tensor
     @torch.library.custom_op("torchrec::reduce_scatter_tensor", mutates_args=())
