Skip to content

Commit 9d62d3c

Browse files
committed
Create ColumnSettings and RowSettings enums
1 parent cc5b7f6 commit 9d62d3c

7 files changed

Lines changed: 96 additions & 82 deletions

File tree

tests/conftest.py

Lines changed: 5 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
from _pytest.config.argparsing import Parser
99

1010
from torchinfo import ModelStatistics
11-
from torchinfo.formatting import HEADER_TITLES
11+
from torchinfo.formatting import HEADER_TITLES, ColumnSettings
1212
from torchinfo.torchinfo import clear_cached_forward_pass
1313

1414

@@ -28,12 +28,6 @@ def verify_capsys(
2828
return
2929

3030
test_name = request.node.name.replace("test_", "")
31-
if sys.version_info < (3, 7) and test_name == "lstm":
32-
warnings.warn(
33-
"Verbose output is not determininstic because dictionaries "
34-
"are not necessarily ordered in versions before Python 3.7."
35-
)
36-
return
3731
if sys.version_info < (3, 8) and test_name == "tmva_net_column_totals":
3832
warnings.warn(
3933
"sys.getsizeof can return different results on earlier Python versions."
@@ -67,7 +61,7 @@ def verify_output_str(output: str, filename: str) -> None:
6761
if output != expected:
6862
print(f"Expected:\n{expected}\nGot:\n{output}")
6963
assert output == expected
70-
for category in ("num_params", "mult_adds"):
64+
for category in (ColumnSettings.NUM_PARAMS, ColumnSettings.MULT_ADDS):
7165
assert_sum_column_totals_match(output, category)
7266

7367

@@ -85,7 +79,7 @@ def get_column_value_for_row(line: str, offset: int) -> int:
8579
return int(col_value.replace(",", "").replace("(", "").replace(")", ""))
8680

8781

88-
def assert_sum_column_totals_match(output: str, category: str) -> None:
82+
def assert_sum_column_totals_match(output: str, category: ColumnSettings) -> None:
8983
lines = output.replace("=", "").split("\n\n")
9084
header_row = lines[0].strip()
9185
offset = header_row.find(HEADER_TITLES[category])
@@ -95,10 +89,10 @@ def assert_sum_column_totals_match(output: str, category: str) -> None:
9589
calculated_total = sum(get_column_value_for_row(line, offset) for line in layers)
9690
results = lines[2].split("\n")
9791

98-
if category == "num_params":
92+
if category == ColumnSettings.NUM_PARAMS:
9993
total_params = results[0].split(":")[1].replace(",", "")
10094
assert calculated_total == int(total_params)
101-
elif category == "mult_adds":
95+
elif category == ColumnSettings.MULT_ADDS:
10296
total_mult_adds = results[-1].split(":")[1].replace(",", "")
10397
assert float(
10498
f"{ModelStatistics.to_readable(calculated_total)[1]:0.2f}"

tests/torchinfo_test.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@
2828
SingleInputNet,
2929
)
3030
from tests.fixtures.tmva_net import TMVANet # type: ignore[attr-defined]
31-
from torchinfo import ALL_COLUMN_SETTINGS, summary
31+
from torchinfo import ColumnSettings, summary
3232

3333

3434
def test_basic_summary() -> None:
@@ -221,7 +221,7 @@ def test_parameter_list() -> None:
221221
model,
222222
input_size=(100, 100),
223223
verbose=2,
224-
col_names=ALL_COLUMN_SETTINGS,
224+
col_names=list(ColumnSettings),
225225
col_width=15,
226226
)
227227

torchinfo/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
""" torchinfo """
2-
from .formatting import ALL_COLUMN_SETTINGS, ALL_ROW_SETTINGS
2+
from .enums import ColumnSettings, RowSettings
33
from .model_statistics import ModelStatistics
44
from .torchinfo import summary
55

6-
__all__ = ("ModelStatistics", "summary", "ALL_COLUMN_SETTINGS", "ALL_ROW_SETTINGS")
6+
__all__ = ("ModelStatistics", "summary", "ColumnSettings", "RowSettings")
77
__version__ = "1.6.0"

torchinfo/enums.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
""" constants.py """
2+
from __future__ import annotations
3+
4+
from enum import Enum, unique
5+
6+
7+
@unique
8+
class RowSettings(Enum):
9+
"""Enum containing all available row settings."""
10+
11+
DEPTH = "depth"
12+
VAR_NAMES = "var_names"
13+
14+
15+
@unique
16+
class ColumnSettings(Enum):
17+
"""Enum containing all available column settings."""
18+
19+
KERNEL_SIZE = "kernel_size"
20+
INPUT_SIZE = "input_size"
21+
OUTPUT_SIZE = "output_size"
22+
NUM_PARAMS = "num_params"
23+
MULT_ADDS = "mult_adds"
24+
25+
26+
@unique
27+
class Verbosity(Enum):
28+
"""Contains verbosity levels."""
29+
30+
QUIET, DEFAULT, VERBOSE = 0, 1, 2

torchinfo/formatting.py

Lines changed: 21 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -2,45 +2,30 @@
22
from __future__ import annotations
33

44
import math
5-
from enum import Enum, unique
6-
from typing import Any, Iterable
5+
from typing import Any
76

7+
from .enums import ColumnSettings, RowSettings, Verbosity
88
from .layer_info import LayerInfo
99

10-
ALL_ROW_SETTINGS = ("depth", "var_names")
11-
ALL_COLUMN_SETTINGS = (
12-
"kernel_size",
13-
"input_size",
14-
"output_size",
15-
"num_params",
16-
"mult_adds",
17-
)
1810
HEADER_TITLES = {
19-
"kernel_size": "Kernel Shape",
20-
"input_size": "Input Shape",
21-
"output_size": "Output Shape",
22-
"num_params": "Param #",
23-
"mult_adds": "Mult-Adds",
11+
ColumnSettings.KERNEL_SIZE: "Kernel Shape",
12+
ColumnSettings.INPUT_SIZE: "Input Shape",
13+
ColumnSettings.OUTPUT_SIZE: "Output Shape",
14+
ColumnSettings.NUM_PARAMS: "Param #",
15+
ColumnSettings.MULT_ADDS: "Mult-Adds",
2416
}
2517

2618

27-
@unique
28-
class Verbosity(Enum):
29-
"""Contains verbosity levels."""
30-
31-
QUIET, DEFAULT, VERBOSE = 0, 1, 2
32-
33-
3419
class FormattingOptions:
3520
"""Class that holds information about formatting the table output."""
3621

3722
def __init__(
3823
self,
3924
max_depth: int,
4025
verbose: int,
41-
col_names: Iterable[str],
26+
col_names: tuple[ColumnSettings, ...],
4227
col_width: int,
43-
row_settings: Iterable[str],
28+
row_settings: set[RowSettings],
4429
) -> None:
4530
self.max_depth = max_depth
4631
self.verbose = verbose
@@ -49,8 +34,8 @@ def __init__(
4934
self.row_settings = row_settings
5035

5136
self.layer_name_width = 40
52-
self.show_var_name = "var_names" in self.row_settings
53-
self.show_depth = "depth" in self.row_settings
37+
self.show_var_name = RowSettings.VAR_NAMES in self.row_settings
38+
self.show_depth = RowSettings.DEPTH in self.row_settings
5439

5540
@staticmethod
5641
def get_start_str(depth: int) -> str:
@@ -95,7 +80,7 @@ def get_total_width(self) -> int:
9580
"""Calculate the total width of all lines in the table."""
9681
return len(tuple(self.col_names)) * self.col_width + self.layer_name_width
9782

98-
def format_row(self, layer_name: str, row_values: dict[str, str]) -> str:
83+
def format_row(self, layer_name: str, row_values: dict[ColumnSettings, str]) -> str:
9984
"""Get the string representation of a single layer of the model."""
10085
info_to_use = [row_values.get(row_type, "") for row_type in self.col_names]
10186
new_line = f"{layer_name:<{self.layer_name_width}} "
@@ -118,16 +103,18 @@ def layer_info_to_row(
118103
children_layers: list[LayerInfo],
119104
) -> str:
120105
"""Convert layer_info to string representation of a row."""
121-
row_values = {
122-
"kernel_size": self.str_(layer_info.kernel_size),
123-
"input_size": self.str_(layer_info.input_size),
124-
"output_size": self.str_(layer_info.output_size),
125-
"num_params": layer_info.num_params_to_str(reached_max_depth),
126-
"mult_adds": layer_info.macs_to_str(reached_max_depth, children_layers),
106+
values_for_row = {
107+
ColumnSettings.KERNEL_SIZE: self.str_(layer_info.kernel_size),
108+
ColumnSettings.INPUT_SIZE: self.str_(layer_info.input_size),
109+
ColumnSettings.OUTPUT_SIZE: self.str_(layer_info.output_size),
110+
ColumnSettings.NUM_PARAMS: layer_info.num_params_to_str(reached_max_depth),
111+
ColumnSettings.MULT_ADDS: layer_info.macs_to_str(
112+
reached_max_depth, children_layers
113+
),
127114
}
128115
start_str = self.get_start_str(layer_info.depth)
129116
layer_name = layer_info.get_layer_name(self.show_var_name, self.show_depth)
130-
new_line = self.format_row(f"{start_str}{layer_name}", row_values)
117+
new_line = self.format_row(f"{start_str}{layer_name}", values_for_row)
131118

132119
if self.verbose == Verbosity.VERBOSE.value:
133120
for inner_name, inner_layer_info in layer_info.inner_layers.items():

torchinfo/layer_info.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,8 @@
77
from torch import nn
88
from torch.jit import ScriptModule
99

10+
from .enums import ColumnSettings
11+
1012
DETECTED_INPUT_OUTPUT_TYPES = Union[
1113
Sequence[Any], Dict[Any, torch.Tensor], torch.Tensor
1214
]
@@ -39,8 +41,8 @@ def __init__(
3941
if isinstance(module, ScriptModule)
4042
else module.__class__.__name__
4143
)
42-
# {layer name: {row_name: row_value}}
43-
self.inner_layers: dict[str, dict[str, Any]] = {}
44+
# {layer name: {col_name: value_for_row}}
45+
self.inner_layers: dict[str, dict[ColumnSettings, Any]] = {}
4446
self.depth = depth
4547
self.depth_index = depth_index
4648
self.executed = False
@@ -163,13 +165,13 @@ def calculate_num_params(self) -> None:
163165

164166
# RNN modules have inner weights such as weight_ih_l0
165167
self.inner_layers[name] = {
166-
"kernel_size": str(ksize),
167-
"num_params": f"├─{cur_params:,}",
168+
ColumnSettings.KERNEL_SIZE: str(ksize),
169+
ColumnSettings.NUM_PARAMS: f"├─{cur_params:,}",
168170
}
169171
if self.inner_layers:
170172
self.inner_layers[name][
171-
"num_params"
172-
] = f"└─{self.inner_layers[name]['num_params'][2:]}"
173+
ColumnSettings.NUM_PARAMS
174+
] = f"└─{self.inner_layers[name][ColumnSettings.NUM_PARAMS][2:]}"
173175

174176
def calculate_macs(self) -> None:
175177
"""

torchinfo/torchinfo.py

Lines changed: 28 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -20,12 +20,8 @@
2020
from torch.jit import ScriptModule
2121
from torch.utils.hooks import RemovableHandle
2222

23-
from .formatting import (
24-
ALL_COLUMN_SETTINGS,
25-
ALL_ROW_SETTINGS,
26-
FormattingOptions,
27-
Verbosity,
28-
)
23+
from .enums import ColumnSettings, RowSettings, Verbosity
24+
from .formatting import FormattingOptions
2925
from .layer_info import LayerInfo
3026
from .model_statistics import ModelStatistics
3127

@@ -40,8 +36,13 @@
4036
INPUT_SIZE_TYPE = Sequence[Union[int, Sequence[Any], torch.Size]]
4137
CORRECTED_INPUT_SIZE_TYPE = List[Union[Sequence[Any], torch.Size]]
4238

43-
DEFAULT_COLUMN_NAMES = ("output_size", "num_params")
44-
DEFAULT_ROW_SETTINGS = ("depth",)
39+
DEFAULT_COLUMN_NAMES = (ColumnSettings.OUTPUT_SIZE, ColumnSettings.NUM_PARAMS)
40+
DEFAULT_ROW_SETTINGS = {RowSettings.DEPTH}
41+
REQUIRES_INPUT = {
42+
ColumnSettings.INPUT_SIZE,
43+
ColumnSettings.OUTPUT_SIZE,
44+
ColumnSettings.MULT_ADDS,
45+
}
4546

4647
_cached_forward_pass: dict[str, list[LayerInfo]] = {}
4748

@@ -166,12 +167,19 @@ class name as the key. If the forward pass is an expensive operation,
166167
See torchinfo/model_statistics.py for more information.
167168
"""
168169
input_data_specified = input_data is not None or input_size is not None
169-
170170
if col_names is None:
171-
col_names = DEFAULT_COLUMN_NAMES if input_data_specified else ("num_params",)
171+
columns = (
172+
DEFAULT_COLUMN_NAMES
173+
if input_data_specified
174+
else (ColumnSettings.NUM_PARAMS,)
175+
)
176+
else:
177+
columns = tuple(ColumnSettings(name) for name in col_names)
172178

173179
if row_settings is None:
174-
row_settings = DEFAULT_ROW_SETTINGS
180+
rows = DEFAULT_ROW_SETTINGS
181+
else:
182+
rows = {RowSettings(name) for name in row_settings}
175183

176184
if verbose is None:
177185
# pylint: disable=no-member
@@ -184,17 +192,15 @@ class name as the key. If the forward pass is an expensive operation,
184192
if device is None:
185193
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
186194

187-
validate_user_params(
188-
input_data, input_size, col_names, col_width, row_settings, verbose
189-
)
195+
validate_user_params(input_data, input_size, columns, col_width, verbose)
190196

191197
x, correct_input_size = process_input(
192198
input_data, input_size, batch_dim, device, dtypes
193199
)
194200
summary_list = forward_pass(
195201
model, x, batch_dim, cache_forward_pass, device, **kwargs
196202
)
197-
formatting = FormattingOptions(depth, verbose, col_names, col_width, row_settings)
203+
formatting = FormattingOptions(depth, verbose, columns, col_width, rows)
198204
results = ModelStatistics(
199205
summary_list, correct_input_size, get_total_memory_used(x), formatting
200206
)
@@ -287,9 +293,8 @@ def forward_pass(
287293
def validate_user_params(
288294
input_data: INPUT_DATA_TYPE | None,
289295
input_size: INPUT_SIZE_TYPE | None,
290-
col_names: Iterable[str],
296+
col_names: tuple[ColumnSettings, ...],
291297
col_width: int,
292-
row_settings: Iterable[str],
293298
verbose: int,
294299
) -> None:
295300
"""Raise exceptions if the user's input is invalid."""
@@ -304,16 +309,12 @@ def validate_user_params(
304309
raise RuntimeError("Only one of (input_data, input_size) should be specified.")
305310

306311
neither_input_specified = input_data is None and input_size is None
307-
for col in col_names:
308-
if col not in ALL_COLUMN_SETTINGS:
309-
raise ValueError(f"Column {col} is not a valid column name.")
310-
if neither_input_specified and col not in ("num_params", "kernel_size"):
311-
raise ValueError(
312-
f"You must pass input_data or input_size in order to use column {col}"
313-
)
314-
for setting in row_settings:
315-
if setting not in ALL_ROW_SETTINGS:
316-
raise ValueError(f"Row setting {setting} is not a valid setting.")
312+
not_allowed = set(col_names) & REQUIRES_INPUT
313+
if neither_input_specified and not_allowed:
314+
raise ValueError(
315+
"You must pass input_data or input_size in order "
316+
f"to use columns: {not_allowed}"
317+
)
317318

318319

319320
def traverse_input_data(

0 commit comments

Comments
 (0)