Skip to content

Commit 0629b15

Browse files
committed
Improve backend test robustness and imports
Refactor test startup and error handling for check_install and simplify a type-only import. - dlclive/benchmark.py: replace the try/except tensorflow import under TYPE_CHECKING with a direct import (type ignored) to simplify typing logic. - dlclive/check_install/check_install.py: - Defer importing export_modelzoo_model into run_pytorch_test to avoid importing heavy modules unless PyTorch test runs. - Move MODELS_FOLDER.mkdir to after temporary directory creation. - Add a --nodisplay flag and set default for --display to False so CLI can explicitly disable display. - Comment out resize parameters in test calls and remove an unnecessary model_dir.exists() assertion. - Wrap per-backend test runs in try/except, collect backend failures, allow other backends to continue, and raise an aggregated RuntimeError if all backend tests fail. These changes improve robustness when some backends fail or are unavailable and reduce unnecessary imports during initial checks.
1 parent a4f6573 commit 0629b15

2 files changed

Lines changed: 40 additions & 18 deletions

File tree

dlclive/benchmark.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -29,10 +29,7 @@
2929
from dlclive.utils import decode_fourcc
3030

3131
if TYPE_CHECKING:
32-
try:
33-
import tensorflow
34-
except ImportError:
35-
tensorflow = None
32+
import tensorflow # type: ignore
3633

3734

3835
def download_benchmarking_data(

dlclive/check_install/check_install.py

Lines changed: 39 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,6 @@
1616
from dlclive.utils import download_file
1717
from dlclive.benchmark import benchmark_videos
1818
from dlclive.engine import Engine
19-
from dlclive.modelzoo.pytorch_model_zoo_export import export_modelzoo_model
2019
from dlclive.utils import get_available_backends
2120

2221
MODEL_NAME = "superanimal_quadruped"
@@ -31,10 +30,10 @@
3130
}
3231
TF_MODEL_DIR = TMP_DIR / "DLC_Dog_resnet_50_iteration-0_shuffle-0"
3332

34-
MODELS_FOLDER.mkdir(parents=True, exist_ok=True)
35-
3633

3734
def run_pytorch_test(video_file: str, display: bool = False):
35+
from dlclive.modelzoo.pytorch_model_zoo_export import export_modelzoo_model
36+
3837
if Engine.PYTORCH not in get_available_backends():
3938
raise NotImplementedError(
4039
"PyTorch backend is not available. Please ensure PyTorch is installed to run the PyTorch test."
@@ -56,7 +55,7 @@ def run_pytorch_test(video_file: str, display: bool = False):
5655
model_type="pytorch",
5756
video_path=video_file,
5857
display=display,
59-
resize=0.5,
58+
# resize=0.5,
6059
pcutoff=0.25,
6160
pixels=1000,
6261
)
@@ -69,7 +68,6 @@ def run_tensorflow_test(video_file: str, display: bool = False):
6968
)
7069
model_dir = TF_MODEL_DIR
7170
model_dir.mkdir(parents=True, exist_ok=True)
72-
assert model_dir.exists(), f"Model directory {model_dir} does not exist"
7371
if Path(model_dir / SNAPSHOT_NAME).exists():
7472
print("Model already downloaded, using cached version")
7573
else:
@@ -87,7 +85,7 @@ def run_tensorflow_test(video_file: str, display: bool = False):
8785
model_type="base",
8886
video_path=video_file,
8987
display=display,
90-
resize=0.5,
88+
# resize=0.5,
9189
pcutoff=0.25,
9290
pixels=1000,
9391
)
@@ -102,8 +100,16 @@ def main():
102100
parser.add_argument(
103101
"--display",
104102
action="store_true",
103+
default=False,
105104
help="Run the test and display tracking",
106105
)
106+
parser.add_argument(
107+
"--nodisplay",
108+
action="store_false",
109+
dest="display",
110+
help=argparse.SUPPRESS,
111+
)
112+
107113
args = parser.parse_args()
108114
display = args.display
109115

@@ -114,6 +120,7 @@ def main():
114120
print("\nCreating temporary directory...\n")
115121
tmp_dir = TMP_DIR
116122
tmp_dir.mkdir(mode=0o775, exist_ok=True)
123+
MODELS_FOLDER.mkdir(parents=True, exist_ok=True)
117124

118125
video_file = str(tmp_dir / "dog_clip.avi")
119126

@@ -131,19 +138,37 @@ def main():
131138

132139
# assert these things exist so we can give informative error messages
133140
assert Path(video_file).exists(), f"Missing video file {video_file}"
141+
backend_failures = {}
142+
any_backend_succeeded = False
134143

135144
for backend in get_available_backends():
136-
if backend == Engine.PYTORCH:
137-
print("\nRunning PyTorch test...\n")
138-
run_pytorch_test(video_file, display=display)
139-
elif backend == Engine.TENSORFLOW:
140-
print("\nRunning TensorFlow test...\n")
141-
run_tensorflow_test(video_file, display=display)
142-
else:
145+
try:
146+
if backend == Engine.PYTORCH:
147+
print("\nRunning PyTorch test...\n")
148+
run_pytorch_test(video_file, display=display)
149+
any_backend_succeeded = True
150+
elif backend == Engine.TENSORFLOW:
151+
print("\nRunning TensorFlow test...\n")
152+
run_tensorflow_test(video_file, display=display)
153+
any_backend_succeeded = True
154+
else:
155+
warnings.warn(
156+
f"Unrecognized backend {backend}, skipping...", UserWarning
157+
)
158+
except Exception as e:
159+
backend_failures[backend] = e
143160
warnings.warn(
144-
f"Unrecognized backend {backend}, skipping...", UserWarning
161+
f"Error while running test for backend {backend}: {e}. "
162+
"Continuing to test other available backends.",
163+
UserWarning,
145164
)
146165

166+
if not any_backend_succeeded and backend_failures:
167+
failure_messages = "; ".join(
168+
f"{b}: {exc}" for b, exc in backend_failures.items()
169+
)
170+
raise RuntimeError(f"All backend tests failed. Details: {failure_messages}")
171+
147172
finally:
148173
# deleting temporary files
149174
print("\n Deleting temporary files...\n")

0 commit comments

Comments
 (0)