Skip to content

Commit 68b3902

Browse files
committed
Add PyTorch support and refactor check_install
Refactor check_install.py to support both PyTorch and TensorFlow backends and improve temporary file handling. Introduces TMP_DIR, MODELS_FOLDER, and separate run_pytorch_test/run_tensorflow_test helpers; PyTorch now exports and benchmarks an exported .pt checkpoint, TensorFlow model download logic is preserved. Replaces --nodisplay with --display, centralizes video download and assertions, tightens error handling for downloads, and ensures proper cleanup of temporary files. Also updates imports (urllib.error, export_modelzoo_model) and updates backend availability checks to require at least one backend.
1 parent 18913ac commit 68b3902

1 file changed

Lines changed: 119 additions & 54 deletions

File tree

dlclive/check_install/check_install.py

100755100644
Lines changed: 119 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -9,98 +9,163 @@
99
import shutil
1010
import urllib
1111
import warnings
12+
import urllib.error
1213
from pathlib import Path
1314

1415
from dlclibrary.dlcmodelzoo.modelzoo_download import download_huggingface_model
1516

16-
import dlclive
17+
from dlclive.utils import download_file
1718
from dlclive.benchmark import benchmark_videos
1819
from dlclive.engine import Engine
19-
from dlclive.utils import download_file, get_available_backends
20+
from dlclive.modelzoo.pytorch_model_zoo_export import export_modelzoo_model
21+
from dlclive.utils import get_available_backends
2022

2123
MODEL_NAME = "superanimal_quadruped"
2224
SNAPSHOT_NAME = "snapshot-700000.pb"
25+
TMP_DIR = Path(__file__).parent / "dlc-live-tmp"
2326

27+
MODELS_FOLDER = TMP_DIR / "test_models"
28+
TORCH_MODEL = "resnet_50"
29+
TORCH_CONFIG = {
30+
"checkpoint": MODELS_FOLDER / f"exported_quadruped_{TORCH_MODEL}.pt",
31+
"super_animal": "superanimal_quadruped",
32+
}
33+
TF_MODEL_DIR = TMP_DIR / "DLC_Dog_resnet_50_iteration-0_shuffle-0"
2434

25-
def main():
26-
parser = argparse.ArgumentParser(
27-
description="Test DLC-Live installation by downloading and evaluating a demo DLC project!"
28-
)
29-
parser.add_argument(
30-
"--nodisplay",
31-
action="store_false",
32-
help="Run the test without displaying tracking",
33-
)
34-
args = parser.parse_args()
35-
display = args.nodisplay
36-
37-
if not display:
38-
print("Running without displaying video")
35+
MODELS_FOLDER.mkdir(parents=True, exist_ok=True)
3936

40-
# make temporary directory
41-
print("\nCreating temporary directory...\n")
42-
tmp_dir = Path(dlclive.__file__).parent / "check_install" / "dlc-live-tmp"
43-
tmp_dir.mkdir(mode=0o775, exist_ok=True)
4437

45-
video_file = str(tmp_dir / "dog_clip.avi")
46-
model_dir = tmp_dir / "DLC_Dog_resnet_50_iteration-0_shuffle-0"
38+
def run_pytorch_test(video_file: str, display: bool = False):
39+
if Engine.PYTORCH not in get_available_backends():
40+
raise NotImplementedError(
41+
"PyTorch backend is not available. Please ensure PyTorch is installed to run the PyTorch test."
42+
)
43+
# Download model from the DeepLabCut Model Zoo
44+
export_modelzoo_model(
45+
export_path=TORCH_CONFIG["checkpoint"],
46+
super_animal=TORCH_CONFIG["super_animal"],
47+
model_name=TORCH_MODEL,
48+
)
49+
assert TORCH_CONFIG["checkpoint"].exists(), (
50+
f"Failed to export {TORCH_CONFIG['super_animal']} model"
51+
)
52+
assert TORCH_CONFIG["checkpoint"].stat().st_size > 0, (
53+
f"Exported {TORCH_CONFIG['super_animal']} model is empty"
54+
)
55+
benchmark_videos(
56+
model_path=str(TORCH_CONFIG["checkpoint"]),
57+
model_type="pytorch",
58+
video_path=video_file,
59+
display=display,
60+
resize=0.5,
61+
pcutoff=0.25,
62+
pixels=1000,
63+
)
4764

48-
# download dog test video from github:
49-
# Use raw.githubusercontent.com for direct file access
50-
if not Path(video_file).exists():
51-
print(f"Downloading Video to {video_file}")
52-
url_link = "https://raw.githubusercontent.com/DeepLabCut/DeepLabCut-live/master/check_install/dog_clip.avi"
53-
try:
54-
download_file(url_link, video_file)
55-
except (OSError, urllib.error.URLError) as e:
56-
raise RuntimeError(f"Failed to download video file: {e}") from e
57-
else:
58-
print(f"Video file already exists at {video_file}, skipping download.")
5965

