From 8ce46d557e62982337056a2bae641022312d6a2b Mon Sep 17 00:00:00 2001 From: Yong Hoon Shin Date: Thu, 8 Feb 2024 07:49:14 -0800 Subject: [PATCH] Improve msg error thrown when no sharding option found (#1693) Summary: Add additional info in error msg and warning logs to make debugging easier for users. Reviewed By: henrylhtsang Differential Revision: D53539635 --- torchrec/distributed/planner/enumerators.py | 8 ++++++-- .../planner/tests/test_enumerators.py | 17 +++++++++++------ 2 files changed, 17 insertions(+), 8 deletions(-) diff --git a/torchrec/distributed/planner/enumerators.py b/torchrec/distributed/planner/enumerators.py index 90a506e6c..2338fc0d0 100644 --- a/torchrec/distributed/planner/enumerators.py +++ b/torchrec/distributed/planner/enumerators.py @@ -136,6 +136,7 @@ def enumerate( for compute_kernel in self._filter_compute_kernels( name, sharder.compute_kernels(sharding_type, self._compute_device), + sharding_type, ): ( shard_sizes, @@ -179,7 +180,9 @@ def enumerate( if not sharding_options_per_table: raise RuntimeError( "No available sharding type and compute kernel combination " - f"after applying user provided constraints for {name}" + f"after applying user provided constraints for {name}. " + f"Module: {sharder_key}, sharder: {sharder.__class__.__name__}, compute device: {self._compute_device}. " + f"To debug, search above for warning logs about no available sharding types/compute kernels for table: {name}" ) sharding_options.extend(sharding_options_per_table) @@ -222,6 +225,7 @@ def _filter_compute_kernels( self, name: str, allowed_compute_kernels: List[str], + sharding_type: str, ) -> List[str]: # for the log message only constrained_compute_kernels: List[str] = [ @@ -251,7 +255,7 @@ def _filter_compute_kernels( f"constraints for {name}. Constrained compute kernels: " f"{constrained_compute_kernels}, allowed compute kernels: " f"{allowed_compute_kernels}, filtered compute kernels: " - f"{filtered_compute_kernels}. Please check if the constrained " + f"{filtered_compute_kernels}, sharding type: {sharding_type}. Please check if the constrained " "compute kernels are too restrictive, if the sharder allows the " "compute kernels, or if non-strings are passed in." ) diff --git a/torchrec/distributed/planner/tests/test_enumerators.py b/torchrec/distributed/planner/tests/test_enumerators.py index 3736f8d62..0b61be814 100644 --- a/torchrec/distributed/planner/tests/test_enumerators.py +++ b/torchrec/distributed/planner/tests/test_enumerators.py @@ -746,8 +746,9 @@ def test_filter_compute_kernels_ebc(self) -> None: ) sharder = EmbeddingBagCollectionSharder() + sharding_type = ShardingType.ROW_WISE.value allowed_compute_kernels = enumerator._filter_compute_kernels( - "table_0", sharder.compute_kernels(ShardingType.ROW_WISE.value, "cuda") + "table_0", sharder.compute_kernels(sharding_type, "cuda"), sharding_type ) self.assertEqual( @@ -774,8 +775,9 @@ def test_filter_compute_kernels_mch_ebc(self) -> None: ) sharder = ManagedCollisionEmbeddingBagCollectionSharder() + sharding_type = ShardingType.ROW_WISE.value allowed_compute_kernels = enumerator._filter_compute_kernels( - "table_0", sharder.compute_kernels(ShardingType.ROW_WISE.value, "cuda") + "table_0", sharder.compute_kernels(sharding_type, "cuda"), sharding_type ) self.assertEqual( @@ -797,9 +799,10 @@ def test_filter_compute_kernels_mch_ebc_no_available(self) -> None: ) sharder = ManagedCollisionEmbeddingBagCollectionSharder() + sharding_type = ShardingType.ROW_WISE.value with self.assertWarns(Warning): allowed_compute_kernels = enumerator._filter_compute_kernels( - "table_0", sharder.compute_kernels(ShardingType.ROW_WISE.value, "cuda") + "table_0", sharder.compute_kernels(sharding_type, "cuda"), sharding_type ) self.assertEqual(allowed_compute_kernels, []) @@ -900,7 +903,9 @@ def test_throw_ex_no_sharding_option_for_table(self) -> None: with self.assertRaises(Exception) as context: _ = enumerator.enumerate(self.model, [sharder]) - self.assertTrue( - "No available sharding type and compute kernel combination after applying user provided constraints for table_1" - in str(context.exception) + self.assertEqual( + str(context.exception), + "No available sharding type and compute kernel combination after applying user provided constraints for table_1. " + "Module: torchrec.modules.embedding_modules.EmbeddingBagCollection, sharder: CWSharder, compute device: cuda. " + "To debug, search above for warning logs about no available sharding types/compute kernels for table: table_1", )