11from typing import Any
22
3- import pytest
43import torch
5- import torchvision # type: ignore[import]
64from torch import nn
75from torch .nn .utils import prune
86
97from tests .conftest import verify_output_str
10- from tests .fixtures .genotype import GenotypeNetwork # type: ignore[attr-defined]
118from tests .fixtures .models import (
129 AutoEncoder ,
1310 ContainerModule ,
3128 SiameseNets ,
3229 SingleInputNet ,
3330)
34- from tests .fixtures .tmva_net import TMVANet # type: ignore[attr-defined]
3531from torchinfo import ColumnSettings , summary
3632from torchinfo .enums import Verbosity
3733
@@ -141,33 +137,6 @@ def test_single_input_batch_dim() -> None:
141137 )
142138
143139
144- def test_frozen_layers () -> None :
145- model = torchvision .models .resnet18 ()
146- for ind , param in enumerate (model .parameters ()):
147- if ind < 30 :
148- param .requires_grad = False
149-
150- summary (
151- model ,
152- input_size = (1 , 3 , 64 , 64 ),
153- depth = 3 ,
154- col_names = ("output_size" , "num_params" , "kernel_size" , "mult_adds" ),
155- )
156-
157-
158- def test_resnet18_depth_consistency () -> None :
159- model = torchvision .models .resnet18 ()
160-
161- for depth in range (1 , 3 ):
162- summary (model , (1 , 3 , 64 , 64 ), depth = depth , cache_forward_pass = True )
163-
164-
165- def test_resnet152 () -> None :
166- model = torchvision .models .resnet152 ()
167-
168- summary (model , (1 , 3 , 224 , 224 ), depth = 3 )
169-
170-
171140def test_pruning () -> None :
172141 model = SingleInputNet ()
173142 for module in model .modules ():
@@ -287,18 +256,6 @@ def test_recursive() -> None:
287256 assert results .total_mult_adds == 173709312
288257
289258
290- def test_resnet () -> None :
291- # According to https://arxiv.org/abs/1605.07146,
292- # resnet50 has ~25.6 M trainable params.
293- model = torchvision .models .resnet50 ()
294- results = summary (model , input_size = (2 , 3 , 224 , 224 ))
295-
296- assert results .total_params == 25557032 # close to 25.6e6
297- assert results .total_mult_adds == sum (
298- layer .macs for layer in results .summary_list if layer .is_leaf_layer
299- )
300-
301-
302259def test_siamese_net () -> None :
303260 metrics = summary (SiameseNets (), input_size = [(1 , 1 , 88 , 88 ), (1 , 1 , 88 , 88 )])
304261
@@ -313,16 +270,6 @@ def test_empty_module() -> None:
313270 summary (EmptyModule ())
314271
315272
316- @pytest .mark .skip
317- def test_fasterrcnn () -> None :
318- model = torchvision .models .detection .fasterrcnn_resnet50_fpn (
319- pretrained_backbone = False
320- )
321- results = summary (model , input_size = (1 , 3 , 112 , 112 ))
322-
323- assert results .total_params == 41755286
324-
325-
326273def test_device () -> None :
327274 device = torch .device ("cuda" if torch .cuda .is_available () else "cpu" )
328275 model = SingleInputNet ()
@@ -351,7 +298,7 @@ def test_pack_padded() -> None:
351298 ]).long ()
352299 # fmt: on
353300
354- summary (PackPaddedLSTM (), input_data = x , lengths = y , device = "cpu" )
301+ summary (PackPaddedLSTM (), input_data = x , lengths = y )
355302
356303
357304def test_module_dict () -> None :
@@ -417,7 +364,7 @@ def test_namedtuple() -> None:
417364 model = NamedTuple ()
418365 input_size = [(2 , 1 , 28 , 28 ), (2 , 1 , 28 , 28 )]
419366 named_tuple = model .Point (* input_size )
420- summary (model , input_size = input_size , z = named_tuple , device = "cpu" )
367+ summary (model , input_size = input_size , z = named_tuple )
421368
422369
423370def test_return_dict () -> None :
@@ -432,61 +379,11 @@ def test_containers() -> None:
432379 summary (ContainerModule (), input_size = (5 ,))
433380
434381
435- def test_eval_order_doesnt_matter () -> None :
436- input_size = (1 , 3 , 224 , 224 )
437- input_tensor = torch .ones (input_size )
438-
439- model1 = torchvision .models .resnet18 (pretrained = True )
440- model1 .eval ()
441- summary (model1 , input_size = input_size , device = "cpu" )
442- with torch .inference_mode (): # type: ignore[no-untyped-call]
443- output1 = model1 (input_tensor )
444-
445- model2 = torchvision .models .resnet18 (pretrained = True )
446- summary (model2 , input_size = input_size , device = "cpu" )
447- model2 .eval ()
448- with torch .inference_mode (): # type: ignore[no-untyped-call]
449- output2 = model2 (input_tensor )
450-
451- assert torch .all (torch .eq (output1 , output2 ))
452-
453-
454382def test_autoencoder () -> None :
455383 model = AutoEncoder ()
456384 summary (model , input_size = (1 , 3 , 64 , 64 ))
457385
458386
459- def test_genotype () -> None :
460- model = GenotypeNetwork ()
461-
462- x = summary (model , (2 , 3 , 32 , 32 ), depth = 3 , cache_forward_pass = True )
463- y = summary (model , (2 , 3 , 32 , 32 ), depth = 7 , cache_forward_pass = True )
464-
465- assert x .total_params == y .total_params , (x , y )
466-
467-
468- def test_tmva_net_column_totals () -> None :
469- for depth in (1 , 3 , 5 ):
470- results = summary (
471- TMVANet (n_classes = 4 , n_frames = 5 ),
472- input_data = [
473- torch .randn (1 , 1 , 5 , 256 , 64 ),
474- torch .randn (1 , 1 , 5 , 256 , 256 ),
475- torch .randn (1 , 1 , 5 , 256 , 64 ),
476- ],
477- col_names = ["output_size" , "num_params" , "mult_adds" ],
478- depth = depth ,
479- cache_forward_pass = True ,
480- )
481-
482- assert results .total_params == sum (
483- layer .num_params for layer in results .summary_list if layer .is_leaf_layer
484- )
485- assert results .total_mult_adds == sum (
486- layer .macs for layer in results .summary_list if layer .is_leaf_layer
487- )
488-
489-
490387def test_reusing_activation_layers () -> None :
491388 act = nn .LeakyReLU (inplace = True )
492389 model1 = nn .Sequential (act , nn .Identity (), act , nn .Identity (), act ) # type: ignore[no-untyped-call] # noqa
@@ -511,21 +408,6 @@ def test_mixed_trainable_parameters() -> None:
511408 assert result .total_params == 20
512409
513410
514- def test_ascii_only () -> None :
515- result = summary (
516- torchvision .models .resnet18 (),
517- depth = 3 ,
518- input_size = (1 , 3 , 64 , 64 ),
519- row_settings = ["ascii_only" ],
520- )
521-
522- assert str (result ).encode ("ascii" ).decode ("ascii" )
523-
524-
525- def test_google () -> None :
526- summary (torchvision .models .googlenet (), (1 , 3 , 112 , 112 ), depth = 7 )
527-
528-
529411def test_too_many_linear () -> None :
530412 net = ReuseLinear ()
531413 summary (net , (2 , 10 ))
0 commit comments