60-
# download model from the DeepLabCut Model Zoo
66+
def run_tensorflow_test(video_file: str, display: bool = False):
67+
if Engine.TENSORFLOW not in get_available_backends():
68+
raise NotImplementedError(
69+
"TensorFlow backend is not available. Please ensure TensorFlow is installed to run the TensorFlow test."
70+
)
71+
model_dir = TF_MODEL_DIR
72+
model_dir.mkdir(parents=True, exist_ok=True)
73+
assert model_dir.exists(), f"Model directory {model_dir} does not exist"
6174
if Path(model_dir / SNAPSHOT_NAME).exists():
6275
print("Model already downloaded, using cached version")
6376
else:
64-
print("Downloading superanimal_quadruped model from the DeepLabCut Model Zoo...")
65-
download_huggingface_model(MODEL_NAME, model_dir)
77+
print(
78+
"Downloading superanimal_quadruped model from the DeepLabCut Model Zoo..."
79+
)
80+
download_huggingface_model(MODEL_NAME, str(model_dir))
6681

67-
# assert these things exist so we can give informative error messages
68-
assert Path(video_file).exists(), f"Missing video file {video_file}"
69-
assert Path(model_dir / SNAPSHOT_NAME).exists(), f"Missing model file {model_dir / SNAPSHOT_NAME}"
82+
assert Path(model_dir / SNAPSHOT_NAME).exists(), (
83+
f"Missing model file {model_dir / SNAPSHOT_NAME}"
84+
)
7085

71-
# run benchmark videos
72-
print("\n Running inference...\n")
7386
benchmark_videos(
7487
model_path=str(model_dir),
75-
model_type="base" if Engine.from_model_path(model_dir) == Engine.TENSORFLOW else "pytorch",
88+
model_type="base",
7689
video_path=video_file,
7790
display=display,
7891
resize=0.5,
7992
pcutoff=0.25,
93+
pixels=1000,
8094
)
8195

82-
# deleting temporary files
83-
print("\n Deleting temporary files...\n")
96+
97+
def main():
98+
tmp_dir = None
8499
try:
85-
shutil.rmtree(tmp_dir)
86-
except PermissionError:
87-
warnings.warn(
88-
f"Could not delete temporary directory {str(tmp_dir)} due to a permissions error.",
89-
stacklevel=2,
100+
parser = argparse.ArgumentParser(
101+
description="Test DLC-Live installation by downloading and evaluating a demo DLC project!"
90102
)
103+
parser.add_argument(
104+
"--display",
105+
action="store_true",
106+
help="Run the test and display tracking",
107+
)
108+
args = parser.parse_args()
109+
display = args.display
110+
111+
if not display:
112+
print("Running without displaying video")
113+
114+
# make temporary directory
115+
print("\nCreating temporary directory...\n")
116+
tmp_dir = TMP_DIR
117+
tmp_dir.mkdir(mode=0o775, exist_ok=True)
118+
119+
video_file = str(tmp_dir / "dog_clip.avi")
120+
121+
# download dog test video from github:
122+
# Use raw.githubusercontent.com for direct file access
123+
if not Path(video_file).exists():
124+
print(f"Downloading Video to {video_file}")
125+
url_link = "https://raw.githubusercontent.com/DeepLabCut/DeepLabCut-live/master/check_install/dog_clip.avi"
126+
try:
127+
download_file(url_link, video_file)
128+
except (urllib.error.URLError, IOError) as e:
129+
raise RuntimeError(f"Failed to download video file: {e}") from e
130+
else:
131+
print(f"Video file already exists at {video_file}, skipping download.")
132+
133+
# assert these things exist so we can give informative error messages
134+
assert Path(video_file).exists(), f"Missing video file {video_file}"
135+
136+
for backend in get_available_backends():
137+
if backend == Engine.PYTORCH:
138+
print("\nRunning PyTorch test...\n")
139+
run_pytorch_test(video_file, display=display)
140+
elif backend == Engine.TENSORFLOW:
141+
print("\nRunning TensorFlow test...\n")
142+
run_tensorflow_test(video_file, display=display)
143+
else:
144+
warnings.warn(
145+
f"Unrecognized backend {backend}, skipping...", UserWarning
146+
)
147+
148+
finally:
149+
# deleting temporary files
150+
print("\n Deleting temporary files...\n")
151+
try:
152+
if tmp_dir is not None and tmp_dir.exists():
153+
shutil.rmtree(tmp_dir)
154+
except PermissionError:
155+
warnings.warn(
156+
f"Could not delete temporary directory {str(tmp_dir)} due to a permissions error, but otherwise dlc-live seems to be working fine!"
157+
)
91158

92-
print("\nDone!\n")
159+
print("\nDone!\n")
93160

94161

95162
if __name__ == "__main__":
96163
# Get available backends (emits a warning if neither TensorFlow nor PyTorch is installed)
97164
available_backends: list[Engine] = get_available_backends()
98165
print(f"Available backends: {[b.value for b in available_backends]}")
99-
100-
# TODO: JR add support for PyTorch in check_install.py (requires some exported pytorch model to be downloaded)
101-
if Engine.TENSORFLOW not in available_backends:
166+
if len(available_backends) == 0:
102167
raise NotImplementedError(
103-
"TensorFlow is not installed. Currently check_install.py only supports testing the TensorFlow installation."
168+
"Neither TensorFlow nor PyTorch is installed. Please install at least one of these frameworks to run the installation test."
104169
)
105170

106171
main()

0 commit comments

Comments
 (0)