
Commit 051fa20

Allow existing hooks to still work with reused layer variables
Parent: 30df5e8

4 files changed: 113 additions & 7 deletions
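For context: before this change, apply_hooks() skipped registering torchinfo's own hooks on any module whose hook registry was already non-empty, so a model whose reused layer carried a user-registered hook could render a different summary than the same model without one. A minimal repro sketch (the model and shapes here are illustrative, not taken from this commit):

import torch
from torch import nn
from torchinfo import summary

# A layer referenced twice in the same Sequential, plus a user hook.
layer = nn.Linear(10, 10)
model = nn.Sequential(layer, nn.ReLU(True), layer, nn.ReLU(True))
layer.register_forward_hook(lambda m, i, o: None)

# Previously torchinfo saw layer._forward_hooks as non-empty and did not
# attach its own hook there; after this commit the summary matches the
# hook-free run.
summary(model, input_size=(2, 10))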

tests/fixtures/models.py
Lines changed: 15 additions & 0 deletions

@@ -500,6 +500,21 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         return cast(torch.Tensor, self.model(x))


+class ReuseLinearExtended(nn.Module):
+    """Model that uses a reference to the same Linear layer over and over."""
+
+    def __init__(self) -> None:
+        super().__init__()
+        self.linear = nn.Linear(10, 10)
+        model = []
+        for _ in range(4):
+            model += [self.linear, nn.ReLU(True)]
+        self.model = nn.Sequential(*model)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        return cast(torch.Tensor, self.model(x))
+
+
 class ReuseReLU(nn.Module):
     """Model that uses a reference to the same ReLU layer over and over."""
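A quick usage sketch of the new fixture (assuming the imports already present in models.py): the same nn.Linear(10, 10) object is appended four times, so the flattened Sequential holds eight entries but only one weight matrix.

# Hypothetical check, not part of the diff: shared weights mean the
# parameter count stays at 10 * 10 + 10 = 110.
net = ReuseLinearExtended()
assert sum(p.numel() for p in net.parameters()) == 110
assert net(torch.randn(2, 10)).shape == (2, 10)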

New expected-output fixture (file path not shown in this capture)
Lines changed: 56 additions & 0 deletions

@@ -0,0 +1,56 @@
+==========================================================================================
+Layer (type:depth-idx)                   Output Shape              Param #
+==========================================================================================
+ReuseLinearExtended                      --                        --
+├─Sequential: 1-1                        [2, 10]                   --
+├─Linear: 1-2                            [2, 10]                   110
+├─Sequential: 1                          --                        --
+│    └─ReLU: 2-1                         [2, 10]                   --
+├─Linear: 1-3                            [2, 10]                   (recursive)
+├─Sequential: 1                          --                        --
+│    └─ReLU: 2-2                         [2, 10]                   --
+├─Linear: 1-4                            [2, 10]                   (recursive)
+├─Sequential: 1                          --                        --
+│    └─ReLU: 2-3                         [2, 10]                   --
+├─Linear: 1-5                            [2, 10]                   (recursive)
+├─Sequential: 1                          --                        --
+│    └─ReLU: 2-4                         [2, 10]                   --
+==========================================================================================
+Total params: 110
+Trainable params: 110
+Non-trainable params: 0
+Total mult-adds (M): 0.00
+==========================================================================================
+Input size (MB): 0.00
+Forward/backward pass size (MB): 0.00
+Params size (MB): 0.00
+Estimated Total Size (MB): 0.00
+==========================================================================================
+==========================================================================================
+Layer (type:depth-idx)                   Output Shape              Param #
+==========================================================================================
+ReuseLinearExtended                      --                        --
+├─Sequential: 1-1                        [2, 10]                   --
+├─Linear: 1-2                            [2, 10]                   110
+├─Sequential: 1                          --                        --
+│    └─ReLU: 2-1                         [2, 10]                   --
+├─Linear: 1-3                            [2, 10]                   (recursive)
+├─Sequential: 1                          --                        --
+│    └─ReLU: 2-2                         [2, 10]                   --
+├─Linear: 1-4                            [2, 10]                   (recursive)
+├─Sequential: 1                          --                        --
+│    └─ReLU: 2-3                         [2, 10]                   --
+├─Linear: 1-5                            [2, 10]                   (recursive)
+├─Sequential: 1                          --                        --
+│    └─ReLU: 2-4                         [2, 10]                   --
+==========================================================================================
+Total params: 110
+Trainable params: 110
+Non-trainable params: 0
+Total mult-adds (M): 0.00
+==========================================================================================
+Input size (MB): 0.00
+Forward/backward pass size (MB): 0.00
+Params size (MB): 0.00
+Estimated Total Size (MB): 0.00
+==========================================================================================
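In the fixture above, the shared Linear is printed four times (1-2 through 1-5) but counted once: the first occurrence reports its 110 parameters and the later ones are marked (recursive), which is why both tables total 110. The table appears twice to match the two summary() calls in the test below.

# Sanity arithmetic for the totals above (not part of the diff): one
# Linear(10, 10) owns a 10x10 weight matrix plus a 10-element bias.
from torch import nn
linear = nn.Linear(10, 10)
assert sum(p.numel() for p in linear.parameters()) == 110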
tests/torchinfo_test.py
Lines changed: 30 additions & 0 deletions

@@ -1,3 +1,5 @@
+from typing import Any
+
 import pytest
 import torch
 import torchvision  # type: ignore[import]
@@ -24,6 +26,7 @@
     RecursiveNet,
     ReturnDict,
     ReuseLinear,
+    ReuseLinearExtended,
     ReuseReLU,
     SiameseNets,
     SingleInputNet,
@@ -528,5 +531,32 @@ def test_too_many_linear() -> None:
     summary(net, (2, 10))


+def test_too_many_linear_plus_existing_hooks() -> None:
+    a, b = False, False
+
+    def pre_hook(module: nn.Module, inputs: Any) -> None:
+        del module, inputs
+        nonlocal a
+        a = True
+
+    def hook(module: nn.Module, inputs: Any, outputs: Any) -> None:
+        del module, inputs, outputs
+        nonlocal b
+        b = True
+
+    net = ReuseLinearExtended()
+    result_1 = summary(net, (2, 10))
+
+    net = ReuseLinearExtended()
+    net.linear.register_forward_pre_hook(pre_hook)
+    net.linear.register_forward_hook(hook)
+
+    result_2 = summary(net, (2, 10))
+
+    assert a is True
+    assert b is True
+    assert str(result_1) == str(result_2)
+
+
 def test_too_many_relus() -> None:
     summary(ReuseReLU(), (4, 4, 64, 64))
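The test leans on PyTorch's hook registry allowing several hooks per module: each register_* call stores the callback in an ordered dict and returns a torch.utils.hooks.RemovableHandle that detaches only its own entry. A standalone sketch of that mechanism, independent of torchinfo:

import torch
from torch import nn

linear = nn.Linear(10, 10)
calls: list[str] = []

# Two user hooks on one module; both fire on every forward pass, and
# torchinfo's own pair can now coexist with them.
h1 = linear.register_forward_pre_hook(lambda m, i: calls.append("pre"))
h2 = linear.register_forward_hook(lambda m, i, o: calls.append("post"))

linear(torch.randn(2, 10))
assert calls == ["pre", "post"]

h1.remove()  # each handle removes only the hook it registered
h2.remove()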

torchinfo/torchinfo.py
Lines changed: 12 additions & 7 deletions

@@ -251,7 +251,9 @@ def forward_pass(

     all_layers: list[LayerInfo] = []
     summary_list: list[LayerInfo] = []
-    hooks: list[RemovableHandle] | None = None if x is None else []
+    hooks: dict[int, tuple[RemovableHandle, RemovableHandle]] | None = (
+        None if x is None else {}
+    )
     named_module = (model_name, model)
     apply_hooks(named_module, model, batch_dim, summary_list, hooks, all_layers)

@@ -282,7 +284,8 @@ def forward_pass(
         ) from e
     finally:
         if hooks is not None:
-            for hook in hooks:
+            for pre_hook, hook in hooks.values():
+                pre_hook.remove()
                 hook.remove()
         model.train(saved_model_mode)

@@ -473,7 +476,7 @@ def apply_hooks(
     orig_model: nn.Module,
     batch_dim: int | None,
     summary_list: list[LayerInfo],
-    hooks: list[RemovableHandle] | None,
+    hooks: dict[int, tuple[RemovableHandle, RemovableHandle]] | None,
     all_layers: list[LayerInfo],
     curr_depth: int = 0,
     parent_info: LayerInfo | None = None,
@@ -511,10 +514,12 @@ def hook(module: nn.Module, inputs: Any, outputs: Any) -> None:
         if hooks is None or isinstance(module, WRAPPER_MODULES):
             pre_hook(module, None)
         else:
-            if not module._forward_pre_hooks:  # pylint: disable=protected-access
-                hooks.append(module.register_forward_pre_hook(pre_hook))
-            if not module._forward_hooks:  # pylint: disable=protected-access
-                hooks.append(module.register_forward_hook(hook))
+            key = id(module)
+            if key not in hooks:
+                hooks[key] = (
+                    module.register_forward_pre_hook(pre_hook),
+                    module.register_forward_hook(hook),
+                )

     # module.named_modules(remove_duplicate=False) doesn't work (infinite recursion).
     for name, mod in module._modules.items():  # pylint: disable=protected-access
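The fix itself is a dedup pattern: handles are keyed by id(module), so a module reached through multiple references (self.linear appearing four times in the Sequential) gets exactly one (pre_hook, hook) pair no matter what user hooks it already carries, and the finally block removes each pair exactly once. A stripped-down sketch of the pattern, with placeholder callbacks:

from torch import nn
from torch.utils.hooks import RemovableHandle

def register_once(
    module: nn.Module,
    hooks: dict[int, tuple[RemovableHandle, RemovableHandle]],
) -> None:
    # id(module) is stable while the object lives, so every path to a
    # reused layer resolves to the same key.
    key = id(module)
    if key not in hooks:
        hooks[key] = (
            module.register_forward_pre_hook(lambda m, i: None),
            module.register_forward_hook(lambda m, i, o: None),
        )

hooks: dict[int, tuple[RemovableHandle, RemovableHandle]] = {}
shared = nn.Linear(10, 10)
for _ in range(4):  # traversal visits the same module four times
    register_once(shared, hooks)
assert len(hooks) == 1  # one pair, however often the module was seen

for pre, post in hooks.values():  # mirrors the finally-block cleanup
    pre.remove()
    post.remove()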
