support config changes from MVAI down to fbgemm
Summary: integrate the new rocksdb config into the mvai model authoring chain so that we can tune the model config and have the changes take effect in rocksdb

Differential Revision: D59785241
duduyi2013 authored and facebook-github-bot committed Jul 30, 2024
1 parent 2771a90 commit 3ed2966
Showing 2 changed files with 21 additions and 1 deletion.
15 changes: 15 additions & 0 deletions torchrec/distributed/types.py
@@ -26,6 +26,7 @@
     Union,
 )
 
+from fbgemm_gpu.runtime_monitor import TBEStatsReporterConfig
 from fbgemm_gpu.split_table_batched_embeddings_ops_common import (
     BoundsCheckMode,
     CacheAlgorithm,
@@ -588,16 +589,30 @@ class KeyValueParams:
         ps_hosts (Optional[Tuple[Tuple[str, int]]]): List of PS host ip addresses
             and ports. Example: (("::1", 2000), ("::1", 2001), ("::1", 2002)).
             Reason for using tuple is we want it hashable.
+        ssd_rocksdb_write_buffer_size: Optional[int]: rocksdb write buffer size per TBE,
+            relevant to rocksdb compaction frequency
+        ssd_rocksdb_shards: Optional[int]: number of rocksdb shards
+        gather_ssd_cache_stats: bool: whether to enable ssd stats collection, std reporter and ods reporter
+        report_interval: int: report interval, in training iterations, if gather_ssd_cache_stats is enabled
+        ods_prefix: str: ods prefix for ods reporting
     """
 
     ssd_storage_directory: Optional[str] = None
     ps_hosts: Optional[Tuple[Tuple[str, int], ...]] = None
+    ssd_rocksdb_write_buffer_size: Optional[int] = None
+    ssd_rocksdb_shards: Optional[int] = None
+    gather_ssd_cache_stats: Optional[bool] = None
+    stats_reporter_config: Optional[TBEStatsReporterConfig] = None
 
     def __hash__(self) -> int:
         return hash(
             (
                 self.ssd_storage_directory,
                 self.ps_hosts,
+                self.ssd_rocksdb_write_buffer_size,
+                self.ssd_rocksdb_shards,
+                self.gather_ssd_cache_stats,
+                self.stats_reporter_config,
             )
         )
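To see how these knobs are meant to be used, here is a minimal sketch that constructs a KeyValueParams with the new fields. The concrete values are arbitrary assumptions for illustration, not defaults from this commit, and TBEStatsReporterConfig is assumed to accept an interval argument per fbgemm_gpu.runtime_monitor:

from fbgemm_gpu.runtime_monitor import TBEStatsReporterConfig
from torchrec.distributed.types import KeyValueParams

# All values below are illustrative; tune them per workload.
kv_params = KeyValueParams(
    ssd_storage_directory="/tmp/ssd_tbe",            # where the rocksdb files live
    ssd_rocksdb_write_buffer_size=64 * 1024 * 1024,  # 64 MiB memtable per TBE (assumed value)
    ssd_rocksdb_shards=8,                            # shard count (assumed value)
    gather_ssd_cache_stats=True,                     # enable ssd stats collection
    stats_reporter_config=TBEStatsReporterConfig(interval=1000),  # assumed ctor arg
)

Since __hash__ covers exactly these fields, two instances with the same settings hash identically, which keeps KeyValueParams usable inside a hashable sharding config.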
7 changes: 6 additions & 1 deletion torchrec/distributed/utils.py
@@ -410,10 +410,15 @@ def add_params_from_parameter_sharding(
         parameter_sharding.compute_kernel in {EmbeddingComputeKernel.KEY_VALUE.value}
         and parameter_sharding.key_value_params is not None
     ):
-        key_value_params_dict = asdict(parameter_sharding.key_value_params)
+        kv_params = parameter_sharding.key_value_params
+        key_value_params_dict = asdict(kv_params)
         key_value_params_dict = {
             k: v for k, v in key_value_params_dict.items() if v is not None
         }
+        if kv_params.stats_reporter_config:
+            key_value_params_dict["stats_reporter_config"] = (
+                kv_params.stats_reporter_config
+            )
         fused_params.update(key_value_params_dict)
 
     # print warning if sharding_type is data_parallel or kernel is dense
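The explicit re-assignment of stats_reporter_config is needed because dataclasses.asdict recursively converts nested dataclasses into plain dicts, so the reporter config would otherwise reach fused_params as a dict rather than as a TBEStatsReporterConfig object. A self-contained sketch of that stdlib behavior, using stand-in dataclasses instead of the real torchrec/fbgemm types:

from dataclasses import asdict, dataclass
from typing import Optional

@dataclass
class ReporterConfig:  # stand-in for TBEStatsReporterConfig
    interval: int = 1000

@dataclass
class Params:  # stand-in for KeyValueParams
    ssd_rocksdb_shards: Optional[int] = None
    stats_reporter_config: Optional[ReporterConfig] = None

p = Params(ssd_rocksdb_shards=8, stats_reporter_config=ReporterConfig())
d = {k: v for k, v in asdict(p).items() if v is not None}
print(type(d["stats_reporter_config"]))  # <class 'dict'>: asdict flattened the object

# Restore the original object, mirroring what the commit does:
d["stats_reporter_config"] = p.stats_reporter_config
print(type(d["stats_reporter_config"]))  # <class '__main__.ReporterConfig'>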
