diff --git a/torchrec/distributed/model_parallel.py b/torchrec/distributed/model_parallel.py index 8b917cce1..0f60362b6 100644 --- a/torchrec/distributed/model_parallel.py +++ b/torchrec/distributed/model_parallel.py @@ -598,7 +598,7 @@ def named_buffers( yield key, param @property - def fused_optimizer(self) -> KeyedOptimizer: + def fused_optimizer(self) -> CombinedOptimizer: return self._optim @property