Skip to content

Commit

Permalink
Add ability to specify pipelineable preproc modules to ignore during …
Browse files Browse the repository at this point in the history
…SDD model rewrite (pytorch#2149)

Summary:
Pull Request resolved: pytorch#2149

Make torchrec automatically pipeline any modules that don't have trainable params during sparse data dist pipelining.

tldr; with some traversal logic changes, TorchRec sparse data dist pipeline can support arbitrary input transformations at input dist stage as long as they are composed of either nn.Module calls or currently supported ops (mainly getattr and getitem)

Differential Revision: D57944338
  • Loading branch information
sarckk authored and facebook-github-bot committed Jun 24, 2024
1 parent 3eff5ce commit 54047a5
Show file tree
Hide file tree
Showing 6 changed files with 1,153 additions and 43 deletions.
270 changes: 266 additions & 4 deletions torchrec/distributed/test_utils/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

# pyre-strict

import copy
import random
from dataclasses import dataclass
from typing import Any, cast, Dict, List, Optional, Tuple, Type, Union
Expand Down Expand Up @@ -239,10 +240,16 @@ def _validate_pooling_factor(
else None
)

global_float = torch.rand(
(batch_size * world_size, num_float_features), device=device
)
global_label = torch.rand(batch_size * world_size, device=device)
if randomize_indices:
global_float = torch.rand(
(batch_size * world_size, num_float_features), device=device
)
global_label = torch.rand(batch_size * world_size, device=device)
else:
global_float = torch.zeros(
(batch_size * world_size, num_float_features), device=device
)
global_label = torch.zeros(batch_size * world_size, device=device)

# Split global batch into local batches.
local_inputs = []
Expand Down Expand Up @@ -939,6 +946,7 @@ def __init__(
max_feature_lengths_list: Optional[List[Dict[str, int]]] = None,
feature_processor_modules: Optional[Dict[str, torch.nn.Module]] = None,
over_arch_clazz: Type[nn.Module] = TestOverArch,
preproc_module: Optional[nn.Module] = None,
) -> None:
super().__init__(
tables=cast(List[BaseEmbeddingConfig], tables),
Expand All @@ -960,13 +968,22 @@ def __init__(
embedding_names = (
list(embedding_groups.values())[0] if embedding_groups else None
)
self._embedding_names: List[str] = (
embedding_names
if embedding_names
else [feature for table in tables for feature in table.feature_names]
)
self._weighted_features: List[str] = [
feature for table in weighted_tables for feature in table.feature_names
]
self.over: nn.Module = over_arch_clazz(
tables, weighted_tables, embedding_names, dense_device
)
self.register_buffer(
"dummy_ones",
torch.ones(1, device=dense_device),
)
self.preproc_module = preproc_module

def sparse_forward(self, input: ModelInput) -> KeyedTensor:
return self.sparse(
Expand All @@ -993,6 +1010,8 @@ def forward(
self,
input: ModelInput,
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
if self.preproc_module:
input = self.preproc_module(input)
return self.dense_forward(input, self.sparse_forward(input))


Expand Down Expand Up @@ -1409,3 +1428,246 @@ def _post_ebc_test_wrap_function(kt: KeyedTensor) -> KeyedTensor:
continue

return kt


class TestPreprocNonWeighted(nn.Module):
    """
    Test preproc module that narrows a non-weighted KJT to three features.

    Example:
        >>> out = TestPreprocNonWeighted()(kjt)

    Returns:
        List[KeyedJaggedTensor]: single-element list containing a KJT with
        only ``feature_0``, ``feature_1`` and ``feature_2``.
    """

    def forward(self, kjt: KeyedJaggedTensor) -> List[KeyedJaggedTensor]:
        """
        Keep features 0-2 from the incoming KJT (dropping e.g. feature_3)
        and repackage them into a fresh KJT.
        """
        kept = {
            name: kjt[name]
            for name in ("feature_0", "feature_1", "feature_2")
        }
        return [KeyedJaggedTensor.from_jt_dict(kept)]


class TestPreprocWeighted(nn.Module):
    """
    Test preproc module that narrows a weighted KJT to a single feature.

    Example:
        >>> out = TestPreprocWeighted()(kjt)

    Returns:
        List[KeyedJaggedTensor]: single-element list containing a KJT with
        only ``weighted_feature_0``.
    """

    def forward(self, kjt: KeyedJaggedTensor) -> List[KeyedJaggedTensor]:
        """Keep only weighted_feature_0 from the incoming weighted KJT."""
        kept = {"weighted_feature_0": kjt["weighted_feature_0"]}
        return [KeyedJaggedTensor.from_jt_dict(kept)]


class TestModelWithPreproc(nn.Module):
    """
    Model with up to three preproc hooks around a plain and a weighted EBC:
      - a fixed preproc on idlist_features feeding the non-weighted EBC
      - a fixed preproc on idscore_features feeding the weighted EBC
      - an optional preproc applied to the whole model input first

    Args:
        tables: configs for the non-weighted EBC.
        weighted_tables: configs for the weighted EBC.
        device: device for all submodules.
        preproc_module: optional module run on the full ModelInput.
        num_float_features: width of the dense arch input.
        run_preproc_inline: when True (and no preproc_module), rebuild
            idlist_features inline via from_lengths_sync.

    Example:
        >>> TestModelWithPreproc(tables, weighted_tables, device)

    Returns:
        Tuple[torch.Tensor, torch.Tensor]
    """

    def __init__(
        self,
        tables: List[EmbeddingBagConfig],
        weighted_tables: List[EmbeddingBagConfig],
        device: torch.device,
        preproc_module: Optional[nn.Module] = None,
        num_float_features: int = 10,
        run_preproc_inline: bool = False,
    ) -> None:
        super().__init__()
        self.dense = TestDenseArch(num_float_features, device)

        self.ebc: EmbeddingBagCollection = EmbeddingBagCollection(
            tables=tables,
            device=device,
        )
        self.weighted_ebc = EmbeddingBagCollection(
            tables=weighted_tables,
            is_weighted=True,
            device=device,
        )
        self.preproc_nonweighted = TestPreprocNonWeighted()
        self.preproc_weighted = TestPreprocWeighted()
        self._preproc_module = preproc_module
        self._run_preproc_inline = run_preproc_inline

    def forward(
        self,
        input: ModelInput,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Runs preproc for the EBC and weighted EBC, optionally preprocessing
        the full input first.

        Args:
            input: the batch to embed.

        Returns:
            Tuple[torch.Tensor, torch.Tensor]: (scalar sum, concatenated
            pooled embeddings).
        """
        modified_input = input

        if self._preproc_module is not None:
            modified_input = self._preproc_module(modified_input)
        elif self._run_preproc_inline:
            # Rebuild the KJT inline to exercise non-module preproc ops.
            idlist = modified_input.idlist_features
            modified_input.idlist_features = KeyedJaggedTensor.from_lengths_sync(
                idlist.keys(),
                idlist.values(),
                idlist.lengths(),
            )

        ebc_out = self.ebc(
            self.preproc_nonweighted(modified_input.idlist_features)[0]
        )
        weighted_ebc_out = self.weighted_ebc(
            self.preproc_weighted(modified_input.idscore_features)[0]
        )

        pred = torch.cat([ebc_out.values(), weighted_ebc_out.values()], dim=1)
        return pred.sum(), pred


class TestNegSamplingModule(torch.nn.Module):
    """
    Preproc module simulating feature augmentation (e.g. negative sampling)
    by appending a fixed extra batch to every incoming ModelInput.

    Args:
        extra_input: batch whose rows/features are appended to each input.
        has_params: when True, the module owns a Linear layer so that it has
            trainable parameters.

    Example:
        >>> preproc = TestNegSamplingModule(extra_input)
        >>> out = preproc(in)

    Returns:
        ModelInput
    """

    def __init__(
        self,
        extra_input: ModelInput,
        has_params: bool = False,
    ) -> None:
        super().__init__()
        self._extra_input = extra_input
        # Optionally carry trainable params to exercise the
        # "preproc module with params" code path in tests.
        if has_params:
            self._linear: nn.Module = nn.Linear(30, 30)

    def forward(self, input: ModelInput) -> ModelInput:
        """
        Appends the extra features to the model input.

        Args:
            input: the batch to augment.

        Returns:
            ModelInput: a new batch with the extra input appended.
        """
        # Work on a deep copy so the caller's batch is left untouched.
        out = copy.deepcopy(input)

        # Batch dim (dim=0) grows by self._extra_input.float_features.shape[0].
        out.float_features = torch.concat(
            (out.float_features, self._extra_input.float_features), dim=0
        )

        # Stride stays the same; features are joined.
        out.idlist_features = KeyedJaggedTensor.concat(
            [out.idlist_features, self._extra_input.idlist_features]
        )

        extra_idscore = self._extra_input.idscore_features
        if extra_idscore is not None:
            # Stride stays the same; features are joined.
            # pyre-ignore
            out.idscore_features = KeyedJaggedTensor.concat(
                [out.idscore_features, extra_idscore]
            )

        # Batch dim (dim=0) grows by self._extra_input.label.shape[0].
        out.label = torch.concat((out.label, self._extra_input.label), dim=0)

        return out


class TestPositionWeightedPreprocModule(torch.nn.Module):
    """
    Preproc module that applies a PositionWeightedProcessor to the input's
    idlist_features.

    Example:
        >>> preproc = TestPositionWeightedPreprocModule(max_feature_lengths, device)
        >>> out = preproc(in)

    Returns:
        ModelInput
    """

    def __init__(
        self, max_feature_lengths: Dict[str, int], device: torch.device
    ) -> None:
        super().__init__()
        self.fp_proc = PositionWeightedProcessor(
            max_feature_lengths=max_feature_lengths,
            device=device,
        )

    def forward(self, input: ModelInput) -> ModelInput:
        """
        Runs the PositionWeightedProcessor over idlist_features on a copy
        of the input, leaving the caller's batch untouched.

        Args:
            input: the batch to process.

        Returns:
            ModelInput: copy with processed idlist_features.
        """
        out = copy.deepcopy(input)
        out.idlist_features = self.fp_proc(out.idlist_features)
        return out
Loading

0 comments on commit 54047a5

Please sign in to comment.