Skip to content

Commit 398eae2

Browse files
committed
Fix rebase issues
Clean up and minor fixes in dlclive/benchmark.py and dlclive/check_install/check_install.py. Changes include: reorder and consolidate imports; remove duplicated/extra blank lines and repeated statements (duplicate prints, duplicated init_inference call, duplicated function args); normalize long-line wrapping; switch some prints to f-strings; tighten exception handling for video download (catch OSError and URLError); add stacklevel to warnings.warn calls; and minor formatting improvements (single-line tuples, joined long expressions). These are mostly non-functional refactors, plus a few small bug/safety fixes, intended to reduce redundancy and improve diagnostics.
1 parent dbf9840 commit 398eae2

2 files changed

Lines changed: 24 additions & 72 deletions

File tree

dlclive/benchmark.py

Lines changed: 12 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -13,10 +13,9 @@
1313
import sys
1414
import time
1515
import warnings
16-
from typing import TYPE_CHECKING
1716
from pathlib import Path
18-
import argparse
19-
import os
17+
from typing import TYPE_CHECKING
18+
2019
import colorcet as cc
2120
import cv2
2221
import numpy as np
@@ -25,8 +24,7 @@
2524
from pip._internal.operations import freeze
2625
from tqdm import tqdm
2726

28-
from dlclive import DLCLive
29-
from dlclive import VERSION
27+
from dlclive import VERSION, DLCLive
3028
from dlclive.engine import Engine
3129
from dlclive.utils import decode_fourcc
3230

