@@ -253,11 +253,12 @@ def forward_pass(
253253 summary_list : list [LayerInfo ] = []
254254 hooks : list [RemovableHandle ] | None = None if x is None else []
255255 named_module = (model_name , model )
256- apply_hooks (named_module , model , batch_dim , summary_list , {}, hooks , all_layers )
256+ apply_hooks (named_module , model , batch_dim , summary_list , hooks , all_layers )
257257
258258 if x is None :
259259 if not summary_list or summary_list [0 ].var_name != model_name :
260260 summary_list .insert (0 , LayerInfo ("" , model , 0 ))
261+ set_depth_index (summary_list )
261262 return summary_list
262263
263264 kwargs = set_device (kwargs , device )
@@ -318,33 +319,14 @@ def add_missing_layers(
318319 if not set (b ) - set (a ):
319320 return
320321
321- for tag , _ , i2 , j1 , j2 in difflib .SequenceMatcher (None , a , b ).get_opcodes ():
322+ for tag , _ , _ , j1 , j2 in difflib .SequenceMatcher (None , a , b ).get_opcodes ():
322323 # Ignore all other layer types besides "insert".
323324 if tag == "insert" :
324- depth_shifts : dict [int , int ] = {}
325325 for i , info in enumerate (all_layers [j1 :j2 ]):
326- # Find the correct depth_index for the new layer.
327- info .depth_index = 1
328- for prev_layer in reversed (summary_list [: j1 + i ]):
329- if (
330- prev_layer .depth_index is not None
331- and prev_layer .depth == info .depth
332- ):
333- info .depth_index = prev_layer .depth_index + 1
334- break
335- if info .depth not in depth_shifts :
336- depth_shifts [info .depth ] = 0
337- depth_shifts [info .depth ] += 1
338-
339326 info .calculate_num_params ()
340327 info .check_recursive (summary_list )
341328 summary_list .insert (j1 + i , info )
342329
343- # Shift depths forward for all existing later layers
344- for info in summary_list [i2 + (j2 - j1 ) :]:
345- if info .depth_index is not None :
346- info .depth_index += depth_shifts .get (info .depth , 0 )
347-
348330
349331def validate_user_params (
350332 input_data : INPUT_DATA_TYPE | None ,
@@ -491,7 +473,6 @@ def apply_hooks(
491473 orig_model : nn .Module ,
492474 batch_dim : int | None ,
493475 summary_list : list [LayerInfo ],
494- idx : dict [int , int ],
495476 hooks : list [RemovableHandle ] | None ,
496477 all_layers : list [LayerInfo ],
497478 curr_depth : int = 0 ,
@@ -505,15 +486,14 @@ def apply_hooks(
505486 # Fallback is used if the layer's pre-hook is never called, for example in
506487 # ModuleLists or Sequentials.
507488 var_name , module = named_module
508- info = LayerInfo (var_name , module , curr_depth , None , parent_info )
489+ info = LayerInfo (var_name , module , curr_depth , parent_info )
509490 all_layers .append (info )
510491
511492 def pre_hook (module : nn .Module , inputs : Any ) -> None :
512493 """Create a LayerInfo object to aggregate information about that layer."""
513494 del inputs
514495 nonlocal info
515- idx [curr_depth ] = idx .get (curr_depth , 0 ) + 1
516- info = LayerInfo (var_name , module , curr_depth , idx [curr_depth ], parent_info )
496+ info = LayerInfo (var_name , module , curr_depth , parent_info )
517497 info .calculate_num_params ()
518498 info .check_recursive (summary_list )
519499 summary_list .append (info )
@@ -543,7 +523,6 @@ def hook(module: nn.Module, inputs: Any, outputs: Any) -> None:
543523 orig_model ,
544524 batch_dim ,
545525 summary_list ,
546- idx ,
547526 hooks ,
548527 all_layers ,
549528 curr_depth + 1 ,
0 commit comments