Skip to content

Commit cf8cb6e

Browse files
committed
Rename remaining_params to leftover_params
1 parent d27f902 commit cf8cb6e

3 files changed

Lines changed: 26 additions & 28 deletions

File tree

torchinfo/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,4 +10,4 @@
1010
"RowSettings",
1111
"Verbosity",
1212
)
13-
__version__ = "1.6.6"
13+
__version__ = "1.7.0"

torchinfo/layer_info.py

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -289,23 +289,27 @@ def num_params_to_str(self, reached_max_depth: bool) -> str:
289289
return (
290290
param_count_str if self.trainable_params else f"({param_count_str})"
291291
)
292-
remaining_params = self.remaining_params()
293-
if remaining_params > 0:
294-
return f"{remaining_params:,}"
292+
leftover_params = self.leftover_params()
293+
if leftover_params > 0:
294+
return f"{leftover_params:,}"
295295
return "--"
296296

297-
def remaining_params(self) -> int:
297+
def leftover_params(self) -> int:
298+
"""
299+
Leftover params are the number of params this current layer has that are not
300+
included in the child num_param counts.
301+
"""
298302
return self.num_params - sum(
299-
child.num_params if child.is_leaf_layer else child.remaining_params()
303+
child.num_params if child.is_leaf_layer else child.leftover_params()
300304
for child in self.children
301305
if not child.is_recursive
302306
)
303307

304-
def remaining_trainable_params(self) -> int:
308+
def leftover_trainable_params(self) -> int:
305309
return self.trainable_params - sum(
306310
child.trainable_params
307311
if child.is_leaf_layer
308-
else child.remaining_trainable_params()
312+
else child.leftover_trainable_params()
309313
for child in self.children
310314
if not child.is_recursive
311315
)

torchinfo/model_statistics.py

Lines changed: 14 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -38,8 +38,8 @@ def __init__(
3838
else:
3939
if layer_info.is_recursive:
4040
continue
41-
self.total_params += layer_info.remaining_params()
42-
self.trainable_params += layer_info.remaining_trainable_params()
41+
self.total_params += layer_info.leftover_params()
42+
self.trainable_params += layer_info.leftover_trainable_params()
4343

4444
self.formatting.set_layer_name_width(summary_list)
4545

@@ -55,25 +55,19 @@ def __repr__(self) -> str:
5555
f"Non-trainable params: {self.total_params - self.trainable_params:,}\n"
5656
)
5757
if self.input_size:
58+
unit, macs = self.to_readable(self.total_mult_adds)
59+
input_size = self.to_megabytes(self.total_input)
60+
output_bytes = self.to_megabytes(self.total_output_bytes)
61+
param_bytes = self.to_megabytes(self.total_param_bytes)
62+
total_bytes = self.to_megabytes(
63+
self.total_input + self.total_output_bytes + self.total_param_bytes
64+
)
5865
summary_str += (
59-
"Total mult-adds ({}): {:0.2f}\n{}\n" # pylint: disable=consider-using-f-string # noqa: E501
60-
"Input size (MB): {:0.2f}\n"
61-
"Forward/backward pass size (MB): {:0.2f}\n"
62-
"Params size (MB): {:0.2f}\n"
63-
"Estimated Total Size (MB): {:0.2f}\n".format(
64-
*self.to_readable(self.total_mult_adds),
65-
divider,
66-
self.to_megabytes(self.total_input),
67-
self.to_megabytes(self.total_output_bytes),
68-
self.to_megabytes(self.total_param_bytes),
69-
(
70-
self.to_megabytes(
71-
self.total_input
72-
+ self.total_output_bytes
73-
+ self.total_param_bytes
74-
)
75-
),
76-
)
66+
f"Total mult-adds ({unit}): {macs:0.2f}\n{divider}\n"
67+
f"Input size (MB): {input_size:0.2f}\n"
68+
f"Forward/backward pass size (MB): {output_bytes:0.2f}\n"
69+
f"Params size (MB): {param_bytes:0.2f}\n"
70+
f"Estimated Total Size (MB): {total_bytes:0.2f}\n"
7771
)
7872
summary_str += divider
7973
return summary_str

0 commit comments

Comments
 (0)