Skip to content

Commit

Permalink
Improve error handling in enumerator
Browse files Browse the repository at this point in the history
Summary: Improve warning messaging handling for enumerator.

Differential Revision: D52545226
  • Loading branch information
henrylhtsang authored and facebook-github-bot committed Jan 10, 2024
1 parent 9bc6103 commit 118b507
Showing 1 changed file with 33 additions and 14 deletions.
47 changes: 33 additions & 14 deletions torchrec/distributed/planner/enumerators.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,38 +186,51 @@ def populate_estimates(self, sharding_options: List[ShardingOption]) -> None:
for estimator in self._estimators:
estimator.estimate(sharding_options, self._sharder_map)

def _filter_sharding_types(self, name: str, sharding_types: List[str]) -> List[str]:
def _filter_sharding_types(
self, name: str, allowed_sharding_types: List[str]
) -> List[str]:
if not self._constraints or not self._constraints.get(name):
return sharding_types
return allowed_sharding_types
constraints: ParameterConstraints = self._constraints[name]
if not constraints.sharding_types:
return sharding_types
return allowed_sharding_types
constrained_sharding_types: List[str] = constraints.sharding_types

sharding_types = list(set(constrained_sharding_types) & set(sharding_types))
filtered_sharding_types = list(
set(constrained_sharding_types) & set(allowed_sharding_types)
)

if not sharding_types:
if not filtered_sharding_types:
logger.warn(
f"No available sharding types after applying user provided constraints for {name}"
"No available sharding types after applying user provided "
f"constraints for {name}. Constrained sharding types: "
f"{constrained_sharding_types}, allowed sharding types: "
f"{allowed_sharding_types}, filtered sharding types: "
f"{filtered_sharding_types}. Please check if the constrained "
"sharding types are too restrictive, if the sharder allows the "
"sharding types, or if non-strings are passed in."
)
return sharding_types
return filtered_sharding_types

def _filter_compute_kernels(
self,
name: str,
compute_kernels: List[str],
allowed_compute_kernels: List[str],
) -> List[str]:

# for the log message only
constrained_compute_kernels: List[str] = [
compute_kernel.value for compute_kernel in EmbeddingComputeKernel
]
if not self._constraints or not self._constraints.get(name):
filtered_compute_kernels = compute_kernels
filtered_compute_kernels = allowed_compute_kernels
else:
constraints: ParameterConstraints = self._constraints[name]
if not constraints.compute_kernels:
filtered_compute_kernels = compute_kernels
filtered_compute_kernels = allowed_compute_kernels
else:
constrained_compute_kernels: List[str] = constraints.compute_kernels
constrained_compute_kernels = constraints.compute_kernels
filtered_compute_kernels = list(
set(constrained_compute_kernels) & set(compute_kernels)
set(constrained_compute_kernels) & set(allowed_compute_kernels)
)

if EmbeddingComputeKernel.DENSE.value in filtered_compute_kernels:
Expand All @@ -228,7 +241,13 @@ def _filter_compute_kernels(

if not filtered_compute_kernels:
logger.warn(
f"No available compute kernels after applying user provided constraints for {name}"
"No available compute kernels after applying user provided "
f"constraints for {name}. Constrained compute kernels: "
f"{constrained_compute_kernels}, allowed compute kernels: "
f"{allowed_compute_kernels}, filtered compute kernels: "
f"{filtered_compute_kernels}. Please check if the constrained "
"compute kernels are too restrictive, if the sharder allows the "
"compute kernels, or if non-strings are passed in."
)
return filtered_compute_kernels

Expand Down

0 comments on commit 118b507

Please sign in to comment.