Skip to content

Commit 7def795

Browse files
bsridattaTylerYep
andauthored
Add "Trainable" column (#128)
* add is_trainable column * test: testcase for is_trainable column * model which has fully, partial and non trainable modules * update tests that require all coloums to display is_trainable coloumn as well * docs: update README.md * fix type ignores and nits * Rename is_trainable to trainable * Calculate trainable in pre_hook * Fix readme Co-authored-by: Tyler Yep <tyler.yep@robinhood.com>
1 parent 793c4f5 commit 7def795

10 files changed

Lines changed: 106 additions & 27 deletions

File tree

README.md

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -115,7 +115,8 @@ Summarize the given PyTorch model. Summarized information includes:
115115
2) input/output shapes,
116116
3) kernel shape,
117117
4) # of parameters,
118-
5) # of operations (Mult-Adds)
118+
5) # of operations (Mult-Adds),
119+
6) whether layer is trainable
119120
120121
NOTE: If neither input_data or input_size are provided, no forward pass through the
121122
network is performed, and the provided model information is limited to layer names.
@@ -166,6 +167,7 @@ Args:
166167
"num_params",
167168
"kernel_size",
168169
"mult_adds",
170+
"trainable",
169171
)
170172
Default: ("output_size", "num_params")
171173
If input_data / input_size are not provided, only "num_params" is used.

tests/fixtures/models.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -502,6 +502,32 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
502502
return self.w * x + self.b
503503

504504