@@ -56,20 +54,16 @@ def download_benchmarking_data(
5654
print(f"{zip_path} already exists. Skipping download.")
5755
else:
5856

59-
6057
def show_progress(count, block_size, total_size):
6158
pbar.update(block_size)
6259

6360
print(f"Downloading the benchmarking data from {url} ...")
6461
pbar = tqdm(unit="B", total=0, position=0, desc="Downloading")
6562

66-
filename, _ = urllib.request.urlretrieve(
67-
url, filename=zip_path, reporthook=show_progress
68-
)
63+
filename, _ = urllib.request.urlretrieve(url, filename=zip_path, reporthook=show_progress)
6964
pbar.close()
7065

7166
print(f"Extracting {zip_path} to {target_dir} ...")
72-
with zipfile.ZipFile(zip_path, "r") as zip_ref:
7367
with zipfile.ZipFile(zip_path, "r") as zip_ref:
7468
zip_ref.extractall(target_dir)
7569

@@ -192,7 +186,6 @@ def benchmark_videos(
192186

193187
for i in range(len(resize)):
194188
print(f"\nRun {i + 1} / {len(resize)}\n")
195-
print(f"\nRun {i + 1} / {len(resize)}\n")
196189

197190
this_inf_times, this_im_size, meta = benchmark(
198191
model_path=model_path,
@@ -296,7 +289,6 @@ def get_system_info() -> dict:
296289
}
297290

298291

299-
def save_inf_times(sys_info, inf_times, im_size, model=None, meta=None, output=None):
300292
def save_inf_times(sys_info, inf_times, im_size, model=None, meta=None, output=None):
301293
"""Save inference time data collected using :function:`benchmark` with system information to a pickle file.
302294
This is primarily used through :function:`benchmark_videos`
@@ -366,7 +358,6 @@ def save_inf_times(sys_info, inf_times, im_size, model=None, meta=None, output=N
366358
return True
367359

368360

369-
370361
def benchmark(
371362
model_path: str,
372363
model_type: str,
@@ -380,8 +371,6 @@ def benchmark(
380371
dynamic: tuple[bool, float, int] = (False, 0.5, 10),
381372
n_frames: int = 1000,
382373
print_rate: bool = False,
383-
n_frames: int = 1000,
384-
print_rate: bool = False,
385374
precision: str = "FP32",
386375
display: bool = True,
387376
pcutoff: float = 0.5,
@@ -526,9 +515,7 @@ def benchmark(
526515
frame_index = 0
527516

528517
total_n_frames = cap.get(cv2.CAP_PROP_FRAME_COUNT)
529-
n_frames = int(
530-
n_frames if (n_frames > 0) and n_frames < total_n_frames else total_n_frames
531-
)
518+
n_frames = int(n_frames if (n_frames > 0) and n_frames < total_n_frames else total_n_frames)
532519
iterator = range(n_frames) if print_rate or display else tqdm(range(n_frames))
533520
for _ in iterator:
534521
ret, frame = cap.read()
@@ -543,7 +530,6 @@ def benchmark(
543530
start_time = time.perf_counter()
544531
if frame_index == 0:
545532
pose = dlc_live.init_inference(frame) # Loads model
546-
pose = dlc_live.init_inference(frame) # Loads model
547533
else:
548534
pose = dlc_live.get_pose(frame)
549535

@@ -552,9 +538,7 @@ def benchmark(
552538
times.append(inf_time)
553539

554540
if print_rate:
555-
print(
556-
"Inference rate = {:.3f} FPS".format(1 / inf_time), end="\r", flush=True
557-
)
541+
print(f"Inference rate = {1 / inf_time:.3f} FPS", end="\r", flush=True)
558542

559543
if save_video:
560544
draw_pose_and_write(
@@ -567,15 +551,12 @@ def benchmark(
567551
display_radius=display_radius,
568552
draw_keypoint_names=draw_keypoint_names,
569553
vwriter=vwriter,
570-
vwriter=vwriter,
571554
)
572555

573556
frame_index += 1
574557

575558
if print_rate:
576-
print(
577-
"Mean inference rate: {:.3f} FPS".format(np.mean(1 / np.array(times)[1:]))
578-
)
559+
print(f"Mean inference rate: {np.mean(1 / np.array(times)[1:]):.3f} FPS")
579560

580561
metadata = _get_metadata(video_path=video_path, cap=cap, dlc_live=dlc_live)
581562
metadata = _get_metadata(video_path=video_path, cap=cap, dlc_live=dlc_live)
@@ -593,9 +574,7 @@ def benchmark(
593574
else:
594575
individuals = []
595576
n_individuals = len(individuals) or 1
596-
save_poses_to_files(
597-
video_path, save_dir, n_individuals, bodyparts, poses, timestamp=timestamp
598-
)
577+
save_poses_to_files(video_path, save_dir, n_individuals, bodyparts, poses, timestamp=timestamp)
599578

600579
return times, im_size, metadata
601580

@@ -608,13 +587,6 @@ def setup_video_writer(
608587
cmap: str,
609588
fps: float,
610589
frame_size: tuple[int, int],
611-
video_path: str,
612-
save_dir: str,
613-
timestamp: str,
614-
num_keypoints: int,
615-
cmap: str,
616-
fps: float,
617-
frame_size: tuple[int, int],
618590
):
619591
# Set colors and convert to RGB
620592
cmap_colors = getattr(cc, cmap)
@@ -623,9 +595,7 @@ def setup_video_writer(
623595
# Define output video path
624596
video_path = Path(video_path)
625597
video_name = video_path.stem # filename without extension
626-
output_video_path = (
627-
Path(save_dir) / f"{video_name}_DLCLIVE_LABELLED_{timestamp}.mp4"
628-
)
598+
output_video_path = Path(save_dir) / f"{video_name}_DLCLIVE_LABELLED_{timestamp}.mp4"
629599

630600
# Get video writer setup
631601
fourcc = cv2.VideoWriter_fourcc(*"mp4v")
@@ -639,7 +609,6 @@ def setup_video_writer(
639609
return colors, vwriter
640610

641611

642-
643612
def draw_pose_and_write(
644613
frame: np.ndarray,
645614
pose: np.ndarray,
@@ -656,9 +625,7 @@ def draw_pose_and_write(
656625

657626
if resize is not None and resize != 1.0:
658627
# Resize the frame
659-
frame = cv2.resize(
660-
frame, None, fx=resize, fy=resize, interpolation=cv2.INTER_LINEAR
661-
)
628+
frame = cv2.resize(frame, None, fx=resize, fy=resize, interpolation=cv2.INTER_LINEAR)
662629

663630
# Scale pose coordinates
664631
pose = pose.copy()
@@ -692,7 +659,6 @@ def draw_pose_and_write(
692659
vwriter.write(image=frame)
693660

694661

695-
def _get_metadata(video_path: str, cap: cv2.VideoCapture, dlc_live: DLCLive):
696662
def _get_metadata(video_path: str, cap: cv2.VideoCapture, dlc_live: DLCLive):
697663
try:
698664
fourcc = decode_fourcc(cap.get(cv2.CAP_PROP_FOURCC))
@@ -730,9 +696,7 @@ def _get_metadata(video_path: str, cap: cv2.VideoCapture, dlc_live: DLCLive):
730696
return meta
731697

732698

733-
def save_poses_to_files(
734-
video_path, save_dir, n_individuals, bodyparts, poses, timestamp
735-
):
699+
def save_poses_to_files(video_path, save_dir, n_individuals, bodyparts, poses, timestamp):
736700
"""
737701
Saves the detected keypoint poses from the video to CSV and HDF5 files.
738702
@@ -778,7 +742,6 @@ def save_poses_to_files(
778742
pose_df.to_csv(csv_save_path, index=False)
779743

780744

781-
782745
def _create_poses_np_array(n_individuals: int, bodyparts: list, poses: list):
783746
# Create numpy array with poses:
784747
max_frame = max(p["frame"] for p in poses)
@@ -791,9 +754,7 @@ def _create_poses_np_array(n_individuals: int, bodyparts: list, poses: list):
791754
if pose.ndim == 2:
792755
pose = pose[np.newaxis, :, :]
793756
padded_pose = np.full(pose_target_shape, np.nan)
794-
slices = tuple(
795-
slice(0, min(pose.shape[i], pose_target_shape[i])) for i in range(3)
796-
)
757+
slices = tuple(slice(0, min(pose.shape[i], pose_target_shape[i])) for i in range(3))
797758
padded_pose[slices] = pose[slices]
798759
poses_array[frame] = padded_pose
799760

dlclive/check_install/check_install.py

Lines changed: 12 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -8,16 +8,15 @@
88
import argparse
99
import shutil
1010
import urllib
11-
import warnings
1211
import urllib.error
12+
import warnings
1313
from pathlib import Path
1414

1515
from dlclibrary.dlcmodelzoo.modelzoo_download import download_huggingface_model
1616

17-
from dlclive.utils import download_file
1817
from dlclive.benchmark import benchmark_videos
1918
from dlclive.engine import Engine
20-
from dlclive.utils import get_available_backends
19+
from dlclive.utils import download_file, get_available_backends
2120

2221
MODEL_NAME = "superanimal_quadruped"
2322
SNAPSHOT_NAME = "snapshot-700000.pb"
@@ -46,9 +45,7 @@ def run_pytorch_test(video_file: str, display: bool = False):
4645
model_name=TORCH_MODEL,
4746
)
4847
if not TORCH_CONFIG["checkpoint"].exists():
49-
raise FileNotFoundError(
50-
f"Failed to export {TORCH_CONFIG['super_animal']} model"
51-
)
48+
raise FileNotFoundError(f"Failed to export {TORCH_CONFIG['super_animal']} model")
5249
if TORCH_CONFIG["checkpoint"].stat().st_size == 0:
5350
raise ValueError(f"Exported {TORCH_CONFIG['super_animal']} model is empty")
5451
benchmark_videos(
@@ -71,14 +68,10 @@ def run_tensorflow_test(video_file: str, display: bool = False):
7168
if Path(model_dir / SNAPSHOT_NAME).exists():
7269
print("Model already downloaded, using cached version")
7370
else:
74-
print(
75-
"Downloading superanimal_quadruped model from the DeepLabCut Model Zoo..."
76-
)
71+
print("Downloading superanimal_quadruped model from the DeepLabCut Model Zoo...")
7772
download_huggingface_model(MODEL_NAME, str(model_dir))
7873

79-
assert Path(model_dir / SNAPSHOT_NAME).exists(), (
80-
f"Missing model file {model_dir / SNAPSHOT_NAME}"
81-
)
74+
assert Path(model_dir / SNAPSHOT_NAME).exists(), f"Missing model file {model_dir / SNAPSHOT_NAME}"
8275

8376
benchmark_videos(
8477
model_path=str(model_dir),
@@ -131,7 +124,7 @@ def main():
131124
url_link = "https://raw.githubusercontent.com/DeepLabCut/DeepLabCut-live/master/check_install/dog_clip.avi"
132125
try:
133126
download_file(url_link, video_file)
134-
except (urllib.error.URLError, IOError) as e:
127+
except (OSError, urllib.error.URLError) as e:
135128
raise RuntimeError(f"Failed to download video file: {e}") from e
136129
else:
137130
print(f"Video file already exists at {video_file}, skipping download.")
@@ -155,9 +148,7 @@ def main():
155148
any_backend_succeeded = True
156149
backend_results["tensorflow"] = ("SUCCESS", None)
157150
else:
158-
warnings.warn(
159-
f"Unrecognized backend {backend}, skipping...", UserWarning
160-
)
151+
warnings.warn(f"Unrecognized backend {backend}, skipping...", UserWarning, stacklevel=2)
161152
except Exception as e:
162153
backend_name = (
163154
"pytorch"
@@ -172,6 +163,7 @@ def main():
172163
f"Error while running test for backend {backend}: {e}. "
173164
"Continuing to test other available backends.",
174165
UserWarning,
166+
stacklevel=2,
175167
)
176168

177169
print("\n---\nBackend test summary:")
@@ -184,9 +176,7 @@ def main():
184176
print(f"{name.capitalize()} error:\n{error}\n")
185177

186178
if not any_backend_succeeded and backend_failures:
187-
failure_messages = "; ".join(
188-
f"{b}: {exc}" for b, exc in backend_failures.items()
189-
)
179+
failure_messages = "; ".join(f"{b}: {exc}" for b, exc in backend_failures.items())
190180
raise RuntimeError(f"All backend tests failed. Details: {failure_messages}")
191181

192182
finally:
@@ -197,7 +187,7 @@ def main():
197187
shutil.rmtree(tmp_dir)
198188
except PermissionError:
199189
warnings.warn(
200-
f"Could not delete temporary directory {str(tmp_dir)} due to a permissions error."
190+
f"Could not delete temporary directory {str(tmp_dir)} due to a permissions error.", stacklevel=2
201191
)
202192

203193

@@ -207,7 +197,8 @@ def main():
207197
print(f"Available backends: {[b.value for b in available_backends]}")
208198
if len(available_backends) == 0:
209199
raise NotImplementedError(
210-
"Neither TensorFlow nor PyTorch is installed. Please install at least one of these frameworks to run the installation test."
200+
"Neither TensorFlow nor PyTorch is installed. "
201+
"Please install at least one of these frameworks to run the installation test."
211202
)
212203

213204
main()

0 commit comments

Comments (0)