Skip to content

Commit

Permalink
Fail early if no sharding option found for table (pytorch#1657)
Browse files Browse the repository at this point in the history
Summary:

Currently we raise an exception if no sharding options are found for the first table. If a sharding option is found for the first table, but not for the second table, no exception is raised.

This causes the error to be [raised later when sharding the model](https://github.com/pytorch/torchrec/blob/77974f229ce7e229664fbe199e1308cc37a91d7f/torchrec/distributed/embeddingbag.py#L217-L218), which is harder to debug.

Differential Revision: D53044797
  • Loading branch information
sarckk authored and facebook-github-bot committed Jan 24, 2024
1 parent 77974f2 commit 011ae44
Show file tree
Hide file tree
Showing 2 changed files with 52 additions and 2 deletions.
8 changes: 6 additions & 2 deletions torchrec/distributed/planner/enumerators.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,8 @@ def enumerate(
bounds_check_mode,
) = _extract_constraints_for_param(self._constraints, name)

sharding_options_per_table: List[ShardingOption] = []

for sharding_type in self._filter_sharding_types(
name, sharder.sharding_types(self._compute_device)
):
Expand All @@ -150,7 +152,7 @@ def enumerate(
elif isinstance(child_module, EmbeddingTowerCollection):
tower_index = _get_tower_index(name, child_module)
dependency = child_path + ".tower_" + str(tower_index)
sharding_options.append(
sharding_options_per_table.append(
ShardingOption(
name=name,
tensor=param,
Expand All @@ -172,12 +174,14 @@ def enumerate(
is_pooled=is_pooled,
)
)
if not sharding_options:
if not sharding_options_per_table:
raise RuntimeError(
"No available sharding type and compute kernel combination "
f"after applying user provided constraints for {name}"
)

sharding_options.extend(sharding_options_per_table)

self.populate_estimates(sharding_options)

return sharding_options
Expand Down
46 changes: 46 additions & 0 deletions torchrec/distributed/planner/tests/test_enumerators.py
Original file line number Diff line number Diff line change
Expand Up @@ -858,3 +858,49 @@ def test_tower_collection_sharding(self) -> None:
def test_empty(self) -> None:
sharding_options = self.enumerator.enumerate(self.model, sharders=[])
self.assertFalse(sharding_options)

def test_throw_ex_no_sharding_option_for_table(self) -> None:
cw_constraint = ParameterConstraints(
sharding_types=[
ShardingType.COLUMN_WISE.value,
],
compute_kernels=[
EmbeddingComputeKernel.FUSED.value,
],
)

rw_constraint = ParameterConstraints(
sharding_types=[
ShardingType.TABLE_ROW_WISE.value,
],
compute_kernels=[
EmbeddingComputeKernel.FUSED_UVM_CACHING.value,
],
)

constraints = {
"table_0": cw_constraint,
"table_1": rw_constraint,
"table_2": cw_constraint,
"table_3": cw_constraint,
}

enumerator = EmbeddingEnumerator(
topology=Topology(
world_size=self.world_size,
compute_device=self.compute_device,
local_world_size=self.local_world_size,
),
batch_size=self.batch_size,
constraints=constraints,
)

sharder = cast(ModuleSharder[torch.nn.Module], CWSharder())

with self.assertRaises(Exception) as context:
_ = enumerator.enumerate(self.model, [sharder])

self.assertTrue(
"No available sharding type and compute kernel combination after applying user provided constraints for table_1"
in str(context.exception)
)

0 comments on commit 011ae44

Please sign in to comment.