Skip to content

Commit f68c34c

Browse files
committed
Add PyTorch support and refactor check_install
Refactor check_install.py to support both PyTorch and TensorFlow backends and improve temporary file handling. Introduces TMP_DIR, MODELS_FOLDER, and separate run_pytorch_test/run_tensorflow_test helpers; PyTorch now exports and benchmarks an exported .pt checkpoint, TensorFlow model download logic is preserved. Replaces --nodisplay with --display, centralizes video download and assertions, tightens error handling for downloads, and ensures proper cleanup of temporary files. Also updates imports (urllib.error, export_modelzoo_model) and updates backend availability checks to require at least one backend.
1 parent 45715db commit f68c34c

1 file changed

Lines changed: 118 additions & 56 deletions

File tree

dlclive/check_install/check_install.py

Lines changed: 118 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -8,101 +8,163 @@
88
import argparse
99
import shutil
1010
import warnings
11+
import urllib.error
1112
from pathlib import Path
1213

1314
from dlclibrary.dlcmodelzoo.modelzoo_download import download_huggingface_model
1415

15-
import dlclive
1616
from dlclive.utils import download_file
1717
from dlclive.benchmark import benchmark_videos
1818
from dlclive.engine import Engine
19+
from dlclive.modelzoo.pytorch_model_zoo_export import export_modelzoo_model
1920
from dlclive.utils import get_available_backends
2021

2122
MODEL_NAME = "superanimal_quadruped"
2223
SNAPSHOT_NAME = "snapshot-700000.pb"
24+
TMP_DIR = Path(__file__).parent / "dlc-live-tmp"
2325

26+
MODELS_FOLDER = TMP_DIR / "test_models"
27+
TORCH_MODEL = "resnet_50"
28+
TORCH_CONFIG = {
29+
"checkpoint": MODELS_FOLDER / f"exported_quadruped_{TORCH_MODEL}.pt",
30+
"super_animal": "superanimal_quadruped",
31+
}
32+
TF_MODEL_DIR = TMP_DIR / "DLC_Dog_resnet_50_iteration-0_shuffle-0"
2433

25-
def main():
26-
parser = argparse.ArgumentParser(
27-
description="Test DLC-Live installation by downloading and evaluating a demo DLC project!"
28-
)
29-
parser.add_argument(
30-
"--nodisplay",
31-
action="store_false",
32-
help="Run the test without displaying tracking",
33-
)
34-
args = parser.parse_args()
35-
display = args.nodisplay
36-
37-
if not display:
38-
print("Running without displaying video")
34+
MODELS_FOLDER.mkdir(parents=True, exist_ok=True)
3935

40-
# make temporary directory
41-
print("\nCreating temporary directory...\n")
42-
tmp_dir = Path(dlclive.__file__).parent / "check_install" / "dlc-live-tmp"
43-
tmp_dir.mkdir(mode=0o775, exist_ok=True)
4436

45-
video_file = str(tmp_dir / "dog_clip.avi")
46-
model_dir = tmp_dir / "DLC_Dog_resnet_50_iteration-0_shuffle-0"
37+
def run_pytorch_test(video_file: str, display: bool = False):
38+
if Engine.PYTORCH not in get_available_backends():
39+
raise NotImplementedError(
40+
"PyTorch backend is not available. Please ensure PyTorch is installed to run the PyTorch test."
41+
)
42+
# Download model from the DeepLabCut Model Zoo
43+
export_modelzoo_model(
44+
export_path=TORCH_CONFIG["checkpoint"],
45+
super_animal=TORCH_CONFIG["super_animal"],
46+
model_name=TORCH_MODEL,
47+
)
48+
assert TORCH_CONFIG["checkpoint"].exists(), (
49+
f"Failed to export {TORCH_CONFIG['super_animal']} model"
50+
)
51+
assert TORCH_CONFIG["checkpoint"].stat().st_size > 0, (
52+
f"Exported {TORCH_CONFIG['super_animal']} model is empty"
53+
)
54+
benchmark_videos(
55+
model_path=str(TORCH_CONFIG["checkpoint"]),
56+
model_type="pytorch",
57+
video_path=video_file,
58+
display=display,
59+
resize=0.5,
60+
pcutoff=0.25,
61+
pixels=1000,
62+
)
4763

48-
# download dog test video from github:
49-
# Use raw.githubusercontent.com for direct file access
50-
if not Path(video_file).exists():
51-
print(f"Downloading Video to {video_file}")
52-
url_link = "https://raw.githubusercontent.com/DeepLabCut/DeepLabCut-live/master/check_install/dog_clip.avi"
53-
try:
54-
download_file(url_link, video_file)
55-
except (urllib.error.URLError, IOError) as e:
56-
raise RuntimeError(f"Failed to download video file: {e}") from e
57-
else:
58-
print(f"Video file already exists at {video_file}, skipping download.")
5964

