Skip to content

Commit b07adc1

Browse files
committed
add functional benchmark test for pytorch model zoo
1 parent 3a39801 commit b07adc1

1 file changed

Lines changed: 58 additions & 3 deletions

File tree

tests/test_benchmark_script.py

Lines changed: 58 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,12 +4,14 @@
44
from dlclive.engine import Engine
55

66

7-
# TODO: JR include separate functional tests for torch and tf backends
8-
@pytest.mark.functional
9-
def test_benchmark_script_runs(tmp_path):
7+
@pytest.fixture
8+
def datafolder(tmp_path):
109
datafolder = tmp_path / "Data-DLC-live-benchmark"
1110
download_benchmarking_data(str(datafolder))
11+
return datafolder
1212

13+
@pytest.mark.functional
14+
def test_benchmark_script_runs_tf_backend(tmp_path, datafolder):
1315
dog_models = glob.glob(str(datafolder / "dog" / "*[!avi]"))
1416
dog_video = glob.glob(str(datafolder / "dog" / "*.avi"))[0]
1517
mouse_models = glob.glob(str(datafolder / "mouse_lick" / "*[!avi]"))
@@ -52,3 +54,56 @@ def test_benchmark_script_runs(tmp_path):
5254
)
5355

5456
assert any(out_dir.iterdir())
57+
58+
59+
@pytest.mark.parametrize("model_name", ["hrnet_w32", "resnet_50"])
60+
@pytest.mark.functional
61+
def test_benchmark_script_with_torch_modelzoo(tmp_path, datafolder, model_name):
62+
from dlclive import modelzoo
63+
64+
# Test configuration
65+
pixels = [100, 400]
66+
n_frames = 5
67+
out_dir = tmp_path / "results"
68+
out_dir.mkdir(exist_ok=True)
69+
70+
# Export models
71+
model_configs = [
72+
{
73+
"checkpoint": tmp_path / f"exported_quadruped_{model_name}.pt",
74+
"super_animal": "superanimal_quadruped",
75+
"video_dir": "dog",
76+
},
77+
{
78+
"checkpoint": tmp_path / f"exported_topviewmouse_{model_name}.pt",
79+
"super_animal": "superanimal_topviewmouse",
80+
"video_dir": "mouse_lick",
81+
},
82+
]
83+
84+
for config in model_configs:
85+
modelzoo.export_modelzoo_model(
86+
export_path=config["checkpoint"],
87+
super_animal=config["super_animal"],
88+
model_name=model_name,
89+
)
90+
assert config["checkpoint"].exists(), f"Failed to export {config['super_animal']} model"
91+
assert config["checkpoint"].stat().st_size > 0, f"Exported {config['super_animal']} model is empty"
92+
93+
# Get video paths and run benchmarks
94+
for config in model_configs:
95+
video_path = glob.glob(str(datafolder / config["video_dir"] / "*.avi"))[0]
96+
print(f"Running {config['model_display_name']}")
97+
benchmark_videos(
98+
model_path=config["checkpoint"],
99+
model_type="pytorch",
100+
video_path=video_path,
101+
output=str(out_dir),
102+
n_frames=n_frames,
103+
pixels=pixels,
104+
)
105+
106+
# Assertions: verify output files were created
107+
output_files = list(out_dir.iterdir())
108+
assert len(output_files) > 0, "No output files were created by benchmark_videos"
109+
assert any(f.suffix == ".pickle" for f in output_files), "No pickle files found in output directory"

0 commit comments

Comments
 (0)