2020from torch .jit import ScriptModule
2121from torch .utils .hooks import RemovableHandle
2222
23- from .formatting import (
24- ALL_COLUMN_SETTINGS ,
25- ALL_ROW_SETTINGS ,
26- FormattingOptions ,
27- Verbosity ,
28- )
23+ from .enums import ColumnSettings , RowSettings , Verbosity
24+ from .formatting import FormattingOptions
2925from .layer_info import LayerInfo
3026from .model_statistics import ModelStatistics
3127
4036INPUT_SIZE_TYPE = Sequence [Union [int , Sequence [Any ], torch .Size ]]
4137CORRECTED_INPUT_SIZE_TYPE = List [Union [Sequence [Any ], torch .Size ]]
4238
43- DEFAULT_COLUMN_NAMES = ("output_size" , "num_params" )
44- DEFAULT_ROW_SETTINGS = ("depth" ,)
39+ DEFAULT_COLUMN_NAMES = (ColumnSettings .OUTPUT_SIZE , ColumnSettings .NUM_PARAMS )
40+ DEFAULT_ROW_SETTINGS = {RowSettings .DEPTH }
41+ REQUIRES_INPUT = {
42+ ColumnSettings .INPUT_SIZE ,
43+ ColumnSettings .OUTPUT_SIZE ,
44+ ColumnSettings .MULT_ADDS ,
45+ }
4546
4647_cached_forward_pass : dict [str , list [LayerInfo ]] = {}
4748
@@ -166,12 +167,19 @@ class name as the key. If the forward pass is an expensive operation,
166167 See torchinfo/model_statistics.py for more information.
167168 """
168169 input_data_specified = input_data is not None or input_size is not None
169-
170170 if col_names is None :
171- col_names = DEFAULT_COLUMN_NAMES if input_data_specified else ("num_params" ,)
171+ columns = (
172+ DEFAULT_COLUMN_NAMES
173+ if input_data_specified
174+ else (ColumnSettings .NUM_PARAMS ,)
175+ )
176+ else :
177+ columns = tuple (ColumnSettings (name ) for name in col_names )
172178
173179 if row_settings is None :
174- row_settings = DEFAULT_ROW_SETTINGS
180+ rows = DEFAULT_ROW_SETTINGS
181+ else :
182+ rows = {RowSettings (name ) for name in row_settings }
175183
176184 if verbose is None :
177185 # pylint: disable=no-member
@@ -184,17 +192,15 @@ class name as the key. If the forward pass is an expensive operation,
184192 if device is None :
185193 device = torch .device ("cuda" if torch .cuda .is_available () else "cpu" )
186194
187- validate_user_params (
188- input_data , input_size , col_names , col_width , row_settings , verbose
189- )
195+ validate_user_params (input_data , input_size , columns , col_width , verbose )
190196
191197 x , correct_input_size = process_input (
192198 input_data , input_size , batch_dim , device , dtypes
193199 )
194200 summary_list = forward_pass (
195201 model , x , batch_dim , cache_forward_pass , device , ** kwargs
196202 )
197- formatting = FormattingOptions (depth , verbose , col_names , col_width , row_settings )
203+ formatting = FormattingOptions (depth , verbose , columns , col_width , rows )
198204 results = ModelStatistics (
199205 summary_list , correct_input_size , get_total_memory_used (x ), formatting
200206 )
@@ -287,9 +293,8 @@ def forward_pass(
287293def validate_user_params (
288294 input_data : INPUT_DATA_TYPE | None ,
289295 input_size : INPUT_SIZE_TYPE | None ,
290- col_names : Iterable [ str ],
296+ col_names : tuple [ ColumnSettings , ... ],
291297 col_width : int ,
292- row_settings : Iterable [str ],
293298 verbose : int ,
294299) -> None :
295300 """Raise exceptions if the user's input is invalid."""
@@ -304,16 +309,12 @@ def validate_user_params(
304309 raise RuntimeError ("Only one of (input_data, input_size) should be specified." )
305310
306311 neither_input_specified = input_data is None and input_size is None
307- for col in col_names :
308- if col not in ALL_COLUMN_SETTINGS :
309- raise ValueError (f"Column { col } is not a valid column name." )
310- if neither_input_specified and col not in ("num_params" , "kernel_size" ):
311- raise ValueError (
312- f"You must pass input_data or input_size in order to use column { col } "
313- )
314- for setting in row_settings :
315- if setting not in ALL_ROW_SETTINGS :
316- raise ValueError (f"Row setting { setting } is not a valid setting." )
312+ not_allowed = set (col_names ) & REQUIRES_INPUT
313+ if neither_input_specified and not_allowed :
314+ raise ValueError (
315+ "You must pass input_data or input_size in order "
316+ f"to use columns: { not_allowed } "
317+ )
317318
318319
319320def traverse_input_data (
0 commit comments