Skip to content

Commit

Permalink
add ssd-emo checkpoint support for sEC (#2650)
Browse files Browse the repository at this point in the history
Summary:

as title, we only add support for sEBC for ssd-emo, this diff also add support for sEC

Reviewed By: sarckk, jiayulu

Differential Revision: D67183728
  • Loading branch information
duduyi2013 authored and facebook-github-bot committed Dec 28, 2024
1 parent 75307b1 commit 745a479
Show file tree
Hide file tree
Showing 5 changed files with 148 additions and 41 deletions.
57 changes: 47 additions & 10 deletions torchrec/distributed/batched_embedding_kernel.py
Original file line number Diff line number Diff line change
Expand Up @@ -900,7 +900,7 @@ def __init__(
pg,
)
self._param_per_table: Dict[str, nn.Parameter] = dict(
_gen_named_parameters_by_table_ssd(
_gen_named_parameters_by_table_ssd_pmt(
emb_module=self._emb_module,
table_name_to_count=self.table_name_to_count.copy(),
config=self._config,
Expand Down Expand Up @@ -933,11 +933,31 @@ def state_dict(
destination: Optional[Dict[str, Any]] = None,
prefix: str = "",
keep_vars: bool = False,
no_snapshot: bool = True,
) -> Dict[str, Any]:
if destination is None:
destination = OrderedDict()
"""
Args:
no_snapshot (bool): the tensors in the returned dict are
PartiallyMaterializedTensors. this argument controls wether the
PartiallyMaterializedTensor owns a RocksDB snapshot handle. True means the
PartiallyMaterializedTensor doesn't have a RocksDB snapshot handle. False means the
PartiallyMaterializedTensor has a RocksDB snapshot handle
"""
# in the case no_snapshot=False, a flush is required. we rely on the flush operation in
# ShardedEmbeddingBagCollection._pre_state_dict_hook()

return destination
emb_tables = self.split_embedding_weights(no_snapshot=no_snapshot)
emb_table_config_copy = copy.deepcopy(self._config.embedding_tables)
for emb_table in emb_table_config_copy:
emb_table.local_metadata.placement._device = torch.device("cpu")
ret = get_state_dict(
emb_table_config_copy,
emb_tables,
self._pg,
destination,
prefix,
)
return ret

def named_parameters(
self, prefix: str = "", recurse: bool = True, remove_duplicate: bool = True
Expand All @@ -950,14 +970,16 @@ def named_parameters(
):
# hack before we support optimizer on sharded parameter level
# can delete after PEA deprecation
# pyre-ignore [6]
param = nn.Parameter(tensor)
# pyre-ignore
param._in_backward_optimizers = [EmptyFusedOptimizer()]
yield name, param

# pyre-ignore [15]
def named_split_embedding_weights(
self, prefix: str = "", recurse: bool = True, remove_duplicate: bool = True
) -> Iterator[Tuple[str, torch.Tensor]]:
) -> Iterator[Tuple[str, PartiallyMaterializedTensor]]:
assert (
remove_duplicate
), "remove_duplicate=False not supported in BaseBatchedEmbedding.named_split_embedding_weights"
Expand All @@ -968,6 +990,21 @@ def named_split_embedding_weights(
key = append_prefix(prefix, f"{config.name}.weight")
yield key, tensor

def get_named_split_embedding_weights_snapshot(
self, prefix: str = ""
) -> Iterator[Tuple[str, PartiallyMaterializedTensor]]:
"""
Return an iterator over embedding tables, yielding both the table name as well as the embedding
table itself. The embedding table is in the form of PartiallyMaterializedTensor with a valid
RocksDB snapshot to support windowed access.
"""
for config, tensor in zip(
self._config.embedding_tables,
self.split_embedding_weights(no_snapshot=False),
):
key = append_prefix(prefix, f"{config.name}")
yield key, tensor

def flush(self) -> None:
"""
Flush the embeddings in cache back to SSD. Should be pretty expensive.
Expand All @@ -982,11 +1019,11 @@ def purge(self) -> None:
self.emb_module.lxu_cache_weights.zero_()
self.emb_module.lxu_cache_state.fill_(-1)

def split_embedding_weights(self) -> List[torch.Tensor]:
"""
Return fake tensors.
"""
return [param.data for param in self._param_per_table.values()]
# pyre-ignore [15]
def split_embedding_weights(
self, no_snapshot: bool = True
) -> List[PartiallyMaterializedTensor]:
return self.emb_module.split_embedding_weights(no_snapshot)


class BatchedFusedEmbedding(BaseBatchedEmbedding[torch.Tensor], FusedOptimizerModule):
Expand Down
52 changes: 45 additions & 7 deletions torchrec/distributed/embedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -744,7 +744,7 @@ def _initialize_torch_state(self) -> None: # noqa
self._model_parallel_name_to_shards_wrapper = OrderedDict()
self._model_parallel_name_to_sharded_tensor = OrderedDict()
self._model_parallel_name_to_dtensor = OrderedDict()
model_parallel_name_to_compute_kernel: Dict[str, str] = {}
_model_parallel_name_to_compute_kernel: Dict[str, str] = {}
for (
table_name,
parameter_sharding,
Expand All @@ -755,7 +755,7 @@ def _initialize_torch_state(self) -> None: # noqa
self._model_parallel_name_to_shards_wrapper[table_name] = OrderedDict(
[("local_tensors", []), ("local_offsets", [])]
)
model_parallel_name_to_compute_kernel[table_name] = (
_model_parallel_name_to_compute_kernel[table_name] = (
parameter_sharding.compute_kernel
)

Expand Down Expand Up @@ -813,18 +813,17 @@ def _initialize_torch_state(self) -> None: # noqa
"weight", nn.Parameter(torch.empty(0))
)
if (
model_parallel_name_to_compute_kernel[table_name]
_model_parallel_name_to_compute_kernel[table_name]
!= EmbeddingComputeKernel.DENSE.value
):
self.embeddings[table_name].weight._in_backward_optimizers = [
EmptyFusedOptimizer()
]

if model_parallel_name_to_compute_kernel[table_name] in {
EmbeddingComputeKernel.KEY_VALUE.value
}:
continue
if self._output_dtensor:
assert _model_parallel_name_to_compute_kernel[table_name] not in {
EmbeddingComputeKernel.KEY_VALUE.value
}
if shards_wrapper_map["local_tensors"]:
self._model_parallel_name_to_dtensor[table_name] = (
DTensor.from_local(
Expand Down Expand Up @@ -853,6 +852,8 @@ def _initialize_torch_state(self) -> None: # noqa
)
else:
# created ShardedTensors once in init, use in post_state_dict_hook
# note: at this point kvstore backed tensors don't own valid snapshots, so no read
# access is allowed on them.
self._model_parallel_name_to_sharded_tensor[table_name] = (
ShardedTensor._init_from_local_shards(
local_shards,
Expand All @@ -861,6 +862,21 @@ def _initialize_torch_state(self) -> None: # noqa
)
)

def extract_sharded_kvtensors(
module: ShardedEmbeddingCollection,
) -> OrderedDict[str, ShardedTensor]:
# retrieve all kvstore backed tensors
ret = OrderedDict()
for (
table_name,
sharded_t,
) in module._model_parallel_name_to_sharded_tensor.items():
if _model_parallel_name_to_compute_kernel[table_name] in {
EmbeddingComputeKernel.KEY_VALUE.value
}:
ret[table_name] = sharded_t
return ret

def post_state_dict_hook(
module: ShardedEmbeddingCollection,
destination: Dict[str, torch.Tensor],
Expand All @@ -881,6 +897,28 @@ def post_state_dict_hook(
destination_key = f"{prefix}embeddings.{table_name}.weight"
destination[destination_key] = d_tensor

# kvstore backed tensors do not have a valid backing snapshot at this point. Fill in a valid
# snapshot for read access.
sharded_kvtensors = extract_sharded_kvtensors(module)
if len(sharded_kvtensors) == 0:
return

sharded_kvtensors_copy = copy.deepcopy(sharded_kvtensors)
for lookup, sharding_type in zip(
module._lookups, module._sharding_type_to_sharding.keys()
):
if sharding_type != ShardingType.DATA_PARALLEL.value:
# pyre-fixme[29]: `Union[Module, Tensor]` is not a function.
for key, v in lookup.get_named_split_embedding_weights_snapshot():
assert key in sharded_kvtensors_copy
sharded_kvtensors_copy[key].local_shards()[0].tensor = v
for (
table_name,
sharded_kvtensor,
) in sharded_kvtensors_copy.items():
destination_key = f"{prefix}embeddings.{table_name}.weight"
destination[destination_key] = sharded_kvtensor

self.register_state_dict_pre_hook(self._pre_state_dict_hook)
self._register_state_dict_hook(post_state_dict_hook)
self._register_load_state_dict_pre_hook(
Expand Down
12 changes: 12 additions & 0 deletions torchrec/distributed/embedding_lookup.py
Original file line number Diff line number Diff line change
Expand Up @@ -348,6 +348,18 @@ def named_parameters_by_table(
) in embedding_kernel.named_parameters_by_table():
yield (table_name, tbe_slice)

def get_named_split_embedding_weights_snapshot(
self,
) -> Iterator[Tuple[str, PartiallyMaterializedTensor]]:
"""
Return an iterator over embedding tables, yielding both the table name as well as the embedding
table itself. The embedding table is in the form of PartiallyMaterializedTensor with a valid
RocksDB snapshot to support windowed access.
"""
for emb_module in self._emb_modules:
if isinstance(emb_module, KeyValueEmbedding):
yield from emb_module.get_named_split_embedding_weights_snapshot()

def flush(self) -> None:
for emb_module in self._emb_modules:
emb_module.flush()
Expand Down
19 changes: 9 additions & 10 deletions torchrec/distributed/embeddingbag.py
Original file line number Diff line number Diff line change
Expand Up @@ -825,7 +825,7 @@ def _initialize_torch_state(self) -> None: # noqa
self._model_parallel_name_to_sharded_tensor = OrderedDict()
self._model_parallel_name_to_dtensor = OrderedDict()

self._model_parallel_name_to_compute_kernel: Dict[str, str] = {}
_model_parallel_name_to_compute_kernel: Dict[str, str] = {}
for (
table_name,
parameter_sharding,
Expand All @@ -836,7 +836,7 @@ def _initialize_torch_state(self) -> None: # noqa
self._model_parallel_name_to_shards_wrapper[table_name] = OrderedDict(
[("local_tensors", []), ("local_offsets", [])]
)
self._model_parallel_name_to_compute_kernel[table_name] = (
_model_parallel_name_to_compute_kernel[table_name] = (
parameter_sharding.compute_kernel
)

Expand Down Expand Up @@ -892,15 +892,15 @@ def _initialize_torch_state(self) -> None: # noqa
"weight", nn.Parameter(torch.empty(0))
)
if (
self._model_parallel_name_to_compute_kernel[table_name]
_model_parallel_name_to_compute_kernel[table_name]
!= EmbeddingComputeKernel.DENSE.value
):
self.embedding_bags[table_name].weight._in_backward_optimizers = [
EmptyFusedOptimizer()
]

if self._output_dtensor:
assert self._model_parallel_name_to_compute_kernel[table_name] not in {
assert _model_parallel_name_to_compute_kernel[table_name] not in {
EmbeddingComputeKernel.KEY_VALUE.value
}
if shards_wrapper_map["local_tensors"]:
Expand Down Expand Up @@ -954,7 +954,7 @@ def extract_sharded_kvtensors(
table_name,
sharded_t,
) in module._model_parallel_name_to_sharded_tensor.items():
if self._model_parallel_name_to_compute_kernel[table_name] in {
if _model_parallel_name_to_compute_kernel[table_name] in {
EmbeddingComputeKernel.KEY_VALUE.value
}:
ret[table_name] = sharded_t
Expand Down Expand Up @@ -983,15 +983,14 @@ def post_state_dict_hook(
# kvstore backed tensors do not have a valid backing snapshot at this point. Fill in a valid
# snapshot for read access.
sharded_kvtensors = extract_sharded_kvtensors(module)
if len(sharded_kvtensors) == 0:
return

sharded_kvtensors_copy = copy.deepcopy(sharded_kvtensors)
for lookup, sharding in zip(module._lookups, module._embedding_shardings):
if isinstance(sharding, DpPooledEmbeddingSharding):
# unwrap DDP
lookup = lookup.module
else:
if not isinstance(sharding, DpPooledEmbeddingSharding):
# pyre-fixme[29]: `Union[Module, Tensor]` is not a function.
for key, v in lookup.get_named_split_embedding_weights_snapshot():
destination_key = f"{prefix}embedding_bags.{key}.weight"
assert key in sharded_kvtensors_copy
sharded_kvtensors_copy[key].local_shards()[0].tensor = v
for (
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -111,11 +111,23 @@ def _copy_ssd_emb_modules(
"SSDEmbeddingBag or SSDEmbeddingBag."
)

weights = emb_module1.emb_module.debug_split_embedding_weights()
# need to set emb_module1 as well, since otherwise emb_module1 would
# produce a random debug_split_embedding_weights everytime
_load_split_embedding_weights(emb_module1, weights)
_load_split_embedding_weights(emb_module2, weights)
emb1_kv = dict(
emb_module1.get_named_split_embedding_weights_snapshot()
)
for (
k,
v,
) in emb_module2.get_named_split_embedding_weights_snapshot():
v1 = emb1_kv.get(k)
v1_full_tensor = v1.full_tensor()

# write value into ssd for both emb module for later comparison
v.wrapped.set_range(
0, 0, v1_full_tensor.size(0), v1_full_tensor
)
v1.wrapped.set_range(
0, 0, v1_full_tensor.size(0), v1_full_tensor
)

# purge after loading. This is needed, since we pass a batch
# through dmp when instantiating them.
Expand All @@ -141,10 +153,12 @@ def _copy_ssd_emb_modules(
sharding_type=st.sampled_from(
[
ShardingType.TABLE_WISE.value,
ShardingType.COLUMN_WISE.value,
# TODO: uncomment when ssd ckpt support cw sharding
# ShardingType.COLUMN_WISE.value,
ShardingType.ROW_WISE.value,
ShardingType.TABLE_ROW_WISE.value,
ShardingType.TABLE_COLUMN_WISE.value,
# TODO: uncomment when ssd ckpt support cw sharding
# ShardingType.TABLE_COLUMN_WISE.value,
]
),
is_training=st.booleans(),
Expand Down Expand Up @@ -220,10 +234,12 @@ def test_ssd_load_state_dict(
sharding_type=st.sampled_from(
[
ShardingType.TABLE_WISE.value,
ShardingType.COLUMN_WISE.value,
# TODO: uncomment when ssd ckpt support cw sharding
# ShardingType.COLUMN_WISE.value,
ShardingType.ROW_WISE.value,
ShardingType.TABLE_ROW_WISE.value,
ShardingType.TABLE_COLUMN_WISE.value,
# TODO: uncomment when ssd ckpt support cw sharding
# ShardingType.TABLE_COLUMN_WISE.value,
]
),
is_training=st.booleans(),
Expand Down Expand Up @@ -344,10 +360,12 @@ def test_ssd_tbe_numerical_accuracy(
sharding_type=st.sampled_from(
[
ShardingType.TABLE_WISE.value,
ShardingType.COLUMN_WISE.value,
# TODO: uncomment when ssd ckpt support cw sharding
# ShardingType.COLUMN_WISE.value,
ShardingType.ROW_WISE.value,
ShardingType.TABLE_ROW_WISE.value,
ShardingType.TABLE_COLUMN_WISE.value,
# TODO: uncomment when ssd ckpt support cw sharding
# ShardingType.TABLE_COLUMN_WISE.value,
]
),
is_training=st.booleans(),
Expand Down Expand Up @@ -455,10 +473,12 @@ def test_ssd_fused_optimizer(
sharding_type=st.sampled_from(
[
ShardingType.TABLE_WISE.value,
ShardingType.COLUMN_WISE.value,
# TODO: uncomment when ssd ckpt support cw sharding
# ShardingType.COLUMN_WISE.value,
ShardingType.ROW_WISE.value,
ShardingType.TABLE_ROW_WISE.value,
ShardingType.TABLE_COLUMN_WISE.value,
# TODO: uncomment when ssd ckpt support cw sharding
# ShardingType.TABLE_COLUMN_WISE.value,
]
),
is_training=st.booleans(),
Expand Down Expand Up @@ -682,7 +702,8 @@ def _copy_ssd_emb_modules(
sharding_type=st.sampled_from(
[
ShardingType.TABLE_WISE.value,
ShardingType.COLUMN_WISE.value,
# TODO: uncomment when ssd ckpt support cw sharding
# ShardingType.COLUMN_WISE.value,
ShardingType.ROW_WISE.value,
]
),
Expand Down

0 comments on commit 745a479

Please sign in to comment.