1212]
1313
1414
def rgetattr(module: torch.nn.Module, attr: str) -> torch.Tensor:
    """Get the tensor submodule called attr from module.

    Walks the dotted path `attr` (e.g. "conv.weight_mask") through
    the module hierarchy via getattr.

    Raises:
        AttributeError: if the resolved attribute is not a torch.Tensor.
    """
    for attr_i in attr.split("."):
        module = getattr(module, attr_i)
    # Explicit raise instead of `assert`: asserts are stripped under
    # `python -O`, and the error message names the offending attribute.
    if not isinstance(module, torch.Tensor):
        raise AttributeError(f"{attr} is not a tensor")
    return module
2421
2522
2623class LayerInfo :
@@ -127,18 +124,22 @@ def get_layer_name(self, show_var_name: bool, show_depth: bool) -> str:
127124 layer_name += f"-{ self .depth_index } "
128125 return layer_name
129126
130- def __get_cur_params (self , name : str , param : torch .Tensor ) -> tuple [int , str ]:
127+ def get_param_count (self , name : str , param : torch .Tensor ) -> tuple [int , str ]:
131128 """
132- Get count of number of params, accounting for mask
129+ Get count of number of params, accounting for mask.
130+
131+ Masked models save parameters with the suffix "_orig" added.
132+ They have a buffer ending with "_mask" which has only 0s and 1s.
133+ If a mask exists, the sum of 1s in mask is number of params.
133134 """
134- # Masked models save the parameter with the name "_orig" added
135- # They have a buffer ending with "_mask" which has only 0s and 1s
136- if name [ - 4 :] == "orig" :
137- # If a mask exists, the sum of 1s in mask is number of params
138- # Remove "_orig" for better readability and integration
139- return int ( torch . sum ( rgetattr ( self . module , f" { name [: - 4 ] } mask" ))), name [: - 5 ]
140- else :
141- return param .nelement (), name
135+ if name . endswith ( "orig" ):
136+ # Remove "_orig" suffix for better readability and integration
137+ without_suffix = name [: - 5 ]
138+ parameter_count = int (
139+ torch . sum ( rgetattr ( self . module , f" { without_suffix } _mask" ))
140+ )
141+ return parameter_count , without_suffix
142+ return param .nelement (), name
142143
143144 def calculate_num_params (self ) -> None :
144145 """
@@ -147,7 +148,7 @@ def calculate_num_params(self) -> None:
147148 """
148149 name = ""
149150 for name , param in self .module .named_parameters ():
150- cur_params , name = self .__get_cur_params (name , param )
151+ cur_params , name = self .get_param_count (name , param )
151152
152153 self .num_params += cur_params
153154 if param .requires_grad :
@@ -179,7 +180,7 @@ def calculate_macs(self) -> None:
179180 i.e., taking the batch-dimension into account.
180181 """
181182 for name , param in self .module .named_parameters ():
182- cur_params , name = self .__get_cur_params (name , param )
183+ cur_params , name = self .get_param_count (name , param )
183184 if name in ("weight" , "bias" ):
184185 # ignore C when calculating Mult-Adds in ConvNd
185186 if "Conv" in self .class_name :
0 commit comments