Skip to content

Commit 4f94639

Browse files
committed
Refactor check_install tests and model paths
Rename MODELS_FOLDER to MODELS_DIR and update references (TORCH_CONFIG checkpoint and TF_MODEL_DIR) for clearer naming. Change missing-backend errors from NotImplementedError to ImportError to better reflect installation issues. Simplify main(): consolidate arg parsing, consistently create TMP_DIR and MODELS_DIR, and add backend_results tracking to report per-backend SUCCESS/ERROR statuses with a printed summary. Improve error recording for backend failures and adjust cleanup check when removing the temporary directory.
1 parent 2a215a9 commit 4f94639

1 file changed

Lines changed: 57 additions & 38 deletions

File tree

dlclive/check_install/check_install.py

Lines changed: 57 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -23,20 +23,20 @@
2323
SNAPSHOT_NAME = "snapshot-700000.pb"
2424
TMP_DIR = Path(__file__).parent / "dlc-live-tmp"
2525

26-
MODELS_FOLDER = TMP_DIR / "test_models"
26+
MODELS_DIR = TMP_DIR / "test_models"
2727
TORCH_MODEL = "resnet_50"
2828
TORCH_CONFIG = {
29-
"checkpoint": MODELS_FOLDER / f"exported_quadruped_{TORCH_MODEL}.pt",
29+
"checkpoint": MODELS_DIR / f"exported_quadruped_{TORCH_MODEL}.pt",
3030
"super_animal": "superanimal_quadruped",
3131
}
32-
TF_MODEL_DIR = TMP_DIR / "DLC_Dog_resnet_50_iteration-0_shuffle-0"
32+
TF_MODEL_DIR = MODELS_DIR / "DLC_Dog_resnet_50_iteration-0_shuffle-0"
3333

3434

3535
def run_pytorch_test(video_file: str, display: bool = False):
3636
from dlclive.modelzoo.pytorch_model_zoo_export import export_modelzoo_model
3737

3838
if Engine.PYTORCH not in get_available_backends():
39-
raise NotImplementedError(
39+
raise ImportError(
4040
"PyTorch backend is not available. Please ensure PyTorch is installed to run the PyTorch test."
4141
)
4242
# Download model from the DeepLabCut Model Zoo
@@ -64,7 +64,7 @@ def run_pytorch_test(video_file: str, display: bool = False):
6464

6565
def run_tensorflow_test(video_file: str, display: bool = False):
6666
if Engine.TENSORFLOW not in get_available_backends():
67-
raise NotImplementedError(
67+
raise ImportError(
6868
"TensorFlow backend is not available. Please ensure TensorFlow is installed to run the TensorFlow test."
6969
)
7070
model_dir = TF_MODEL_DIR
@@ -93,38 +93,39 @@ def run_tensorflow_test(video_file: str, display: bool = False):
9393

9494

9595
def main():
96-
tmp_dir = None
97-
try:
98-
parser = argparse.ArgumentParser(
99-
description="Test DLC-Live installation by downloading and evaluating a demo DLC project!"
100-
)
101-
parser.add_argument(
102-
"--display",
103-
action="store_true",
104-
default=False,
105-
help="Run the test and display tracking",
106-
)
107-
parser.add_argument(
108-
"--nodisplay",
109-
action="store_false",
110-
dest="display",
111-
help=argparse.SUPPRESS,
112-
)
96+
backend_results = {}
97+
98+
parser = argparse.ArgumentParser(
99+
description="Test DLC-Live installation by downloading and evaluating a demo DLC project!"
100+
)
101+
parser.add_argument(
102+
"--display",
103+
action="store_true",
104+
default=False,
105+
help="Run the test and display tracking",
106+
)
107+
parser.add_argument(
108+
"--nodisplay",
109+
action="store_false",
110+
dest="display",
111+
help=argparse.SUPPRESS,
112+
)
113113

114-
args = parser.parse_args()
115-
display = args.display
114+
args = parser.parse_args()
115+
display = args.display
116116

117-
if not display:
118-
print("Running without displaying video")
117+
if not display:
118+
print("Running without displaying video")
119119

120-
# make temporary directory
121-
print("\nCreating temporary directory...\n")
122-
tmp_dir = TMP_DIR
123-
tmp_dir.mkdir(mode=0o775, exist_ok=True)
124-
MODELS_FOLDER.mkdir(parents=True, exist_ok=True)
120+
# make temporary directory
121+
print("\nCreating temporary directory...\n")
122+
tmp_dir = TMP_DIR
123+
tmp_dir.mkdir(mode=0o775, exist_ok=True)
124+
MODELS_DIR.mkdir(parents=True, exist_ok=True)
125125

126-
video_file = str(tmp_dir / "dog_clip.avi")
126+
video_file = str(tmp_dir / "dog_clip.avi")
127127

128+
try:
128129
# download dog test video from github:
129130
# Use raw.githubusercontent.com for direct file access
130131
if not Path(video_file).exists():
@@ -148,33 +149,51 @@ def main():
148149
print("\nRunning PyTorch test...\n")
149150
run_pytorch_test(video_file, display=display)
150151
any_backend_succeeded = True
152+
backend_results["pytorch"] = ("SUCCESS", None)
151153
elif backend == Engine.TENSORFLOW:
152154
print("\nRunning TensorFlow test...\n")
153155
run_tensorflow_test(video_file, display=display)
154156
any_backend_succeeded = True
157+
backend_results["tensorflow"] = ("SUCCESS", None)
155158
else:
156159
warnings.warn(
157160
f"Unrecognized backend {backend}, skipping...", UserWarning
158161
)
159162
except Exception as e:
163+
backend_name = (
164+
"pytorch" if backend == Engine.PYTORCH else
165+
"tensorflow" if backend == Engine.TENSORFLOW else
166+
str(backend)
167+
)
168+
backend_results[backend_name] = ("ERROR", str(e))
160169
backend_failures[backend] = e
161170
warnings.warn(
162171
f"Error while running test for backend {backend}: {e}. "
163172
"Continuing to test other available backends.",
164173
UserWarning,
165174
)
166175

167-
if not any_backend_succeeded and backend_failures:
168-
failure_messages = "; ".join(
169-
f"{b}: {exc}" for b, exc in backend_failures.items()
170-
)
171-
raise RuntimeError(f"All backend tests failed. Details: {failure_messages}")
176+
print("\n---\nBackend test summary:")
177+
for name in ("tensorflow", "pytorch"):
178+
status, _ = backend_results.get(name, ("SKIPPED", None))
179+
print(f"{name:<11} [{status}]")
180+
print("---")
181+
for name, (status, error) in backend_results.items():
182+
if status == "ERROR":
183+
print(f"{name.capitalize()} error:\n{error}\n")
184+
185+
if not any_backend_succeeded and backend_failures:
186+
failure_messages = "; ".join(
187+
f"{b}: {exc}" for b, exc in backend_failures.items()
188+
)
189+
raise RuntimeError(f"All backend tests failed. Details: {failure_messages}")
190+
172191

173192
finally:
174193
# deleting temporary files
175194
print("\n Deleting temporary files...\n")
176195
try:
177-
if tmp_dir is not None and tmp_dir.exists():
196+
if tmp_dir.exists():
178197
shutil.rmtree(tmp_dir)
179198
except PermissionError:
180199
warnings.warn(

0 commit comments

Comments
 (0)