
Commit 30df5e8

Fix reuse of layer variables bug
1 parent c626a92 commit 30df5e8

5 files changed: 109 additions & 3 deletions


tests/fixtures/models.py

Lines changed: 41 additions & 1 deletion
@@ -16,7 +16,7 @@ class IdentityModel(nn.Module):
 
     def __init__(self) -> None:
         super().__init__()
-        self.identity = nn.Identity()  # type: ignore[no-untyped-call] # noqa
+        self.identity = nn.Identity()  # type: ignore[no-untyped-call]
 
     def forward(self, x: Any) -> Any:
         return self.identity(x)
@@ -483,3 +483,43 @@ def __init__(self) -> None:
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         return self.w * x
+
+
+class ReuseLinear(nn.Module):
+    """Model that uses a reference to the same Linear layer over and over."""
+
+    def __init__(self) -> None:
+        super().__init__()
+        linear = nn.Linear(10, 10)
+        model = []
+        for _ in range(4):
+            model += [linear, nn.ReLU(True)]
+        self.model = nn.Sequential(*model)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        return cast(torch.Tensor, self.model(x))
+
+
+class ReuseReLU(nn.Module):
+    """Model that uses a reference to the same ReLU layer over and over."""
+
+    def __init__(self) -> None:
+        super().__init__()
+        activation = nn.ReLU(True)
+        model = [
+            nn.ReflectionPad2d(3),
+            nn.Conv2d(4, 1, kernel_size=1, padding=0),
+            nn.BatchNorm2d(1),  # type: ignore[no-untyped-call]
+            activation,
+        ]
+        for i in range(3):
+            mult = 2 ** i
+            model += [
+                nn.Conv2d(mult, mult * 2, kernel_size=1, stride=2, padding=1),
+                nn.BatchNorm2d(mult * 2),  # type: ignore[no-untyped-call]
+                activation,
+            ]
+        self.model = nn.Sequential(*model)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        return cast(torch.Tensor, self.model(x))
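Both fixtures hinge on the same module object filling several slots of an nn.Sequential. A minimal sketch (illustrative, not part of the commit) of the two behaviors this sets up, parameter deduplication versus per-slot traversal:

    import torch
    from torch import nn

    linear = nn.Linear(10, 10)
    # The same Linear object fills four slots, as in ReuseLinear above.
    model = nn.Sequential(*[m for _ in range(4) for m in (linear, nn.ReLU(True))])

    # nn.Module.parameters() deduplicates shared tensors:
    # 10 * 10 weights + 10 biases = 110, not 4 * 110.
    print(sum(p.numel() for p in model.parameters()))  # 110

    # _modules, by contrast, has one entry per slot, so a naive walk
    # visits the shared Linear four times.
    print(len(model._modules))                          # 8 slots
    print(len(set(map(id, model._modules.values()))))   # 5 distinct modules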
Lines changed: 24 additions & 0 deletions
@@ -0,0 +1,24 @@
+==========================================================================================
+Layer (type:depth-idx)                   Output Shape              Param #
+==========================================================================================
+ReuseLinear                              --                        --
+├─Sequential: 1-1                        [2, 10]                   --
+│    └─Linear: 2-1                       [2, 10]                   110
+│    └─ReLU: 2-2                         [2, 10]                   --
+│    └─Linear: 2-3                       [2, 10]                   (recursive)
+│    └─ReLU: 2-4                         [2, 10]                   --
+│    └─Linear: 2-5                       [2, 10]                   (recursive)
+│    └─ReLU: 2-6                         [2, 10]                   --
+│    └─Linear: 2-7                       [2, 10]                   (recursive)
+│    └─ReLU: 2-8                         [2, 10]                   --
+==========================================================================================
+Total params: 110
+Trainable params: 110
+Non-trainable params: 0
+Total mult-adds (M): 0.00
+==========================================================================================
+Input size (MB): 0.00
+Forward/backward pass size (MB): 0.00
+Params size (MB): 0.00
+Estimated Total Size (MB): 0.00
+==========================================================================================
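The totals match the sharing: the single Linear(10, 10) holds 10 × 10 weights + 10 biases = 110 parameters, counted once, while its later occurrences are reported as (recursive) rather than re-counted.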
Lines changed: 29 additions & 0 deletions
@@ -0,0 +1,29 @@
+==========================================================================================
+Layer (type:depth-idx)                   Output Shape              Param #
+==========================================================================================
+ReuseReLU                                --                        --
+├─Sequential: 1-1                        [4, 8, 11, 11]            --
+│    └─ReflectionPad2d: 2-1              [4, 4, 70, 70]            --
+│    └─Conv2d: 2-2                       [4, 1, 70, 70]            5
+│    └─BatchNorm2d: 2-3                  [4, 1, 70, 70]            2
+│    └─ReLU: 2-4                         [4, 1, 70, 70]            --
+│    └─Conv2d: 2-5                       [4, 2, 36, 36]            4
+│    └─BatchNorm2d: 2-6                  [4, 2, 36, 36]            4
+│    └─ReLU: 2-7                         [4, 2, 36, 36]            --
+│    └─Conv2d: 2-8                       [4, 4, 19, 19]            12
+│    └─BatchNorm2d: 2-9                  [4, 4, 19, 19]            8
+│    └─ReLU: 2-10                        [4, 4, 19, 19]            --
+│    └─Conv2d: 2-11                      [4, 8, 11, 11]            40
+│    └─BatchNorm2d: 2-12                 [4, 8, 11, 11]            16
+│    └─ReLU: 2-13                        [4, 8, 11, 11]            --
+==========================================================================================
+Total params: 91
+Trainable params: 91
+Non-trainable params: 0
+Total mult-adds (M): 0.16
+==========================================================================================
+Input size (MB): 0.26
+Forward/backward pass size (MB): 0.63
+Params size (MB): 0.00
+Estimated Total Size (MB): 0.90
+==========================================================================================
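As a sanity check on this fixture's expected totals (the breakdown is mine, not part of the commit), the 1×1 convolutions and batch norms account for all 91 parameters; the shared ReLU contributes none:

    # Per-layer parameter counts for ReuseReLU (weights + biases;
    # BatchNorm2d has one weight and one bias per channel).
    counts = [
        4 * 1 + 1,   # Conv2d(4, 1, kernel_size=1)       -> 5
        2 * 1,       # BatchNorm2d(1)                    -> 2
        1 * 2 + 2,   # Conv2d(1, 2, kernel_size=1, ...)  -> 4
        2 * 2,       # BatchNorm2d(2)                    -> 4
        2 * 4 + 4,   # Conv2d(2, 4, kernel_size=1, ...)  -> 12
        2 * 4,       # BatchNorm2d(4)                    -> 8
        4 * 8 + 8,   # Conv2d(4, 8, kernel_size=1, ...)  -> 40
        2 * 8,       # BatchNorm2d(8)                    -> 16
    ]
    assert sum(counts) == 91  # matches "Total params: 91" above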

tests/torchinfo_test.py

Lines changed: 11 additions & 0 deletions
@@ -23,6 +23,8 @@
     PartialJITModel,
     RecursiveNet,
     ReturnDict,
+    ReuseLinear,
+    ReuseReLU,
     SiameseNets,
     SingleInputNet,
 )
@@ -519,3 +521,12 @@ def test_ascii_only() -> None:
 
 def test_google() -> None:
     summary(torchvision.models.googlenet(), (1, 3, 112, 112), depth=7)
+
+
+def test_too_many_linear() -> None:
+    net = ReuseLinear()
+    summary(net, (2, 10))
+
+
+def test_too_many_relus() -> None:
+    summary(ReuseReLU(), (4, 4, 64, 64))
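Assuming a standard pytest setup (the commit itself doesn't show the runner), the two new cases can be run in isolation with a keyword filter:

    pytest tests/torchinfo_test.py -k "too_many"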

torchinfo/torchinfo.py

Lines changed: 4 additions & 2 deletions
@@ -511,8 +511,10 @@ def hook(module: nn.Module, inputs: Any, outputs: Any) -> None:
     if hooks is None or isinstance(module, WRAPPER_MODULES):
         pre_hook(module, None)
     else:
-        hooks.append(module.register_forward_pre_hook(pre_hook))
-        hooks.append(module.register_forward_hook(hook))
+        if not module._forward_pre_hooks:  # pylint: disable=protected-access
+            hooks.append(module.register_forward_pre_hook(pre_hook))
+        if not module._forward_hooks:  # pylint: disable=protected-access
+            hooks.append(module.register_forward_hook(hook))
 
     # module.named_modules(remove_duplicate=False) doesn't work (infinite recursion).
     for name, mod in module._modules.items():  # pylint: disable=protected-access
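These two guards are the entire fix: when traversal reaches a module through a second reference, the hooks from the first visit are already in _forward_pre_hooks/_forward_hooks, so registering again would stack duplicates that fire on every execution. A standalone demonstration of that failure mode (illustrative, outside torchinfo):

    import torch
    from torch import nn

    linear = nn.Linear(10, 10)
    model = nn.Sequential(linear, nn.ReLU(True), linear, nn.ReLU(True))

    calls = []
    for mod in model._modules.values():  # visits the shared Linear twice
        if isinstance(mod, nn.Linear):
            # Without the guard, the second visit stacks a duplicate hook.
            mod.register_forward_hook(lambda m, i, o: calls.append(m))

    model(torch.randn(2, 10))
    print(len(calls))  # 4: two hooks x two executions of one Linear

With the guard around the registration, the count drops to 2, one hook firing once per execution, which matches the (recursive) entries in the expected output above instead of listing phantom layers.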