505+
class MixedTrainable(nn.Module):
506+
"""Model with fully, partial and non trainable modules."""
507+
508+
def __init__(self) -> None:
509+
super().__init__()
510+
self.fully_trainable = nn.Conv1d(1, 1, 1)
511+
512+
self.partially_trainable = nn.Conv1d(1, 1, 1, bias=True)
513+
assert self.partially_trainable.bias is not None
514+
self.partially_trainable.bias.requires_grad = False
515+
516+
self.non_trainable = nn.Conv1d(1, 1, 1, 1, bias=True)
517+
self.non_trainable.weight.requires_grad = False
518+
assert self.non_trainable.bias is not None
519+
self.non_trainable.bias.requires_grad = False
520+
521+
self.dropout = nn.Dropout()
522+
523+
def forward(self, x: torch.Tensor) -> torch.Tensor:
524+
x = self.fully_trainable(x)
525+
x = self.partially_trainable(x)
526+
x = self.non_trainable(x)
527+
x = self.dropout(x)
528+
return x
529+
530+
505531
class ReuseLinear(nn.Module):
506532
"""Model that uses a reference to the same Linear layer over and over."""
507533

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,18 @@
1-
===================================================================================================================
2-
Layer (type:depth-idx) Kernel Shape Input Shape Output Shape Param # Mult-Adds
3-
===================================================================================================================
4-
ParameterListModel -- -- -- -- --
5-
├─ParameterList: 1-1 -- -- -- 30,000 --
6-
│ └─0 [100, 100] ├─10,000
7-
│ └─1 [100, 200] └─20,000
8-
===================================================================================================================
1+
================================================================================================================================================================
2+
Layer (type:depth-idx) Kernel Shape Input Shape Output Shape Param # Mult-Adds Trainable
3+
================================================================================================================================================================
4+
ParameterListModel -- -- -- -- -- True
5+
├─ParameterList: 1-1 -- -- -- 30,000 -- True
6+
│ └─0 [100, 100] ├─10,000
7+
│ └─1 [100, 200] └─20,000
8+
================================================================================================================================================================
99
Total params: 30,000
1010
Trainable params: 30,000
1111
Non-trainable params: 0
1212
Total mult-adds (M): 0.00
13-
===================================================================================================================
13+
================================================================================================================================================================
1414
Input size (MB): 0.04
1515
Forward/backward pass size (MB): 0.00
1616
Params size (MB): 0.12
1717
Estimated Total Size (MB): 0.16
18-
===================================================================================================================
18+
================================================================================================================================================================
Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,20 +1,20 @@
1-
============================================================================================================================================
2-
Layer (type:depth-idx) Kernel Shape Input Shape Output Shape Param # Mult-Adds
3-
============================================================================================================================================
4-
SingleInputNet -- -- -- -- --
5-
├─Conv2d: 1-1 [5, 5] [7, 1, 28, 28] [7, 10, 24, 24] 260 1,048,320
6-
├─Conv2d: 1-2 [5, 5] [7, 10, 12, 12] [7, 20, 8, 8] 5,020 2,248,960
7-
├─Dropout2d: 1-3 -- [7, 20, 8, 8] [7, 20, 8, 8] -- --
8-
├─Linear: 1-4 -- [7, 320] [7, 50] 16,050 112,350
9-
├─Linear: 1-5 -- [7, 50] [7, 10] 510 3,570
10-
============================================================================================================================================
1+
================================================================================================================================================================
2+
Layer (type:depth-idx) Kernel Shape Input Shape Output Shape Param # Mult-Adds Trainable
3+
================================================================================================================================================================
4+
SingleInputNet -- -- -- -- -- True
5+
├─Conv2d: 1-1 [5, 5] [7, 1, 28, 28] [7, 10, 24, 24] 260 1,048,320 True
6+
├─Conv2d: 1-2 [5, 5] [7, 10, 12, 12] [7, 20, 8, 8] 5,020 2,248,960 True
7+
├─Dropout2d: 1-3 -- [7, 20, 8, 8] [7, 20, 8, 8] -- -- --
8+
├─Linear: 1-4 -- [7, 320] [7, 50] 16,050 112,350 True
9+
├─Linear: 1-5 -- [7, 50] [7, 10] 510 3,570 True
10+
================================================================================================================================================================
1111
Total params: 21,840
1212
Trainable params: 21,840
1313
Non-trainable params: 0
1414
Total mult-adds (M): 3.41
15-
============================================================================================================================================
15+
================================================================================================================================================================
1616
Input size (MB): 0.02
1717
Forward/backward pass size (MB): 0.40
1818
Params size (MB): 0.09
1919
Estimated Total Size (MB): 0.51
20-
============================================================================================================================================
20+
================================================================================================================================================================
Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
============================================================================================================================================
2+
Layer (type:depth-idx) Kernel Shape Input Shape Output Shape Trainable
3+
============================================================================================================================================
4+
MixedTrainable -- -- -- Partial
5+
├─Conv1d: 1-1 [1] [1, 1, 1] [1, 1, 1] True
6+
├─Conv1d: 1-2 [1] [1, 1, 1] [1, 1, 1] Partial
7+
├─Conv1d: 1-3 [1] [1, 1, 1] [1, 1, 1] False
8+
├─Dropout: 1-4 -- [1, 1, 1] [1, 1, 1] --
9+
============================================================================================================================================
10+
Total params: 6
11+
Trainable params: 3
12+
Non-trainable params: 3
13+
Total mult-adds (M): 0.00
14+
============================================================================================================================================
15+
Input size (MB): 0.00
16+
Forward/backward pass size (MB): 0.00
17+
Params size (MB): 0.00
18+
Estimated Total Size (MB): 0.00
19+
============================================================================================================================================

tests/torchinfo_test.py

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
FakePrunedLayerModel,
1515
LinearModel,
1616
LSTMNet,
17+
MixedTrainable,
1718
MixedTrainableParameters,
1819
ModuleDictModel,
1920
MultipleInputNetDifferentDtypes,
@@ -111,13 +112,12 @@ def test_multiple_input_types() -> None:
111112

112113
def test_single_input_all_cols() -> None:
113114
model = SingleInputNet()
114-
col_names = ("kernel_size", "input_size", "output_size", "num_params", "mult_adds")
115115
input_shape = (7, 1, 28, 28)
116116
summary(
117117
model,
118118
input_data=torch.randn(*input_shape),
119119
depth=1,
120-
col_names=col_names,
120+
col_names=list(ColumnSettings),
121121
col_width=20,
122122
)
123123

