Skip to content

Commit e7be90b

Browse files
committed
Do not error if there are None values in ModuleLists
1 parent 7def795 commit e7be90b

2 files changed

Lines changed: 16 additions & 13 deletions

File tree

tests/fixtures/models.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -407,11 +407,14 @@ def __init__(self) -> None:
407407
self._layers.append(nn.Linear(5, 5))
408408
self._layers.append(ContainerChildModule())
409409
self._layers.append(nn.Linear(5, 5))
410+
# Add None, but filter out this value later.
411+
self._layers.append(None) # type: ignore[arg-type]
410412

411413
def forward(self, x: torch.Tensor) -> torch.Tensor:
412414
out = x
413415
for layer in self._layers:
414-
out = layer(out)
416+
if layer is not None:
417+
out = layer(out)
415418
return out
416419

417420

torchinfo/torchinfo.py

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -566,18 +566,18 @@ def hook(module: nn.Module, inputs: Any, outputs: Any) -> None:
566566

567567
# module.named_modules(remove_duplicate=False) doesn't work (infinite recursion).
568568
for name, mod in module._modules.items(): # pylint: disable=protected-access
569-
assert mod is not None
570-
child = (name, mod)
571-
apply_hooks(
572-
child,
573-
orig_model,
574-
batch_dim,
575-
summary_list,
576-
hooks,
577-
all_layers,
578-
curr_depth + 1,
579-
info,
580-
)
569+
if mod is not None:
570+
child = (name, mod)
571+
apply_hooks(
572+
child,
573+
orig_model,
574+
batch_dim,
575+
summary_list,
576+
hooks,
577+
all_layers,
578+
curr_depth + 1,
579+
info,
580+
)
581581

582582

583583
def clear_cached_forward_pass() -> None:

0 commit comments

Comments
 (0)