Skip to content

Commit 2141b78

Browse files
authored
Add support for pruned models (#103)
* Add support for pruned models

According to the [pytorch documentation on pruning](https://pytorch.org/tutorials/intermediate/pruning_tutorial.html), the original parameter is replaced with one ending with `_orig`, and a new buffer ending with `_mask` is registered. The mask contains 0s and 1s, based on which the effective parameters are selected. All instances of `param.nelement()` have been replaced by a variable `cur_params`, whose value is set based on whether the model is masked or not. To keep consistency with the rest of the code base, the `_orig` suffix is removed from the `name` variable right after `cur_params` is calculated.

* Add tests for pruning
1 parent d9f4857 commit 2141b78

3 files changed

Lines changed: 67 additions & 7 deletions

File tree

tests/test_output/pruning.out

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
==========================================================================================
2+
Layer (type:depth-idx) Output Shape Param #
3+
==========================================================================================
4+
SingleInputNet -- --
5+
├─Conv2d: 1-1 [16, 10, 24, 24] 135
6+
├─Conv2d: 1-2 [16, 20, 8, 8] 2,520
7+
├─Dropout2d: 1-3 [16, 20, 8, 8] --
8+
├─Linear: 1-4 [16, 50] 8,050
9+
├─Linear: 1-5 [16, 10] 260
10+
==========================================================================================
11+
Total params: 10,965
12+
Trainable params: 10,965
13+
Non-trainable params: 0
14+
Total mult-adds (M): 3.96
15+
==========================================================================================
16+
Input size (MB): 0.05
17+
Forward/backward pass size (MB): 0.91
18+
Params size (MB): 0.04
19+
Estimated Total Size (MB): 1.00
20+
==========================================================================================

tests/torchinfo_test.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import torch
44
import torchvision # type: ignore[import]
55
from torch import nn
6+
from torch.nn.utils import prune
67

78
from tests.conftest import verify_output_str
89
from tests.fixtures.genotype import GenotypeNetwork # type: ignore[attr-defined]
@@ -162,6 +163,19 @@ def test_resnet152() -> None:
162163
summary(model, (1, 3, 224, 224), depth=3)
163164

164165

166+
def test_pruning() -> None:
    """Summary of a pruned model should count only the unmasked parameters."""
    model = SingleInputNet()

    # Prune half the weights of every conv/linear layer so the module gains
    # a `weight_orig` parameter and a `weight_mask` buffer.
    prunable_types = (torch.nn.Conv2d, torch.nn.Linear)
    for layer in model.modules():
        if isinstance(layer, prunable_types):
            prune.l1_unstructured(  # type: ignore[no-untyped-call]
                layer, "weight", 0.5
            )

    results = summary(model, input_size=(16, 1, 28, 28))

    assert results.total_params == 10965
    assert results.total_mult_adds == 3957600
177+
178+
165179
def test_dict_input() -> None:
166180
# TODO: expand this test to handle intermediate dict layers.
167181
model = MultipleInputNetDifferentDtypes()

torchinfo/layer_info.py

Lines changed: 33 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,17 @@
1212
]
1313

1414

15+
def rgetattr(obj: torch.nn.Module, attr: str) -> torch.Tensor:
    """Get the tensor submodule called attr from obj.

    `attr` may be a dotted path (e.g. "conv1.weight_mask"); each segment
    is resolved with getattr in turn. Raises AttributeError if the final
    resolved attribute is not a torch.Tensor.
    """
    target = obj
    for segment in attr.split("."):
        target = getattr(target, segment)

    if not isinstance(target, torch.Tensor):
        raise AttributeError(f"{attr} is not a tensor")
    return target
24+
25+
1526
class LayerInfo:
1627
"""Class that holds information about a layer module."""
1728

@@ -116,16 +127,31 @@ def get_layer_name(self, show_var_name: bool, show_depth: bool) -> str:
116127
layer_name += f"-{self.depth_index}"
117128
return layer_name
118129

130+
def __get_cur_params(self, name: str, param: torch.Tensor) -> tuple[int, str]:
    """
    Get the effective parameter count, accounting for pruning masks.

    Pruned models (torch.nn.utils.prune) rename the parameter to
    "<name>_orig" and register a "<name>_mask" buffer containing only
    0s and 1s; the number of live parameters is the number of 1s.

    Returns:
        (param_count, display_name) — display_name has the "_orig"
        suffix stripped for better readability and integration.
    """
    # Use endswith("_orig") rather than checking only the last 4 chars:
    # a genuine parameter merely ending in "orig" (no underscore) must
    # not be treated as pruned, since the strip below removes 5 chars.
    if name.endswith("_orig"):
        base = name[: -len("_orig")]
        # The sum of 1s in the mask is the number of live params.
        mask = rgetattr(self.module, f"{base}_mask")
        return int(torch.sum(mask)), base
    return param.nelement(), name
142+
119143
def calculate_num_params(self) -> None:
120144
"""
121145
Set num_params, trainable, inner_layers, and kernel_size
122146
using the module's parameters.
123147
"""
124148
name = ""
125149
for name, param in self.module.named_parameters():
126-
self.num_params += param.nelement()
150+
cur_params, name = self.__get_cur_params(name, param)
151+
152+
self.num_params += cur_params
127153
if param.requires_grad:
128-
self.trainable_params += param.nelement()
154+
self.trainable_params += cur_params
129155

130156
ksize = list(param.size())
131157
if name == "weight":
@@ -137,7 +163,7 @@ def calculate_num_params(self) -> None:
137163
# RNN modules have inner weights such as weight_ih_l0
138164
self.inner_layers[name] = {
139165
"kernel_size": str(ksize),
140-
"num_params": f"├─{param.nelement():,}",
166+
"num_params": f"├─{cur_params:,}",
141167
}
142168
if self.inner_layers:
143169
self.inner_layers[name][
@@ -153,18 +179,18 @@ def calculate_macs(self) -> None:
153179
i.e., taking the batch-dimension into account.
154180
"""
155181
for name, param in self.module.named_parameters():
182+
cur_params, name = self.__get_cur_params(name, param)
156183
if name in ("weight", "bias"):
157184
# ignore C when calculating Mult-Adds in ConvNd
158185
if "Conv" in self.class_name:
159186
self.macs += int(
160-
param.nelement()
161-
* prod(self.output_size[:1] + self.output_size[2:])
187+
cur_params * prod(self.output_size[:1] + self.output_size[2:])
162188
)
163189
else:
164-
self.macs += self.output_size[0] * param.nelement()
190+
self.macs += self.output_size[0] * cur_params
165191
# RNN modules have inner weights such as weight_ih_l0
166192
elif "weight" in name or "bias" in name:
167-
self.macs += prod(self.output_size[:2]) * param.nelement()
193+
self.macs += prod(self.output_size[:2]) * cur_params
168194

169195
def check_recursive(self, summary_list: list[LayerInfo]) -> None:
170196
"""

0 commit comments

Comments
 (0)