44from dlclive .engine import Engine
55
66
7- # TODO: JR include separate functional tests for torch and tf backends
8- @pytest .mark .functional
9- def test_benchmark_script_runs (tmp_path ):
7+ @pytest .fixture
8+ def datafolder (tmp_path ):
109 datafolder = tmp_path / "Data-DLC-live-benchmark"
1110 download_benchmarking_data (str (datafolder ))
11+ return datafolder
1212
13+ @pytest .mark .functional
14+ def test_benchmark_script_runs_tf_backend (tmp_path , datafolder ):
1315 dog_models = glob .glob (str (datafolder / "dog" / "*[!avi]" ))
1416 dog_video = glob .glob (str (datafolder / "dog" / "*.avi" ))[0 ]
1517 mouse_models = glob .glob (str (datafolder / "mouse_lick" / "*[!avi]" ))
@@ -52,3 +54,56 @@ def test_benchmark_script_runs(tmp_path):
5254 )
5355
5456 assert any (out_dir .iterdir ())
57+
58+
59+ @pytest .mark .parametrize ("model_name" , ["hrnet_w32" , "resnet_50" ])
60+ @pytest .mark .functional
61+ def test_benchmark_script_with_torch_modelzoo (tmp_path , datafolder , model_name ):
62+ from dlclive import modelzoo
63+
64+ # Test configuration
65+ pixels = [100 , 400 ]
66+ n_frames = 5
67+ out_dir = tmp_path / "results"
68+ out_dir .mkdir (exist_ok = True )
69+
70+ # Export models
71+ model_configs = [
72+ {
73+ "checkpoint" : tmp_path / f"exported_quadruped_{ model_name } .pt" ,
74+ "super_animal" : "superanimal_quadruped" ,
75+ "video_dir" : "dog" ,
76+ },
77+ {
78+ "checkpoint" : tmp_path / f"exported_topviewmouse_{ model_name } .pt" ,
79+ "super_animal" : "superanimal_topviewmouse" ,
80+ "video_dir" : "mouse_lick" ,
81+ },
82+ ]
83+
84+ for config in model_configs :
85+ modelzoo .export_modelzoo_model (
86+ export_path = config ["checkpoint" ],
87+ super_animal = config ["super_animal" ],
88+ model_name = model_name ,
89+ )
90+ assert config ["checkpoint" ].exists (), f"Failed to export { config ['super_animal' ]} model"
91+ assert config ["checkpoint" ].stat ().st_size > 0 , f"Exported { config ['super_animal' ]} model is empty"
92+
93+ # Get video paths and run benchmarks
94+ for config in model_configs :
95+ video_path = glob .glob (str (datafolder / config ["video_dir" ] / "*.avi" ))[0 ]
96+ print (f"Running { config ['model_display_name' ]} " )
97+ benchmark_videos (
98+ model_path = config ["checkpoint" ],
99+ model_type = "pytorch" ,
100+ video_path = video_path ,
101+ output = str (out_dir ),
102+ n_frames = n_frames ,
103+ pixels = pixels ,
104+ )
105+
106+ # Assertions: verify output files were created
107+ output_files = list (out_dir .iterdir ())
108+ assert len (output_files ) > 0 , "No output files were created by benchmark_videos"
109+ assert any (f .suffix == ".pickle" for f in output_files ), "No pickle files found in output directory"
0 commit comments