@@ -126,17 +126,10 @@ def nested_list_size(inputs: Sequence[Any]) -> tuple[list[int], int]:
126126
127127 return size , elem_bytes
128128
129- def get_layer_name (self , show_var_name : bool , show_depth : bool ) -> str :
130- layer_name = self .class_name
131- if show_var_name and self .var_name :
132- layer_name += f" ({ self .var_name } )"
133- if show_depth and self .depth > 0 :
134- layer_name += f": { self .depth } "
135- if self .depth_index is not None :
136- layer_name += f"-{ self .depth_index } "
137- return layer_name
138-
139- def get_param_count (self , name : str , param : torch .Tensor ) -> tuple [int , str ]:
129+ @staticmethod
130+ def get_param_count (
131+ module : nn .Module , name : str , param : torch .Tensor
132+ ) -> tuple [int , str ]:
140133 """
141134 Get count of number of params, accounting for mask.
142135
@@ -146,12 +139,22 @@ def get_param_count(self, name: str, param: torch.Tensor) -> tuple[int, str]:
146139 """
147140 if name .endswith ("_orig" ):
148141 without_suffix = name [:- 5 ]
149- pruned_weights = rgetattr (self . module , f"{ without_suffix } _mask" )
142+ pruned_weights = rgetattr (module , f"{ without_suffix } _mask" )
150143 if pruned_weights is not None :
151144 parameter_count = int (torch .sum (pruned_weights ))
152145 return parameter_count , without_suffix
153146 return param .nelement (), name
154147
148+ def get_layer_name (self , show_var_name : bool , show_depth : bool ) -> str :
149+ layer_name = self .class_name
150+ if show_var_name and self .var_name :
151+ layer_name += f" ({ self .var_name } )"
152+ if show_depth and self .depth > 0 :
153+ layer_name += f": { self .depth } "
154+ if self .depth_index is not None :
155+ layer_name += f"-{ self .depth_index } "
156+ return layer_name
157+
155158 def calculate_num_params (self ) -> None :
156159 """
157160 Set num_params, trainable, inner_layers, and kernel_size
@@ -161,7 +164,7 @@ def calculate_num_params(self) -> None:
161164 for name , param in self .module .named_parameters ():
162165 if is_lazy (param ):
163166 continue
164- cur_params , name = self .get_param_count (name , param )
167+ cur_params , name = self .get_param_count (self . module , name , param )
165168 self .param_bytes += param .element_size () * cur_params
166169
167170 self .num_params += cur_params
@@ -194,7 +197,7 @@ def calculate_macs(self) -> None:
194197 i.e., taking the batch-dimension into account.
195198 """
196199 for name , param in self .module .named_parameters ():
197- cur_params , name = self .get_param_count (name , param )
200+ cur_params , name = self .get_param_count (self . module , name , param )
198201 if name in ("weight" , "bias" ):
199202 # ignore C when calculating Mult-Adds in ConvNd
200203 if "Conv" in self .class_name :
0 commit comments