@@ -38,8 +38,8 @@ def __init__(
3838 else :
3939 if layer_info .is_recursive :
4040 continue
41- self .total_params += layer_info .remaining_params ()
42- self .trainable_params += layer_info .remaining_trainable_params ()
41+ self .total_params += layer_info .leftover_params ()
42+ self .trainable_params += layer_info .leftover_trainable_params ()
4343
4444 self .formatting .set_layer_name_width (summary_list )
4545
@@ -55,25 +55,19 @@ def __repr__(self) -> str:
5555 f"Non-trainable params: { self .total_params - self .trainable_params :,} \n "
5656 )
5757 if self .input_size :
58+ unit , macs = self .to_readable (self .total_mult_adds )
59+ input_size = self .to_megabytes (self .total_input )
60+ output_bytes = self .to_megabytes (self .total_output_bytes )
61+ param_bytes = self .to_megabytes (self .total_param_bytes )
62+ total_bytes = self .to_megabytes (
63+ self .total_input + self .total_output_bytes + self .total_param_bytes
64+ )
5865 summary_str += (
59- "Total mult-adds ({}): {:0.2f}\n {}\n " # pylint: disable=consider-using-f-string # noqa: E501
60- "Input size (MB): {:0.2f}\n "
61- "Forward/backward pass size (MB): {:0.2f}\n "
62- "Params size (MB): {:0.2f}\n "
63- "Estimated Total Size (MB): {:0.2f}\n " .format (
64- * self .to_readable (self .total_mult_adds ),
65- divider ,
66- self .to_megabytes (self .total_input ),
67- self .to_megabytes (self .total_output_bytes ),
68- self .to_megabytes (self .total_param_bytes ),
69- (
70- self .to_megabytes (
71- self .total_input
72- + self .total_output_bytes
73- + self .total_param_bytes
74- )
75- ),
76- )
66+ f"Total mult-adds ({ unit } ): { macs :0.2f} \n { divider } \n "
67+ f"Input size (MB): { input_size :0.2f} \n "
68+ f"Forward/backward pass size (MB): { output_bytes :0.2f} \n "
69+ f"Params size (MB): { param_bytes :0.2f} \n "
70+ f"Estimated Total Size (MB): { total_bytes :0.2f} \n "
7771 )
7872 summary_str += divider
7973 return summary_str
0 commit comments