Improve error msg thrown when no sharding option found (pytorch#1693)
Summary:

Add additional info (module path, sharder class, compute device, and sharding type) to the error msg and warning logs to make debugging easier for users.
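With this change, a failing plan reports the module, sharder, and compute device alongside the table name. A minimal sketch of what users now see; `enumerator`, `model`, and `sharder` are assumed to be set up as in `test_throw_ex_no_sharding_option_for_table` below, and the expected message is taken verbatim from that test:

    # Sketch only: fixture setup elided; see the updated test for the exact inputs.
    try:
        enumerator.enumerate(model, [sharder])
    except RuntimeError as e:
        print(e)
        # No available sharding type and compute kernel combination after
        # applying user provided constraints for table_1. Module:
        # torchrec.modules.embedding_modules.EmbeddingBagCollection,
        # sharder: CWSharder, compute device: cuda. To debug, search above
        # for warning logs about no available sharding types/compute kernels
        # for table: table_1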

Reviewed By: henrylhtsang

Differential Revision: D53539635
sarckk authored and facebook-github-bot committed Feb 8, 2024
1 parent b617cf8 commit 03f101b
Showing 2 changed files with 17 additions and 8 deletions.
8 changes: 6 additions & 2 deletions torchrec/distributed/planner/enumerators.py
@@ -136,6 +136,7 @@ def enumerate(
                 for compute_kernel in self._filter_compute_kernels(
                     name,
                     sharder.compute_kernels(sharding_type, self._compute_device),
+                    sharding_type,
                 ):
                     (
                         shard_sizes,
@@ -179,7 +180,9 @@ def enumerate(
             if not sharding_options_per_table:
                 raise RuntimeError(
                     "No available sharding type and compute kernel combination "
-                    f"after applying user provided constraints for {name}"
+                    f"after applying user provided constraints for {name}. "
+                    f"Module: {sharder_key}, sharder: {sharder.__class__.__name__}, compute device: {self._compute_device}. "
+                    f"To debug, search above for warning logs about no available sharding types/compute kernels for table: {name}"
                 )
 
             sharding_options.extend(sharding_options_per_table)
@@ -222,6 +225,7 @@ def _filter_compute_kernels(
         self,
         name: str,
         allowed_compute_kernels: List[str],
+        sharding_type: str,
     ) -> List[str]:
         # for the log message only
         constrained_compute_kernels: List[str] = [
@@ -251,7 +255,7 @@
                 f"constraints for {name}. Constrained compute kernels: "
                 f"{constrained_compute_kernels}, allowed compute kernels: "
                 f"{allowed_compute_kernels}, filtered compute kernels: "
-                f"{filtered_compute_kernels}. Please check if the constrained "
+                f"{filtered_compute_kernels}, sharding type: {sharding_type}. Please check if the constrained "
                 "compute kernels are too restrictive, if the sharder allows the "
                 "compute kernels, or if non-strings are passed in."
             )
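Callers of `_filter_compute_kernels` now thread the sharding type through so it can appear in the warning log. A minimal sketch of the updated call, mirroring the test changes below (`enumerator` and `sharder` are assumed to be constructed as in those tests):

    sharding_type = ShardingType.ROW_WISE.value
    allowed_compute_kernels = enumerator._filter_compute_kernels(
        "table_0",
        sharder.compute_kernels(sharding_type, "cuda"),
        sharding_type,  # new third argument, surfaced in the warning when filtering leaves nothing
    )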
17 changes: 11 additions & 6 deletions torchrec/distributed/planner/tests/test_enumerators.py
@@ -746,8 +746,9 @@ def test_filter_compute_kernels_ebc(self) -> None:
         )
 
         sharder = EmbeddingBagCollectionSharder()
+        sharding_type = ShardingType.ROW_WISE.value
         allowed_compute_kernels = enumerator._filter_compute_kernels(
-            "table_0", sharder.compute_kernels(ShardingType.ROW_WISE.value, "cuda")
+            "table_0", sharder.compute_kernels(sharding_type, "cuda"), sharding_type
         )
 
         self.assertEqual(
@@ -774,8 +775,9 @@ def test_filter_compute_kernels_mch_ebc(self) -> None:
         )
 
         sharder = ManagedCollisionEmbeddingBagCollectionSharder()
+        sharding_type = ShardingType.ROW_WISE.value
         allowed_compute_kernels = enumerator._filter_compute_kernels(
-            "table_0", sharder.compute_kernels(ShardingType.ROW_WISE.value, "cuda")
+            "table_0", sharder.compute_kernels(sharding_type, "cuda"), sharding_type
         )
 
         self.assertEqual(
@@ -797,9 +799,10 @@ def test_filter_compute_kernels_mch_ebc_no_available(self) -> None:
         )
 
         sharder = ManagedCollisionEmbeddingBagCollectionSharder()
+        sharding_type = ShardingType.ROW_WISE.value
         with self.assertWarns(Warning):
             allowed_compute_kernels = enumerator._filter_compute_kernels(
-                "table_0", sharder.compute_kernels(ShardingType.ROW_WISE.value, "cuda")
+                "table_0", sharder.compute_kernels(sharding_type, "cuda"), sharding_type
             )
 
         self.assertEqual(allowed_compute_kernels, [])
@@ -900,7 +903,9 @@ def test_throw_ex_no_sharding_option_for_table(self) -> None:
         with self.assertRaises(Exception) as context:
             _ = enumerator.enumerate(self.model, [sharder])
 
-        self.assertTrue(
-            "No available sharding type and compute kernel combination after applying user provided constraints for table_1"
-            in str(context.exception)
+        self.assertEqual(
+            str(context.exception),
+            "No available sharding type and compute kernel combination after applying user provided constraints for table_1. "
+            "Module: torchrec.modules.embedding_modules.EmbeddingBagCollection, sharder: CWSharder, compute device: cuda. "
+            "To debug, search above for warning logs about no available sharding types/compute kernels for table: table_1",
         )
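To exercise the updated assertions locally, a standard pytest invocation should work (the exact command is an assumption, not part of this diff):

    python -m pytest torchrec/distributed/planner/tests/test_enumerators.py \
        -k "filter_compute_kernels or throw_ex_no_sharding_option"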
