Skip to content

Commit cc5b7f6

Browse files
committed
Add community contributions to the README
1 parent 2141b78 commit cc5b7f6

4 files changed

Lines changed: 37 additions & 29 deletions

File tree

.pre-commit-config.yaml

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
ci:
2-
skip: [mypy, pylint, pytest]
2+
skip: [mypy, pytest]
33
repos:
44
- repo: https://github.com/asottile/pyupgrade
55
rev: v2.29.1
@@ -34,6 +34,12 @@ repos:
3434
flake8-comprehensions,
3535
]
3636

37+
- repo: https://github.com/PyCQA/pylint
38+
rev: v2.12.2
39+
hooks:
40+
- id: pylint
41+
args: ["--disable=import-error"]
42+
3743
- repo: local
3844
hooks:
3945
- id: mypy
@@ -43,12 +49,6 @@ repos:
4349
types: [python]
4450
require_serial: true
4551

46-
- id: pylint
47-
name: pylint
48-
entry: pylint
49-
language: python
50-
types: [python]
51-
5252
- id: pytest
5353
name: pytest
5454
entry: pytest --cov=torchinfo --cov-report=html --durations=0

README.md

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
[![pre-commit.ci status](https://results.pre-commit.ci/badge/github/TylerYep/torchinfo/main.svg)](https://results.pre-commit.ci/latest/github/TylerYep/torchinfo/main)
88
[![GitHub license](https://img.shields.io/github/license/TylerYep/torchinfo)](https://github.com/TylerYep/torchinfo/blob/main/LICENSE)
99
[![codecov](https://codecov.io/gh/TylerYep/torchinfo/branch/main/graph/badge.svg)](https://codecov.io/gh/TylerYep/torchinfo)
10-
[![Downloads](https://pepy.tech/badge/torch-summary)](https://pepy.tech/project/torch-summary)
10+
[![Downloads](https://pepy.tech/badge/torchinfo)](https://pepy.tech/project/torchinfo)
1111

1212
(formerly torch-summary)
1313

@@ -69,7 +69,6 @@ See `tests/jupyter_test.ipynb` for examples.
6969
**This version now supports:**
7070

7171
- RNNs, LSTMs, and other recursive layers
72-
- Sequentials & ModuleLists
7372
- Branching output used to explore model layers using specified depths
7473
- Returns ModelStatistics object containing all summary data fields
7574
- Configurable rows/columns
@@ -82,6 +81,13 @@ See `tests/jupyter_test.ipynb` for examples.
8281
- Customizable line widths and batch dimension
8382
- Comprehensive unit/output testing, linting, and code coverage testing
8483

84+
**Community Contributions:**
85+
86+
- Sequentials & ModuleLists (thanks to @roym899)
87+
- Improved Mult-Add calculations (thanks to @TE-StefanUhlich, @zmzhang2000)
88+
- Dict/Misc input data (thanks to @e-dorigatti)
89+
- Pruned layer support (thanks to @MajorCarrot)
90+
8591
# Documentation
8692

8793
```python

tests/fixtures/models.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
""" fixtures/models.py """
2+
# pylint: disable=too-few-public-methods
23
from __future__ import annotations
34

45
import math

torchinfo/layer_info.py

Lines changed: 21 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -12,15 +12,12 @@
1212
]
1313

1414

15-
def rgetattr(obj: torch.nn.Module, attr: str) -> torch.Tensor:
16-
"""Get the tensor submodule called attr from obj."""
15+
def rgetattr(module: torch.nn.Module, attr: str) -> torch.Tensor:
16+
"""Get the tensor submodule called attr from module."""
1717
for attr_i in attr.split("."):
18-
obj = getattr(obj, attr_i)
19-
20-
if isinstance(obj, torch.Tensor):
21-
return obj
22-
else:
23-
raise AttributeError(f"{attr} is not a tensor")
18+
module = getattr(module, attr_i)
19+
assert isinstance(module, torch.Tensor)
20+
return module
2421

2522

2623
class LayerInfo:
@@ -127,18 +124,22 @@ def get_layer_name(self, show_var_name: bool, show_depth: bool) -> str:
127124
layer_name += f"-{self.depth_index}"
128125
return layer_name
129126

130-
def __get_cur_params(self, name: str, param: torch.Tensor) -> tuple[int, str]:
127+
def get_param_count(self, name: str, param: torch.Tensor) -> tuple[int, str]:
131128
"""
132-
Get count of number of params, accounting for mask
129+
Get count of number of params, accounting for mask.
130+
131+
Masked models save parameters with the suffix "_orig" added.
132+
They have a buffer ending with "_mask" which has only 0s and 1s.
133+
If a mask exists, the sum of 1s in mask is number of params.
133134
"""
134-
# Masked models save the parameter with the name "_orig" added
135-
# They have a buffer ending with "_mask" which has only 0s and 1s
136-
if name[-4:] == "orig":
137-
# If a mask exists, the sum of 1s in mask is number of params
138-
# Remove "_orig" for better readability and integration
139-
return int(torch.sum(rgetattr(self.module, f"{name[:-4]}mask"))), name[:-5]
140-
else:
141-
return param.nelement(), name
135+
if name.endswith("orig"):
136+
# Remove "_orig" suffix for better readability and integration
137+
without_suffix = name[:-5]
138+
parameter_count = int(
139+
torch.sum(rgetattr(self.module, f"{without_suffix}_mask"))
140+
)
141+
return parameter_count, without_suffix
142+
return param.nelement(), name
142143

143144
def calculate_num_params(self) -> None:
144145
"""
@@ -147,7 +148,7 @@ def calculate_num_params(self) -> None:
147148
"""
148149
name = ""
149150
for name, param in self.module.named_parameters():
150-
cur_params, name = self.__get_cur_params(name, param)
151+
cur_params, name = self.get_param_count(name, param)
151152

152153
self.num_params += cur_params
153154
if param.requires_grad:
@@ -179,7 +180,7 @@ def calculate_macs(self) -> None:
179180
i.e., taking the batch-dimension into account.
180181
"""
181182
for name, param in self.module.named_parameters():
182-
cur_params, name = self.__get_cur_params(name, param)
183+
cur_params, name = self.get_param_count(name, param)
183184
if name in ("weight", "bias"):
184185
# ignore C when calculating Mult-Adds in ConvNd
185186
if "Conv" in self.class_name:

0 commit comments

Comments
 (0)