add ssd-emo checkpoint support for sEC (#2650)
Summary:

As the title says: ssd-emo previously only had checkpoint support for sEBC (ShardedEmbeddingBagCollection); this diff adds support for sEC (ShardedEmbeddingCollection) as well.

Differential Revision: D67183728
duduyi2013 authored and facebook-github-bot committed Dec 23, 2024
1 parent 5f607ff commit 8b44a74
Showing 3 changed files with 105 additions and 17 deletions.
57 changes: 47 additions & 10 deletions torchrec/distributed/batched_embedding_kernel.py
@@ -900,7 +900,7 @@ def __init__(
pg,
)
self._param_per_table: Dict[str, nn.Parameter] = dict(
_gen_named_parameters_by_table_ssd(
_gen_named_parameters_by_table_ssd_pmt(
emb_module=self._emb_module,
table_name_to_count=self.table_name_to_count.copy(),
config=self._config,
@@ -933,11 +933,31 @@ def state_dict(
destination: Optional[Dict[str, Any]] = None,
prefix: str = "",
keep_vars: bool = False,
no_snapshot: bool = True,
) -> Dict[str, Any]:
if destination is None:
destination = OrderedDict()
"""
Args:
no_snapshot (bool): the tensors in the returned dict are
PartiallyMaterializedTensors. This argument controls whether each
PartiallyMaterializedTensor owns a RocksDB snapshot handle. True means the
PartiallyMaterializedTensor does not hold a RocksDB snapshot handle; False means it does.
"""
# in the case no_snapshot=False, a flush is required. we rely on the flush operation in
# ShardedEmbeddingBagCollection._pre_state_dict_hook()

return destination
emb_tables = self.split_embedding_weights(no_snapshot=no_snapshot)
emb_table_config_copy = copy.deepcopy(self._config.embedding_tables)
for emb_table in emb_table_config_copy:
emb_table.local_metadata.placement._device = torch.device("cpu")
ret = get_state_dict(
emb_table_config_copy,
emb_tables,
self._pg,
destination,
prefix,
)
return ret

def named_parameters(
self, prefix: str = "", recurse: bool = True, remove_duplicate: bool = True
@@ -950,14 +970,16 @@ def named_parameters(
):
# hack before we support optimizer on sharded parameter level
# can delete after PEA deprecation
# pyre-ignore [6]
param = nn.Parameter(tensor)
# pyre-ignore
param._in_backward_optimizers = [EmptyFusedOptimizer()]
yield name, param

# pyre-ignore [15]
def named_split_embedding_weights(
self, prefix: str = "", recurse: bool = True, remove_duplicate: bool = True
) -> Iterator[Tuple[str, torch.Tensor]]:
) -> Iterator[Tuple[str, PartiallyMaterializedTensor]]:
assert (
remove_duplicate
), "remove_duplicate=False not supported in BaseBatchedEmbedding.named_split_embedding_weights"
@@ -968,6 +990,21 @@ def named_split_embedding_weights(
key = append_prefix(prefix, f"{config.name}.weight")
yield key, tensor

def get_named_split_embedding_weights_snapshot(
self, prefix: str = ""
) -> Iterator[Tuple[str, PartiallyMaterializedTensor]]:
"""
Return an iterator over embedding tables, yielding both the table name and the embedding
table itself. The embedding table is in the form of a PartiallyMaterializedTensor with a valid
RocksDB snapshot to support windowed access.
"""
for config, tensor in zip(
self._config.embedding_tables,
self.split_embedding_weights(no_snapshot=False),
):
key = append_prefix(prefix, f"{config.name}")
yield key, tensor

def flush(self) -> None:
"""
Flush the embeddings in cache back to SSD. Should be pretty expensive.
@@ -982,11 +1019,11 @@ def purge(self) -> None:
self.emb_module.lxu_cache_weights.zero_()
self.emb_module.lxu_cache_state.fill_(-1)

def split_embedding_weights(self) -> List[torch.Tensor]:
"""
Return fake tensors.
"""
return [param.data for param in self._param_per_table.values()]
# pyre-ignore [15]
def split_embedding_weights(
self, no_snapshot: bool = True
) -> List[PartiallyMaterializedTensor]:
return self.emb_module.split_embedding_weights(no_snapshot)


class BatchedFusedEmbedding(BaseBatchedEmbedding[torch.Tensor], FusedOptimizerModule):
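To make the kernel-level contract added above concrete, here is a minimal, self-contained sketch of what the new docstrings describe. Every class name below is a stand-in invented for illustration (not a real torchrec/fbgemm type); only the method names and the no_snapshot semantics follow this diff.

# Illustration only: stand-in classes mimicking the snapshot contract of the
# SSD-backed kernel above. FakePartiallyMaterializedTensor and
# FakeKeyValueEmbedding are invented names, not real torchrec/fbgemm classes.
from typing import Iterator, List, Tuple


class FakePartiallyMaterializedTensor:
    def __init__(self, name: str, has_snapshot: bool) -> None:
        self.name = name
        # True -> the tensor holds a RocksDB snapshot handle and supports windowed reads
        self.has_snapshot = has_snapshot


class FakeKeyValueEmbedding:
    def __init__(self, table_names: List[str]) -> None:
        self._table_names = table_names

    def split_embedding_weights(
        self, no_snapshot: bool = True
    ) -> List[FakePartiallyMaterializedTensor]:
        # no_snapshot=True  -> placeholder tensors, no snapshot handle, no reads
        # no_snapshot=False -> tensors that own a snapshot handle (a flush is required first)
        return [
            FakePartiallyMaterializedTensor(name, has_snapshot=not no_snapshot)
            for name in self._table_names
        ]

    def get_named_split_embedding_weights_snapshot(
        self, prefix: str = ""
    ) -> Iterator[Tuple[str, FakePartiallyMaterializedTensor]]:
        # Always asks for snapshot-backed tensors and keys them by table name.
        for name, tensor in zip(
            self._table_names, self.split_embedding_weights(no_snapshot=False)
        ):
            yield f"{prefix}{name}", tensor


if __name__ == "__main__":
    kernel = FakeKeyValueEmbedding(["t1", "t2"])
    for key, pmt in kernel.get_named_split_embedding_weights_snapshot():
        print(key, pmt.has_snapshot)  # "t1 True", "t2 True"
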
53 changes: 46 additions & 7 deletions torchrec/distributed/embedding.py
@@ -744,7 +744,7 @@ def _initialize_torch_state(self) -> None: # noqa
self._model_parallel_name_to_shards_wrapper = OrderedDict()
self._model_parallel_name_to_sharded_tensor = OrderedDict()
self._model_parallel_name_to_dtensor = OrderedDict()
model_parallel_name_to_compute_kernel: Dict[str, str] = {}
self._model_parallel_name_to_compute_kernel: Dict[str, str] = {}
for (
table_name,
parameter_sharding,
@@ -755,7 +755,7 @@ def _initialize_torch_state(self) -> None: # noqa
self._model_parallel_name_to_shards_wrapper[table_name] = OrderedDict(
[("local_tensors", []), ("local_offsets", [])]
)
model_parallel_name_to_compute_kernel[table_name] = (
self._model_parallel_name_to_compute_kernel[table_name] = (
parameter_sharding.compute_kernel
)

@@ -813,18 +813,17 @@ def _initialize_torch_state(self) -> None: # noqa
"weight", nn.Parameter(torch.empty(0))
)
if (
model_parallel_name_to_compute_kernel[table_name]
self._model_parallel_name_to_compute_kernel[table_name]
!= EmbeddingComputeKernel.DENSE.value
):
self.embeddings[table_name].weight._in_backward_optimizers = [
EmptyFusedOptimizer()
]

if model_parallel_name_to_compute_kernel[table_name] in {
EmbeddingComputeKernel.KEY_VALUE.value
}:
continue
if self._output_dtensor:
assert self._model_parallel_name_to_compute_kernel[table_name] not in {
EmbeddingComputeKernel.KEY_VALUE.value
}
if shards_wrapper_map["local_tensors"]:
self._model_parallel_name_to_dtensor[table_name] = (
DTensor.from_local(
@@ -853,6 +852,8 @@ def _initialize_torch_state(self) -> None: # noqa
)
else:
# created ShardedTensors once in init, use in post_state_dict_hook
# note: at this point kvstore backed tensors don't own valid snapshots, so no read
# access is allowed on them.
self._model_parallel_name_to_sharded_tensor[table_name] = (
ShardedTensor._init_from_local_shards(
local_shards,
@@ -861,6 +862,21 @@ def post_state_dict_hook(
)
)

def extract_sharded_kvtensors(
module: ShardedEmbeddingCollection,
) -> OrderedDict[str, ShardedTensor]:
# retrieve all kvstore backed tensors
ret = OrderedDict()
for (
table_name,
sharded_t,
) in module._model_parallel_name_to_sharded_tensor.items():
if self._model_parallel_name_to_compute_kernel[table_name] in {
EmbeddingComputeKernel.KEY_VALUE.value
}:
ret[table_name] = sharded_t
return ret

def post_state_dict_hook(
module: ShardedEmbeddingCollection,
destination: Dict[str, torch.Tensor],
@@ -881,6 +897,29 @@ def post_state_dict_hook(
destination_key = f"{prefix}embeddings.{table_name}.weight"
destination[destination_key] = d_tensor

# kvstore backed tensors do not have a valid backing snapshot at this point. Fill in a valid
# snapshot for read access.
sharded_kvtensors = extract_sharded_kvtensors(module)
sharded_kvtensors_copy = copy.deepcopy(sharded_kvtensors)
for lookup, sharding_type in zip(
module._lookups, module._sharding_type_to_sharding.keys()
):
if sharding_type == ShardingType.DATA_PARALLEL.value:
# unwrap DDP
lookup = lookup.module
else:
# pyre-fixme[29]: `Union[Module, Tensor]` is not a function.
for key, v in lookup.get_named_split_embedding_weights_snapshot():
destination_key = f"{prefix}embeddings.{key}.weight"
assert key in sharded_kvtensors_copy
sharded_kvtensors_copy[key].local_shards()[0].tensor = v
for (
table_name,
sharded_kvtensor,
) in sharded_kvtensors_copy.items():
destination_key = f"{prefix}embeddings.{table_name}.weight"
destination[destination_key] = sharded_kvtensor

self.register_state_dict_pre_hook(self._pre_state_dict_hook)
self._register_state_dict_hook(post_state_dict_hook)
self._register_load_state_dict_pre_hook(
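The heart of the new post_state_dict_hook logic above is a swap: collect the ShardedTensors whose compute kernel is KEY_VALUE, deep-copy them, replace each copy's local shard with a snapshot-backed PartiallyMaterializedTensor taken from the lookup, and only then write them into the state-dict destination. Below is a toy sketch of that control flow; plain dicts stand in for ShardedTensor and the lookup, and all names are assumed for illustration only.

# Toy sketch of the snapshot swap done in post_state_dict_hook above.
# Plain dicts stand in for ShardedTensor and the lookup modules; only the
# control flow mirrors the diff.
import copy
from typing import Dict


def fill_snapshots(
    name_to_sharded: Dict[str, Dict],   # table name -> fake "ShardedTensor"
    name_to_kernel: Dict[str, str],     # table name -> compute kernel string
    snapshot_weights: Dict[str, str],   # table name -> snapshot-backed tensor (fake)
    destination: Dict[str, object],
    prefix: str = "",
) -> None:
    # 1. keep only kvstore-backed tensors
    kvtensors = {
        name: st
        for name, st in name_to_sharded.items()
        if name_to_kernel[name] == "key_value"
    }
    # 2. deep-copy so the originals (which own no snapshot) are left untouched
    kvtensors_copy = copy.deepcopy(kvtensors)
    # 3. swap each copy's local shard for a snapshot-backed tensor
    for name, pmt in snapshot_weights.items():
        assert name in kvtensors_copy
        kvtensors_copy[name]["local_shards"][0] = pmt
    # 4. publish the snapshot-backed copies under the usual state-dict keys
    for name, st in kvtensors_copy.items():
        destination[f"{prefix}embeddings.{name}.weight"] = st


destination: Dict[str, object] = {}
fill_snapshots(
    name_to_sharded={"t1": {"local_shards": [None]}},
    name_to_kernel={"t1": "key_value"},
    snapshot_weights={"t1": "pmt-with-rocksdb-snapshot"},
    destination=destination,
)
print(destination)
# {'embeddings.t1.weight': {'local_shards': ['pmt-with-rocksdb-snapshot']}}
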
12 changes: 12 additions & 0 deletions torchrec/distributed/embedding_lookup.py
@@ -348,6 +348,18 @@ def named_parameters_by_table(
) in embedding_kernel.named_parameters_by_table():
yield (table_name, tbe_slice)

def get_named_split_embedding_weights_snapshot(
self,
) -> Iterator[Tuple[str, PartiallyMaterializedTensor]]:
"""
Return an iterator over embedding tables, yielding both the table name and the embedding
table itself. The embedding table is in the form of a PartiallyMaterializedTensor with a valid
RocksDB snapshot to support windowed access.
"""
for emb_module in self._emb_modules:
if isinstance(emb_module, KeyValueEmbedding):
yield from emb_module.get_named_split_embedding_weights_snapshot()

def flush(self) -> None:
for emb_module in self._emb_modules:
emb_module.flush()
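At the lookup level the new method is a thin dispatcher: it walks the underlying embedding kernels and forwards only from the SSD-backed KeyValueEmbedding ones, since dense or fused kernels have no snapshot to expose. A small sketch of that pattern follows; all class names are stand-ins invented for this sketch.

# Illustration of the dispatch pattern used by
# get_named_split_embedding_weights_snapshot above; all class names are
# stand-ins invented for this sketch.
from typing import Iterator, List, Tuple


class FakeDenseKernel:
    """No snapshot support; skipped by the dispatcher."""


class FakeSSDKernel:
    def __init__(self, tables: List[str]) -> None:
        self._tables = tables

    def get_named_split_embedding_weights_snapshot(self) -> Iterator[Tuple[str, str]]:
        for t in self._tables:
            yield t, f"snapshot-backed weights of {t}"


class FakeLookup:
    def __init__(self, emb_modules: List[object]) -> None:
        self._emb_modules = emb_modules

    def get_named_split_embedding_weights_snapshot(self) -> Iterator[Tuple[str, str]]:
        # Only SSD-backed kernels can hand out snapshot-backed tensors.
        for emb_module in self._emb_modules:
            if isinstance(emb_module, FakeSSDKernel):
                yield from emb_module.get_named_split_embedding_weights_snapshot()


lookup = FakeLookup([FakeDenseKernel(), FakeSSDKernel(["t1", "t2"])])
print(list(lookup.get_named_split_embedding_weights_snapshot()))
# [('t1', 'snapshot-backed weights of t1'), ('t2', 'snapshot-backed weights of t2')]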
