From 11e64b4ba8a43e2b572e02e64280eab3b4675d4f Mon Sep 17 00:00:00 2001
From: Wanchao Liang
Date: Tue, 19 Mar 2024 15:33:02 -0700
Subject: [PATCH] [dtensor] aten.cat to use stack strategy approach (#122209)

This PR switches aten.cat to use the strategy approach that is similar to aten.stack, as these two ops share similar semantics.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/122209
Approved by: https://github.com/wz337
---
 torch/distributed/_tensor/ops/tensor_ops.py | 181 +++++---------------
 1 file changed, 39 insertions(+), 142 deletions(-)

diff --git a/torch/distributed/_tensor/ops/tensor_ops.py b/torch/distributed/_tensor/ops/tensor_ops.py
index 71879eb88a9404..0fce2fcd005d34 100644
--- a/torch/distributed/_tensor/ops/tensor_ops.py
+++ b/torch/distributed/_tensor/ops/tensor_ops.py
@@ -4,7 +4,6 @@
 import torch
 
-from torch.distributed._tensor._utils import compute_local_shape
 from torch.distributed._tensor.op_schema import (
     OpSchema,
     OpStrategy,
@@ -22,7 +21,6 @@
     is_tensor_partial,
     is_tensor_shardable,
     normalize_dim,
-    prod,
     register_op_strategy,
     register_prop_rule,
 )
@@ -478,7 +476,12 @@ def stack_strategy(mesh: DeviceMesh, op_schema: OpSchema) -> StrategyType:
     args_schema = op_schema.args_schema
     input_tuple_strategy = args_schema[0]
     assert isinstance(input_tuple_strategy, TupleStrategy), f"{input_tuple_strategy}"
+    first_input_strategy = input_tuple_strategy.childs[0]
+    assert isinstance(first_input_strategy, OpStrategy), f"{first_input_strategy}"
+    common_input_ndim = first_input_strategy.output_ndim
     dim = cast(int, args_schema[1]) if len(args_schema) > 1 else 0
+    # normalize the dim to be within the common input ndim
+    dim = normalize_dim(dim, common_input_ndim)
 
     follow_placements = _derive_follow_placements_from_tuple_strategy(
         input_tuple_strategy
@@ -501,6 +504,40 @@ def stack_strategy(mesh: DeviceMesh, op_schema: OpSchema) -> StrategyType:
     return op_strategy
 
 
+@register_op_strategy(aten.cat.default, RuntimeSchemaInfo(1, needs_pytree=True))
+def cat_strategy(mesh: DeviceMesh, op_schema: OpSchema) -> StrategyType:
+    args_schema = op_schema.args_schema
+    input_tuple_strategy = args_schema[0]
+    assert isinstance(input_tuple_strategy, TupleStrategy), f"{input_tuple_strategy}"
+    first_input_strategy = input_tuple_strategy.childs[0]
+    assert isinstance(first_input_strategy, OpStrategy), f"{first_input_strategy}"
+    common_input_ndim = first_input_strategy.output_ndim
+    dim = cast(int, args_schema[1]) if len(args_schema) > 1 else 0
+    # normalize the dim to be within the common input ndim
+    dim = normalize_dim(dim, common_input_ndim)
+
+    follow_placements = _derive_follow_placements_from_tuple_strategy(
+        input_tuple_strategy
+    )
+    # for cat we unshard the cat dim if it is sharded
+    follow_placements = unshard_tensor_dim(follow_placements, dim)
+
+    # create op strategy base on the follow placements
+    op_strategy = OpStrategy([])
+
+    input_specs = tuple(
+        DTensorSpec(mesh, tuple(follow_placements))
+        for _ in range(len(input_tuple_strategy.childs))
+    )
+    op_strategy.strategies.append(
+        PlacementStrategy(
+            output_specs=DTensorSpec(mesh, tuple(follow_placements)),
+            input_specs=input_specs,
+        )
+    )
+    return op_strategy
+
+
 @register_prop_rule(aten.index_select.default, schema_info=RuntimeSchemaInfo(1))
 def prop_index_select(op_schema: OpSchema) -> OutputSharding:
     values_spec, dim, indices_spec = op_schema.args_schema
@@ -656,146 +693,6 @@ def place(vp: Placement, ip: Placement) -> Placement:
     return result
 
 
-@register_prop_rule(
-    aten.cat.default, schema_info=RuntimeSchemaInfo(1, needs_pytree=True)
-)
-def cat_rule(op_schema: OpSchema) -> OutputSharding:
-    # torch.cat requires all tensors must either have the same shape (except
-    # in the concatenating dimension) or be "empty". "Empty" here strictly means
-    # tensor.shape is torch.Size([0]). When tensor.ndim > 1, it will be treated
-    # as a non-empty tensor and the shape must match on non-cat dimensions.
-    def is_empty(spec: DTensorSpec) -> bool:
-        return list(spec.shape) == [0]
-
-    # the first arg is a list of input tensor specs
-    tensor_list_specs = cast(List[DTensorSpec], op_schema.args_schema[0])
-    assert len(tensor_list_specs) > 0, "torch.cat expects a non-empty list of tensors"
-    non_empty_specs = [spec for spec in tensor_list_specs if not is_empty(spec)]
-
-    if len(non_empty_specs) == 0:
-        # all tensors are empty, we can return any output sharding
-        return OutputSharding(
-            output_spec=DTensorSpec(
-                mesh=tensor_list_specs[0].mesh,
-                placements=tensor_list_specs[0].placements,
-            )
-        )
-
-    assert all(
-        spec.ndim == non_empty_specs[0].ndim for spec in non_empty_specs
-    ), f"Expect all tensors to have same shape or empty, but got {tensor_list_specs}"
-    assert all(
-        spec.mesh == tensor_list_specs[0].mesh for spec in tensor_list_specs
-    ), f"Expect all tensors to have same mesh, but got {tensor_list_specs}"
-
-    # ndim will also be the result's ndim
-    ndim = 1
-    for spec in tensor_list_specs:
-        ndim = max(ndim, spec.ndim)
-
-    dim = 0  # default dim = 0
-    if len(op_schema.args_schema) > 1:
-        dim = cast(int, op_schema.args_schema[1])
-    dim = normalize_dim(dim, ndim)
-
-    # Make sure all tensors are replicated on cat dimension
-    need_reshard = False
-    tensor_list_specs_after: List[DTensorSpec] = []
-    for spec in tensor_list_specs:
-        if not is_empty(spec) and (
-            is_tensor_dim_sharded(spec, dim=dim) or is_tensor_partial(spec)
-        ):
-            need_reshard = True
-            tensor_list_specs_after.append(
-                DTensorSpec(
-                    mesh=spec.mesh,
-                    placements=replicate_tensor_dim(spec.placements, dim=dim),
-                    tensor_meta=spec.tensor_meta,
-                )
-            )
-        else:
-            tensor_list_specs_after.append(spec)
-
-    tensor_list_specs = tensor_list_specs_after
-
-    # align non-cat dimensions placements based on reshard cost
-    non_empty_specs = [spec for spec in tensor_list_specs if not is_empty(spec)]
-    mesh = non_empty_specs[0].mesh
-    ndim = non_empty_specs[0].ndim
-    new_placements: List[Placement] = []
-    for mesh_dim in range(mesh.ndim):
-        # compute the minimum cost of resharding on this mesh_dim
-        if any(
-            spec.placements[mesh_dim] != non_empty_specs[0].placements[mesh_dim]
-            for spec in non_empty_specs
-        ):
-            # only reshard if there is a mismatch
-            need_reshard = True
-            reshard_cost = []
-            for shard_dim in range(ndim):
-                # compute the cost of resharding on this shard_dim
-                cost: float = 0.0
-                for spec in non_empty_specs:
-                    global_shape = spec.shape
-                    if global_shape[shard_dim] < mesh.size(mesh_dim):
-                        # found one tensor where the shard_dim is smaller than
-                        # mesh_dim. In this case, we cannot shard on this shard_dim,
-                        # and hence set cost to infinity.
-                        cost = +float("inf")
-                    elif (
-                        is_tensor_dim_sharded(spec, dim=shard_dim)
-                        or prod(global_shape) == 0
-                    ):
-                        continue
-                    else:
-                        local_shape = compute_local_shape(
-                            global_shape, spec.mesh, spec.placements
-                        )
-                        cost += prod(local_shape) * spec.mesh.size(mesh_dim)
-                reshard_cost.append(cost)
-            best_dim = reshard_cost.index(min(reshard_cost))
-            new_placements.append(Shard(best_dim))
-        else:
-            # no mismatch, keep the original placement
-            new_placements.append(non_empty_specs[0].placements[mesh_dim])
-
-    if need_reshard:
-        tensor_list_specs_after = []
-        for spec in tensor_list_specs:
-            if is_empty(spec):
-                tensor_list_specs_after.append(spec)
-            else:
-                tensor_list_specs_after.append(
-                    DTensorSpec(
-                        mesh=spec.mesh,
-                        placements=tuple(new_placements),
-                        tensor_meta=spec.tensor_meta,
-                    )
-                )
-
-        return OutputSharding(
-            output_spec=None,
-            schema_suggestions=[
-                OpSchema(
-                    op=op_schema.op,
-                    args_schema=(
-                        tuple(tensor_list_specs_after),
-                        *op_schema.args_schema[1:],
-                    ),
-                    kwargs_schema=op_schema.kwargs_schema,
-                ),
-            ],
-        )
-    else:
-        # at this point, the cat dim is not sharded,
-        return OutputSharding(
-            output_spec=DTensorSpec(
-                mesh=non_empty_specs[0].mesh,
-                placements=non_empty_specs[0].placements,
-            ),
-        )
-
-
 @register_prop_rule(
     [
         aten.split.Tensor,
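
A minimal usage sketch of the behavior the new cat_strategy governs (not part of the patch): it assumes a 2-rank job launched with torchrun, a CPU/gloo device mesh, and made-up tensor shapes, and it uses only the existing DTensor helpers init_device_mesh, distribute_tensor, and Shard.

# torchrun --nproc-per-node=2 cat_sketch.py
import torch
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed._tensor import distribute_tensor, Shard

mesh = init_device_mesh("cpu", (2,))

# same global tensors on every rank, both sharded along dim 0 (the cat dim)
torch.manual_seed(0)
a = distribute_tensor(torch.randn(4, 8), mesh, [Shard(0)])
b = distribute_tensor(torch.randn(4, 8), mesh, [Shard(0)])

# cat_strategy derives follow placements from the inputs; because the cat dim
# is sharded here, it unshards that dim (unshard_tensor_dim in the diff above),
# so the inputs are redistributed before the local concatenation
out = torch.cat([a, b], dim=0)

# global shape is (8, 8); under this strategy the output placement on the
# single mesh dim is expected to be Replicate()
print(out.shape, out.placements)

Concatenating along a dimension that is not sharded should instead keep the inputs' sharding and only concatenate the local shards, since unsharding the cat dim is then a no-op.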