[dtensor] aten.cat to use stack strategy approach (pytorch#122209)
This PR switches aten.cat to the strategy approach used by aten.stack, since
these two ops share similar semantics.

Pull Request resolved: pytorch#122209
Approved by: https://github.com/wz337
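
For context, a minimal sketch (not part of the commit) of the behavior this strategy encodes, assuming a 1-D mesh over two ranks and the torch.distributed._tensor API as of this commit; the mesh setup and shapes here are illustrative only:

import torch
from torch.distributed._tensor import DeviceMesh, Shard, distribute_tensor

# assumed setup: a 1-D device mesh over 2 ranks (process group already initialized)
mesh = DeviceMesh("cuda", list(range(2)))
a = distribute_tensor(torch.randn(4, 8), mesh, [Shard(0)])
b = distribute_tensor(torch.randn(4, 8), mesh, [Shard(0)])

# cat along the sharded dim: the strategy derives "follow" placements from the
# inputs and unshards the cat dim, so the inputs are redistributed and the
# output is not sharded along dim 0.
out = torch.cat([a, b], dim=0)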
wanchaol authored and pytorchmergebot committed Mar 20, 2024
1 parent 5b7ceab commit 11e64b4
Showing 1 changed file with 39 additions and 142 deletions.
181 changes: 39 additions & 142 deletions torch/distributed/_tensor/ops/tensor_ops.py
@@ -4,7 +4,6 @@

import torch

from torch.distributed._tensor._utils import compute_local_shape
from torch.distributed._tensor.op_schema import (
OpSchema,
OpStrategy,
@@ -22,7 +21,6 @@
is_tensor_partial,
is_tensor_shardable,
normalize_dim,
prod,
register_op_strategy,
register_prop_rule,
)
@@ -478,7 +476,12 @@ def stack_strategy(mesh: DeviceMesh, op_schema: OpSchema) -> StrategyType:
args_schema = op_schema.args_schema
input_tuple_strategy = args_schema[0]
assert isinstance(input_tuple_strategy, TupleStrategy), f"{input_tuple_strategy}"
first_input_strategy = input_tuple_strategy.childs[0]
assert isinstance(first_input_strategy, OpStrategy), f"{first_input_strategy}"
common_input_ndim = first_input_strategy.output_ndim
dim = cast(int, args_schema[1]) if len(args_schema) > 1 else 0
# normalize the dim to be within the common input ndim
dim = normalize_dim(dim, common_input_ndim)

follow_placements = _derive_follow_placements_from_tuple_strategy(
input_tuple_strategy
@@ -501,6 +504,40 @@ def stack_strategy(mesh: DeviceMesh, op_schema: OpSchema) -> StrategyType:
return op_strategy


@register_op_strategy(aten.cat.default, RuntimeSchemaInfo(1, needs_pytree=True))
def cat_strategy(mesh: DeviceMesh, op_schema: OpSchema) -> StrategyType:
args_schema = op_schema.args_schema
input_tuple_strategy = args_schema[0]
assert isinstance(input_tuple_strategy, TupleStrategy), f"{input_tuple_strategy}"
first_input_strategy = input_tuple_strategy.childs[0]
assert isinstance(first_input_strategy, OpStrategy), f"{first_input_strategy}"
common_input_ndim = first_input_strategy.output_ndim
dim = cast(int, args_schema[1]) if len(args_schema) > 1 else 0
# normalize the dim to be within the common input ndim
dim = normalize_dim(dim, common_input_ndim)

follow_placements = _derive_follow_placements_from_tuple_strategy(
input_tuple_strategy
)
# for cat we unshard the cat dim if it is sharded
follow_placements = unshard_tensor_dim(follow_placements, dim)

# create op strategy based on the follow placements
op_strategy = OpStrategy([])

input_specs = tuple(
DTensorSpec(mesh, tuple(follow_placements))
for _ in range(len(input_tuple_strategy.childs))
)
op_strategy.strategies.append(
PlacementStrategy(
output_specs=DTensorSpec(mesh, tuple(follow_placements)),
input_specs=input_specs,
)
)
return op_strategy
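# Editorial illustration (not part of the diff): a hypothetical trace of
# cat_strategy for a 1-D mesh of size 2, two inputs whose strategies place them
# as [Shard(0)], and dim=0 as the cat dim. The follow placements derived from
# the inputs are [Shard(0)]; unshard_tensor_dim(follow_placements, dim=0) is
# then expected to replace that Shard(0) with Replicate(), so both the
# suggested input specs and the output spec leave the cat dim unsharded.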


@register_prop_rule(aten.index_select.default, schema_info=RuntimeSchemaInfo(1))
def prop_index_select(op_schema: OpSchema) -> OutputSharding:
values_spec, dim, indices_spec = op_schema.args_schema
@@ -656,146 +693,6 @@ def place(vp: Placement, ip: Placement) -> Placement:
return result


@register_prop_rule(
aten.cat.default, schema_info=RuntimeSchemaInfo(1, needs_pytree=True)
)
def cat_rule(op_schema: OpSchema) -> OutputSharding:
# torch.cat requires all tensors must either have the same shape (except
# in the concatenating dimension) or be "empty". "Empty" here strictly means
# tensor.shape is torch.Size([0]). When tensor.ndim > 1, it will be treated
# as a non-empty tensor and the shape must match on non-cat dimensions.
def is_empty(spec: DTensorSpec) -> bool:
return list(spec.shape) == [0]
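# Editorial illustration (not part of the diff) of the "empty" rule above,
# using plain torch tensors rather than DTensors:
#   torch.cat([torch.randn(2, 3), torch.empty(0)], dim=0)     # ok: shape [0] is treated as empty and skipped
#   torch.cat([torch.randn(2, 3), torch.empty(0, 4)], dim=0)  # raises: ndim > 1, so dim 1 must match (3 vs 4)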

# the first arg is a list of input tensor specs
tensor_list_specs = cast(List[DTensorSpec], op_schema.args_schema[0])
assert len(tensor_list_specs) > 0, "torch.cat expects a non-empty list of tensors"
non_empty_specs = [spec for spec in tensor_list_specs if not is_empty(spec)]

if len(non_empty_specs) == 0:
# all tensors are empty, we can return any output sharding
return OutputSharding(
output_spec=DTensorSpec(
mesh=tensor_list_specs[0].mesh,
placements=tensor_list_specs[0].placements,
)
)

assert all(
spec.ndim == non_empty_specs[0].ndim for spec in non_empty_specs
), f"Expect all tensors to have same shape or empty, but got {tensor_list_specs}"
assert all(
spec.mesh == tensor_list_specs[0].mesh for spec in tensor_list_specs
), f"Expect all tensors to have same mesh, but got {tensor_list_specs}"

# ndim will also be the result's ndim
ndim = 1
for spec in tensor_list_specs:
ndim = max(ndim, spec.ndim)

dim = 0 # default dim = 0
if len(op_schema.args_schema) > 1:
dim = cast(int, op_schema.args_schema[1])
dim = normalize_dim(dim, ndim)

# Make sure all tensors are replicated on cat dimension
need_reshard = False
tensor_list_specs_after: List[DTensorSpec] = []
for spec in tensor_list_specs:
if not is_empty(spec) and (
is_tensor_dim_sharded(spec, dim=dim) or is_tensor_partial(spec)
):
need_reshard = True
tensor_list_specs_after.append(
DTensorSpec(
mesh=spec.mesh,
placements=replicate_tensor_dim(spec.placements, dim=dim),
tensor_meta=spec.tensor_meta,
)
)
else:
tensor_list_specs_after.append(spec)

tensor_list_specs = tensor_list_specs_after

# align non-cat dimensions placements based on reshard cost
non_empty_specs = [spec for spec in tensor_list_specs if not is_empty(spec)]
mesh = non_empty_specs[0].mesh
ndim = non_empty_specs[0].ndim
new_placements: List[Placement] = []
for mesh_dim in range(mesh.ndim):
# compute the minimum cost of resharding on this mesh_dim
if any(
spec.placements[mesh_dim] != non_empty_specs[0].placements[mesh_dim]
for spec in non_empty_specs
):
# only reshard if there is a mismatch
need_reshard = True
reshard_cost = []
for shard_dim in range(ndim):
# compute the cost of resharding on this shard_dim
cost: float = 0.0
for spec in non_empty_specs:
global_shape = spec.shape
if global_shape[shard_dim] < mesh.size(mesh_dim):
# found one tensor where the shard_dim is smaller than
# mesh_dim. In this case, we cannot shard on this shard_dim,
# and hence set cost to infinity.
cost = +float("inf")
elif (
is_tensor_dim_sharded(spec, dim=shard_dim)
or prod(global_shape) == 0
):
continue
else:
local_shape = compute_local_shape(
global_shape, spec.mesh, spec.placements
)
cost += prod(local_shape) * spec.mesh.size(mesh_dim)
reshard_cost.append(cost)
best_dim = reshard_cost.index(min(reshard_cost))
new_placements.append(Shard(best_dim))
else:
# no mismatch, keep the original placement
new_placements.append(non_empty_specs[0].placements[mesh_dim])
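# Editorial illustration (not part of the diff): a worked instance of the cost
# heuristic above. Take a mesh dim of size 2 and two non-empty 2-D specs with
# global shape (4, 6), where spec A is Shard(0) and spec B is Replicate() on
# that mesh dim (a mismatch, so resharding is needed). For shard_dim 0, A is
# already sharded there and is skipped, while B contributes prod((4, 6)) * 2
# = 48, giving cost 48. For shard_dim 1, A's local shape (2, 6) contributes
# 12 * 2 = 24 and B contributes 48, giving cost 72. The minimum is shard_dim 0,
# so this mesh dim gets Shard(0): keep A's sharding and reshard only B.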

if need_reshard:
tensor_list_specs_after = []
for spec in tensor_list_specs:
if is_empty(spec):
tensor_list_specs_after.append(spec)
else:
tensor_list_specs_after.append(
DTensorSpec(
mesh=spec.mesh,
placements=tuple(new_placements),
tensor_meta=spec.tensor_meta,
)
)

return OutputSharding(
output_spec=None,
schema_suggestions=[
OpSchema(
op=op_schema.op,
args_schema=(
tuple(tensor_list_specs_after),
*op_schema.args_schema[1:],
),
kwargs_schema=op_schema.kwargs_schema,
),
],
)
else:
# at this point, the cat dim is not sharded,
return OutputSharding(
output_spec=DTensorSpec(
mesh=non_empty_specs[0].mesh,
placements=non_empty_specs[0].placements,
),
)


@register_prop_rule(
[
aten.split.Tensor,
