Refactor passing over cache params (pytorch#2155)
Summary:
Pull Request resolved: pytorch#2155

Slightly refactor how cache params are passed from the dataclass into the fused_params dict.

Motivation:
I am trying to add KeyValueParams.

Differential Revision: D58886177
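
To illustrate the pattern the summary describes, here is a minimal, self-contained sketch (not the torchrec code itself; ToyParams is a hypothetical stand-in): asdict() turns the params dataclass into a plain dict once, and every populated field is copied into fused_params, so a future params dataclass such as the KeyValueParams mentioned above can reuse the same mapping without adding per-field branches.

from dataclasses import asdict, dataclass
from typing import Any, Dict, Optional


@dataclass
class ToyParams:  # hypothetical stand-in for a params dataclass
    algorithm: Optional[str] = None
    load_factor: Optional[float] = None


def to_fused(params: ToyParams) -> Dict[str, Any]:
    # copy every populated field; no per-field `if ... is not None` branch needed
    return {k: v for k, v in asdict(params).items() if v is not None}


print(to_fused(ToyParams(load_factor=0.2)))  # {'load_factor': 0.2}
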
henrylhtsang authored and facebook-github-bot committed Jun 21, 2024
1 parent a117ae2 commit 6d5970e
Showing 1 changed file with 39 additions and 37 deletions.
76 changes: 39 additions & 37 deletions torchrec/distributed/utils.py
@@ -13,6 +13,7 @@
 import sys
 
 from collections import OrderedDict
+from dataclasses import asdict
 from typing import Any, Dict, List, Optional, Set, Type, TypeVar, Union
 
 import torch
@@ -377,45 +378,46 @@ def add_params_from_parameter_sharding(
     # update fused_params using params from parameter_sharding
     # this will take precidence over the fused_params provided from sharders
     if parameter_sharding.cache_params is not None:
-        cache_params = parameter_sharding.cache_params
-        if cache_params.algorithm is not None:
-            fused_params["cache_algorithm"] = cache_params.algorithm
-        if cache_params.load_factor is not None:
-            fused_params["cache_load_factor"] = cache_params.load_factor
-        if cache_params.reserved_memory is not None:
-            fused_params["cache_reserved_memory"] = cache_params.reserved_memory
-        if cache_params.precision is not None:
-            fused_params["cache_precision"] = cache_params.precision
-        if cache_params.prefetch_pipeline is not None:
-            fused_params["prefetch_pipeline"] = cache_params.prefetch_pipeline
-        if cache_params.multipass_prefetch_config is not None:
-            fused_params["multipass_prefetch_config"] = (
-                cache_params.multipass_prefetch_config
-            )
-
-    if parameter_sharding.enforce_hbm is not None:
-        fused_params["enforce_hbm"] = parameter_sharding.enforce_hbm
-
-    if parameter_sharding.stochastic_rounding is not None:
-        fused_params["stochastic_rounding"] = parameter_sharding.stochastic_rounding
-
-    if parameter_sharding.bounds_check_mode is not None:
-        fused_params["bounds_check_mode"] = parameter_sharding.bounds_check_mode
-
-    if parameter_sharding.output_dtype is not None:
-        fused_params["output_dtype"] = parameter_sharding.output_dtype
+        cache_params_dict = asdict(parameter_sharding.cache_params)
+
+        def _add_cache_prefix(key: str) -> str:
+            if key in {"algorithm", "load_factor", "reserved_memory", "precision"}:
+                return f"cache_{key}"
+            return key
+
+        cache_params_dict = {
+            _add_cache_prefix(k): v
+            for k, v in cache_params_dict.items()
+            if v is not None and k not in {"stats"}
+        }
+        fused_params.update(cache_params_dict)
+
+    parameter_sharding_dict = asdict(parameter_sharding)
+    params_to_fused_tbe: Set[str] = {
+        "enforce_hbm",
+        "stochastic_rounding",
+        "bounds_check_mode",
+        "output_dtype",
+    }
+    parameter_sharding_dict = {
+        k: v
+        for k, v in parameter_sharding_dict.items()
+        if v is not None and k in params_to_fused_tbe
+    }
+    fused_params.update(parameter_sharding_dict)
 
     # print warning if sharding_type is data_parallel or kernel is dense
-    if parameter_sharding.sharding_type == ShardingType.DATA_PARALLEL.value:
-        logger.warning(
-            f"Sharding Type is {parameter_sharding.sharding_type}, "
-            "caching params will be ignored"
-        )
-    elif parameter_sharding.compute_kernel == EmbeddingComputeKernel.DENSE.value:
-        logger.warning(
-            f"Compute Kernel is {parameter_sharding.compute_kernel}, "
-            "caching params will be ignored"
-        )
+    if parameter_sharding.cache_params is not None:
+        if parameter_sharding.sharding_type == ShardingType.DATA_PARALLEL.value:
+            logger.warning(
+                f"Sharding Type is {parameter_sharding.sharding_type}, "
+                "caching params will be ignored"
+            )
+        elif parameter_sharding.compute_kernel == EmbeddingComputeKernel.DENSE.value:
+            logger.warning(
+                f"Compute Kernel is {parameter_sharding.compute_kernel}, "
+                "caching params will be ignored"
+            )
 
     return fused_params
 
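For reference, a self-contained sketch of the behavior the new code converges on, exercised against toy stand-in dataclasses (the real CacheParams and ParameterSharding classes in torchrec carry more fields and richer types): only algorithm, load_factor, reserved_memory, and precision gain a "cache_" prefix, stats is dropped, and only the four TBE-facing top-level fields are forwarded.

from dataclasses import asdict, dataclass
from typing import Any, Dict, Optional, Set


@dataclass
class ToyCacheParams:  # toy stand-in for CacheParams
    algorithm: Optional[str] = None
    load_factor: Optional[float] = None
    prefetch_pipeline: Optional[bool] = None
    stats: Optional[dict] = None  # deliberately never forwarded


@dataclass
class ToyParameterSharding:  # toy stand-in for ParameterSharding
    cache_params: Optional[ToyCacheParams] = None
    enforce_hbm: Optional[bool] = None
    stochastic_rounding: Optional[bool] = None
    sharding_type: str = "table_wise"  # not a fused param, so never forwarded


def _add_cache_prefix(key: str) -> str:
    # only these four keys gain the "cache_" prefix, mirroring the diff
    if key in {"algorithm", "load_factor", "reserved_memory", "precision"}:
        return f"cache_{key}"
    return key


def add_params(fused_params: Dict[str, Any], ps: ToyParameterSharding) -> Dict[str, Any]:
    if ps.cache_params is not None:
        fused_params.update(
            {
                _add_cache_prefix(k): v
                for k, v in asdict(ps.cache_params).items()
                if v is not None and k not in {"stats"}
            }
        )
    params_to_fused_tbe: Set[str] = {
        "enforce_hbm",
        "stochastic_rounding",
        "bounds_check_mode",
        "output_dtype",
    }
    fused_params.update(
        {k: v for k, v in asdict(ps).items() if v is not None and k in params_to_fused_tbe}
    )
    return fused_params


ps = ToyParameterSharding(
    cache_params=ToyCacheParams(load_factor=0.2, prefetch_pipeline=True, stats={"hits": 1}),
    enforce_hbm=True,
)
print(add_params({}, ps))
# {'cache_load_factor': 0.2, 'prefetch_pipeline': True, 'enforce_hbm': True}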
