diff --git a/torchrec/distributed/embedding_lookup.py b/torchrec/distributed/embedding_lookup.py index 011ea6383..050fde40e 100644 --- a/torchrec/distributed/embedding_lookup.py +++ b/torchrec/distributed/embedding_lookup.py @@ -358,9 +358,6 @@ def purge(self) -> None: class CommOpGradientScaling(torch.autograd.Function): - # user override: inline autograd.Function is safe to trace since only tensor mutations / no global state - _compiled_autograd_should_lift = False - @staticmethod # pyre-ignore def forward( @@ -501,23 +498,16 @@ def _need_prefetch(config: GroupedEmbeddingConfig) -> bool: "If you don't turn on prefetch_pipeline, cache locations might be wrong in backward and can cause wrong results.\n" ) if hasattr(emb_op.emb_module, "prefetch"): - if isinstance(emb_op.emb_module, SSDTableBatchedEmbeddingBags): - emb_op.emb_module.prefetch( - indices=features.values(), - offsets=features.offsets(), - forward_stream=forward_stream, - ) - else: - emb_op.emb_module.prefetch( - indices=features.values(), - offsets=features.offsets(), - forward_stream=forward_stream, - batch_size_per_feature_per_rank=( - features.stride_per_key_per_rank() - if features.variable_stride_per_key() - else None - ), - ) + emb_op.emb_module.prefetch( + indices=features.values(), + offsets=features.offsets(), + forward_stream=forward_stream, + batch_size_per_feature_per_rank=( + features.stride_per_key_per_rank() + if features.variable_stride_per_key() + else None + ), + ) def _merge_variable_batch_embeddings( self, embeddings: List[torch.Tensor], splits: List[List[int]] diff --git a/torchrec/distributed/fused_params.py b/torchrec/distributed/fused_params.py index 171f94cb2..26af33938 100644 --- a/torchrec/distributed/fused_params.py +++ b/torchrec/distributed/fused_params.py @@ -7,7 +7,7 @@ # pyre-strict -from typing import Any, Dict, Iterable, Optional +from typing import Any, Dict, Iterable, List, Optional import torch @@ -24,6 +24,10 @@ FUSED_PARAM_TBE_ROW_ALIGNMENT: str = "__register_tbe_row_alignment" FUSED_PARAM_BOUNDS_CHECK_MODE: str = "__register_tbe_bounds_check_mode" +# Force lengths to offsets conversion before TBE lookup. Helps with performance +# with certain ways to split models. 
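+# Disabled by default in the inference DEFAULT_SHARDERS fused params; enable it by
+# passing {FUSED_PARAM_LENGTHS_TO_OFFSETS_LOOKUP: True} in a sharder's fused_params.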
+FUSED_PARAM_LENGTHS_TO_OFFSETS_LOOKUP: str = "__register_lengths_to_offsets_lookup" + class TBEToRegisterMixIn: def get_tbes_to_register( @@ -68,6 +72,18 @@ def fused_param_bounds_check_mode( return fused_params[FUSED_PARAM_BOUNDS_CHECK_MODE] +def fused_param_lengths_to_offsets_lookup( + fused_params: Optional[Dict[str, Any]] +) -> bool: + if ( + fused_params is None + or FUSED_PARAM_LENGTHS_TO_OFFSETS_LOOKUP not in fused_params + ): + return False + else: + return fused_params[FUSED_PARAM_LENGTHS_TO_OFFSETS_LOOKUP] + + def is_fused_param_quant_state_dict_split_scale_bias( fused_params: Optional[Dict[str, Any]] ) -> bool: @@ -93,5 +109,7 @@ def tbe_fused_params( fused_params_for_tbe.pop(FUSED_PARAM_TBE_ROW_ALIGNMENT) if FUSED_PARAM_BOUNDS_CHECK_MODE in fused_params_for_tbe: fused_params_for_tbe.pop(FUSED_PARAM_BOUNDS_CHECK_MODE) + if FUSED_PARAM_LENGTHS_TO_OFFSETS_LOOKUP in fused_params_for_tbe: + fused_params_for_tbe.pop(FUSED_PARAM_LENGTHS_TO_OFFSETS_LOOKUP) return fused_params_for_tbe diff --git a/torchrec/distributed/quant_embedding_kernel.py b/torchrec/distributed/quant_embedding_kernel.py index 9b230103e..68f799652 100644 --- a/torchrec/distributed/quant_embedding_kernel.py +++ b/torchrec/distributed/quant_embedding_kernel.py @@ -33,6 +33,7 @@ ) from torchrec.distributed.fused_params import ( fused_param_bounds_check_mode, + fused_param_lengths_to_offsets_lookup, is_fused_param_quant_state_dict_split_scale_bias, is_fused_param_register_tbe, tbe_fused_params, @@ -171,6 +172,19 @@ def _unwrap_kjt_for_cpu( return indices, offsets, None +@torch.fx.wrap +def _unwrap_kjt_lengths( + features: KeyedJaggedTensor, +) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]: + indices = features.values() + lengths = features.lengths() + return ( + indices.int(), + lengths.int(), + features.weights_or_none(), + ) + + @torch.fx.wrap def _unwrap_optional_tensor( tensor: Optional[torch.Tensor], @@ -180,6 +194,26 @@ def _unwrap_optional_tensor( return tensor +class IntNBitTableBatchedEmbeddingBagsCodegenWithLength( + IntNBitTableBatchedEmbeddingBagsCodegen +): + def __init__(self, *args: Any, **kwargs: Any) -> None: + super().__init__(*args, **kwargs) + + # pyre-ignore Inconsistent override [14] + def forward( + self, + indices: torch.Tensor, + lengths: torch.Tensor, + per_sample_weights: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + return self._forward_impl( + indices=indices, + offsets=(torch.ops.fbgemm.asynchronous_complete_cumsum(lengths)), + per_sample_weights=per_sample_weights, + ) + + class QuantBatchedEmbeddingBag( BaseBatchedEmbeddingBag[ Tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]] @@ -237,22 +271,27 @@ def __init__( ) ) - self._emb_module: IntNBitTableBatchedEmbeddingBagsCodegen = ( - IntNBitTableBatchedEmbeddingBagsCodegen( - embedding_specs=embedding_specs, - device=device, - pooling_mode=self._pooling, - feature_table_map=self._feature_table_map, - row_alignment=self._tbe_row_alignment, - uvm_host_mapped=True, # Use cudaHostAlloc for UVM CACHING to fix imbalance numa memory issue - bounds_check_mode=( - bounds_check_mode if bounds_check_mode else BoundsCheckMode.WARNING - ), - feature_names_per_table=[ - table.feature_names for table in config.embedding_tables - ], - **(tbe_fused_params(fused_params) or {}), - ) + self.lengths_to_tbe: bool = fused_param_lengths_to_offsets_lookup(fused_params) + + if self.lengths_to_tbe: + tbe_clazz = IntNBitTableBatchedEmbeddingBagsCodegenWithLength + else: + tbe_clazz = IntNBitTableBatchedEmbeddingBagsCodegen + + 
self._emb_module: IntNBitTableBatchedEmbeddingBagsCodegen = tbe_clazz( + embedding_specs=embedding_specs, + device=device, + pooling_mode=self._pooling, + feature_table_map=self._feature_table_map, + row_alignment=self._tbe_row_alignment, + uvm_host_mapped=True, # Use cudaHostAlloc for UVM CACHING to fix imbalance numa memory issue + bounds_check_mode=( + bounds_check_mode if bounds_check_mode else BoundsCheckMode.WARNING + ), + feature_names_per_table=[ + table.feature_names for table in config.embedding_tables + ], + **(tbe_fused_params(fused_params) or {}), ) if device is not None: self._emb_module.initialize_weights() @@ -271,44 +310,50 @@ def get_tbes_to_register( ) -> Dict[IntNBitTableBatchedEmbeddingBagsCodegen, GroupedEmbeddingConfig]: return {self._emb_module: self._config} - def forward(self, features: KeyedJaggedTensor) -> torch.Tensor: - # Important: _unwrap_kjt regex for FX tracing TAGing - if self._runtime_device.type == "cpu": - indices, offsets, per_sample_weights = _unwrap_kjt_for_cpu( - features, self._config.is_weighted - ) + def _emb_module_forward( + self, + indices: torch.Tensor, + lengths_or_offsets: torch.Tensor, + weights: Optional[torch.Tensor], + ) -> torch.Tensor: + kwargs = {"indices": indices} + + if self.lengths_to_tbe: + kwargs["lengths"] = lengths_or_offsets else: - indices, offsets, per_sample_weights = _unwrap_kjt(features) + kwargs["offsets"] = lengths_or_offsets if self._is_weighted: - weights = _unwrap_optional_tensor(per_sample_weights) - if self._emb_module_registered: - # Conditional call of .forward function for FX: - # emb_module() can go through FX only if emb_module is registered in named_modules (FX node call_module) - # emb_module.forward() does not require registering emb_module in named_modules (FX node call_function) - # For some post processing that requires TBE emb_module copied in fx.GraphModule we need to be call_module, as it will copies this module inside fx.GraphModule unchanged. - return self.emb_module( - indices=indices, - offsets=offsets, - per_sample_weights=weights, - ) + kwargs["per_sample_weights"] = _unwrap_optional_tensor(weights) + + if self._emb_module_registered: + # Conditional call of .forward function for FX: + # emb_module() can go through FX only if emb_module is registered in named_modules (FX node call_module) + # emb_module.forward() does not require registering emb_module in named_modules (FX node call_function) + # For some post processing that requires TBE emb_module copied in fx.GraphModule we need to be call_module, as it will copies this module inside fx.GraphModule unchanged. 
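+            # kwargs carries `lengths` for the lengths-based TBE variant and `offsets`
+            # for the standard TBE, so a single call site serves both lookup paths.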
+ return self._emb_module(**kwargs) + else: + return self._emb_module.forward(**kwargs) + + def forward(self, features: KeyedJaggedTensor) -> torch.Tensor: + # Important: _unwrap_kjt regex for FX tracing TAGing + lengths, offsets = None, None + if self._runtime_device.type == "cpu": + if self.lengths_to_tbe: + indices, lengths, per_sample_weights = _unwrap_kjt_lengths(features) else: - return self.emb_module.forward( - indices=indices, - offsets=offsets, - per_sample_weights=weights, + indices, offsets, per_sample_weights = _unwrap_kjt_for_cpu( + features, self._config.is_weighted ) else: - if self._emb_module_registered: - return self.emb_module( - indices=indices, - offsets=offsets, - ) + if self.lengths_to_tbe: + indices, lengths, per_sample_weights = _unwrap_kjt_lengths(features) else: - return self.emb_module.forward( - indices=indices, - offsets=offsets, - ) + indices, offsets, per_sample_weights = _unwrap_kjt(features) + + return self._emb_module_forward( + indices, lengths if lengths is not None else offsets, per_sample_weights + ) def named_buffers( self, prefix: str = "", recurse: bool = True, remove_duplicate: bool = True diff --git a/torchrec/distributed/tests/test_pt2.py b/torchrec/distributed/tests/test_pt2.py index c14906abc..e1738b598 100644 --- a/torchrec/distributed/tests/test_pt2.py +++ b/torchrec/distributed/tests/test_pt2.py @@ -319,7 +319,10 @@ def _test_kjt_input_module( # Need to set as size in order to run a proper forward em_inputs[0][0] = kjt.values().size(0) em_inputs[2][0] = kjt.weights().size(0) - eager_output = symint_wrapper(*em_inputs) + + if not kjt.values().is_meta: + eager_output = symint_wrapper(*em_inputs) + pt2_ir = torch.export.export( symint_wrapper, em_inputs, {}, strict=False ) @@ -504,6 +507,28 @@ def forward(self, kjt: KeyedJaggedTensor): test_pt2_ir_export=True, ) + def test_kjt_length_per_key_meta(self) -> None: + class M(torch.nn.Module): + def forward(self, kjt: KeyedJaggedTensor): + return kjt.length_per_key() + + kjt: KeyedJaggedTensor = make_kjt([2, 3, 4, 5, 6], [1, 2, 1, 1]) + kjt = kjt.to("meta") + + # calling forward on meta inputs once traced should error out + # as calculating length_per_key requires a .tolist() call of lengths + self.assertRaisesRegex( + RuntimeError, + r".*Tensor\.item\(\) cannot be called on meta tensors.*", + lambda: self._test_kjt_input_module( + M(), + kjt, + (), + test_aot_inductor=False, + test_pt2_ir_export=True, + ), + ) + def test_kjt_offset_per_key(self) -> None: class M(torch.nn.Module): def forward(self, kjt: KeyedJaggedTensor): @@ -629,7 +654,6 @@ def test_sharded_quant_ebc_non_strict_export(self) -> None: local_device="cpu", compute_device="cpu" ) kjt = input_kjts[0] - kjt = kjt.to("meta") sharded_model(kjt.values(), kjt.lengths()) from torch.export import _trace @@ -652,9 +676,6 @@ def test_sharded_quant_ebc_non_strict_export(self) -> None: for n in ep.graph_module.graph.nodes: self.assertFalse("auto_functionalized" in str(n.name)) - # TODO: Fix Unflatten - # torch.export.unflatten(ep) - # pyre-ignore @unittest.skipIf( torch.cuda.device_count() <= 1, @@ -665,9 +686,6 @@ def test_sharded_quant_fpebc_non_strict_export(self) -> None: local_device="cpu", compute_device="cpu", feature_processor=True ) kjt = input_kjts[0] - kjt = kjt.to("meta") - # Move FP parameters - sharded_model.to("meta") sharded_model(kjt.values(), kjt.lengths()) @@ -690,11 +708,6 @@ def test_sharded_quant_fpebc_non_strict_export(self) -> None: for n in ep.graph_module.graph.nodes: self.assertFalse("auto_functionalized" in 
str(n.name)) - # The nn_module_stack for this model forms a skip connection that looks like: - # a -> a.b -> a.b.c -> a.d - # This is currently not supported by unflatten. - # torch.export.unflatten(ep) - def test_maybe_compute_kjt_to_jt_dict(self) -> None: kjt: KeyedJaggedTensor = make_kjt([2, 3, 4, 5, 6], [1, 2, 1, 1]) self._test_kjt_input_module( diff --git a/torchrec/distributed/types.py b/torchrec/distributed/types.py index 4734f4cd7..b7a886773 100644 --- a/torchrec/distributed/types.py +++ b/torchrec/distributed/types.py @@ -650,6 +650,7 @@ class KeyValueParams: stats_reporter_config: Optional[TBEStatsReporterConfig] = None use_passed_in_path: bool = True l2_cache_size: Optional[int] = None + enable_async_update: Optional[bool] = None # Parameter Server (PS) Attributes ps_hosts: Optional[Tuple[Tuple[str, int], ...]] = None @@ -672,6 +673,7 @@ def __hash__(self) -> int: self.gather_ssd_cache_stats, self.stats_reporter_config, self.l2_cache_size, + self.enable_async_update, ) ) diff --git a/torchrec/inference/modules.py b/torchrec/inference/modules.py index f2d8903dc..4d136f488 100644 --- a/torchrec/inference/modules.py +++ b/torchrec/inference/modules.py @@ -26,6 +26,7 @@ from torchrec.distributed.embedding_types import EmbeddingComputeKernel from torchrec.distributed.fused_params import ( FUSED_PARAM_BOUNDS_CHECK_MODE, + FUSED_PARAM_LENGTHS_TO_OFFSETS_LOOKUP, FUSED_PARAM_QUANT_STATE_DICT_SPLIT_SCALE_BIAS, FUSED_PARAM_REGISTER_TBE_BOOL, ) @@ -82,6 +83,7 @@ def trim_torch_package_prefix_from_typename(typename: str) -> str: FUSED_PARAM_REGISTER_TBE_BOOL: True, FUSED_PARAM_QUANT_STATE_DICT_SPLIT_SCALE_BIAS: True, FUSED_PARAM_BOUNDS_CHECK_MODE: BoundsCheckMode.NONE, + FUSED_PARAM_LENGTHS_TO_OFFSETS_LOOKUP: False, } DEFAULT_SHARDERS: List[ModuleSharder[torch.nn.Module]] = [ diff --git a/torchrec/inference/tests/test_inference.py b/torchrec/inference/tests/test_inference.py index d0ad0469a..b13c32e9f 100644 --- a/torchrec/inference/tests/test_inference.py +++ b/torchrec/inference/tests/test_inference.py @@ -10,17 +10,22 @@ import unittest from argparse import Namespace +from typing import Any, cast, Dict, List import torch from fbgemm_gpu.split_embedding_configs import SparseType +from torch.fx import symbolic_trace from torchrec import PoolingType from torchrec.datasets.criteo import DEFAULT_CAT_NAMES, DEFAULT_INT_NAMES +from torchrec.distributed.fused_params import FUSED_PARAM_LENGTHS_TO_OFFSETS_LOOKUP from torchrec.distributed.global_settings import set_propogate_device +from torchrec.distributed.quant_embeddingbag import QuantEmbeddingBagCollectionSharder from torchrec.distributed.test_utils.test_model import ( ModelInput, TestOverArchRegroupModule, TestSparseNN, ) +from torchrec.distributed.types import ModuleSharder from torchrec.inference.dlrm_predict import ( create_training_batch, @@ -300,6 +305,62 @@ def test_sharded_quantized_tbe_count(self) -> None: expected_num_embeddings[spec[0]], ) + def test_sharded_quantized_lengths_to_tbe(self) -> None: + set_propogate_device(True) + + fused_params: Dict[str, Any] = {FUSED_PARAM_LENGTHS_TO_OFFSETS_LOOKUP: True} + sharders: List[ModuleSharder[torch.nn.Module]] = [ + cast( + ModuleSharder[torch.nn.Module], + QuantEmbeddingBagCollectionSharder(fused_params=fused_params), + ), + ] + + model = TestSparseNN( + tables=self.tables, + weighted_tables=self.weighted_tables, + num_float_features=10, + dense_device=torch.device("cpu"), + sparse_device=torch.device("cpu"), + over_arch_clazz=TestOverArchRegroupModule, + ) + + model.eval() + _, 
local_batch = ModelInput.generate( + batch_size=16, + world_size=1, + num_float_features=10, + tables=self.tables, + weighted_tables=self.weighted_tables, + ) + + # with torch.inference_mode(): # TODO: Why does inference mode fail when using different quant data types + output = model(local_batch[0]) + + # Quantize the model and collect quantized weights + quantized_model = quantize_inference_model(model) + quantized_output = quantized_model(local_batch[0]) + table_to_weight = get_table_to_weights_from_tbe(quantized_model) + + # Shard the model, all weights are initialized back to 0, so have to reassign weights + sharded_quant_model, _ = shard_quant_model( + quantized_model, + world_size=1, + compute_device="cpu", + sharding_device="cpu", + sharders=sharders, + ) + assign_weights_to_tbe(quantized_model, table_to_weight) + sharded_quant_output = sharded_quant_model(local_batch[0]) + + # When world_size = 1, we should have 1 TBE per sharded, quantized ebc + self.assertTrue(len(sharded_quant_model.sparse.ebc.tbes) == 1) + self.assertTrue(len(sharded_quant_model.sparse.weighted_ebc.tbes) == 1) + + # Check the weights are close + self.assertTrue(torch.allclose(output, quantized_output, atol=1e-3)) + self.assertTrue(torch.allclose(output, sharded_quant_output, atol=1e-3)) + def test_quantized_tbe_count_different_pooling(self) -> None: set_propogate_device(True) diff --git a/torchrec/metrics/gauc.py b/torchrec/metrics/gauc.py index c509475b9..829f875c0 100644 --- a/torchrec/metrics/gauc.py +++ b/torchrec/metrics/gauc.py @@ -209,8 +209,8 @@ def _compute(self) -> List[MetricComputationReport]: name=MetricName.GAUC_NUM_SAMPLES, metric_prefix=MetricPrefix.LIFETIME, value=compute_window_auc( - self.get_window_state("auc_sum"), - self.get_window_state("num_samples"), + cast(torch.Tensor, self.auc_sum), + cast(torch.Tensor, self.num_samples), )["num_samples"], ), MetricComputationReport( diff --git a/torchrec/sparse/jagged_tensor.py b/torchrec/sparse/jagged_tensor.py index 4b5359f0d..15952bfa5 100644 --- a/torchrec/sparse/jagged_tensor.py +++ b/torchrec/sparse/jagged_tensor.py @@ -1172,7 +1172,12 @@ def _maybe_compute_length_per_key( values: Optional[torch.Tensor], ) -> List[int]: if length_per_key is None: - if len(keys) and values is not None and values.is_meta: + if ( + len(keys) + and values is not None + and values.is_meta + and not is_non_strict_exporting() + ): # create dummy lengths per key when on meta device total_length = values.numel() _length = [total_length // len(keys)] * len(keys)
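
Usage sketch, not part of the patch: a minimal illustration of how the new lengths-to-offsets
lookup path is opted into, modeled on test_sharded_quantized_lengths_to_tbe above. The sharder,
fused-param key, and fbgemm op come from this diff; the sample lengths values are illustrative
assumptions only.

from typing import Any, cast, Dict, List

import torch

from torchrec.distributed.fused_params import FUSED_PARAM_LENGTHS_TO_OFFSETS_LOOKUP
from torchrec.distributed.quant_embeddingbag import QuantEmbeddingBagCollectionSharder
from torchrec.distributed.types import ModuleSharder

# Opt a quantized EBC sharder into the lengths-based TBE forward; with this flag set,
# QuantBatchedEmbeddingBag instantiates IntNBitTableBatchedEmbeddingBagsCodegenWithLength.
fused_params: Dict[str, Any] = {FUSED_PARAM_LENGTHS_TO_OFFSETS_LOOKUP: True}
sharders: List[ModuleSharder[torch.nn.Module]] = [
    cast(
        ModuleSharder[torch.nn.Module],
        QuantEmbeddingBagCollectionSharder(fused_params=fused_params),
    ),
]
# `sharders` can then be passed to shard_quant_model(..., sharders=sharders),
# exactly as the new test does.

# Inside the lookup, lengths are converted to offsets on the fly via fbgemm's
# complete cumsum, which prepends a zero and appends the running total:
lengths = torch.tensor([2, 0, 3], dtype=torch.int32)
offsets = torch.ops.fbgemm.asynchronous_complete_cumsum(lengths)
assert offsets.tolist() == [0, 2, 2, 5]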