Skip to content

Commit

Permalink
api: summary keeps current mode (#331)
Browse files Browse the repository at this point in the history
  • Loading branch information
ego-thales committed Nov 29, 2024
1 parent d937bbd commit ac27457
Show file tree
Hide file tree
Showing 5 changed files with 14 additions and 30 deletions.
1 change: 0 additions & 1 deletion ruff.toml
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,6 @@ lint.ignore = [
"PLW0602", # Using global for `_cached_forward_pass` but no assignment is done
"PLW0603", # Using the global statement to update `_cached_forward_pass` is discouraged
"PLW2901", # `for` loop variable `name` overwritten by assignment target
"SIM108", # [*] Use ternary operator `model_mode = Mode.EVAL if mode is None else Mode(mode)` instead of `if`-`else`-block
"SLF001", # Private member accessed: `_modules`
"TCH002", # Move third-party import into a type-checking block
"TRY004", # Prefer `TypeError` exception for invalid type
Expand Down
4 changes: 2 additions & 2 deletions tests/torchinfo_xl_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ def test_eval_order_doesnt_matter() -> None:
model2 = torchvision.models.resnet18(
weights=torchvision.models.ResNet18_Weights.DEFAULT
)
summary(model2, input_size=input_size)
summary(model2, input_size=input_size, mode="eval")
model2.eval()
with torch.inference_mode():
output2 = model2(input_tensor)
Expand Down Expand Up @@ -144,7 +144,7 @@ def test_tmva_net_column_totals() -> None:
def test_google() -> None:
google_net = torchvision.models.googlenet(init_weights=False)

summary(google_net, (1, 3, 112, 112), depth=7)
summary(google_net, (1, 3, 112, 112), depth=7, mode="eval")

# Check googlenet in training mode since InceptionAux layers are used in
# forward-prop in train mode but not in eval mode.
Expand Down
3 changes: 1 addition & 2 deletions torchinfo/__init__.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,9 @@
from .enums import ColumnSettings, Mode, RowSettings, Units, Verbosity
from .enums import ColumnSettings, RowSettings, Units, Verbosity
from .model_statistics import ModelStatistics
from .torchinfo import summary

__all__ = (
"ColumnSettings",
"Mode",
"ModelStatistics",
"RowSettings",
"Units",
Expand Down
10 changes: 0 additions & 10 deletions torchinfo/enums.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,16 +3,6 @@
from enum import Enum, IntEnum, unique


@unique
class Mode(str, Enum):
"""Enum containing all model modes."""

__slots__ = ()

TRAIN = "train"
EVAL = "eval"


@unique
class RowSettings(str, Enum):
"""Enum containing all available row settings."""
Expand Down
26 changes: 11 additions & 15 deletions torchinfo/torchinfo.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
from torch.jit import ScriptModule
from torch.utils.hooks import RemovableHandle

from .enums import ColumnSettings, Mode, RowSettings, Verbosity
from .enums import ColumnSettings, RowSettings, Verbosity
from .formatting import FormattingOptions
from .layer_info import LayerInfo, get_children_layers, prod
from .model_statistics import ModelStatistics
Expand Down Expand Up @@ -155,10 +155,11 @@ class name as the key. If the forward pass is an expensive operation,
also specify the types of each parameter here.
Default: None
mode (str)
Either "train" or "eval", which determines whether we call
model.train() or model.eval() before calling summary().
Default: "eval".
mode (str | None)
One of None, "eval" or "train". If not None, summary() will call either
mode.eval() or mode.train() (respectively) before processing the model.
In any case, original model mode is restored at the end.
Default: None
row_settings (Iterable[str]):
Specify which features to show in a row. Currently supported: (
Expand Down Expand Up @@ -198,11 +199,6 @@ class name as the key. If the forward pass is an expensive operation,
else:
rows = {RowSettings(name) for name in row_settings}

if mode is None:
model_mode = Mode.EVAL
else:
model_mode = Mode(mode)

if verbose is None:
verbose = 0 if hasattr(sys, "ps1") and sys.ps1 else 1

Expand All @@ -223,7 +219,7 @@ class name as the key. If the forward pass is an expensive operation,
input_data, input_size, batch_dim, device, dtypes
)
summary_list = forward_pass(
model, x, batch_dim, cache_forward_pass, device, model_mode, **kwargs
model, x, batch_dim, cache_forward_pass, device, mode, **kwargs
)
formatting = FormattingOptions(depth, verbose, columns, col_width, rows)
results = ModelStatistics(
Expand Down Expand Up @@ -265,7 +261,7 @@ def forward_pass(
batch_dim: int | None,
cache_forward_pass: bool,
device: torch.device | None,
mode: Mode,
mode: str | None,
**kwargs: Any,
) -> list[LayerInfo]:
"""Perform a forward pass on the model using forward hooks."""
Expand All @@ -282,11 +278,11 @@ def forward_pass(
kwargs = set_device(kwargs, device)
saved_model_mode = model.training
try:
if mode == Mode.TRAIN:
if mode == "train":
model.train()
elif mode == Mode.EVAL:
elif mode == "eval":
model.eval()
else:
elif mode is not None:
raise RuntimeError(
f"Specified model mode ({list(Mode)}) not recognized: {mode}"
)
Expand Down

0 comments on commit ac27457

Please sign in to comment.