Skip to content

Commit 2b82275

Browse files
committed
Add ascii_only RowSettings
1 parent 9d62d3c commit 2b82275

6 files changed

Lines changed: 109 additions & 8 deletions

File tree

README.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -190,6 +190,7 @@ Args:
190190
191191
row_settings (Iterable[str]):
192192
Specify which features to show in a row. Currently supported: (
193+
"ascii_only",
193194
"depth",
194195
"var_names",
195196
)

tests/test_output/ascii_only.out

Lines changed: 84 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,84 @@
1+
==========================================================================================
2+
Layer (type) Output Shape Param #
3+
==========================================================================================
4+
ResNet -- --
5+
+ Conv2d [1, 64, 32, 32] 9,408
6+
+ BatchNorm2d [1, 64, 32, 32] 128
7+
+ ReLU [1, 64, 32, 32] --
8+
+ MaxPool2d [1, 64, 16, 16] --
9+
+ Sequential [1, 64, 16, 16] --
10+
| + BasicBlock [1, 64, 16, 16] --
11+
| | + Conv2d [1, 64, 16, 16] 36,864
12+
| | + BatchNorm2d [1, 64, 16, 16] 128
13+
| | + ReLU [1, 64, 16, 16] --
14+
| | + Conv2d [1, 64, 16, 16] 36,864
15+
| | + BatchNorm2d [1, 64, 16, 16] 128
16+
| | + ReLU [1, 64, 16, 16] --
17+
| + BasicBlock [1, 64, 16, 16] --
18+
| | + Conv2d [1, 64, 16, 16] 36,864
19+
| | + BatchNorm2d [1, 64, 16, 16] 128
20+
| | + ReLU [1, 64, 16, 16] --
21+
| | + Conv2d [1, 64, 16, 16] 36,864
22+
| | + BatchNorm2d [1, 64, 16, 16] 128
23+
| | + ReLU [1, 64, 16, 16] --
24+
+ Sequential [1, 128, 8, 8] --
25+
| + BasicBlock [1, 128, 8, 8] --
26+
| | + Conv2d [1, 128, 8, 8] 73,728
27+
| | + BatchNorm2d [1, 128, 8, 8] 256
28+
| | + ReLU [1, 128, 8, 8] --
29+
| | + Conv2d [1, 128, 8, 8] 147,456
30+
| | + BatchNorm2d [1, 128, 8, 8] 256
31+
| | + Sequential [1, 128, 8, 8] 8,448
32+
| | + ReLU [1, 128, 8, 8] --
33+
| + BasicBlock [1, 128, 8, 8] --
34+
| | + Conv2d [1, 128, 8, 8] 147,456
35+
| | + BatchNorm2d [1, 128, 8, 8] 256
36+
| | + ReLU [1, 128, 8, 8] --
37+
| | + Conv2d [1, 128, 8, 8] 147,456
38+
| | + BatchNorm2d [1, 128, 8, 8] 256
39+
| | + ReLU [1, 128, 8, 8] --
40+
+ Sequential [1, 256, 4, 4] --
41+
| + BasicBlock [1, 256, 4, 4] --
42+
| | + Conv2d [1, 256, 4, 4] 294,912
43+
| | + BatchNorm2d [1, 256, 4, 4] 512
44+
| | + ReLU [1, 256, 4, 4] --
45+
| | + Conv2d [1, 256, 4, 4] 589,824
46+
| | + BatchNorm2d [1, 256, 4, 4] 512
47+
| | + Sequential [1, 256, 4, 4] 33,280
48+
| | + ReLU [1, 256, 4, 4] --
49+
| + BasicBlock [1, 256, 4, 4] --
50+
| | + Conv2d [1, 256, 4, 4] 589,824
51+
| | + BatchNorm2d [1, 256, 4, 4] 512
52+
| | + ReLU [1, 256, 4, 4] --
53+
| | + Conv2d [1, 256, 4, 4] 589,824
54+
| | + BatchNorm2d [1, 256, 4, 4] 512
55+
| | + ReLU [1, 256, 4, 4] --
56+
+ Sequential [1, 512, 2, 2] --
57+
| + BasicBlock [1, 512, 2, 2] --
58+
| | + Conv2d [1, 512, 2, 2] 1,179,648
59+
| | + BatchNorm2d [1, 512, 2, 2] 1,024
60+
| | + ReLU [1, 512, 2, 2] --
61+
| | + Conv2d [1, 512, 2, 2] 2,359,296
62+
| | + BatchNorm2d [1, 512, 2, 2] 1,024
63+
| | + Sequential [1, 512, 2, 2] 132,096
64+
| | + ReLU [1, 512, 2, 2] --
65+
| + BasicBlock [1, 512, 2, 2] --
66+
| | + Conv2d [1, 512, 2, 2] 2,359,296
67+
| | + BatchNorm2d [1, 512, 2, 2] 1,024
68+
| | + ReLU [1, 512, 2, 2] --
69+
| | + Conv2d [1, 512, 2, 2] 2,359,296
70+
| | + BatchNorm2d [1, 512, 2, 2] 1,024
71+
| | + ReLU [1, 512, 2, 2] --
72+
+ AdaptiveAvgPool2d [1, 512, 1, 1] --
73+
+ Linear [1, 1000] 513,000
74+
==========================================================================================
75+
Total params: 11,689,512
76+
Trainable params: 11,689,512
77+
Non-trainable params: 0
78+
Total mult-adds (M): 148.57
79+
==========================================================================================
80+
Input size (MB): 0.05
81+
Forward/backward pass size (MB): 3.25
82+
Params size (MB): 46.76
83+
Estimated Total Size (MB): 50.06
84+
==========================================================================================

