Skip to content

Commit fbc3468

Browse files
committed
Convert get_param_count to @staticmethod
1 parent 39fe1d4 commit fbc3468

1 file changed

Lines changed: 17 additions & 14 deletions

File tree

torchinfo/layer_info.py

Lines changed: 17 additions & 14 deletions
Original file line number | Diff line number | Diff line change
@@ -126,17 +126,10 @@ def nested_list_size(inputs: Sequence[Any]) -> tuple[list[int], int]:
126126

127127
return size, elem_bytes
128128

129-
def get_layer_name(self, show_var_name: bool, show_depth: bool) -> str:
130-
layer_name = self.class_name
131-
if show_var_name and self.var_name:
132-
layer_name += f" ({self.var_name})"
133-
if show_depth and self.depth > 0:
134-
layer_name += f": {self.depth}"
135-
if self.depth_index is not None:
136-
layer_name += f"-{self.depth_index}"
137-
return layer_name
138-
139-
def get_param_count(self, name: str, param: torch.Tensor) -> tuple[int, str]:
129+
@staticmethod
130+
def get_param_count(
131+
module: nn.Module, name: str, param: torch.Tensor
132+
) -> tuple[int, str]:
140133
"""
141134
Get count of number of params, accounting for mask.
142135
@@ -146,12 +139,22 @@ def get_param_count(self, name: str, param: torch.Tensor) -> tuple[int, str]:
146139
"""
147140
if name.endswith("_orig"):
148141
without_suffix = name[:-5]
149-
pruned_weights = rgetattr(self.module, f"{without_suffix}_mask")
142+
pruned_weights = rgetattr(module, f"{without_suffix}_mask")
150143
if pruned_weights is not None:
151144
parameter_count = int(torch.sum(pruned_weights))
152145
return parameter_count, without_suffix
153146
return param.nelement(), name
154147

148+
def get_layer_name(self, show_var_name: bool, show_depth: bool) -> str:
    """Return the display name for this layer.

    Starts from the class name, optionally appends the variable name
    (when show_var_name is set and a var_name exists) and the depth
    suffix ``: depth[-depth_index]`` (when show_depth is set and the
    layer sits below the root, i.e. depth > 0).
    """
    parts = [self.class_name]
    if show_var_name and self.var_name:
        parts.append(f" ({self.var_name})")
    if show_depth and self.depth > 0:
        parts.append(f": {self.depth}")
        # depth_index is only meaningful when a depth suffix is shown
        if self.depth_index is not None:
            parts.append(f"-{self.depth_index}")
    return "".join(parts)
157+
155158
def calculate_num_params(self) -> None:
156159
"""
157160
Set num_params, trainable, inner_layers, and kernel_size
@@ -161,7 +164,7 @@ def calculate_num_params(self) -> None:
161164
for name, param in self.module.named_parameters():
162165
if is_lazy(param):
163166
continue
164-
cur_params, name = self.get_param_count(name, param)
167+
cur_params, name = self.get_param_count(self.module, name, param)
165168
self.param_bytes += param.element_size() * cur_params
166169

167170
self.num_params += cur_params
@@ -194,7 +197,7 @@ def calculate_macs(self) -> None:
194197
i.e., taking the batch-dimension into account.
195198
"""
196199
for name, param in self.module.named_parameters():
197-
cur_params, name = self.get_param_count(name, param)
200+
cur_params, name = self.get_param_count(self.module, name, param)
198201
if name in ("weight", "bias"):
199202
# ignore C when calculating Mult-Adds in ConvNd
200203
if "Conv" in self.class_name:

0 commit comments

Comments
 (0)