Skip to content

Commit 6dd06db

Browse files
committed
Move minimum version to Python 3.7
1 parent 0a69f3f commit 6dd06db

8 files changed

Lines changed: 85 additions & 70 deletions

File tree

.pre-commit-config.yaml

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,12 @@ repos:
2727
rev: 4.0.1
2828
hooks:
2929
- id: flake8
30+
additional_dependencies:
31+
[
32+
flake8-future-annotations,
33+
flake8-bugbear,
34+
flake8-comprehensions,
35+
]
3036

3137
- repo: local
3238
hooks:

README.md

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
# torchinfo
22

3-
[![Python 3.6+](https://img.shields.io/badge/python-3.6+-blue.svg)](https://www.python.org/downloads/release/python-360/)
3+
[![Python 3.7+](https://img.shields.io/badge/python-3.7+-blue.svg)](https://www.python.org/downloads/release/python-370/)
44
[![PyPI version](https://badge.fury.io/py/torchinfo.svg)](https://badge.fury.io/py/torchinfo)
55
[![Conda version](https://img.shields.io/conda/vn/conda-forge/torchinfo)](https://anaconda.org/conda-forge/torchinfo)
66
[![Build Status](https://github.com/TylerYep/torchinfo/actions/workflows/test.yml/badge.svg)](https://github.com/TylerYep/torchinfo/actions/workflows/test.yml)
@@ -22,6 +22,7 @@ pip install torchinfo
2222
```
2323

2424
Alternatively, via conda:
25+
2526
```
2627
conda install -c conda-forge torchinfo
2728
```
@@ -452,7 +453,7 @@ Estimated Total Size (MB): 0.00
452453
All issues and pull requests are much appreciated! If you are wondering how to build the project:
453454

454455
- torchinfo is actively developed using the latest version of Python.
455-
- Changes should be backward compatible with Python 3.6, but this is subject to change in the future.
456+
- Changes should be backward compatible with Python 3.7, and will follow Python's End-of-Life guidance for old versions.
456457
- Run `pip install -r requirements-dev.txt`. We use the latest versions of all dev packages.
457458
- Run `pre-commit install`.
458459
- To use auto-formatting tools, use `pre-commit run -a`.

setup.cfg

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ keywords = torch pytorch torchsummary torch-summary summary keras deep-learning
2323

2424
[options]
2525
packages = torchinfo
26-
python_requires = >=3.6
26+
python_requires = >=3.7
2727
include_package_data = True
2828

2929
[options.package_data]

tests/fixtures/models.py

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
11
""" fixtures/models.py """
2+
from __future__ import annotations
3+
24
import math
35
from collections import namedtuple
4-
from typing import Any, Dict, Tuple, cast
6+
from typing import Any, cast
57

68
import torch
79
from torch import nn
@@ -113,7 +115,7 @@ def __init__(
113115
self.encoder = nn.LSTM(embed_dim, hidden_dim, num_layers=num_layers) # type: ignore[no-untyped-call] # noqa
114116
self.decoder = nn.Linear(hidden_dim, vocab_size)
115117

116-
def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
118+
def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
117119
embed = self.embedding(x)
118120
out, hidden = self.encoder(embed)
119121
out = self.decoder(out)
@@ -244,7 +246,7 @@ def __init__(self) -> None:
244246
self.fc1 = nn.Linear(320, 50)
245247
self.fc2 = nn.Linear(50, 10)
246248

247-
def forward(self, x: torch.Tensor) -> Dict[str, torch.Tensor]:
249+
def forward(self, x: torch.Tensor) -> dict[str, torch.Tensor]:
248250
activation_dict = {}
249251
x = self.conv1(x)
250252
activation_dict["conv1"] = x
@@ -269,9 +271,9 @@ def __init__(self) -> None:
269271
super().__init__()
270272
self.return_dict = ReturnDictLayer()
271273

272-
def forward(self, x: torch.Tensor, y: Any) -> Dict[str, torch.Tensor]:
274+
def forward(self, x: torch.Tensor, y: Any) -> dict[str, torch.Tensor]:
273275
del y
274-
activation_dict: Dict[str, torch.Tensor] = self.return_dict(x)
276+
activation_dict: dict[str, torch.Tensor] = self.return_dict(x)
275277
return activation_dict
276278

277279

@@ -282,7 +284,7 @@ def __init__(self) -> None:
282284
super().__init__()
283285
self.constant = 5
284286

285-
def forward(self, x: Dict[int, torch.Tensor], scale_factor: int) -> torch.Tensor:
287+
def forward(self, x: dict[int, torch.Tensor], scale_factor: int) -> torch.Tensor:
286288
return scale_factor * (x[256] + x[512][0]) * self.constant
287289

288290

@@ -427,7 +429,7 @@ def __init__(self) -> None:
427429
self.parameter = torch.rand(3, 3, requires_grad=True)
428430
self.example_input_array = torch.zeros(1, 2, 3, 4, 5)
429431

430-
def forward(self) -> Dict[str, Any]:
432+
def forward(self) -> dict[str, Any]:
431433
return {"loss": self.parameter.sum()}
432434

433435

torchinfo/formatting.py

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
11
""" formatting.py """
2+
from __future__ import annotations
3+
24
import math
35
from enum import Enum, unique
4-
from typing import Any, Dict, Iterable, List
6+
from typing import Any, Iterable
57

68
from .layer_info import LayerInfo
79

@@ -64,8 +66,8 @@ def str_(val: Any) -> str:
6466

6567
@staticmethod
6668
def get_children_layers(
67-
summary_list: List[LayerInfo], layer_info: LayerInfo, index: int
68-
) -> List[LayerInfo]:
69+
summary_list: list[LayerInfo], layer_info: LayerInfo, index: int
70+
) -> list[LayerInfo]:
6971
"""Fetches all of the children of a given layer."""
7072
num_children = 0
7173
for layer in summary_list[index + 1 :]:
@@ -75,7 +77,7 @@ def get_children_layers(
7577
return summary_list[index + 1 : index + 1 + num_children]
7678

7779
def set_layer_name_width(
78-
self, summary_list: List[LayerInfo], align_val: int = 5
80+
self, summary_list: list[LayerInfo], align_val: int = 5
7981
) -> None:
8082
"""
8183
Set layer name width by taking the longest line length and rounding up to
@@ -93,7 +95,7 @@ def get_total_width(self) -> int:
9395
"""Calculate the total width of all lines in the table."""
9496
return len(tuple(self.col_names)) * self.col_width + self.layer_name_width
9597

96-
def format_row(self, layer_name: str, row_values: Dict[str, str]) -> str:
98+
def format_row(self, layer_name: str, row_values: dict[str, str]) -> str:
9799
"""Get the string representation of a single layer of the model."""
98100
info_to_use = [row_values.get(row_type, "") for row_type in self.col_names]
99101
new_line = f"{layer_name:<{self.layer_name_width}} "
@@ -113,7 +115,7 @@ def layer_info_to_row(
113115
self,
114116
layer_info: LayerInfo,
115117
reached_max_depth: bool,
116-
children_layers: List["LayerInfo"],
118+
children_layers: list[LayerInfo],
117119
) -> str:
118120
"""Convert layer_info to string representation of a row."""
119121
row_values = {
@@ -133,13 +135,13 @@ def layer_info_to_row(
133135
new_line += self.format_row(f"{prefix}{inner_name}", inner_layer_info)
134136
return new_line
135137

136-
def layers_to_str(self, summary_list: List[LayerInfo]) -> str:
138+
def layers_to_str(self, summary_list: list[LayerInfo]) -> str:
137139
"""
138140
Print each layer of the model using a fancy branching diagram.
139141
This is necessary to handle Container modules that don't have explicit parents.
140142
"""
141143
new_str = ""
142-
current_hierarchy: Dict[int, LayerInfo] = {}
144+
current_hierarchy: dict[int, LayerInfo] = {}
143145
for i, layer_info in enumerate(summary_list):
144146
if layer_info.depth > self.max_depth:
145147
continue

torchinfo/layer_info.py

Lines changed: 15 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
""" layer_info.py """
2-
from typing import Any, Dict, Iterable, List, Optional, Sequence, Union
2+
from __future__ import annotations
3+
4+
from typing import Any, Dict, Iterable, Sequence, Union
35

46
import torch
57
from torch import nn
@@ -18,8 +20,8 @@ def __init__(
1820
var_name: str,
1921
module: nn.Module,
2022
depth: int,
21-
depth_index: Optional[int] = None,
22-
parent_info: Optional["LayerInfo"] = None,
23+
depth_index: int | None = None,
24+
parent_info: LayerInfo | None = None,
2325
) -> None:
2426
# Identifying information
2527
self.layer_id = id(module)
@@ -30,7 +32,7 @@ def __init__(
3032
else module.__class__.__name__
3133
)
3234
# {layer name: {row_name: row_value}}
33-
self.inner_layers: Dict[str, Dict[str, Any]] = {}
35+
self.inner_layers: dict[str, dict[str, Any]] = {}
3436
self.depth = depth
3537
self.depth_index = depth_index
3638
self.executed = False
@@ -41,9 +43,9 @@ def __init__(
4143
# Statistics
4244
self.trainable_params = 0
4345
self.is_recursive = False
44-
self.input_size: List[int] = []
45-
self.output_size: List[int] = []
46-
self.kernel_size: List[int] = []
46+
self.input_size: list[int] = []
47+
self.output_size: list[int] = []
48+
self.kernel_size: list[int] = []
4749
self.num_params = 0
4850
self.macs = 0
4951

@@ -52,11 +54,11 @@ def __repr__(self) -> str:
5254

5355
@staticmethod
5456
def calculate_size(
55-
inputs: DETECTED_INPUT_OUTPUT_TYPES, batch_dim: Optional[int]
56-
) -> List[int]:
57+
inputs: DETECTED_INPUT_OUTPUT_TYPES, batch_dim: int | None
58+
) -> list[int]:
5759
"""Set input_size or output_size using the model's inputs."""
5860

59-
def nested_list_size(inputs: Sequence[Any]) -> List[int]:
61+
def nested_list_size(inputs: Sequence[Any]) -> list[int]:
6062
"""Flattens nested list size."""
6163
if hasattr(inputs, "tensors"):
6264
return nested_list_size(inputs.tensors) # type: ignore[attr-defined]
@@ -164,7 +166,7 @@ def calculate_macs(self) -> None:
164166
elif "weight" in name or "bias" in name:
165167
self.macs += prod(self.output_size[:2]) * param.nelement()
166168

167-
def check_recursive(self, summary_list: List["LayerInfo"]) -> None:
169+
def check_recursive(self, summary_list: list[LayerInfo]) -> None:
168170
"""
169171
If the current module is already-used, mark as (recursive).
170172
Must check before adding line to the summary.
@@ -175,7 +177,7 @@ def check_recursive(self, summary_list: List["LayerInfo"]) -> None:
175177
self.is_recursive = True
176178

177179
def macs_to_str(
178-
self, reached_max_depth: bool, children_layers: List["LayerInfo"]
180+
self, reached_max_depth: bool, children_layers: list[LayerInfo]
179181
) -> str:
180182
"""Convert MACs to string."""
181183
if self.macs <= 0:
@@ -199,7 +201,7 @@ def num_params_to_str(self, reached_max_depth: bool) -> str:
199201
return "--"
200202

201203

202-
def prod(num_list: Union[Iterable[int], torch.Size]) -> int:
204+
def prod(num_list: Iterable[int] | torch.Size) -> int:
203205
result = 1
204206
if isinstance(num_list, Iterable):
205207
for item in num_list:

torchinfo/model_statistics.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
""" model_statistics.py """
2-
from typing import Any, List, Tuple
2+
from __future__ import annotations
3+
4+
from typing import Any
35

46
from .formatting import FormattingOptions
57
from .layer_info import LayerInfo, prod
@@ -10,7 +12,7 @@ class ModelStatistics:
1012

1113
def __init__(
1214
self,
13-
summary_list: List[LayerInfo],
15+
summary_list: list[LayerInfo],
1416
input_size: Any,
1517
total_input_size: int,
1618
formatting: FormattingOptions,
@@ -78,7 +80,7 @@ def to_megabytes(num: int) -> float:
7880
return num / 1e6
7981

8082
@staticmethod
81-
def to_readable(num: int) -> Tuple[str, float]:
83+
def to_readable(num: int) -> tuple[str, float]:
8284
"""Converts a number to millions, billions, or trillions."""
8385
if num >= 1e12:
8486
return "T", num / 1e12

0 commit comments

Comments
 (0)