Skip to content

Commit 37df6c8

Browse files
committed
Set depth_index at the very end
1 parent 73622ad commit 37df6c8

2 files changed

Lines changed: 11 additions & 3 deletions

File tree

tests/test_output/module_dict.out

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,9 @@ ModuleDictModel -- --
55
├─ModuleDict: 1-1 -- --
66
│ └─Conv2d: 2-1 [1, 10, 1, 1] 910
77
│ └─MaxPool2d: 2-2 -- --
8-
├─ModuleDict: 1-3 -- --
9-
│ └─LeakyReLU: 2-4 [1, 10, 1, 1] --
10-
│ └─PReLU: 2-5 -- 1
8+
├─ModuleDict: 1-2 -- --
9+
│ └─LeakyReLU: 2-3 [1, 10, 1, 1] --
10+
│ └─PReLU: 2-4 -- 1
1111
==========================================================================================
1212
Total params: 911
1313
Trainable params: 911

torchinfo/torchinfo.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -289,11 +289,19 @@ def forward_pass(
289289
summary_list.insert(0, LayerInfo("", model, 0))
290290

291291
add_missing_layers(summary_list, all_layers)
292+
set_depth_index(summary_list)
292293

293294
_cached_forward_pass[model_name] = summary_list
294295
return summary_list
295296

296297

298+
def set_depth_index(summary_list: list[LayerInfo]) -> None:
299+
idx: dict[int, int] = {}
300+
for layer in summary_list:
301+
idx[layer.depth] = idx.get(layer.depth, 0) + 1
302+
layer.depth_index = idx[layer.depth]
303+
304+
297305
def add_missing_layers(
298306
summary_list: list[LayerInfo], all_layers: list[LayerInfo]
299307
) -> None:

0 commit comments

Comments
 (0)