diff --git a/torchtnt/framework/unit.py b/torchtnt/framework/unit.py
index 7e470f100d..41074d85d5 100644
--- a/torchtnt/framework/unit.py
+++ b/torchtnt/framework/unit.py
@@ -13,11 +13,11 @@
 from typing import Any, cast, Dict, Generic, Iterator, TypeVar, Union
 
 import torch
+from pyre_extensions import none_throws
 from torchtnt.framework._unit_utils import (
     _find_optimizers_for_module,
     _step_requires_iterator,
 )
-
 from torchtnt.framework.state import State
 from torchtnt.utils.lr_scheduler import TLRScheduler
 from torchtnt.utils.prepare_module import _is_fsdp_module, FSDPOptimizerWrapper
@@ -312,6 +312,7 @@ def on_train_epoch_end(self, state: State) -> None:
     def __init__(self) -> None:
         super().__init__()
         self.train_progress = Progress()
+        self._first_train_batch: TTrainData | None = None
 
     def on_train_start(self, state: State) -> None:
         """Hook called before training starts.
@@ -329,6 +330,14 @@ def on_train_epoch_start(self, state: State) -> None:
         """
         pass
 
+    @property
+    def first_train_batch(self) -> TTrainData:
+        return none_throws(self._first_train_batch)
+
+    @first_train_batch.setter
+    def first_train_batch(self, data: TTrainData) -> None:
+        self._first_train_batch = data
+
     @abstractmethod
     # pyre-fixme[3]: Return annotation cannot be `Any`.
     def train_step(self, state: State, data: TTrainData) -> Any:
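
For reference, a minimal sketch of how the new property could be used from a concrete `TrainUnit` subclass. The subclass, the batch type, and the "stash on first step" trigger below are hypothetical illustrations, not part of this diff:

```python
from typing import Tuple

import torch
from torchtnt.framework.state import State
from torchtnt.framework.unit import TrainUnit

# Hypothetical batch type, purely for illustration: (inputs, targets).
Batch = Tuple[torch.Tensor, torch.Tensor]


class MyTrainUnit(TrainUnit[Batch]):
    def train_step(self, state: State, data: Batch) -> None:
        # Stash the first batch seen. The setter writes the private
        # backing field; later reads go through the none_throws getter,
        # which raises instead of silently returning None.
        if self.train_progress.num_steps_completed == 0:
            self.first_train_batch = data
        # ... rest of the training step ...
```

Note that the backing field uses a different name (`_first_train_batch`) than the property itself: if the getter and setter referred to `self.first_train_batch`, every access would re-enter the property and recurse without bound.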