Skip to content

Commit c626a92

Browse files
committed
Remove references to depth_index in apply_hooks
1 parent 37df6c8 commit c626a92

2 files changed

Lines changed: 6 additions & 28 deletions

File tree

torchinfo/layer_info.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,6 @@ def __init__(
2929
var_name: str,
3030
module: nn.Module,
3131
depth: int,
32-
depth_index: int | None = None,
3332
parent_info: LayerInfo | None = None,
3433
) -> None:
3534
# Identifying information
@@ -43,7 +42,7 @@ def __init__(
4342
# {layer name: {col_name: value_for_row}}
4443
self.inner_layers: dict[str, dict[ColumnSettings, Any]] = {}
4544
self.depth = depth
46-
self.depth_index = depth_index
45+
self.depth_index: int | None = None # set at the very end
4746
self.executed = False
4847
self.parent_info = parent_info
4948
self.var_name = var_name

torchinfo/torchinfo.py

Lines changed: 5 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -253,11 +253,12 @@ def forward_pass(
253253
summary_list: list[LayerInfo] = []
254254
hooks: list[RemovableHandle] | None = None if x is None else []
255255
named_module = (model_name, model)
256-
apply_hooks(named_module, model, batch_dim, summary_list, {}, hooks, all_layers)
256+
apply_hooks(named_module, model, batch_dim, summary_list, hooks, all_layers)
257257

258258
if x is None:
259259
if not summary_list or summary_list[0].var_name != model_name:
260260
summary_list.insert(0, LayerInfo("", model, 0))
261+
set_depth_index(summary_list)
261262
return summary_list
262263

263264
kwargs = set_device(kwargs, device)
@@ -318,33 +319,14 @@ def add_missing_layers(
318319
if not set(b) - set(a):
319320
return
320321

321-
for tag, _, i2, j1, j2 in difflib.SequenceMatcher(None, a, b).get_opcodes():
322+
for tag, _, _, j1, j2 in difflib.SequenceMatcher(None, a, b).get_opcodes():
322323
# Ignore all other layer types besides "insert".
323324
if tag == "insert":
324-
depth_shifts: dict[int, int] = {}
325325
for i, info in enumerate(all_layers[j1:j2]):
326-
# Find the correct depth_index for the new layer.
327-
info.depth_index = 1
328-
for prev_layer in reversed(summary_list[: j1 + i]):
329-
if (
330-
prev_layer.depth_index is not None
331-
and prev_layer.depth == info.depth
332-
):
333-
info.depth_index = prev_layer.depth_index + 1
334-
break
335-
if info.depth not in depth_shifts:
336-
depth_shifts[info.depth] = 0
337-
depth_shifts[info.depth] += 1
338-
339326
info.calculate_num_params()
340327
info.check_recursive(summary_list)
341328
summary_list.insert(j1 + i, info)
342329

343-
# Shift depths forward for all existing later layers
344-
for info in summary_list[i2 + (j2 - j1) :]:
345-
if info.depth_index is not None:
346-
info.depth_index += depth_shifts.get(info.depth, 0)
347-
348330

349331
def validate_user_params(
350332
input_data: INPUT_DATA_TYPE | None,
@@ -491,7 +473,6 @@ def apply_hooks(
491473
orig_model: nn.Module,
492474
batch_dim: int | None,
493475
summary_list: list[LayerInfo],
494-
idx: dict[int, int],
495476
hooks: list[RemovableHandle] | None,
496477
all_layers: list[LayerInfo],
497478
curr_depth: int = 0,
@@ -505,15 +486,14 @@ def apply_hooks(
505486
# Fallback is used if the layer's pre-hook is never called, for example in
506487
# ModuleLists or Sequentials.
507488
var_name, module = named_module
508-
info = LayerInfo(var_name, module, curr_depth, None, parent_info)
489+
info = LayerInfo(var_name, module, curr_depth, parent_info)
509490
all_layers.append(info)
510491

511492
def pre_hook(module: nn.Module, inputs: Any) -> None:
512493
"""Create a LayerInfo object to aggregate information about that layer."""
513494
del inputs
514495
nonlocal info
515-
idx[curr_depth] = idx.get(curr_depth, 0) + 1
516-
info = LayerInfo(var_name, module, curr_depth, idx[curr_depth], parent_info)
496+
info = LayerInfo(var_name, module, curr_depth, parent_info)
517497
info.calculate_num_params()
518498
info.check_recursive(summary_list)
519499
summary_list.append(info)
@@ -543,7 +523,6 @@ def hook(module: nn.Module, inputs: Any, outputs: Any) -> None:
543523
orig_model,
544524
batch_dim,
545525
summary_list,
546-
idx,
547526
hooks,
548527
all_layers,
549528
curr_depth + 1,

0 commit comments

Comments
 (0)