diff --git a/README.md b/README.md index 79768e6..9696912 100644 --- a/README.md +++ b/README.md @@ -104,7 +104,7 @@ def summary( depth: int = 3, device: Optional[torch.device] = None, dtypes: Optional[List[torch.dtype]] = None, - mode: str | None = None, + mode: str = "same", row_settings: Optional[Iterable[str]] = None, verbose: int = 1, **kwargs: Any, @@ -198,9 +198,10 @@ Args: Default: None mode (str) - Either "train" or "eval", which determines whether we call - model.train() or model.eval() before calling summary(). - Default: "eval". + One of "train", "eval" or "same", which determines whether we call + model.train(), model.eval() or neither ("same") before calling summary(). + In any case, the original model mode is restored at the end. + Default: "same". row_settings (Iterable[str]): Specify which features to show in a row. Currently supported: ( diff --git a/tests/torchinfo_xl_test.py b/tests/torchinfo_xl_test.py index 97581c7..1a9ca46 100644 --- a/tests/torchinfo_xl_test.py +++ b/tests/torchinfo_xl_test.py @@ -57,7 +57,7 @@ def test_eval_order_doesnt_matter() -> None: model2 = torchvision.models.resnet18( weights=torchvision.models.ResNet18_Weights.DEFAULT ) - summary(model2, input_size=input_size) + summary(model2, input_size=input_size, mode="eval") model2.eval() with torch.inference_mode(): output2 = model2(input_tensor) @@ -144,7 +144,7 @@ def test_tmva_net_column_totals() -> None: def test_google() -> None: google_net = torchvision.models.googlenet(init_weights=False) - summary(google_net, (1, 3, 112, 112), depth=7) + summary(google_net, (1, 3, 112, 112), depth=7, mode="eval") # Check googlenet in training mode since InceptionAux layers are used in # forward-prop in train mode but not in eval mode. 
diff --git a/torchinfo/enums.py b/torchinfo/enums.py index 191de8d..336bd78 100644 --- a/torchinfo/enums.py +++ b/torchinfo/enums.py @@ -11,6 +11,7 @@ class Mode(str, Enum): TRAIN = "train" EVAL = "eval" + SAME = "same" @unique diff --git a/torchinfo/torchinfo.py b/torchinfo/torchinfo.py index e9e5823..c0b6bb5 100644 --- a/torchinfo/torchinfo.py +++ b/torchinfo/torchinfo.py @@ -62,7 +62,7 @@ def summary( depth: int = 3, device: torch.device | str | None = None, dtypes: list[torch.dtype] | None = None, - mode: str | None = None, + mode: str = "same", row_settings: Iterable[str] | None = None, verbose: int | None = None, **kwargs: Any, @@ -156,9 +156,10 @@ class name as the key. If the forward pass is an expensive operation, Default: None mode (str) - Either "train" or "eval", which determines whether we call - model.train() or model.eval() before calling summary(). - Default: "eval". + One of "train", "eval" or "same", which determines whether we call + model.train(), model.eval() or neither ("same") before calling summary(). + In any case, the original model mode is restored at the end. + Default: "same". row_settings (Iterable[str]): Specify which features to show in a row. Currently supported: ( @@ -198,10 +199,7 @@ class name as the key. If the forward pass is an expensive operation, else: rows = {RowSettings(name) for name in row_settings} - if mode is None: - model_mode = Mode.EVAL - else: - model_mode = Mode(mode) + model_mode = Mode(mode) if verbose is None: verbose = 0 if hasattr(sys, "ps1") and sys.ps1 else 1 @@ -286,7 +284,7 @@ def forward_pass( model.train() elif mode == Mode.EVAL: model.eval() - else: + elif mode != Mode.SAME: raise RuntimeError( f"Specified model mode ({list(Mode)}) not recognized: {mode}" )