-
Notifications
You must be signed in to change notification settings - Fork 54
Expand file tree
/
Copy path: test_benchmark_script.py
More file actions
107 lines (89 loc) · 3.56 KB
/
test_benchmark_script.py
File metadata and controls
107 lines (89 loc) · 3.56 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
import glob
import pytest
from dlclive.benchmark import benchmark_videos, download_benchmarking_data
from dlclive.engine import Engine
@pytest.fixture
def datafolder(tmp_path):
    """Download the DLC-live benchmarking dataset into a per-test temp dir.

    Returns the path of the populated ``Data-DLC-live-benchmark`` folder.
    """
    target = tmp_path / "Data-DLC-live-benchmark"
    download_benchmarking_data(str(target))
    return target
@pytest.mark.functional
@pytest.mark.slow
def test_benchmark_script_runs_tf_backend(tmp_path, datafolder):
    """Run ``benchmark_videos`` on every dog and mouse model in the dataset.

    Each scene folder contains one or more exported models plus one ``.avi``
    video; every model is benchmarked against that scene's video and the test
    asserts that at least one result file was written.
    """
    out_dir = tmp_path / "results"
    out_dir.mkdir(exist_ok=True)
    pixels = [100, 400]  # downsized from [2500, 10000] to keep the test fast
    n_frames = 5

    for scene in ("dog", "mouse_lick"):
        scene_dir = datafolder / scene
        # BUG FIX: the previous glob pattern "*[!avi]" only excluded names
        # whose LAST character was 'a', 'v', or 'i' — it would silently skip
        # any model entry ending in one of those letters. Filter on the
        # ".avi" suffix, which is the actual intent (everything but videos).
        models = sorted(p for p in scene_dir.iterdir() if p.suffix != ".avi")
        video = sorted(scene_dir.glob("*.avi"))[0]
        _benchmark_models(scene, models, video, out_dir, n_frames, pixels)

    # At least one benchmark result must have been written.
    assert any(out_dir.iterdir()), "benchmark_videos produced no output files"


def _benchmark_models(scene, models, video, out_dir, n_frames, pixels):
    """Benchmark each model in *models* against *video*, writing to *out_dir*.

    The model type passed to ``benchmark_videos`` is derived from the model's
    engine: "base" for TensorFlow exports, "pytorch" otherwise.
    """
    for model_path in models:
        print(f"Running {scene} model: {model_path}")
        engine = Engine.from_model_path(str(model_path))
        benchmark_videos(
            model_path=str(model_path),
            model_type=("base" if engine == Engine.TENSORFLOW else "pytorch"),
            video_path=str(video),
            output=str(out_dir),
            n_frames=n_frames,
            pixels=pixels,
        )
@pytest.mark.parametrize("model_name", ["hrnet_w32", "resnet_50"])
@pytest.mark.functional
@pytest.mark.slow
def test_benchmark_script_with_torch_modelzoo(tmp_path, datafolder, model_name):
    """Export SuperAnimal modelzoo models and benchmark them on sample videos.

    For each SuperAnimal variant the model is first exported to a ``.pt``
    checkpoint, then ``benchmark_videos`` is run against the matching video
    set; finally the test asserts that pickle result files were produced.
    """
    from dlclive import modelzoo

    # Test configuration
    pixels = 4096  # approximately 64x64 pixels, keeping aspect ratio
    n_frames = 5
    out_dir = tmp_path / "results"
    out_dir.mkdir(exist_ok=True)

    # One entry per SuperAnimal model: export destination + matching video dir.
    model_configs = [
        {
            "checkpoint": tmp_path / f"exported_quadruped_{model_name}.pt",
            "super_animal": "superanimal_quadruped",
            "video_dir": "dog",
        },
        {
            "checkpoint": tmp_path / f"exported_topviewmouse_{model_name}.pt",
            "super_animal": "superanimal_topviewmouse",
            "video_dir": "mouse_lick",
        },
    ]

    # Export all models up front so an export failure is reported before any
    # (slow) benchmarking starts.
    for config in model_configs:
        modelzoo.export_modelzoo_model(
            export_path=config["checkpoint"],
            super_animal=config["super_animal"],
            model_name=model_name,
        )
        assert config["checkpoint"].exists(), f"Failed to export {config['super_animal']} model"
        assert config["checkpoint"].stat().st_size > 0, f"Exported {config['super_animal']} model is empty"

    # Get video paths and run benchmarks.
    for config in model_configs:
        video_dir = datafolder / config["video_dir"]
        # ROBUSTNESS FIX: the previous list(...)[0] raised a bare IndexError on
        # an incomplete dataset and picked a filesystem-order-dependent video;
        # sort for determinism and fail with a clear message instead.
        videos = sorted(video_dir.glob("*.avi"))
        assert videos, f"No .avi videos found in {video_dir}"
        video_path = videos[0]
        print(f"Running {config['checkpoint'].stem}")
        benchmark_videos(
            model_path=config["checkpoint"],
            model_type="pytorch",
            video_path=video_path,
            output=str(out_dir),
            n_frames=n_frames,
            pixels=pixels,
        )

    # Assertions: verify output files were created.
    output_files = list(out_dir.iterdir())
    assert len(output_files) > 0, "No output files were created by benchmark_videos"
    assert any(f.suffix == ".pickle" for f in output_files), "No pickle files found in output directory"