add ssd-emo checkpoint support for sEC #2650

Closed · wants to merge 1 commit
57 changes: 47 additions & 10 deletions torchrec/distributed/batched_embedding_kernel.py
@@ -900,7 +900,7 @@ def __init__(
pg,
)
self._param_per_table: Dict[str, nn.Parameter] = dict(
_gen_named_parameters_by_table_ssd(
_gen_named_parameters_by_table_ssd_pmt(
emb_module=self._emb_module,
table_name_to_count=self.table_name_to_count.copy(),
config=self._config,
@@ -933,11 +933,31 @@ def state_dict(
destination: Optional[Dict[str, Any]] = None,
prefix: str = "",
keep_vars: bool = False,
no_snapshot: bool = True,
) -> Dict[str, Any]:
if destination is None:
destination = OrderedDict()
"""
Args:
no_snapshot (bool): the tensors in the returned dict are
PartiallyMaterializedTensors. This argument controls whether the
PartiallyMaterializedTensor owns a RocksDB snapshot handle. True means the
PartiallyMaterializedTensor doesn't have a RocksDB snapshot handle. False means the
PartiallyMaterializedTensor has a RocksDB snapshot handle.
"""
# in the case no_snapshot=False, a flush is required. we rely on the flush operation in
# ShardedEmbeddingBagCollection._pre_state_dict_hook()

return destination
emb_tables = self.split_embedding_weights(no_snapshot=no_snapshot)
emb_table_config_copy = copy.deepcopy(self._config.embedding_tables)
for emb_table in emb_table_config_copy:
emb_table.local_metadata.placement._device = torch.device("cpu")
ret = get_state_dict(
emb_table_config_copy,
emb_tables,
self._pg,
destination,
prefix,
)
return ret

def named_parameters(
self, prefix: str = "", recurse: bool = True, remove_duplicate: bool = True
@@ -950,14 +970,16 @@ def named_parameters(
):
# hack before we support optimizer on sharded parameter level
# can delete after PEA deprecation
# pyre-ignore [6]
param = nn.Parameter(tensor)
# pyre-ignore
param._in_backward_optimizers = [EmptyFusedOptimizer()]
yield name, param

# pyre-ignore [15]
def named_split_embedding_weights(
self, prefix: str = "", recurse: bool = True, remove_duplicate: bool = True
) -> Iterator[Tuple[str, torch.Tensor]]:
) -> Iterator[Tuple[str, PartiallyMaterializedTensor]]:
assert (
remove_duplicate
), "remove_duplicate=False not supported in BaseBatchedEmbedding.named_split_embedding_weights"
@@ -968,6 +990,21 @@ def named_split_embedding_weights(
key = append_prefix(prefix, f"{config.name}.weight")
yield key, tensor

def get_named_split_embedding_weights_snapshot(
self, prefix: str = ""
) -> Iterator[Tuple[str, PartiallyMaterializedTensor]]:
"""
Return an iterator over embedding tables, yielding both the table name and the embedding
table itself. The embedding table is in the form of a PartiallyMaterializedTensor with a valid
RocksDB snapshot to support windowed access.
"""
for config, tensor in zip(
self._config.embedding_tables,
self.split_embedding_weights(no_snapshot=False),
):
key = append_prefix(prefix, f"{config.name}")
yield key, tensor

def flush(self) -> None:
"""
Flush the embeddings in cache back to SSD. Should be pretty expensive.
@@ -982,11 +1019,11 @@ def purge(self) -> None:
self.emb_module.lxu_cache_weights.zero_()
self.emb_module.lxu_cache_state.fill_(-1)

def split_embedding_weights(self) -> List[torch.Tensor]:
"""
Return fake tensors.
"""
return [param.data for param in self._param_per_table.values()]
# pyre-ignore [15]
def split_embedding_weights(
self, no_snapshot: bool = True
) -> List[PartiallyMaterializedTensor]:
return self.emb_module.split_embedding_weights(no_snapshot)


class BatchedFusedEmbedding(BaseBatchedEmbedding[torch.Tensor], FusedOptimizerModule):
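For orientation, a minimal usage sketch of the kernel-level API touched above (state_dict(no_snapshot=...), flush(), get_named_split_embedding_weights_snapshot()). It is not part of the diff; `kv_emb` and `checkpoint_kv_embedding` are hypothetical names, assuming `kv_emb` is an SSD/RocksDB-backed KeyValueEmbedding kernel instance.

import torch


def checkpoint_kv_embedding(kv_emb) -> None:
    # Default path: no_snapshot=True returns PartiallyMaterializedTensors that
    # do not own a RocksDB snapshot handle (cheap, but not readable).
    _sd_without_snapshot = kv_emb.state_dict(no_snapshot=True)

    # Read path: flush cached rows to SSD first (the sharded module relies on
    # its _pre_state_dict_hook for this), then iterate snapshot-backed tables.
    kv_emb.flush()
    for name, pmt in kv_emb.get_named_split_embedding_weights_snapshot():
        # full_tensor() materializes the whole table from the snapshot; the
        # test file in this PR uses the same call to compare two SSD modules.
        full: torch.Tensor = pmt.full_tensor()
        print(name, tuple(full.shape))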
52 changes: 45 additions & 7 deletions torchrec/distributed/embedding.py
@@ -744,7 +744,7 @@ def _initialize_torch_state(self) -> None: # noqa
self._model_parallel_name_to_shards_wrapper = OrderedDict()
self._model_parallel_name_to_sharded_tensor = OrderedDict()
self._model_parallel_name_to_dtensor = OrderedDict()
model_parallel_name_to_compute_kernel: Dict[str, str] = {}
_model_parallel_name_to_compute_kernel: Dict[str, str] = {}
for (
table_name,
parameter_sharding,
@@ -755,7 +755,7 @@ def _initialize_torch_state(self) -> None: # noqa
self._model_parallel_name_to_shards_wrapper[table_name] = OrderedDict(
[("local_tensors", []), ("local_offsets", [])]
)
model_parallel_name_to_compute_kernel[table_name] = (
_model_parallel_name_to_compute_kernel[table_name] = (
parameter_sharding.compute_kernel
)

@@ -813,18 +813,17 @@ def _initialize_torch_state(self) -> None: # noqa
"weight", nn.Parameter(torch.empty(0))
)
if (
model_parallel_name_to_compute_kernel[table_name]
_model_parallel_name_to_compute_kernel[table_name]
!= EmbeddingComputeKernel.DENSE.value
):
self.embeddings[table_name].weight._in_backward_optimizers = [
EmptyFusedOptimizer()
]

if model_parallel_name_to_compute_kernel[table_name] in {
EmbeddingComputeKernel.KEY_VALUE.value
}:
continue
if self._output_dtensor:
assert _model_parallel_name_to_compute_kernel[table_name] not in {
EmbeddingComputeKernel.KEY_VALUE.value
}
if shards_wrapper_map["local_tensors"]:
self._model_parallel_name_to_dtensor[table_name] = (
DTensor.from_local(
@@ -853,6 +852,8 @@ def _initialize_torch_state(self) -> None: # noqa
)
else:
# created ShardedTensors once in init, use in post_state_dict_hook
# note: at this point kvstore backed tensors don't own valid snapshots, so no read
# access is allowed on them.
self._model_parallel_name_to_sharded_tensor[table_name] = (
ShardedTensor._init_from_local_shards(
local_shards,
@@ -861,6 +862,21 @@ def _initialize_torch_state(self) -> None: # noqa
)
)

def extract_sharded_kvtensors(
module: ShardedEmbeddingCollection,
) -> OrderedDict[str, ShardedTensor]:
# retrieve all kvstore backed tensors
ret = OrderedDict()
for (
table_name,
sharded_t,
) in module._model_parallel_name_to_sharded_tensor.items():
if _model_parallel_name_to_compute_kernel[table_name] in {
EmbeddingComputeKernel.KEY_VALUE.value
}:
ret[table_name] = sharded_t
return ret

def post_state_dict_hook(
module: ShardedEmbeddingCollection,
destination: Dict[str, torch.Tensor],
@@ -881,6 +897,28 @@ def post_state_dict_hook(
destination_key = f"{prefix}embeddings.{table_name}.weight"
destination[destination_key] = d_tensor

# kvstore backed tensors do not have a valid backing snapshot at this point. Fill in a valid
# snapshot for read access.
sharded_kvtensors = extract_sharded_kvtensors(module)
if len(sharded_kvtensors) == 0:
return

sharded_kvtensors_copy = copy.deepcopy(sharded_kvtensors)
for lookup, sharding_type in zip(
module._lookups, module._sharding_type_to_sharding.keys()
):
if sharding_type != ShardingType.DATA_PARALLEL.value:
# pyre-fixme[29]: `Union[Module, Tensor]` is not a function.
for key, v in lookup.get_named_split_embedding_weights_snapshot():
assert key in sharded_kvtensors_copy
sharded_kvtensors_copy[key].local_shards()[0].tensor = v
for (
table_name,
sharded_kvtensor,
) in sharded_kvtensors_copy.items():
destination_key = f"{prefix}embeddings.{table_name}.weight"
destination[destination_key] = sharded_kvtensor

self.register_state_dict_pre_hook(self._pre_state_dict_hook)
self._register_state_dict_hook(post_state_dict_hook)
self._register_load_state_dict_pre_hook(
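To make the post_state_dict_hook change above concrete, here is a hedged sketch of reading a kv-store backed table out of the resulting state dict. `ec` and `read_kv_table_from_state_dict` are hypothetical; `ec` stands for a ShardedEmbeddingCollection whose tables use the KEY_VALUE compute kernel, and the key layout follows the `embeddings.{table_name}.weight` pattern written by the hook.

from typing import List

import torch
from torch.distributed._shard.sharded_tensor import ShardedTensor


def read_kv_table_from_state_dict(ec, table_name: str) -> List[torch.Tensor]:
    # state_dict() runs _pre_state_dict_hook (flush) and then the
    # post_state_dict_hook above, which deep-copies the kv-store backed
    # ShardedTensors and swaps snapshot-backed PartiallyMaterializedTensors
    # into their local shards, so reads on the returned dict are valid.
    sd = ec.state_dict()
    sharded_w = sd[f"embeddings.{table_name}.weight"]
    assert isinstance(sharded_w, ShardedTensor)
    # Materialize this rank's local shard(s) from the RocksDB snapshot.
    return [shard.tensor.full_tensor() for shard in sharded_w.local_shards()]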
12 changes: 12 additions & 0 deletions torchrec/distributed/embedding_lookup.py
@@ -348,6 +348,18 @@ def named_parameters_by_table(
) in embedding_kernel.named_parameters_by_table():
yield (table_name, tbe_slice)

def get_named_split_embedding_weights_snapshot(
self,
) -> Iterator[Tuple[str, PartiallyMaterializedTensor]]:
"""
Return an iterator over embedding tables, yielding both the table name and the embedding
table itself. The embedding table is in the form of a PartiallyMaterializedTensor with a valid
RocksDB snapshot to support windowed access.
"""
for emb_module in self._emb_modules:
if isinstance(emb_module, KeyValueEmbedding):
yield from emb_module.get_named_split_embedding_weights_snapshot()

def flush(self) -> None:
for emb_module in self._emb_modules:
emb_module.flush()
19 changes: 9 additions & 10 deletions torchrec/distributed/embeddingbag.py
@@ -825,7 +825,7 @@ def _initialize_torch_state(self) -> None: # noqa
self._model_parallel_name_to_sharded_tensor = OrderedDict()
self._model_parallel_name_to_dtensor = OrderedDict()

self._model_parallel_name_to_compute_kernel: Dict[str, str] = {}
_model_parallel_name_to_compute_kernel: Dict[str, str] = {}
for (
table_name,
parameter_sharding,
@@ -836,7 +836,7 @@ def _initialize_torch_state(self) -> None: # noqa
self._model_parallel_name_to_shards_wrapper[table_name] = OrderedDict(
[("local_tensors", []), ("local_offsets", [])]
)
self._model_parallel_name_to_compute_kernel[table_name] = (
_model_parallel_name_to_compute_kernel[table_name] = (
parameter_sharding.compute_kernel
)

@@ -892,15 +892,15 @@ def _initialize_torch_state(self) -> None: # noqa
"weight", nn.Parameter(torch.empty(0))
)
if (
self._model_parallel_name_to_compute_kernel[table_name]
_model_parallel_name_to_compute_kernel[table_name]
!= EmbeddingComputeKernel.DENSE.value
):
self.embedding_bags[table_name].weight._in_backward_optimizers = [
EmptyFusedOptimizer()
]

if self._output_dtensor:
assert self._model_parallel_name_to_compute_kernel[table_name] not in {
assert _model_parallel_name_to_compute_kernel[table_name] not in {
EmbeddingComputeKernel.KEY_VALUE.value
}
if shards_wrapper_map["local_tensors"]:
@@ -954,7 +954,7 @@ def extract_sharded_kvtensors(
table_name,
sharded_t,
) in module._model_parallel_name_to_sharded_tensor.items():
if self._model_parallel_name_to_compute_kernel[table_name] in {
if _model_parallel_name_to_compute_kernel[table_name] in {
EmbeddingComputeKernel.KEY_VALUE.value
}:
ret[table_name] = sharded_t
@@ -983,15 +983,14 @@ def post_state_dict_hook(
# kvstore backed tensors do not have a valid backing snapshot at this point. Fill in a valid
# snapshot for read access.
sharded_kvtensors = extract_sharded_kvtensors(module)
if len(sharded_kvtensors) == 0:
return

sharded_kvtensors_copy = copy.deepcopy(sharded_kvtensors)
for lookup, sharding in zip(module._lookups, module._embedding_shardings):
if isinstance(sharding, DpPooledEmbeddingSharding):
# unwrap DDP
lookup = lookup.module
else:
if not isinstance(sharding, DpPooledEmbeddingSharding):
# pyre-fixme[29]: `Union[Module, Tensor]` is not a function.
for key, v in lookup.get_named_split_embedding_weights_snapshot():
destination_key = f"{prefix}embedding_bags.{key}.weight"
assert key in sharded_kvtensors_copy
sharded_kvtensors_copy[key].local_shards()[0].tensor = v
for (
Changes to an SSD sharding test file (file path not shown in this view)
@@ -111,11 +111,23 @@ def _copy_ssd_emb_modules(
"SSDEmbeddingBag or SSDEmbeddingBag."
)

weights = emb_module1.emb_module.debug_split_embedding_weights()
# need to set emb_module1 as well, since otherwise emb_module1 would
# produce a random debug_split_embedding_weights everytime
_load_split_embedding_weights(emb_module1, weights)
_load_split_embedding_weights(emb_module2, weights)
emb1_kv = dict(
emb_module1.get_named_split_embedding_weights_snapshot()
)
for (
k,
v,
) in emb_module2.get_named_split_embedding_weights_snapshot():
v1 = emb1_kv.get(k)
v1_full_tensor = v1.full_tensor()

# write value into ssd for both emb modules for later comparison
v.wrapped.set_range(
0, 0, v1_full_tensor.size(0), v1_full_tensor
)
v1.wrapped.set_range(
0, 0, v1_full_tensor.size(0), v1_full_tensor
)

# purge after loading. This is needed, since we pass a batch
# through dmp when instantiating them.
@@ -141,10 +153,12 @@ def _copy_ssd_emb_modules(
sharding_type=st.sampled_from(
[
ShardingType.TABLE_WISE.value,
ShardingType.COLUMN_WISE.value,
# TODO: uncomment when ssd ckpt supports cw sharding
# ShardingType.COLUMN_WISE.value,
ShardingType.ROW_WISE.value,
ShardingType.TABLE_ROW_WISE.value,
ShardingType.TABLE_COLUMN_WISE.value,
# TODO: uncomment when ssd ckpt supports cw sharding
# ShardingType.TABLE_COLUMN_WISE.value,
]
),
is_training=st.booleans(),
@@ -220,10 +234,12 @@ def test_ssd_load_state_dict(
sharding_type=st.sampled_from(
[
ShardingType.TABLE_WISE.value,
ShardingType.COLUMN_WISE.value,
# TODO: uncomment when ssd ckpt supports cw sharding
# ShardingType.COLUMN_WISE.value,
ShardingType.ROW_WISE.value,
ShardingType.TABLE_ROW_WISE.value,
ShardingType.TABLE_COLUMN_WISE.value,
# TODO: uncomment when ssd ckpt supports cw sharding
# ShardingType.TABLE_COLUMN_WISE.value,
]
),
is_training=st.booleans(),
@@ -344,10 +360,12 @@ def test_ssd_tbe_numerical_accuracy(
sharding_type=st.sampled_from(
[
ShardingType.TABLE_WISE.value,
ShardingType.COLUMN_WISE.value,
# TODO: uncomment when ssd ckpt supports cw sharding
# ShardingType.COLUMN_WISE.value,
ShardingType.ROW_WISE.value,
ShardingType.TABLE_ROW_WISE.value,
ShardingType.TABLE_COLUMN_WISE.value,
# TODO: uncomment when ssd ckpt supports cw sharding
# ShardingType.TABLE_COLUMN_WISE.value,
]
),
is_training=st.booleans(),
@@ -455,10 +473,12 @@ def test_ssd_fused_optimizer(
sharding_type=st.sampled_from(
[
ShardingType.TABLE_WISE.value,
ShardingType.COLUMN_WISE.value,
# TODO: uncomment when ssd ckpt supports cw sharding
# ShardingType.COLUMN_WISE.value,
ShardingType.ROW_WISE.value,
ShardingType.TABLE_ROW_WISE.value,
ShardingType.TABLE_COLUMN_WISE.value,
# TODO: uncomment when ssd ckpt supports cw sharding
# ShardingType.TABLE_COLUMN_WISE.value,
]
),
is_training=st.booleans(),
@@ -682,7 +702,8 @@ def _copy_ssd_emb_modules(
sharding_type=st.sampled_from(
[
ShardingType.TABLE_WISE.value,
ShardingType.COLUMN_WISE.value,
# TODO: uncomment when ssd ckpt supports cw sharding
# ShardingType.COLUMN_WISE.value,
ShardingType.ROW_WISE.value,
]
),
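The weight-copy pattern used in _copy_ssd_emb_modules above can be summarized as a standalone helper. This is a sketch, assuming both modules expose get_named_split_embedding_weights_snapshot() and that each yielded PartiallyMaterializedTensor has a .wrapped handle with set_range(), exactly as exercised in the test; `copy_ssd_weights` is a hypothetical name.

def copy_ssd_weights(src_module, dst_module) -> None:
    # Snapshot-backed view of every table in the source module.
    src_kv = dict(src_module.get_named_split_embedding_weights_snapshot())
    for name, dst_pmt in dst_module.get_named_split_embedding_weights_snapshot():
        src_pmt = src_kv[name]
        # Materialize the source table from its RocksDB snapshot.
        full = src_pmt.full_tensor()
        # Write the rows back into both SSD stores with set_range(), using the
        # same arguments as the test, so the two modules hold identical
        # weights for a later numerical comparison.
        dst_pmt.wrapped.set_range(0, 0, full.size(0), full)
        src_pmt.wrapped.set_range(0, 0, full.size(0), full)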