60-
# download model from the DeepLabCut Model Zoo
65+
def run_tensorflow_test(video_file: str, display: bool = False):
66+
if Engine.TENSORFLOW not in get_available_backends():
67+
raise NotImplementedError(
68+
"TensorFlow backend is not available. Please ensure TensorFlow is installed to run the TensorFlow test."
69+
)
70+
model_dir = TF_MODEL_DIR
71+
model_dir.mkdir(parents=True, exist_ok=True)
72+
assert model_dir.exists(), f"Model directory {model_dir} does not exist"
6173
if Path(model_dir / SNAPSHOT_NAME).exists():
6274
print("Model already downloaded, using cached version")
6375
else:
64-
print("Downloading superanimal_quadruped model from the DeepLabCut Model Zoo...")
65-
download_huggingface_model(MODEL_NAME, model_dir)
76+
print(
77+
"Downloading superanimal_quadruped model from the DeepLabCut Model Zoo..."
78+
)
79+
download_huggingface_model(MODEL_NAME, str(model_dir))
6680

67-
# assert these things exist so we can give informative error messages
68-
assert Path(video_file).exists(), f"Missing video file {video_file}"
69-
assert Path(
70-
model_dir / SNAPSHOT_NAME
71-
).exists(), f"Missing model file {model_dir / SNAPSHOT_NAME}"
81+
assert Path(model_dir / SNAPSHOT_NAME).exists(), (
82+
f"Missing model file {model_dir / SNAPSHOT_NAME}"
83+
)
7284

73-
# run benchmark videos
74-
print("\n Running inference...\n")
7585
benchmark_videos(
7686
model_path=str(model_dir),
77-
model_type="base" if Engine.from_model_path(model_dir) == Engine.TENSORFLOW else "pytorch",
87+
model_type="base",
7888
video_path=video_file,
7989
display=display,
8090
resize=0.5,
81-
pcutoff=0.25
91+
pcutoff=0.25,
92+
pixels=1000,
8293
)
8394

84-
# deleting temporary files
85-
print("\n Deleting temporary files...\n")
95+
96+
def main():
97+
tmp_dir = None
8698
try:
87-
shutil.rmtree(tmp_dir)
88-
except PermissionError:
89-
warnings.warn(
90-
f"Could not delete temporary directory {str(tmp_dir)} due to a permissions error, but otherwise dlc-live seems to be working fine!"
99+
parser = argparse.ArgumentParser(
100+
description="Test DLC-Live installation by downloading and evaluating a demo DLC project!"
101+
)
102+
parser.add_argument(
103+
"--display",
104+
action="store_true",
105+
help="Run the test and display tracking",
91106
)
107+
args = parser.parse_args()
108+
display = args.display
109+
110+
if not display:
111+
print("Running without displaying video")
112+
113+
# make temporary directory
114+
print("\nCreating temporary directory...\n")
115+
tmp_dir = TMP_DIR
116+
tmp_dir.mkdir(mode=0o775, exist_ok=True)
117+
118+
video_file = str(tmp_dir / "dog_clip.avi")
119+
120+
# download dog test video from github:
121+
# Use raw.githubusercontent.com for direct file access
122+
if not Path(video_file).exists():
123+
print(f"Downloading Video to {video_file}")
124+
url_link = "https://raw.githubusercontent.com/DeepLabCut/DeepLabCut-live/master/check_install/dog_clip.avi"
125+
try:
126+
download_file(url_link, video_file)
127+
except (urllib.error.URLError, IOError) as e:
128+
raise RuntimeError(f"Failed to download video file: {e}") from e
129+
else:
130+
print(f"Video file already exists at {video_file}, skipping download.")
131+
132+
# assert these things exist so we can give informative error messages
133+
assert Path(video_file).exists(), f"Missing video file {video_file}"
134+
135+
for backend in get_available_backends():
136+
if backend == Engine.PYTORCH:
137+
print("\nRunning PyTorch test...\n")
138+
run_pytorch_test(video_file, display=display)
139+
elif backend == Engine.TENSORFLOW:
140+
print("\nRunning TensorFlow test...\n")
141+
run_tensorflow_test(video_file, display=display)
142+
else:
143+
warnings.warn(
144+
f"Unrecognized backend {backend}, skipping...", UserWarning
145+
)
146+
147+
finally:
148+
# deleting temporary files
149+
print("\n Deleting temporary files...\n")
150+
try:
151+
if tmp_dir is not None and tmp_dir.exists():
152+
shutil.rmtree(tmp_dir)
153+
except PermissionError:
154+
warnings.warn(
155+
f"Could not delete temporary directory {str(tmp_dir)} due to a permissions error, but otherwise dlc-live seems to be working fine!"
156+
)
92157

93-
print("\nDone!\n")
158+
print("\nDone!\n")
94159

95160

96161
if __name__ == "__main__":
97-
98162
# Get available backends (emits a warning if neither TensorFlow nor PyTorch is installed)
99163
available_backends: list[Engine] = get_available_backends()
100164
print(f"Available backends: {[b.value for b in available_backends]}")
101-
102-
# TODO: JR add support for PyTorch in check_install.py (requires some exported pytorch model to be downloaded)
103-
if not Engine.TENSORFLOW in available_backends:
165+
if len(available_backends) == 0:
104166
raise NotImplementedError(
105-
"TensorFlow is not installed. Currently check_install.py only supports testing the TensorFlow installation."
167+
"Neither TensorFlow nor PyTorch is installed. Please install at least one of these frameworks to run the installation test."
106168
)
107169

108170
main()

0 commit comments

Comments
 (0)