From 3ed296659e4cc89ff6a639f04dd220d8994a2981 Mon Sep 17 00:00:00 2001
From: Joe Wang
Date: Mon, 29 Jul 2024 23:02:57 -0700
Subject: [PATCH] support config changes from MVAI down to fbgemm

Summary: Integrate the new RocksDB config into the MVAI model authoring
chain, so that tuning the model config propagates the corresponding
RocksDB changes down to fbgemm.

Differential Revision: D59785241
---
 torchrec/distributed/types.py | 15 +++++++++++++++
 torchrec/distributed/utils.py |  7 ++++++-
 2 files changed, 21 insertions(+), 1 deletion(-)

diff --git a/torchrec/distributed/types.py b/torchrec/distributed/types.py
index 564a44c9c..b3a79abc5 100644
--- a/torchrec/distributed/types.py
+++ b/torchrec/distributed/types.py
@@ -26,6 +26,7 @@
     Union,
 )
 
+from fbgemm_gpu.runtime_monitor import TBEStatsReporterConfig
 from fbgemm_gpu.split_table_batched_embeddings_ops_common import (
     BoundsCheckMode,
     CacheAlgorithm,
@@ -588,16 +589,30 @@ class KeyValueParams:
         ps_hosts (Optional[Tuple[Tuple[str, int]]]): List of PS host ip addresses
             and ports. Example: (("::1", 2000), ("::1", 2001), ("::1", 2002)).
             Reason for using tuple is we want it hashable.
+        ssd_rocksdb_write_buffer_size (Optional[int]): RocksDB write buffer size
+            per TBE; affects RocksDB compaction frequency.
+        ssd_rocksdb_shards (Optional[int]): number of RocksDB shards.
+        gather_ssd_cache_stats (Optional[bool]): whether to enable SSD cache stats
+            collection (std reporter and ODS reporter).
+        stats_reporter_config (Optional[TBEStatsReporterConfig]): TBE stats reporter config (report interval, etc.).
     """
 
     ssd_storage_directory: Optional[str] = None
     ps_hosts: Optional[Tuple[Tuple[str, int], ...]] = None
+    ssd_rocksdb_write_buffer_size: Optional[int] = None
+    ssd_rocksdb_shards: Optional[int] = None
+    gather_ssd_cache_stats: Optional[bool] = None
+    stats_reporter_config: Optional[TBEStatsReporterConfig] = None
 
     def __hash__(self) -> int:
         return hash(
             (
                 self.ssd_storage_directory,
                 self.ps_hosts,
+                self.ssd_rocksdb_write_buffer_size,
+                self.ssd_rocksdb_shards,
+                self.gather_ssd_cache_stats,
+                self.stats_reporter_config,
             )
         )
 
diff --git a/torchrec/distributed/utils.py b/torchrec/distributed/utils.py
index 541dc2ff5..cb529450d 100644
--- a/torchrec/distributed/utils.py
+++ b/torchrec/distributed/utils.py
@@ -410,10 +410,15 @@ def add_params_from_parameter_sharding(
         parameter_sharding.compute_kernel in {EmbeddingComputeKernel.KEY_VALUE.value}
         and parameter_sharding.key_value_params is not None
     ):
-        key_value_params_dict = asdict(parameter_sharding.key_value_params)
+        kv_params = parameter_sharding.key_value_params
+        key_value_params_dict = asdict(kv_params)
         key_value_params_dict = {
             k: v for k, v in key_value_params_dict.items() if v is not None
         }
+        if kv_params.stats_reporter_config:
+            key_value_params_dict["stats_reporter_config"] = (
+                kv_params.stats_reporter_config
+            )
         fused_params.update(key_value_params_dict)
 
     # print warning if sharding_type is data_parallel or kernel is dense
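
Usage sketch (not part of the diff): how the new KeyValueParams fields might be
populated on the model-authoring side and what ends up in fused_params. The
storage path, buffer size, shard count, and the assumption that
TBEStatsReporterConfig is constructible with its defaults are illustrative,
not taken from this change.

    from dataclasses import asdict

    from fbgemm_gpu.runtime_monitor import TBEStatsReporterConfig
    from torchrec.distributed.types import KeyValueParams

    # Illustrative values only.
    kv_params = KeyValueParams(
        ssd_storage_directory="/tmp/ssd_tbe",            # hypothetical path
        ssd_rocksdb_write_buffer_size=64 * 1024 * 1024,  # 64 MB write buffer per TBE
        ssd_rocksdb_shards=8,                            # RocksDB shard count
        gather_ssd_cache_stats=True,                     # enable SSD cache stats reporting
        stats_reporter_config=TBEStatsReporterConfig(),  # assumed default-constructible
    )

    # Mirrors the utils.py hunk: drop unset fields, then restore the reporter
    # config object, since dataclasses.asdict() recursively converts the nested
    # TBEStatsReporterConfig dataclass into a plain dict.
    fused_params = {k: v for k, v in asdict(kv_params).items() if v is not None}
    fused_params["stats_reporter_config"] = kv_params.stats_reporter_config

The re-assignment at the end is presumably why the utils.py hunk keeps a
reference to kv_params: downstream consumers receive the TBEStatsReporterConfig
object itself rather than its dict form.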