From 30dcbb25554848c4802e553d0c10281d053a9c36 Mon Sep 17 00:00:00 2001
From: Yong Hoon Shin
Date: Mon, 10 Jun 2024 19:03:19 -0700
Subject: [PATCH] Support passing in compute kernel in table_wise sharding helper (#2087)

Summary:
Pull Request resolved: https://github.com/pytorch/torchrec/pull/2087

Currently the helper assumes the compute kernel is QUANT whenever a device is
passed in, which isn't very flexible. Make it take the compute kernel
explicitly.

Reviewed By: gnahzg

Differential Revision: D58254737
---
 torchrec/distributed/sharding_plan.py         |  4 +++-
 .../tests/test_infer_hetero_shardings.py      | 19 +++++++++++++++----
 2 files changed, 18 insertions(+), 5 deletions(-)

diff --git a/torchrec/distributed/sharding_plan.py b/torchrec/distributed/sharding_plan.py
index 1cb602197..f5c245baa 100644
--- a/torchrec/distributed/sharding_plan.py
+++ b/torchrec/distributed/sharding_plan.py
@@ -397,6 +397,7 @@ def _parameter_sharding_generator(
 def table_wise(
     rank: int,
     device: Optional[str] = None,
+    compute_kernel: Optional[str] = None,
 ) -> ParameterShardingGenerator:
     """
     Returns a generator of ParameterShardingPlan for `ShardingType::TABLE_WISE` for construct_module_sharding_plan.
@@ -404,6 +405,7 @@ def table_wise(
     Args:
         rank (int): rank to place table when doing table wise
         device (Optional[str]): device to place table when doing table_wise sharding
+        compute_kernel (Optional[str]): embedding compute kernel to use for the table

     Example::

@@ -441,7 +443,7 @@ def _parameter_sharding_generator(
         device_type,
         sharder,
         placements=([placement_helper(device, rank)] if device else None),
-        compute_kernel=(EmbeddingComputeKernel.QUANT.value if device else None),
+        compute_kernel=compute_kernel,
     )

     return _parameter_sharding_generator
diff --git a/torchrec/distributed/tests/test_infer_hetero_shardings.py b/torchrec/distributed/tests/test_infer_hetero_shardings.py
index e3a6780e0..b23ca9534 100755
--- a/torchrec/distributed/tests/test_infer_hetero_shardings.py
+++ b/torchrec/distributed/tests/test_infer_hetero_shardings.py
@@ -16,6 +16,7 @@ import torch
 from hypothesis import given, settings
 from torchrec import EmbeddingBagConfig, EmbeddingCollection, EmbeddingConfig
+from torchrec.distributed.embedding_types import EmbeddingComputeKernel
 from torchrec.distributed.planner import ParameterConstraints
 from torchrec.distributed.planner.planners import HeteroEmbeddingShardingPlanner
 from torchrec.distributed.planner.types import Topology
@@ -71,12 +72,17 @@ def test_sharder_different_world_sizes_for_qec(self, sharding_device: str) -> No
             weight_dtype=torch.qint8,
         )
         sharder = QuantEmbeddingCollectionSharder()
+        compute_kernel = EmbeddingComputeKernel.QUANT.value
         module_plan = construct_module_sharding_plan(
             non_sharded_model._module_kjt_input[0],
             per_param_sharding={
                 "table_0": row_wise(([20, 10, 100], "cpu")),
-                "table_1": table_wise(rank=0, device="cuda"),
-                "table_2": table_wise(rank=1, device="cuda"),
+                "table_1": table_wise(
+                    rank=0, device="cuda", compute_kernel=compute_kernel
+                ),
+                "table_2": table_wise(
+                    rank=1, device="cuda", compute_kernel=compute_kernel
+                ),
             },
             # pyre-ignore
             sharder=sharder,
@@ -165,12 +171,17 @@ def test_sharder_different_world_sizes_for_qebc(self) -> None:
             weight_dtype=torch.qint8,
         )
         sharder = QuantEmbeddingBagCollectionSharder()
+        compute_kernel = EmbeddingComputeKernel.QUANT.value
         module_plan = construct_module_sharding_plan(
             non_sharded_model._module_kjt_input[0],
             per_param_sharding={
                 "table_0": row_wise(([20, 10, 100], "cpu")),
-                "table_1": table_wise(rank=0, device="cuda"),
-                "table_2": table_wise(rank=1, device="cuda"),
+                "table_1": table_wise(
+                    rank=0, device="cuda", compute_kernel=compute_kernel
+                ),
+                "table_2": table_wise(
+                    rank=1, device="cuda", compute_kernel=compute_kernel
+                ),
             },
             # pyre-ignore
             sharder=sharder,