diff --git a/torchrec/distributed/utils.py b/torchrec/distributed/utils.py
index 26efb1263..b7a6bd16c 100644
--- a/torchrec/distributed/utils.py
+++ b/torchrec/distributed/utils.py
@@ -13,6 +13,7 @@
 import sys
 from collections import OrderedDict
+from dataclasses import asdict
 from typing import Any, Dict, List, Optional, Set, Type, TypeVar, Union
 
 import torch
@@ -377,45 +378,46 @@ def add_params_from_parameter_sharding(
 
     # update fused_params using params from parameter_sharding
     # this will take precidence over the fused_params provided from sharders
     if parameter_sharding.cache_params is not None:
-        cache_params = parameter_sharding.cache_params
-        if cache_params.algorithm is not None:
-            fused_params["cache_algorithm"] = cache_params.algorithm
-        if cache_params.load_factor is not None:
-            fused_params["cache_load_factor"] = cache_params.load_factor
-        if cache_params.reserved_memory is not None:
-            fused_params["cache_reserved_memory"] = cache_params.reserved_memory
-        if cache_params.precision is not None:
-            fused_params["cache_precision"] = cache_params.precision
-        if cache_params.prefetch_pipeline is not None:
-            fused_params["prefetch_pipeline"] = cache_params.prefetch_pipeline
-        if cache_params.multipass_prefetch_config is not None:
-            fused_params["multipass_prefetch_config"] = (
-                cache_params.multipass_prefetch_config
-            )
-
-    if parameter_sharding.enforce_hbm is not None:
-        fused_params["enforce_hbm"] = parameter_sharding.enforce_hbm
-
-    if parameter_sharding.stochastic_rounding is not None:
-        fused_params["stochastic_rounding"] = parameter_sharding.stochastic_rounding
-
-    if parameter_sharding.bounds_check_mode is not None:
-        fused_params["bounds_check_mode"] = parameter_sharding.bounds_check_mode
-
-    if parameter_sharding.output_dtype is not None:
-        fused_params["output_dtype"] = parameter_sharding.output_dtype
+        cache_params_dict = asdict(parameter_sharding.cache_params)
+
+        def _add_cache_prefix(key: str) -> str:
+            if key in {"algorithm", "load_factor", "reserved_memory", "precision"}:
+                return f"cache_{key}"
+            return key
+
+        cache_params_dict = {
+            _add_cache_prefix(k): v
+            for k, v in cache_params_dict.items()
+            if v is not None and k not in {"stats"}
+        }
+        fused_params.update(cache_params_dict)
+
+    parameter_sharding_dict = asdict(parameter_sharding)
+    params_to_fused_tbe: Set[str] = {
+        "enforce_hbm",
+        "stochastic_rounding",
+        "bounds_check_mode",
+        "output_dtype",
+    }
+    parameter_sharding_dict = {
+        k: v
+        for k, v in parameter_sharding_dict.items()
+        if v is not None and k in params_to_fused_tbe
+    }
+    fused_params.update(parameter_sharding_dict)
 
     # print warning if sharding_type is data_parallel or kernel is dense
-    if parameter_sharding.sharding_type == ShardingType.DATA_PARALLEL.value:
-        logger.warning(
-            f"Sharding Type is {parameter_sharding.sharding_type}, "
-            "caching params will be ignored"
-        )
-    elif parameter_sharding.compute_kernel == EmbeddingComputeKernel.DENSE.value:
-        logger.warning(
-            f"Compute Kernel is {parameter_sharding.compute_kernel}, "
-            "caching params will be ignored"
-        )
+    if parameter_sharding.cache_params is not None:
+        if parameter_sharding.sharding_type == ShardingType.DATA_PARALLEL.value:
+            logger.warning(
+                f"Sharding Type is {parameter_sharding.sharding_type}, "
+                "caching params will be ignored"
+            )
+        elif parameter_sharding.compute_kernel == EmbeddingComputeKernel.DENSE.value:
+            logger.warning(
+                f"Compute Kernel is {parameter_sharding.compute_kernel}, "
+                "caching params will be ignored"
+            )
 
     return fused_params
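
For reference, a minimal standalone sketch of the asdict-based mapping introduced by this patch. `CacheParamsSketch` and `to_fused_params` are hypothetical names with a reduced field set, not torchrec's actual `CacheParams` or API; they only illustrate how `dataclasses.asdict()` plus a filtering dict comprehension replaces the per-field `if ... is not None` blocks.

```python
# Standalone sketch (simplified, hypothetical stand-in for torchrec's CacheParams).
from dataclasses import asdict, dataclass
from typing import Any, Dict, Optional


@dataclass
class CacheParamsSketch:
    algorithm: Optional[str] = None
    load_factor: Optional[float] = None
    reserved_memory: Optional[float] = None
    precision: Optional[str] = None
    prefetch_pipeline: Optional[bool] = None
    stats: Optional[object] = None  # excluded from fused params, mirroring the diff


def to_fused_params(cache_params: CacheParamsSketch) -> Dict[str, Any]:
    # asdict() flattens the dataclass into {field_name: value}.
    raw = asdict(cache_params)

    def _add_cache_prefix(key: str) -> str:
        # Only these fields are renamed to their "cache_*" fused-param keys.
        if key in {"algorithm", "load_factor", "reserved_memory", "precision"}:
            return f"cache_{key}"
        return key

    # Drop unset fields and the "stats" field, then rename the remaining keys.
    return {
        _add_cache_prefix(k): v
        for k, v in raw.items()
        if v is not None and k != "stats"
    }


print(to_fused_params(CacheParamsSketch(algorithm="LRU", load_factor=0.2)))
# {'cache_algorithm': 'LRU', 'cache_load_factor': 0.2}
```

One general note on the pattern (a property of the stdlib, not specific to this patch): `dataclasses.asdict()` recurses into nested dataclass fields and returns them as plain dicts, so values that are themselves dataclasses arrive as dicts rather than the original objects.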