Skip to content

Commit

Permalink
Rename preproc to postproc for pipelining
Browse files Browse the repository at this point in the history
Summary: The data transformation that happens during model fwd invocation should be post-processing, not pre-processing. renaming accordingly.

Reviewed By: dstaay-fb

Differential Revision: D67756024
  • Loading branch information
sarckk authored and facebook-github-bot committed Jan 2, 2025
1 parent 455de88 commit c61b298
Show file tree
Hide file tree
Showing 6 changed files with 323 additions and 318 deletions.
54 changes: 27 additions & 27 deletions torchrec/distributed/test_utils/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -1192,7 +1192,7 @@ def __init__(
max_feature_lengths: Optional[Dict[str, int]] = None,
feature_processor_modules: Optional[Dict[str, torch.nn.Module]] = None,
over_arch_clazz: Type[nn.Module] = TestOverArch,
preproc_module: Optional[nn.Module] = None,
postproc_module: Optional[nn.Module] = None,
) -> None:
super().__init__(
tables=cast(List[BaseEmbeddingConfig], tables),
Expand Down Expand Up @@ -1229,7 +1229,7 @@ def __init__(
"dummy_ones",
torch.ones(1, device=dense_device),
)
self.preproc_module = preproc_module
self.postproc_module = postproc_module

def sparse_forward(self, input: ModelInput) -> KeyedTensor:
return self.sparse(
Expand All @@ -1256,8 +1256,8 @@ def forward(
self,
input: ModelInput,
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
if self.preproc_module:
input = self.preproc_module(input)
if self.postproc_module:
input = self.postproc_module(input)
return self.dense_forward(input, self.sparse_forward(input))


Expand Down Expand Up @@ -1749,18 +1749,18 @@ def forward(self, kjt: KeyedJaggedTensor) -> List[KeyedJaggedTensor]:

class TestModelWithPreproc(nn.Module):
"""
Basic module with up to 3 preproc modules:
- preproc on idlist_features for non-weighted EBC
- preproc on idscore_features for weighted EBC
- optional preproc on model input shared by both EBCs
Basic module with up to 3 postproc modules:
- postproc on idlist_features for non-weighted EBC
- postproc on idscore_features for weighted EBC
- optional postproc on model input shared by both EBCs
Args:
tables,
weighted_tables,
device,
preproc_module,
postproc_module,
num_float_features,
run_preproc_inline,
run_postproc_inline,
Example:
>>> TestModelWithPreproc(tables, weighted_tables, device)
Expand All @@ -1774,9 +1774,9 @@ def __init__(
tables: List[EmbeddingBagConfig],
weighted_tables: List[EmbeddingBagConfig],
device: torch.device,
preproc_module: Optional[nn.Module] = None,
postproc_module: Optional[nn.Module] = None,
num_float_features: int = 10,
run_preproc_inline: bool = False,
run_postproc_inline: bool = False,
) -> None:
super().__init__()
self.dense = TestDenseArch(num_float_features, device)
Expand All @@ -1790,17 +1790,17 @@ def __init__(
is_weighted=True,
device=device,
)
self.preproc_nonweighted = TestPreprocNonWeighted()
self.preproc_weighted = TestPreprocWeighted()
self._preproc_module = preproc_module
self._run_preproc_inline = run_preproc_inline
self.postproc_nonweighted = TestPreprocNonWeighted()
self.postproc_weighted = TestPreprocWeighted()
self._postproc_module = postproc_module
self._run_postproc_inline = run_postproc_inline

def forward(
self,
input: ModelInput,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Runs preprco for EBC and weighted EBC, optionally runs preproc for input
Runs preprco for EBC and weighted EBC, optionally runs postproc for input
Args:
input
Expand All @@ -1809,20 +1809,20 @@ def forward(
"""
modified_input = input

if self._preproc_module is not None:
modified_input = self._preproc_module(modified_input)
elif self._run_preproc_inline:
if self._postproc_module is not None:
modified_input = self._postproc_module(modified_input)
elif self._run_postproc_inline:
idlist_features = modified_input.idlist_features
modified_input.idlist_features = KeyedJaggedTensor.from_lengths_sync(
idlist_features.keys(), # pyre-ignore [6]
idlist_features.values(), # pyre-ignore [6]
idlist_features.lengths(), # pyre-ignore [16]
)

modified_idlist_features = self.preproc_nonweighted(
modified_idlist_features = self.postproc_nonweighted(
modified_input.idlist_features
)
modified_idscore_features = self.preproc_weighted(
modified_idscore_features = self.postproc_weighted(
modified_input.idscore_features
)
ebc_out = self.ebc(modified_idlist_features[0])
Expand All @@ -1834,15 +1834,15 @@ def forward(

class TestNegSamplingModule(torch.nn.Module):
"""
Basic module to simulate feature augmentation preproc (e.g. neg sampling) for testing
Basic module to simulate feature augmentation postproc (e.g. neg sampling) for testing
Args:
extra_input
has_params
Example:
>>> preproc = TestNegSamplingModule(extra_input)
>>> out = preproc(in)
>>> postproc = TestNegSamplingModule(extra_input)
>>> out = postproc(in)
Returns:
ModelInput
Expand Down Expand Up @@ -1906,8 +1906,8 @@ class TestPositionWeightedPreprocModule(torch.nn.Module):
Args: None
Example:
>>> preproc = TestPositionWeightedPreprocModule(max_feature_lengths, device)
>>> out = preproc(in)
>>> postproc = TestPositionWeightedPreprocModule(max_feature_lengths, device)
>>> out = postproc(in)
Returns:
ModelInput
"""
Expand Down
Loading

0 comments on commit c61b298

Please sign in to comment.