From c8f769ac0c0749a2a1cd071b81523c6bcbaef317 Mon Sep 17 00:00:00 2001
From: Henry Tsang
Date: Thu, 1 Feb 2024 16:11:25 -0800
Subject: [PATCH] Add flushing and reset_cache_states to pre-hook and hook of
 state_dict and load_state_dict (#1674)

Summary:
Doing both in the same diff also makes it possible to enable the unit test.

What this diff does:
* call flush in the state_dict pre-hook
* call reset_cache_states in the load_state_dict pre-hook

The problem previously was that calling sharded_ebc.state_dict() did not
recursively call lookup.state_dict(), so flush was never invoked. A minimal
sketch of the hook pattern is included after the diff.

Differential Revision: D53199744
---
 .../distributed/batched_embedding_kernel.py  | 12 ++++++
 torchrec/distributed/embedding.py            | 17 +++++++++
 torchrec/distributed/embedding_lookup.py     | 32 ++++++++++++++++
 torchrec/distributed/embeddingbag.py         | 17 +++++++++
 .../test_utils/test_model_parallel_base.py   | 37 +++++++++++++++++--
 5 files changed, 111 insertions(+), 4 deletions(-)

diff --git a/torchrec/distributed/batched_embedding_kernel.py b/torchrec/distributed/batched_embedding_kernel.py
index 08cdbc056..d7d1d0d01 100644
--- a/torchrec/distributed/batched_embedding_kernel.py
+++ b/torchrec/distributed/batched_embedding_kernel.py
@@ -529,6 +529,9 @@ def config(self) -> GroupedEmbeddingConfig:
     def flush(self) -> None:
         pass
 
+    def purge(self) -> None:
+        pass
+
     def named_split_embedding_weights(
         self, prefix: str = "", recurse: bool = True, remove_duplicate: bool = True
     ) -> Iterator[Tuple[str, torch.Tensor]]:
@@ -649,6 +652,9 @@ def named_parameters(
     def flush(self) -> None:
         self._emb_module.flush()
 
+    def purge(self) -> None:
+        self._emb_module.reset_cache_states()
+
 
 class BatchedDenseEmbedding(BaseBatchedEmbedding[torch.Tensor]):
     def __init__(
@@ -810,6 +816,9 @@ def config(self) -> GroupedEmbeddingConfig:
     def flush(self) -> None:
         pass
 
+    def purge(self) -> None:
+        pass
+
     def named_split_embedding_weights(
         self, prefix: str = "", recurse: bool = True, remove_duplicate: bool = True
     ) -> Iterator[Tuple[str, torch.Tensor]]:
@@ -935,6 +944,9 @@ def named_parameters(
     def flush(self) -> None:
         self._emb_module.flush()
 
+    def purge(self) -> None:
+        self._emb_module.reset_cache_states()
+
 
 class BatchedDenseEmbeddingBag(BaseBatchedEmbeddingBag[torch.Tensor]):
     def __init__(
diff --git a/torchrec/distributed/embedding.py b/torchrec/distributed/embedding.py
index 322e5e96c..ada8169e8 100644
--- a/torchrec/distributed/embedding.py
+++ b/torchrec/distributed/embedding.py
@@ -421,6 +421,17 @@ def __init__(
         if module.device != torch.device("meta"):
             self.load_state_dict(module.state_dict())
 
+    @staticmethod
+    def _pre_state_dict_hook(
+        self: "ShardedEmbeddingCollection",
+        prefix: str = "",
+        keep_vars: bool = False,
+    ) -> None:
+        for lookup in self._lookups:
+            while isinstance(lookup, DistributedDataParallel):
+                lookup = lookup.module
+            lookup.flush()
+
     @staticmethod
     def _pre_load_state_dict_hook(
         self: "ShardedEmbeddingCollection",
@@ -475,6 +486,11 @@ def _pre_load_state_dict_hook(
                     else torch.cat(local_shards, dim=0)
                 )
 
+        for lookup in self._lookups:
+            while isinstance(lookup, DistributedDataParallel):
+                lookup = lookup.module
+            lookup.purge()
+
     def _initialize_torch_state(self) -> None:  # noqa
         """
         This provides consistency between this class and the EmbeddingCollection's
@@ -562,6 +578,7 @@ def post_state_dict_hook(
                 destination_key = f"{prefix}embeddings.{table_name}.weight"
                 destination[destination_key] = sharded_t
 
+        self.register_state_dict_pre_hook(self._pre_state_dict_hook)
         self._register_state_dict_hook(post_state_dict_hook)
         self._register_load_state_dict_pre_hook(
             self._pre_load_state_dict_hook, with_module=True
diff --git a/torchrec/distributed/embedding_lookup.py b/torchrec/distributed/embedding_lookup.py
index 5eadb19ad..68d1176c7 100644
--- a/torchrec/distributed/embedding_lookup.py
+++ b/torchrec/distributed/embedding_lookup.py
@@ -276,6 +276,14 @@ def named_parameters_by_table(
             ) in embedding_kernel.named_parameters_by_table():
                 yield (table_name, tbe_slice)
 
+    def flush(self) -> None:
+        for emb_module in self._emb_modules:
+            emb_module.flush()
+
+    def purge(self) -> None:
+        for emb_module in self._emb_modules:
+            emb_module.purge()
+
 
 class CommOpGradientScaling(torch.autograd.Function):
     @staticmethod
@@ -503,6 +511,14 @@ def named_parameters_by_table(
             ) in embedding_kernel.named_parameters_by_table():
                 yield (table_name, tbe_slice)
 
+    def flush(self) -> None:
+        for emb_module in self._emb_modules:
+            emb_module.flush()
+
+    def purge(self) -> None:
+        for emb_module in self._emb_modules:
+            emb_module.purge()
+
 
 class MetaInferGroupedEmbeddingsLookup(
     BaseEmbeddingLookup[KeyedJaggedTensor, torch.Tensor], TBEToRegisterMixIn
@@ -627,6 +643,14 @@ def named_buffers(
         for emb_module in self._emb_modules:
             yield from emb_module.named_buffers(prefix, recurse)
 
+    def flush(self) -> None:
+        # not implemented
+        pass
+
+    def purge(self) -> None:
+        # not implemented
+        pass
+
 
 class MetaInferGroupedPooledEmbeddingsLookup(
     BaseEmbeddingLookup[KeyedJaggedTensor, torch.Tensor], TBEToRegisterMixIn
@@ -771,6 +795,14 @@ def named_buffers(
         for emb_module in self._emb_modules:
             yield from emb_module.named_buffers(prefix, recurse)
 
+    def flush(self) -> None:
+        # not implemented
+        pass
+
+    def purge(self) -> None:
+        # not implemented
+        pass
+
 
 class InferGroupedLookupMixin(ABC):
     def forward(
diff --git a/torchrec/distributed/embeddingbag.py b/torchrec/distributed/embeddingbag.py
index bdfcf3c41..78e769cdb 100644
--- a/torchrec/distributed/embeddingbag.py
+++ b/torchrec/distributed/embeddingbag.py
@@ -515,6 +515,17 @@ def __init__(
         ]:
             self.load_state_dict(module.state_dict(), strict=False)
 
+    @staticmethod
+    def _pre_state_dict_hook(
+        self: "ShardedEmbeddingBagCollection",
+        prefix: str = "",
+        keep_vars: bool = False,
+    ) -> None:
+        for lookup in self._lookups:
+            while isinstance(lookup, DistributedDataParallel):
+                lookup = lookup.module
+            lookup.flush()
+
     @staticmethod
     def _pre_load_state_dict_hook(
         self: "ShardedEmbeddingBagCollection",
@@ -571,6 +582,11 @@ def _pre_load_state_dict_hook(
                     f"Unexpected state_dict key type {type(state_dict[key])} found for {key}"
                 )
 
+        for lookup in self._lookups:
+            while isinstance(lookup, DistributedDataParallel):
+                lookup = lookup.module
+            lookup.purge()
+
     def _initialize_torch_state(self) -> None:  # noqa
         """
         This provides consistency between this class and the EmbeddingBagCollection's
@@ -657,6 +673,7 @@ def post_state_dict_hook(
                 destination_key = f"{prefix}embedding_bags.{table_name}.weight"
                 destination[destination_key] = sharded_t
 
+        self.register_state_dict_pre_hook(self._pre_state_dict_hook)
         self._register_state_dict_hook(post_state_dict_hook)
         self._register_load_state_dict_pre_hook(
             self._pre_load_state_dict_hook, with_module=True
diff --git a/torchrec/distributed/test_utils/test_model_parallel_base.py b/torchrec/distributed/test_utils/test_model_parallel_base.py
index 941e8f878..fe54e11d9 100644
--- a/torchrec/distributed/test_utils/test_model_parallel_base.py
+++ b/torchrec/distributed/test_utils/test_model_parallel_base.py
@@ -462,14 +462,43 @@ def test_meta_device_dmp_state_dict(self) -> None:
 
     # pyre-ignore[56]
     @given(
-        sharders=st.sampled_from(
+        sharder_type=st.sampled_from(
             [
-                [EmbeddingBagCollectionSharder()],
+                SharderType.EMBEDDING_BAG_COLLECTION.value,
+            ]
+        ),
+        sharding_type=st.sampled_from(
+            [
+                ShardingType.COLUMN_WISE.value,
+            ]
+        ),
+        kernel_type=st.sampled_from(
+            [
+                EmbeddingComputeKernel.FUSED.value,
+                EmbeddingComputeKernel.FUSED_UVM_CACHING.value,
             ]
         ),
     )
-    @settings(verbosity=Verbosity.verbose, max_examples=2, deadline=None)
-    def test_load_state_dict(self, sharders: List[ModuleSharder[nn.Module]]) -> None:
+    @settings(verbosity=Verbosity.verbose, max_examples=4, deadline=None)
+    def test_load_state_dict(
+        self, sharder_type: str, sharding_type: str, kernel_type: str
+    ) -> None:
+        if (
+            self.device == torch.device("cpu")
+            and kernel_type != EmbeddingComputeKernel.FUSED.value
+        ):
+            self.skipTest("CPU does not support uvm.")
+
+        sharders = [
+            cast(
+                ModuleSharder[nn.Module],
+                create_test_sharder(
+                    sharder_type,
+                    sharding_type,
+                    kernel_type,
+                ),
+            ),
+        ]
         models, batch = self._generate_dmps_and_batch(sharders)
         m1, m2 = models
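
A minimal sketch of the hook pattern the diff wires up, in isolation: a state_dict pre-hook that flushes every lookup and a load_state_dict pre-hook that purges cache state, needed because the sharded collections build their state_dict from registered ShardedTensors rather than recursing into lookup.state_dict(). This is illustrative only; ToyCachedKernel, ToyShardedCollection, and their flush()/purge() bodies are hypothetical stand-ins for the grouped lookups, whose purge() maps to the TBE kernel's reset_cache_states(), and the real hooks also unwrap DistributedDataParallel before calling flush()/purge().

import torch
from torch import nn


class ToyCachedKernel(nn.Module):
    """Stand-in for a cached embedding kernel (hypothetical)."""

    def __init__(self) -> None:
        super().__init__()
        self.weight = nn.Parameter(torch.zeros(4, 8))
        self.cached_rows = 0  # pretend some rows live only in a cache

    def flush(self) -> None:
        # Write cached rows back into self.weight so state_dict() sees them.
        self.cached_rows = 0

    def purge(self) -> None:
        # Drop cache state so stale rows are not served after new weights load.
        self.cached_rows = 0


class ToyShardedCollection(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self._lookups = nn.ModuleList([ToyCachedKernel()])
        # Flush caches right before state_dict() reads the weights ...
        self.register_state_dict_pre_hook(self._pre_state_dict_hook)
        # ... and reset cache state right before load_state_dict() writes them.
        self._register_load_state_dict_pre_hook(
            self._pre_load_state_dict_hook, with_module=True
        )

    @staticmethod
    def _pre_state_dict_hook(module, prefix="", keep_vars=False) -> None:
        for lookup in module._lookups:
            lookup.flush()

    @staticmethod
    def _pre_load_state_dict_hook(module, state_dict, prefix, *args) -> None:
        for lookup in module._lookups:
            lookup.purge()


m = ToyShardedCollection()
sd = m.state_dict()    # state_dict pre-hook fires: every lookup is flushed
m.load_state_dict(sd)  # load pre-hook fires: every lookup's cache is purged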