11""" layer_info.py """
2- from typing import Any , Dict , Iterable , List , Optional , Sequence , Union
2+ from __future__ import annotations
3+
4+ from typing import Any , Dict , Iterable , Sequence , Union
35
46import torch
57from torch import nn
@@ -18,8 +20,8 @@ def __init__(
1820 var_name : str ,
1921 module : nn .Module ,
2022 depth : int ,
21- depth_index : Optional [ int ] = None ,
22- parent_info : Optional [ " LayerInfo" ] = None ,
23+ depth_index : int | None = None ,
24+ parent_info : LayerInfo | None = None ,
2325 ) -> None :
2426 # Identifying information
2527 self .layer_id = id (module )
@@ -30,7 +32,7 @@ def __init__(
3032 else module .__class__ .__name__
3133 )
3234 # {layer name: {row_name: row_value}}
33- self .inner_layers : Dict [str , Dict [str , Any ]] = {}
35+ self .inner_layers : dict [str , dict [str , Any ]] = {}
3436 self .depth = depth
3537 self .depth_index = depth_index
3638 self .executed = False
@@ -41,9 +43,9 @@ def __init__(
4143 # Statistics
4244 self .trainable_params = 0
4345 self .is_recursive = False
44- self .input_size : List [int ] = []
45- self .output_size : List [int ] = []
46- self .kernel_size : List [int ] = []
46+ self .input_size : list [int ] = []
47+ self .output_size : list [int ] = []
48+ self .kernel_size : list [int ] = []
4749 self .num_params = 0
4850 self .macs = 0
4951
@@ -52,11 +54,11 @@ def __repr__(self) -> str:
5254
5355 @staticmethod
5456 def calculate_size (
55- inputs : DETECTED_INPUT_OUTPUT_TYPES , batch_dim : Optional [ int ]
56- ) -> List [int ]:
57+ inputs : DETECTED_INPUT_OUTPUT_TYPES , batch_dim : int | None
58+ ) -> list [int ]:
5759 """Set input_size or output_size using the model's inputs."""
5860
59- def nested_list_size (inputs : Sequence [Any ]) -> List [int ]:
61+ def nested_list_size (inputs : Sequence [Any ]) -> list [int ]:
6062 """Flattens nested list size."""
6163 if hasattr (inputs , "tensors" ):
6264 return nested_list_size (inputs .tensors ) # type: ignore[attr-defined]
@@ -164,7 +166,7 @@ def calculate_macs(self) -> None:
164166 elif "weight" in name or "bias" in name :
165167 self .macs += prod (self .output_size [:2 ]) * param .nelement ()
166168
167- def check_recursive (self , summary_list : List [ " LayerInfo" ]) -> None :
169+ def check_recursive (self , summary_list : list [ LayerInfo ]) -> None :
168170 """
169171 If the current module is already-used, mark as (recursive).
170172 Must check before adding line to the summary.
@@ -175,7 +177,7 @@ def check_recursive(self, summary_list: List["LayerInfo"]) -> None:
175177 self .is_recursive = True
176178
177179 def macs_to_str (
178- self , reached_max_depth : bool , children_layers : List [ " LayerInfo" ]
180+ self , reached_max_depth : bool , children_layers : list [ LayerInfo ]
179181 ) -> str :
180182 """Convert MACs to string."""
181183 if self .macs <= 0 :
@@ -199,7 +201,7 @@ def num_params_to_str(self, reached_max_depth: bool) -> str:
199201 return "--"
200202
201203
202- def prod (num_list : Union [ Iterable [int ], torch .Size ] ) -> int :
204+ def prod (num_list : Iterable [int ] | torch .Size ) -> int :
203205 result = 1
204206 if isinstance (num_list , Iterable ):
205207 for item in num_list :
0 commit comments