tests/torchinfo_test.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -504,3 +504,14 @@ def test_mixed_trainable_parameters() -> None:
504504

505505
assert result.trainable_params == 10
506506
assert result.total_params == 20
507+
508+
509+
def test_ascii_only() -> None:
510+
result = summary(
511+
torchvision.models.resnet18(),
512+
depth=3,
513+
input_size=(1, 3, 64, 64),
514+
row_settings=["ascii_only"],
515+
)
516+
517+
assert str(result).encode("ascii").decode("ascii")

torchinfo/enums.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ class RowSettings(Enum):
1010

1111
DEPTH = "depth"
1212
VAR_NAMES = "var_names"
13+
ASCII_ONLY = "ascii_only"
1314

1415

1516
@unique

torchinfo/formatting.py

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -34,17 +34,10 @@ def __init__(
3434
self.row_settings = row_settings
3535

3636
self.layer_name_width = 40
37+
self.ascii_only = RowSettings.ASCII_ONLY in self.row_settings
3738
self.show_var_name = RowSettings.VAR_NAMES in self.row_settings
3839
self.show_depth = RowSettings.DEPTH in self.row_settings
3940

40-
@staticmethod
41-
def get_start_str(depth: int) -> str:
42-
if depth == 0:
43-
return ""
44-
if depth == 1:
45-
return "├─"
46-
return "│ " * (depth - 1) + "└─"
47-
4841
@staticmethod
4942
def str_(val: Any) -> str:
5043
return str(val) if val else "--"
@@ -61,6 +54,16 @@ def get_children_layers(
6154
num_children += 1
6255
return summary_list[index + 1 : index + 1 + num_children]
6356

57+
def get_start_str(self, depth: int) -> str:
58+
"""This function should handle all ascii/non-ascii-related characters."""
59+
if depth == 0:
60+
return ""
61+
if depth == 1:
62+
return "+ " if self.ascii_only else "├─"
63+
return ("| " if self.ascii_only else "│ ") * (depth - 1) + (
64+
"+ " if self.ascii_only else "└─"
65+
)
66+
6467
def set_layer_name_width(
6568
self, summary_list: list[LayerInfo], align_val: int = 5
6669
) -> None:

torchinfo/torchinfo.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -146,6 +146,7 @@ class name as the key. If the forward pass is an expensive operation,
146146
147147
row_settings (Iterable[str]):
148148
Specify which features to show in a row. Currently supported: (
149+
"ascii_only",
149150
"depth",
150151
"var_names",
151152
)

0 commit comments

Comments
 (0)