Skip to content

Commit 2a215a9

Browse files
committed
Improve backend test robustness and imports
Refactor test startup and error handling for check_install and simplify a type-only import. - dlclive/benchmark.py: replace the try/except tensorflow import under TYPE_CHECKING with a direct import (type ignored) to simplify typing logic. - dlclive/check_install/check_install.py: - Defer importing export_modelzoo_model into run_pytorch_test to avoid importing heavy modules unless PyTorch test runs. - Move MODELS_FOLDER.mkdir to after temporary directory creation. - Add a --nodisplay flag and set default for --display to False so CLI can explicitly disable display. - Comment out resize parameters in test calls and remove an unnecessary model_dir.exists() assertion. - Wrap per-backend test runs in try/except, collect backend failures, allow other backends to continue, and raise an aggregated RuntimeError if all backend tests fail. These changes improve robustness when some backends fail or are unavailable and reduce unnecessary imports during initial checks.
1 parent f3c8a76 commit 2a215a9

2 files changed

Lines changed: 40 additions & 18 deletions

File tree

dlclive/benchmark.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -31,10 +31,7 @@
3131
from dlclive.utils import decode_fourcc
3232

3333
if TYPE_CHECKING:
34-
try:
35-
import tensorflow
36-
except ImportError:
37-
tensorflow = None
34+
import tensorflow # type: ignore
3835

3936

4037
def download_benchmarking_data(

dlclive/check_install/check_install.py

Lines changed: 39 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,6 @@
1717
from dlclive.utils import download_file
1818
from dlclive.benchmark import benchmark_videos
1919
from dlclive.engine import Engine
20-
from dlclive.modelzoo.pytorch_model_zoo_export import export_modelzoo_model
2120
from dlclive.utils import get_available_backends
2221

2322
MODEL_NAME = "superanimal_quadruped"
@@ -32,10 +31,10 @@
3231
}
3332
TF_MODEL_DIR = TMP_DIR / "DLC_Dog_resnet_50_iteration-0_shuffle-0"
3433

35-
MODELS_FOLDER.mkdir(parents=True, exist_ok=True)
36-
3734

3835
def run_pytorch_test(video_file: str, display: bool = False):
36+
from dlclive.modelzoo.pytorch_model_zoo_export import export_modelzoo_model
37+
3938
if Engine.PYTORCH not in get_available_backends():
4039
raise NotImplementedError(
4140
"PyTorch backend is not available. Please ensure PyTorch is installed to run the PyTorch test."
@@ -57,7 +56,7 @@ def run_pytorch_test(video_file: str, display: bool = False):
5756
model_type="pytorch",
5857
video_path=video_file,
5958
display=display,
60-
resize=0.5,
59+
# resize=0.5,
6160
pcutoff=0.25,
6261
pixels=1000,
6362
)
@@ -70,7 +69,6 @@ def run_tensorflow_test(video_file: str, display: bool = False):
7069
)
7170
model_dir = TF_MODEL_DIR
7271
model_dir.mkdir(parents=True, exist_ok=True)
73-
assert model_dir.exists(), f"Model directory {model_dir} does not exist"
7472
if Path(model_dir / SNAPSHOT_NAME).exists():
7573
print("Model already downloaded, using cached version")
7674
else:
@@ -88,7 +86,7 @@ def run_tensorflow_test(video_file: str, display: bool = False):
8886
model_type="base",
8987
video_path=video_file,
9088
display=display,
91-
resize=0.5,
89+
# resize=0.5,
9290
pcutoff=0.25,
9391
pixels=1000,
9492
)
@@ -103,8 +101,16 @@ def main():
103101
parser.add_argument(
104102
"--display",
105103
action="store_true",
104+
default=False,
106105
help="Run the test and display tracking",
107106
)
107+
parser.add_argument(
108+
"--nodisplay",
109+
action="store_false",
110+
dest="display",
111+
help=argparse.SUPPRESS,
112+
)
113+
108114
args = parser.parse_args()
109115
display = args.display
110116

@@ -115,6 +121,7 @@ def main():
115121
print("\nCreating temporary directory...\n")
116122
tmp_dir = TMP_DIR
117123
tmp_dir.mkdir(mode=0o775, exist_ok=True)
124+
MODELS_FOLDER.mkdir(parents=True, exist_ok=True)
118125

119126
video_file = str(tmp_dir / "dog_clip.avi")
120127

@@ -132,19 +139,37 @@ def main():
132139

133140
# assert these things exist so we can give informative error messages
134141
assert Path(video_file).exists(), f"Missing video file {video_file}"
142+
backend_failures = {}
143+
any_backend_succeeded = False
135144

136145
for backend in get_available_backends():
137-
if backend == Engine.PYTORCH:
138-
print("\nRunning PyTorch test...\n")
139-
run_pytorch_test(video_file, display=display)
140-
elif backend == Engine.TENSORFLOW:
141-
print("\nRunning TensorFlow test...\n")
142-
run_tensorflow_test(video_file, display=display)
143-
else:
146+
try:
147+
if backend == Engine.PYTORCH:
148+
print("\nRunning PyTorch test...\n")
149+
run_pytorch_test(video_file, display=display)
150+
any_backend_succeeded = True
151+
elif backend == Engine.TENSORFLOW:
152+
print("\nRunning TensorFlow test...\n")
153+
run_tensorflow_test(video_file, display=display)
154+
any_backend_succeeded = True
155+
else:
156+
warnings.warn(
157+
f"Unrecognized backend {backend}, skipping...", UserWarning
158+
)
159+
except Exception as e:
160+
backend_failures[backend] = e
144161
warnings.warn(
145-
f"Unrecognized backend {backend}, skipping...", UserWarning
162+
f"Error while running test for backend {backend}: {e}. "
163+
"Continuing to test other available backends.",
164+
UserWarning,
146165
)
147166

167+
if not any_backend_succeeded and backend_failures:
168+
failure_messages = "; ".join(
169+
f"{b}: {exc}" for b, exc in backend_failures.items()
170+
)
171+
raise RuntimeError(f"All backend tests failed. Details: {failure_messages}")
172+
148173
finally:
149174
# deleting temporary files
150175
print("\n Deleting temporary files...\n")

0 commit comments

Comments
 (0)