Skip to content

Commit cf3e57e

Browse files
committed
Use system temp dir and improve error handling
Use tempfile.gettempdir() for TMP_DIR instead of a repo-relative folder and add the tempfile import. Replace an assertion checking for the model snapshot with an explicit FileNotFoundError to provide clearer failure semantics. Set the CLI --display option default to False and add a guard that raises a RuntimeError when no available backends are detected, giving a helpful message to the user. Minor whitespace/flow adjustments around backend iteration.
1 parent 383f442 commit cf3e57e

1 file changed

Lines changed: 14 additions & 3 deletions

File tree

dlclive/check_install/check_install.py

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77

88
import argparse
99
import shutil
10+
import tempfile
1011
import urllib
1112
import urllib.error
1213
import warnings
@@ -20,7 +21,7 @@
2021

2122
MODEL_NAME = "superanimal_quadruped"
2223
SNAPSHOT_NAME = "snapshot-700000.pb"
23-
TMP_DIR = Path(__file__).parent / "dlc-live-tmp"
24+
TMP_DIR = Path(tempfile.gettempdir()) / "dlc-live-tmp"
2425

2526
MODELS_DIR = TMP_DIR / "test_models"
2627
TORCH_MODEL = "resnet_50"
@@ -71,7 +72,8 @@ def run_tensorflow_test(video_file: str, display: bool = False):
7172
print("Downloading superanimal_quadruped model from the DeepLabCut Model Zoo...")
7273
download_huggingface_model(MODEL_NAME, str(model_dir))
7374

74-
assert Path(model_dir / SNAPSHOT_NAME).exists(), f"Missing model file {model_dir / SNAPSHOT_NAME}"
75+
if not Path(model_dir / SNAPSHOT_NAME).exists():
76+
raise FileNotFoundError(f"Missing model file {model_dir / SNAPSHOT_NAME}")
7577

7678
benchmark_videos(
7779
model_path=str(model_dir),
@@ -100,6 +102,7 @@ def main():
100102
action="store_false",
101103
dest="display",
102104
help=argparse.SUPPRESS,
105+
default=False,
103106
)
104107

105108
args = parser.parse_args()
@@ -135,7 +138,15 @@ def main():
135138
backend_failures = {}
136139
any_backend_succeeded = False
137140

138-
for backend in get_available_backends():
141+
available_backends = get_available_backends()
142+
if not available_backends:
143+
raise RuntimeError(
144+
"No available backends to test. "
145+
"Please ensure that at least one of the supported backends "
146+
"(TensorFlow or PyTorch) is installed."
147+
)
148+
149+
for backend in available_backends:
139150
try:
140151
if backend == Engine.PYTORCH:
141152
print("\nRunning PyTorch test...\n")

0 commit comments

Comments
 (0)