Support passing in compute kernel in table_wise sharding helper (pytorch#2087)

Summary:
Pull Request resolved: pytorch#2087

Currently, the helper assumes the compute kernel is QUANT whenever a device is passed in, which isn't very flexible. This change makes the helper take the compute kernel explicitly.

Reviewed By: gnahzg

Differential Revision: D58254737
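
As a quick illustration of the new call pattern, here is a minimal sketch adapted from the updated tests in this commit; the table names, shard sizes, and devices are only illustrative.

from torchrec.distributed.embedding_types import EmbeddingComputeKernel
from torchrec.distributed.sharding_plan import row_wise, table_wise

# The compute kernel is now passed explicitly instead of being inferred
# from the presence of a device.
compute_kernel = EmbeddingComputeKernel.QUANT.value

per_param_sharding = {
    # row_wise sharding with heterogeneous shard sizes is unchanged
    "table_0": row_wise(([20, 10, 100], "cpu")),
    # table_wise now accepts compute_kernel directly
    "table_1": table_wise(rank=0, device="cuda", compute_kernel=compute_kernel),
    "table_2": table_wise(rank=1, device="cuda", compute_kernel=compute_kernel),
}

This dictionary is what the updated tests below pass as per_param_sharding to construct_module_sharding_plan.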
sarckk authored and facebook-github-bot committed Jun 11, 2024
1 parent b4b6d0b commit 30dcbb2
Showing 2 changed files with 18 additions and 5 deletions.
4 changes: 3 additions & 1 deletion torchrec/distributed/sharding_plan.py
@@ -397,13 +397,15 @@ def _parameter_sharding_generator(
 def table_wise(
     rank: int,
     device: Optional[str] = None,
+    compute_kernel: Optional[str] = None,
 ) -> ParameterShardingGenerator:
     """
     Returns a generator of ParameterShardingPlan for `ShardingType::TABLE_WISE` for construct_module_sharding_plan.
 
     Args:
         rank (int): rank to place table when doing table wise
         device (Optional[str]): device to place table when doing table_wise sharding
+        compute_kernel (Optional[str]): embedding compute kernel to use for the table
 
     Example::
 
@@ -441,7 +443,7 @@ def _parameter_sharding_generator(
             device_type,
             sharder,
             placements=([placement_helper(device, rank)] if device else None),
-            compute_kernel=(EmbeddingComputeKernel.QUANT.value if device else None),
+            compute_kernel=compute_kernel,
         )
 
     return _parameter_sharding_generator
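
For existing callers of table_wise, the hunk above changes where the QUANT kernel comes from: a non-None device no longer implies EmbeddingComputeKernel.QUANT, and the kernel must be named explicitly. A minimal sketch of the equivalent call after this change (the variable name is illustrative):

from torchrec.distributed.embedding_types import EmbeddingComputeKernel
from torchrec.distributed.sharding_plan import table_wise

# Previously, table_wise(rank=0, device="cuda") implicitly selected the QUANT
# compute kernel. With this change, omitting compute_kernel leaves it as None,
# so QUANT has to be requested explicitly.
quant_table = table_wise(
    rank=0, device="cuda", compute_kernel=EmbeddingComputeKernel.QUANT.value
)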
19 changes: 15 additions & 4 deletions torchrec/distributed/tests/test_infer_hetero_shardings.py
@@ -16,6 +16,7 @@
 import torch
 from hypothesis import given, settings
 from torchrec import EmbeddingBagConfig, EmbeddingCollection, EmbeddingConfig
+from torchrec.distributed.embedding_types import EmbeddingComputeKernel
 from torchrec.distributed.planner import ParameterConstraints
 from torchrec.distributed.planner.planners import HeteroEmbeddingShardingPlanner
 from torchrec.distributed.planner.types import Topology
@@ -71,12 +72,17 @@ def test_sharder_different_world_sizes_for_qec(self, sharding_device: str) -> None:
             weight_dtype=torch.qint8,
         )
         sharder = QuantEmbeddingCollectionSharder()
+        compute_kernel = EmbeddingComputeKernel.QUANT.value
         module_plan = construct_module_sharding_plan(
             non_sharded_model._module_kjt_input[0],
             per_param_sharding={
                 "table_0": row_wise(([20, 10, 100], "cpu")),
-                "table_1": table_wise(rank=0, device="cuda"),
-                "table_2": table_wise(rank=1, device="cuda"),
+                "table_1": table_wise(
+                    rank=0, device="cuda", compute_kernel=compute_kernel
+                ),
+                "table_2": table_wise(
+                    rank=1, device="cuda", compute_kernel=compute_kernel
+                ),
             },
             # pyre-ignore
             sharder=sharder,
@@ -165,12 +171,17 @@ def test_sharder_different_world_sizes_for_qebc(self) -> None:
             weight_dtype=torch.qint8,
         )
         sharder = QuantEmbeddingBagCollectionSharder()
+        compute_kernel = EmbeddingComputeKernel.QUANT.value
         module_plan = construct_module_sharding_plan(
             non_sharded_model._module_kjt_input[0],
             per_param_sharding={
                 "table_0": row_wise(([20, 10, 100], "cpu")),
-                "table_1": table_wise(rank=0, device="cuda"),
-                "table_2": table_wise(rank=1, device="cuda"),
+                "table_1": table_wise(
+                    rank=0, device="cuda", compute_kernel=compute_kernel
+                ),
+                "table_2": table_wise(
+                    rank=1, device="cuda", compute_kernel=compute_kernel
+                ),
             },
             # pyre-ignore
             sharder=sharder,
