diff --git a/torchrec/distributed/batched_embedding_kernel.py b/torchrec/distributed/batched_embedding_kernel.py index 861e6e868..b74b726a0 100644 --- a/torchrec/distributed/batched_embedding_kernel.py +++ b/torchrec/distributed/batched_embedding_kernel.py @@ -900,7 +900,7 @@ def __init__( pg, ) self._param_per_table: Dict[str, nn.Parameter] = dict( - _gen_named_parameters_by_table_ssd( + _gen_named_parameters_by_table_ssd_pmt( emb_module=self._emb_module, table_name_to_count=self.table_name_to_count.copy(), config=self._config, @@ -933,11 +933,31 @@ def state_dict( destination: Optional[Dict[str, Any]] = None, prefix: str = "", keep_vars: bool = False, + no_snapshot: bool = True, ) -> Dict[str, Any]: - if destination is None: - destination = OrderedDict() + """ + Args: + no_snapshot (bool): the tensors in the returned dict are + PartiallyMaterializedTensors. this argument controls wether the + PartiallyMaterializedTensor owns a RocksDB snapshot handle. True means the + PartiallyMaterializedTensor doesn't have a RocksDB snapshot handle. False means the + PartiallyMaterializedTensor has a RocksDB snapshot handle + """ + # in the case no_snapshot=False, a flush is required. we rely on the flush operation in + # ShardedEmbeddingBagCollection._pre_state_dict_hook() - return destination + emb_tables = self.split_embedding_weights(no_snapshot=no_snapshot) + emb_table_config_copy = copy.deepcopy(self._config.embedding_tables) + for emb_table in emb_table_config_copy: + emb_table.local_metadata.placement._device = torch.device("cpu") + ret = get_state_dict( + emb_table_config_copy, + emb_tables, + self._pg, + destination, + prefix, + ) + return ret def named_parameters( self, prefix: str = "", recurse: bool = True, remove_duplicate: bool = True @@ -950,14 +970,16 @@ def named_parameters( ): # hack before we support optimizer on sharded parameter level # can delete after PEA deprecation + # pyre-ignore [6] param = nn.Parameter(tensor) # pyre-ignore param._in_backward_optimizers = [EmptyFusedOptimizer()] yield name, param + # pyre-ignore [15] def named_split_embedding_weights( self, prefix: str = "", recurse: bool = True, remove_duplicate: bool = True - ) -> Iterator[Tuple[str, torch.Tensor]]: + ) -> Iterator[Tuple[str, PartiallyMaterializedTensor]]: assert ( remove_duplicate ), "remove_duplicate=False not supported in BaseBatchedEmbedding.named_split_embedding_weights" @@ -968,6 +990,21 @@ def named_split_embedding_weights( key = append_prefix(prefix, f"{config.name}.weight") yield key, tensor + def get_named_split_embedding_weights_snapshot( + self, prefix: str = "" + ) -> Iterator[Tuple[str, PartiallyMaterializedTensor]]: + """ + Return an iterator over embedding tables, yielding both the table name as well as the embedding + table itself. The embedding table is in the form of PartiallyMaterializedTensor with a valid + RocksDB snapshot to support windowed access. + """ + for config, tensor in zip( + self._config.embedding_tables, + self.split_embedding_weights(no_snapshot=False), + ): + key = append_prefix(prefix, f"{config.name}") + yield key, tensor + def flush(self) -> None: """ Flush the embeddings in cache back to SSD. Should be pretty expensive. @@ -982,11 +1019,11 @@ def purge(self) -> None: self.emb_module.lxu_cache_weights.zero_() self.emb_module.lxu_cache_state.fill_(-1) - def split_embedding_weights(self) -> List[torch.Tensor]: - """ - Return fake tensors. - """ - return [param.data for param in self._param_per_table.values()] + # pyre-ignore [15] + def split_embedding_weights( + self, no_snapshot: bool = True + ) -> List[PartiallyMaterializedTensor]: + return self.emb_module.split_embedding_weights(no_snapshot) class BatchedFusedEmbedding(BaseBatchedEmbedding[torch.Tensor], FusedOptimizerModule): diff --git a/torchrec/distributed/embedding.py b/torchrec/distributed/embedding.py index ff2e4449e..ad164dfeb 100644 --- a/torchrec/distributed/embedding.py +++ b/torchrec/distributed/embedding.py @@ -744,7 +744,7 @@ def _initialize_torch_state(self) -> None: # noqa self._model_parallel_name_to_shards_wrapper = OrderedDict() self._model_parallel_name_to_sharded_tensor = OrderedDict() self._model_parallel_name_to_dtensor = OrderedDict() - model_parallel_name_to_compute_kernel: Dict[str, str] = {} + self._model_parallel_name_to_compute_kernel: Dict[str, str] = {} for ( table_name, parameter_sharding, @@ -755,7 +755,7 @@ def _initialize_torch_state(self) -> None: # noqa self._model_parallel_name_to_shards_wrapper[table_name] = OrderedDict( [("local_tensors", []), ("local_offsets", [])] ) - model_parallel_name_to_compute_kernel[table_name] = ( + self._model_parallel_name_to_compute_kernel[table_name] = ( parameter_sharding.compute_kernel ) @@ -813,18 +813,17 @@ def _initialize_torch_state(self) -> None: # noqa "weight", nn.Parameter(torch.empty(0)) ) if ( - model_parallel_name_to_compute_kernel[table_name] + self._model_parallel_name_to_compute_kernel[table_name] != EmbeddingComputeKernel.DENSE.value ): self.embeddings[table_name].weight._in_backward_optimizers = [ EmptyFusedOptimizer() ] - if model_parallel_name_to_compute_kernel[table_name] in { - EmbeddingComputeKernel.KEY_VALUE.value - }: - continue if self._output_dtensor: + assert self._model_parallel_name_to_compute_kernel[table_name] not in { + EmbeddingComputeKernel.KEY_VALUE.value + } if shards_wrapper_map["local_tensors"]: self._model_parallel_name_to_dtensor[table_name] = ( DTensor.from_local( @@ -853,6 +852,8 @@ def _initialize_torch_state(self) -> None: # noqa ) else: # created ShardedTensors once in init, use in post_state_dict_hook + # note: at this point kvstore backed tensors don't own valid snapshots, so no read + # access is allowed on them. self._model_parallel_name_to_sharded_tensor[table_name] = ( ShardedTensor._init_from_local_shards( local_shards, @@ -861,6 +862,21 @@ def _initialize_torch_state(self) -> None: # noqa ) ) + def extract_sharded_kvtensors( + module: ShardedEmbeddingCollection, + ) -> OrderedDict[str, ShardedTensor]: + # retrieve all kvstore backed tensors + ret = OrderedDict() + for ( + table_name, + sharded_t, + ) in module._model_parallel_name_to_sharded_tensor.items(): + if self._model_parallel_name_to_compute_kernel[table_name] in { + EmbeddingComputeKernel.KEY_VALUE.value + }: + ret[table_name] = sharded_t + return ret + def post_state_dict_hook( module: ShardedEmbeddingCollection, destination: Dict[str, torch.Tensor], @@ -881,6 +897,29 @@ def post_state_dict_hook( destination_key = f"{prefix}embeddings.{table_name}.weight" destination[destination_key] = d_tensor + # kvstore backed tensors do not have a valid backing snapshot at this point. Fill in a valid + # snapshot for read access. + sharded_kvtensors = extract_sharded_kvtensors(module) + sharded_kvtensors_copy = copy.deepcopy(sharded_kvtensors) + for lookup, sharding_type in zip( + module._lookups, module._sharding_type_to_sharding.keys() + ): + if sharding_type == ShardingType.DATA_PARALLEL.value: + # unwrap DDP + lookup = lookup.module + else: + # pyre-fixme[29]: `Union[Module, Tensor]` is not a function. + for key, v in lookup.get_named_split_embedding_weights_snapshot(): + destination_key = f"{prefix}embeddings.{key}.weight" + assert key in sharded_kvtensors_copy + sharded_kvtensors_copy[key].local_shards()[0].tensor = v + for ( + table_name, + sharded_kvtensor, + ) in sharded_kvtensors_copy.items(): + destination_key = f"{prefix}embeddings.{table_name}.weight" + destination[destination_key] = sharded_kvtensor + self.register_state_dict_pre_hook(self._pre_state_dict_hook) self._register_state_dict_hook(post_state_dict_hook) self._register_load_state_dict_pre_hook( diff --git a/torchrec/distributed/embedding_lookup.py b/torchrec/distributed/embedding_lookup.py index 82d5d68fe..12b4b4c90 100644 --- a/torchrec/distributed/embedding_lookup.py +++ b/torchrec/distributed/embedding_lookup.py @@ -348,6 +348,18 @@ def named_parameters_by_table( ) in embedding_kernel.named_parameters_by_table(): yield (table_name, tbe_slice) + def get_named_split_embedding_weights_snapshot( + self, + ) -> Iterator[Tuple[str, PartiallyMaterializedTensor]]: + """ + Return an iterator over embedding tables, yielding both the table name as well as the embedding + table itself. The embedding table is in the form of PartiallyMaterializedTensor with a valid + RocksDB snapshot to support windowed access. + """ + for emb_module in self._emb_modules: + if isinstance(emb_module, KeyValueEmbedding): + yield from emb_module.get_named_split_embedding_weights_snapshot() + def flush(self) -> None: for emb_module in self._emb_modules: emb_module.flush()