Skip to content

Commit e86fb34

Browse files
committed
Refactor backend tests and reporting
Introduce BACKEND_TESTS and BACKEND_DISPLAY_NAMES mappings to dispatch backend-specific tests and print user-friendly names. Replace string keys with Engine enum keys in backend_results and backend_failures, add type annotations, and centralize test invocation to reduce duplication. Improve error handling and warnings (including stacklevel), unify summary output formatting, and clean up minor control-flow and temporary-directory messages
1 parent cf3e57e commit e86fb34

1 file changed

Lines changed: 38 additions & 31 deletions

File tree

dlclive/check_install/check_install.py

Lines changed: 38 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -85,8 +85,19 @@ def run_tensorflow_test(video_file: str, display: bool = False):
8585
)
8686

8787

88+
BACKEND_TESTS = {
89+
Engine.PYTORCH: run_pytorch_test,
90+
Engine.TENSORFLOW: run_tensorflow_test,
91+
}
92+
BACKEND_DISPLAY_NAMES = {
93+
Engine.PYTORCH: "PyTorch",
94+
Engine.TENSORFLOW: "TensorFlow",
95+
}
96+
97+
8898
def main():
89-
backend_results = {}
99+
backend_results: dict[Engine, tuple[str, str | None]] = {}
100+
backend_failures: dict[Engine, Exception] = {}
90101

91102
parser = argparse.ArgumentParser(
92103
description="Test DLC-Live installation by downloading and evaluating a demo DLC project!"
@@ -111,7 +122,6 @@ def main():
111122
if not display:
112123
print("Running without displaying video")
113124

114-
# make temporary directory
115125
print("\nCreating temporary directory...\n")
116126
tmp_dir = TMP_DIR
117127
tmp_dir.mkdir(mode=0o775, exist_ok=True)
@@ -132,10 +142,9 @@ def main():
132142
else:
133143
print(f"Video file already exists at {video_file}, skipping download.")
134144

135-
# assert these things exist so we can give informative error messages
136145
if not Path(video_file).exists():
137146
raise FileNotFoundError(f"Missing video file {video_file}")
138-
backend_failures = {}
147+
139148
any_backend_succeeded = False
140149

141150
available_backends = get_available_backends()
@@ -147,28 +156,23 @@ def main():
147156
)
148157

149158
for backend in available_backends:
159+
test_func = BACKEND_TESTS.get(backend)
160+
if test_func is None:
161+
warnings.warn(
162+
f"No test function defined for backend {backend}, skipping...",
163+
UserWarning,
164+
stacklevel=2,
165+
)
166+
continue
167+
150168
try:
151-
if backend == Engine.PYTORCH:
152-
print("\nRunning PyTorch test...\n")
153-
run_pytorch_test(video_file, display=display)
154-
any_backend_succeeded = True
155-
backend_results["pytorch"] = ("SUCCESS", None)
156-
elif backend == Engine.TENSORFLOW:
157-
print("\nRunning TensorFlow test...\n")
158-
run_tensorflow_test(video_file, display=display)
159-
any_backend_succeeded = True
160-
backend_results["tensorflow"] = ("SUCCESS", None)
161-
else:
162-
warnings.warn(f"Unrecognized backend {backend}, skipping...", UserWarning, stacklevel=2)
169+
print(f"\nRunning {BACKEND_DISPLAY_NAMES.get(backend, backend.value)} test...\n")
170+
test_func(video_file, display=display)
171+
any_backend_succeeded = True
172+
backend_results[backend] = ("SUCCESS", None)
173+
163174
except Exception as e:
164-
backend_name = (
165-
"pytorch"
166-
if backend == Engine.PYTORCH
167-
else "tensorflow"
168-
if backend == Engine.TENSORFLOW
169-
else str(backend)
170-
)
171-
backend_results[backend_name] = ("ERROR", str(e))
175+
backend_results[backend] = ("ERROR", str(e))
172176
backend_failures[backend] = e
173177
warnings.warn(
174178
f"Error while running test for backend {backend}: {e}. "
@@ -178,16 +182,18 @@ def main():
178182
)
179183

180184
print("\n---\nBackend test summary:")
181-
for name in ("tensorflow", "pytorch"):
182-
status, _ = backend_results.get(name, ("SKIPPED", None))
183-
print(f"{name:<11} [{status}]")
185+
for backend in BACKEND_TESTS.keys():
186+
status, _ = backend_results.get(backend, ("SKIPPED", None))
187+
print(f"{backend.value:<11} [{status}]")
184188
print("---")
185-
for name, (status, error) in backend_results.items():
189+
190+
for backend, (status, error) in backend_results.items():
186191
if status == "ERROR":
187-
print(f"{name.capitalize()} error:\n{error}\n")
192+
backend_name = BACKEND_DISPLAY_NAMES.get(backend, backend.value)
193+
print(f"{backend_name} error:\n{error}\n")
188194

189195
if not any_backend_succeeded and backend_failures:
190-
failure_messages = "; ".join(f"{b}: {exc}" for b, exc in backend_failures.items())
196+
failure_messages = "; ".join(f"{b}: {e}" for b, e in backend_failures.items())
191197
raise RuntimeError(f"All backend tests failed. Details: {failure_messages}")
192198

193199
finally:
@@ -198,7 +204,8 @@ def main():
198204
shutil.rmtree(tmp_dir)
199205
except PermissionError:
200206
warnings.warn(
201-
f"Could not delete temporary directory {str(tmp_dir)} due to a permissions error.", stacklevel=2
207+
f"Could not delete temporary directory {str(tmp_dir)} due to a permissions error.",
208+
stacklevel=2,
202209
)
203210

204211

0 commit comments

Comments
 (0)