@@ -194,7 +194,7 @@ def test_parameter_list() -> None:
194194
input_size=(100, 100),
195195
verbose=2,
196196
col_names=list(ColumnSettings),
197-
col_width=15,
197+
col_width=20,
198198
)
199199

200200

@@ -462,3 +462,11 @@ def test_pruned_adversary() -> None:
462462
results = summary(second_model, input_size=(1,))
463463

464464
assert results.total_params == 32 # should be 64
465+
466+
467+
def test_trainable_column() -> None:
468+
summary(
469+
MixedTrainable(),
470+
input_size=(1, 1, 1),
471+
col_names=("kernel_size", "input_size", "output_size", "trainable"),
472+
)

torchinfo/enums.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ class ColumnSettings(str, Enum):
2929
OUTPUT_SIZE = "output_size"
3030
NUM_PARAMS = "num_params"
3131
MULT_ADDS = "mult_adds"
32+
TRAINABLE = "trainable"
3233

3334

3435
@unique

torchinfo/formatting.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
ColumnSettings.OUTPUT_SIZE: "Output Shape",
1313
ColumnSettings.NUM_PARAMS: "Param #",
1414
ColumnSettings.MULT_ADDS: "Mult-Adds",
15+
ColumnSettings.TRAINABLE: "Trainable",
1516
}
1617

1718

@@ -113,6 +114,7 @@ def layer_info_to_row(
113114
ColumnSettings.MULT_ADDS: layer_info.macs_to_str(
114115
reached_max_depth, children_layers
115116
),
117+
ColumnSettings.TRAINABLE: self.str_(layer_info.trainable),
116118
}
117119
start_str = self.get_start_str(layer_info.depth)
118120
layer_name = layer_info.get_layer_name(self.show_var_name, self.show_depth)

torchinfo/layer_info.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,7 @@ def __init__(
5959
self.param_bytes = 0
6060
self.output_bytes = 0
6161
self.macs = 0
62+
self.trainable = self.is_trainable(module)
6263

6364
def __repr__(self) -> str:
6465
return f"{self.class_name}: {self.depth}"
@@ -159,6 +160,24 @@ def get_kernel_size(module: nn.Module) -> int | list[int] | None:
159160
return kernel_size
160161
return None
161162

163+
@staticmethod
164+
def is_trainable(module: nn.Module) -> str:
165+
"""
166+
Checks if the module is trainable. Returns:
167+
"True", if all the parameters are trainable (`requires_grad=True`)
168+
"False" if none of the parameters are trainable.
169+
"Partial" if some weights are trainable, but not all.
170+
"--" if no module has no parameters, like Dropout.
171+
"""
172+
module_requires_grad = [param.requires_grad for param in module.parameters()]
173+
if not module_requires_grad:
174+
return "--"
175+
if all(module_requires_grad):
176+
return "True"
177+
if any(module_requires_grad):
178+
return "Partial"
179+
return "False"
180+
162181
def get_layer_name(self, show_var_name: bool, show_depth: bool) -> str:
163182
layer_name = self.class_name
164183
if show_var_name and self.var_name:

torchinfo/torchinfo.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,8 @@ def summary(
7070
2) input/output shapes,
7171
3) kernel shape,
7272
4) # of parameters,
73-
5) # of operations (Mult-Adds)
73+
5) # of operations (Mult-Adds),
74+
6) whether layer is trainable
7475
7576
NOTE: If neither input_data or input_size are provided, no forward pass through the
7677
network is performed, and the provided model information is limited to layer names.
@@ -121,6 +122,7 @@ class name as the key. If the forward pass is an expensive operation,
121122
"num_params",
122123
"kernel_size",
123124
"mult_adds",
125+
"trainable",
124126
)
125127
Default: ("output_size", "num_params")
126128
If input_data / input_size are not provided, only "num_params" is used.

0 commit comments

Comments
 (0)