From ac27457e63d39b305de5d569987c6746b45fa9bf Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=C3=89lie=20Goudout?= Date: Fri, 29 Nov 2024 11:06:19 +0100 Subject: [PATCH 1/4] api: summary keeps current mode (#331) --- ruff.toml | 1 - tests/torchinfo_xl_test.py | 4 ++-- torchinfo/__init__.py | 3 +-- torchinfo/enums.py | 10 ---------- torchinfo/torchinfo.py | 26 +++++++++++--------------- 5 files changed, 14 insertions(+), 30 deletions(-) diff --git a/ruff.toml b/ruff.toml index 5a58061..ad11958 100644 --- a/ruff.toml +++ b/ruff.toml @@ -36,7 +36,6 @@ lint.ignore = [ "PLW0602", # Using global for `_cached_forward_pass` but no assignment is done "PLW0603", # Using the global statement to update `_cached_forward_pass` is discouraged "PLW2901", # `for` loop variable `name` overwritten by assignment target - "SIM108", # [*] Use ternary operator `model_mode = Mode.EVAL if mode is None else Mode(mode)` instead of `if`-`else`-block "SLF001", # Private member accessed: `_modules` "TCH002", # Move third-party import into a type-checking block "TRY004", # Prefer `TypeError` exception for invalid type diff --git a/tests/torchinfo_xl_test.py b/tests/torchinfo_xl_test.py index 97581c7..1a9ca46 100644 --- a/tests/torchinfo_xl_test.py +++ b/tests/torchinfo_xl_test.py @@ -57,7 +57,7 @@ def test_eval_order_doesnt_matter() -> None: model2 = torchvision.models.resnet18( weights=torchvision.models.ResNet18_Weights.DEFAULT ) - summary(model2, input_size=input_size) + summary(model2, input_size=input_size, mode="eval") model2.eval() with torch.inference_mode(): output2 = model2(input_tensor) @@ -144,7 +144,7 @@ def test_tmva_net_column_totals() -> None: def test_google() -> None: google_net = torchvision.models.googlenet(init_weights=False) - summary(google_net, (1, 3, 112, 112), depth=7) + summary(google_net, (1, 3, 112, 112), depth=7, mode="eval") # Check googlenet in training mode since InceptionAux layers are used in # forward-prop in train mode but not in eval mode. diff --git a/torchinfo/__init__.py b/torchinfo/__init__.py index 65d1432..36eab38 100644 --- a/torchinfo/__init__.py +++ b/torchinfo/__init__.py @@ -1,10 +1,9 @@ -from .enums import ColumnSettings, Mode, RowSettings, Units, Verbosity +from .enums import ColumnSettings, RowSettings, Units, Verbosity from .model_statistics import ModelStatistics from .torchinfo import summary __all__ = ( "ColumnSettings", - "Mode", "ModelStatistics", "RowSettings", "Units", diff --git a/torchinfo/enums.py b/torchinfo/enums.py index 191de8d..ebc3a29 100644 --- a/torchinfo/enums.py +++ b/torchinfo/enums.py @@ -3,16 +3,6 @@ from enum import Enum, IntEnum, unique -@unique -class Mode(str, Enum): - """Enum containing all model modes.""" - - __slots__ = () - - TRAIN = "train" - EVAL = "eval" - - @unique class RowSettings(str, Enum): """Enum containing all available row settings.""" diff --git a/torchinfo/torchinfo.py b/torchinfo/torchinfo.py index e9e5823..25ee4d2 100644 --- a/torchinfo/torchinfo.py +++ b/torchinfo/torchinfo.py @@ -21,7 +21,7 @@ from torch.jit import ScriptModule from torch.utils.hooks import RemovableHandle -from .enums import ColumnSettings, Mode, RowSettings, Verbosity +from .enums import ColumnSettings, RowSettings, Verbosity from .formatting import FormattingOptions from .layer_info import LayerInfo, get_children_layers, prod from .model_statistics import ModelStatistics @@ -155,10 +155,11 @@ class name as the key. If the forward pass is an expensive operation, also specify the types of each parameter here. 
Default: None - mode (str) - Either "train" or "eval", which determines whether we call - model.train() or model.eval() before calling summary(). - Default: "eval". + mode (str | None) + One of None, "eval" or "train". If not None, summary() will call either + mode.eval() or mode.train() (respectively) before processing the model. + In any case, original model mode is restored at the end. + Default: None row_settings (Iterable[str]): Specify which features to show in a row. Currently supported: ( @@ -198,11 +199,6 @@ class name as the key. If the forward pass is an expensive operation, else: rows = {RowSettings(name) for name in row_settings} - if mode is None: - model_mode = Mode.EVAL - else: - model_mode = Mode(mode) - if verbose is None: verbose = 0 if hasattr(sys, "ps1") and sys.ps1 else 1 @@ -223,7 +219,7 @@ class name as the key. If the forward pass is an expensive operation, input_data, input_size, batch_dim, device, dtypes ) summary_list = forward_pass( - model, x, batch_dim, cache_forward_pass, device, model_mode, **kwargs + model, x, batch_dim, cache_forward_pass, device, mode, **kwargs ) formatting = FormattingOptions(depth, verbose, columns, col_width, rows) results = ModelStatistics( @@ -265,7 +261,7 @@ def forward_pass( batch_dim: int | None, cache_forward_pass: bool, device: torch.device | None, - mode: Mode, + mode: str | None, **kwargs: Any, ) -> list[LayerInfo]: """Perform a forward pass on the model using forward hooks.""" @@ -282,11 +278,11 @@ def forward_pass( kwargs = set_device(kwargs, device) saved_model_mode = model.training try: - if mode == Mode.TRAIN: + if mode == "train": model.train() - elif mode == Mode.EVAL: + elif mode == "eval": model.eval() - else: + elif mode is not None: raise RuntimeError( f"Specified model mode ({list(Mode)}) not recognized: {mode}" ) From 0b1aef72a442dd4a3e6e855c0bce4c9347bb3c96 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=C3=89lie=20Goudout?= Date: Fri, 29 Nov 2024 11:27:19 +0100 Subject: [PATCH 2/4] fix: forgot to rm 'Mode' from error msg --- torchinfo/torchinfo.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchinfo/torchinfo.py b/torchinfo/torchinfo.py index 25ee4d2..715f3c3 100644 --- a/torchinfo/torchinfo.py +++ b/torchinfo/torchinfo.py @@ -284,7 +284,7 @@ def forward_pass( model.eval() elif mode is not None: raise RuntimeError( - f"Specified model mode ({list(Mode)}) not recognized: {mode}" + f"Specified model mode should be None, 'eval' or 'train' (got {mode})!" 
) with torch.no_grad(): From 1009053d3b9f463ab4ded43a1c4e70e0e9dc5b8c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=C3=89lie=20Goudout?= Date: Mon, 2 Dec 2024 15:59:39 +0100 Subject: [PATCH 3/4] dev: revert Mode Enum removal -> now 'same' in place of None --- ruff.toml | 1 + torchinfo/__init__.py | 3 ++- torchinfo/enums.py | 11 +++++++++++ torchinfo/torchinfo.py | 28 +++++++++++++++------------- 4 files changed, 29 insertions(+), 14 deletions(-) diff --git a/ruff.toml b/ruff.toml index ad11958..5a58061 100644 --- a/ruff.toml +++ b/ruff.toml @@ -36,6 +36,7 @@ lint.ignore = [ "PLW0602", # Using global for `_cached_forward_pass` but no assignment is done "PLW0603", # Using the global statement to update `_cached_forward_pass` is discouraged "PLW2901", # `for` loop variable `name` overwritten by assignment target + "SIM108", # [*] Use ternary operator `model_mode = Mode.EVAL if mode is None else Mode(mode)` instead of `if`-`else`-block "SLF001", # Private member accessed: `_modules` "TCH002", # Move third-party import into a type-checking block "TRY004", # Prefer `TypeError` exception for invalid type diff --git a/torchinfo/__init__.py b/torchinfo/__init__.py index 36eab38..65d1432 100644 --- a/torchinfo/__init__.py +++ b/torchinfo/__init__.py @@ -1,9 +1,10 @@ -from .enums import ColumnSettings, RowSettings, Units, Verbosity +from .enums import ColumnSettings, Mode, RowSettings, Units, Verbosity from .model_statistics import ModelStatistics from .torchinfo import summary __all__ = ( "ColumnSettings", + "Mode", "ModelStatistics", "RowSettings", "Units", diff --git a/torchinfo/enums.py b/torchinfo/enums.py index ebc3a29..336bd78 100644 --- a/torchinfo/enums.py +++ b/torchinfo/enums.py @@ -3,6 +3,17 @@ from enum import Enum, IntEnum, unique +@unique +class Mode(str, Enum): + """Enum containing all model modes.""" + + __slots__ = () + + TRAIN = "train" + EVAL = "eval" + SAME = "same" + + @unique class RowSettings(str, Enum): """Enum containing all available row settings.""" diff --git a/torchinfo/torchinfo.py b/torchinfo/torchinfo.py index 715f3c3..c0b6bb5 100644 --- a/torchinfo/torchinfo.py +++ b/torchinfo/torchinfo.py @@ -21,7 +21,7 @@ from torch.jit import ScriptModule from torch.utils.hooks import RemovableHandle -from .enums import ColumnSettings, RowSettings, Verbosity +from .enums import ColumnSettings, Mode, RowSettings, Verbosity from .formatting import FormattingOptions from .layer_info import LayerInfo, get_children_layers, prod from .model_statistics import ModelStatistics @@ -62,7 +62,7 @@ def summary( depth: int = 3, device: torch.device | str | None = None, dtypes: list[torch.dtype] | None = None, - mode: str | None = None, + mode: str = "same", row_settings: Iterable[str] | None = None, verbose: int | None = None, **kwargs: Any, @@ -155,11 +155,11 @@ class name as the key. If the forward pass is an expensive operation, also specify the types of each parameter here. Default: None - mode (str | None) - One of None, "eval" or "train". If not None, summary() will call either - mode.eval() or mode.train() (respectively) before processing the model. - In any case, original model mode is restored at the end. - Default: None + mode (str) + Either "train", "eval" or "same", which determines whether we call + model.train() or model.eval() before calling summary(). In any case, + original model mode is restored at the end. + Default: "same". row_settings (Iterable[str]): Specify which features to show in a row. Currently supported: ( @@ -199,6 +199,8 @@ class name as the key. 
If the forward pass is an expensive operation, else: rows = {RowSettings(name) for name in row_settings} + model_mode = Mode(mode) + if verbose is None: verbose = 0 if hasattr(sys, "ps1") and sys.ps1 else 1 @@ -219,7 +221,7 @@ class name as the key. If the forward pass is an expensive operation, input_data, input_size, batch_dim, device, dtypes ) summary_list = forward_pass( - model, x, batch_dim, cache_forward_pass, device, mode, **kwargs + model, x, batch_dim, cache_forward_pass, device, model_mode, **kwargs ) formatting = FormattingOptions(depth, verbose, columns, col_width, rows) results = ModelStatistics( @@ -261,7 +263,7 @@ def forward_pass( batch_dim: int | None, cache_forward_pass: bool, device: torch.device | None, - mode: str | None, + mode: Mode, **kwargs: Any, ) -> list[LayerInfo]: """Perform a forward pass on the model using forward hooks.""" @@ -278,13 +280,13 @@ def forward_pass( kwargs = set_device(kwargs, device) saved_model_mode = model.training try: - if mode == "train": + if mode == Mode.TRAIN: model.train() - elif mode == "eval": + elif mode == Mode.EVAL: model.eval() - elif mode is not None: + elif mode != Mode.SAME: raise RuntimeError( - f"Specified model mode should be None, 'eval' or 'train' (got {mode})!" + f"Specified model mode ({list(Mode)}) not recognized: {mode}" ) with torch.no_grad(): From 48ee0a7b333dcf025dbc5e9430dd8086f52641c7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=C3=89lie=20Goudout?= Date: Mon, 2 Dec 2024 16:00:42 +0100 Subject: [PATCH 4/4] readme: Update summary doc ('same' mode) --- README.md | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/README.md b/README.md index 79768e6..9696912 100644 --- a/README.md +++ b/README.md @@ -104,7 +104,7 @@ def summary( depth: int = 3, device: Optional[torch.device] = None, dtypes: Optional[List[torch.dtype]] = None, - mode: str | None = None, + mode: str = "same", row_settings: Optional[Iterable[str]] = None, verbose: int = 1, **kwargs: Any, @@ -198,9 +198,10 @@ Args: Default: None mode (str) - Either "train" or "eval", which determines whether we call - model.train() or model.eval() before calling summary(). - Default: "eval". + Either "train", "eval" or "same", which determines whether we call + model.train() or model.eval() before calling summary(). In any case, + original model mode is restored at the end. + Default: "same". row_settings (Iterable[str]): Specify which features to show in a row. Currently supported: (
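---

For reference, a minimal usage sketch of the behaviour this series introduces (not part of the patches themselves; torchvision, the `weights=None` argument, and the 224x224 input size are illustrative assumptions, and `verbose=0` is only used to silence output):

    import torchvision
    from torchinfo import summary

    model = torchvision.models.resnet18(weights=None)
    model.train()

    # New default mode="same": summary() runs the forward pass with the model
    # left in its current mode (train here) and restores that mode afterwards.
    summary(model, input_size=(1, 3, 224, 224), verbose=0)
    assert model.training  # unchanged

    # Requesting eval-mode statistics switches to model.eval() only for the
    # duration of the summary; the original train mode is restored at the end.
    summary(model, input_size=(1, 3, 224, 224), mode="eval", verbose=0)
    assert model.training  # still restored

Before this series, omitting `mode` silently implied `model.eval()`; with `mode="same"` as the default, the summary reflects whatever mode the caller has already set.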