Skip to content

Commit 89ba52e

Browse files
committed
Add single_animal option to benchmark_videos
Expose a single_animal flag on benchmark_videos (default False) and forward it to the underlying analysis call. This allows benchmarking to run in single-animal mode when using DeepLabCut-live exported models.
1 parent 68b3902 commit 89ba52e

1 file changed

Lines changed: 66 additions & 19 deletions

File tree

dlclive/benchmark.py

Lines changed: 66 additions & 19 deletions
Original file line number | Diff line number | Diff line change
@@ -13,9 +13,10 @@
1313
import sys
1414
import time
1515
import warnings
16-
from pathlib import Path
1716
from typing import TYPE_CHECKING
18-
17+
from pathlib import Path
18+
import argparse
19+
import os
1920
import colorcet as cc
2021
import cv2
2122
import numpy as np
@@ -24,14 +25,16 @@
2425
from pip._internal.operations import freeze
2526
from tqdm import tqdm
2627

28+
from dlclive import DLCLive
29+
from dlclive import VERSION
2730
from dlclive.engine import Engine
2831
from dlclive.utils import decode_fourcc
2932

30-
from .dlclive import DLCLive
31-
from .version import VERSION
32-
3333
if TYPE_CHECKING:
34-
import tensorflow
34+
try:
35+
import tensorflow
36+
except ImportError:
37+
tensorflow = None
3538

3639

3740
def download_benchmarking_data(
@@ -56,16 +59,20 @@ def download_benchmarking_data(
5659
print(f"{zip_path} already exists. Skipping download.")
5760
else:
5861

62+
5963
def show_progress(count, block_size, total_size):
6064
pbar.update(block_size)
6165

6266
print(f"Downloading the benchmarking data from {url} ...")
6367
pbar = tqdm(unit="B", total=0, position=0, desc="Downloading")
6468

65-
filename, _ = urllib.request.urlretrieve(url, filename=zip_path, reporthook=show_progress)
69+
filename, _ = urllib.request.urlretrieve(
70+
url, filename=zip_path, reporthook=show_progress
71+
)
6672
pbar.close()
6773

6874
print(f"Extracting {zip_path} to {target_dir} ...")
75+
with zipfile.ZipFile(zip_path, "r") as zip_ref:
6976
with zipfile.ZipFile(zip_path, "r") as zip_ref:
7077
zip_ref.extractall(target_dir)
7178

@@ -88,6 +95,7 @@ def benchmark_videos(
8895
cmap="bmy",
8996
save_poses=False,
9097
save_video=False,
98+
single_animal=False,
9199
):
92100
"""Analyze videos using DeepLabCut-live exported models.
93101
Analyze multiple videos and/or multiple options for the size of the video
@@ -187,6 +195,7 @@ def benchmark_videos(
187195

188196
for i in range(len(resize)):
189197
print(f"\nRun {i + 1} / {len(resize)}\n")
198+
print(f"\nRun {i + 1} / {len(resize)}\n")
190199

191200
this_inf_times, this_im_size, meta = benchmark(
192201
model_path=model_path,
@@ -206,6 +215,7 @@ def benchmark_videos(
206215
save_poses=save_poses,
207216
save_video=save_video,
208217
save_dir=output,
218+
single_animal=single_animal,
209219
)
210220

211221
inf_times.append(this_inf_times)
@@ -271,7 +281,7 @@ def get_system_info() -> dict:
271281
dev_type = "GPU"
272282
dev = [torch.cuda.get_device_name(torch.cuda.current_device())]
273283
else:
274-
from cpuinfo import get_cpu_info
284+
from cpuinfo import get_cpu_info # noqa: F401
275285

276286
dev_type = "CPU"
277287
dev = get_cpu_info()
@@ -289,6 +299,7 @@ def get_system_info() -> dict:
289299
}
290300

291301

302+
def save_inf_times(sys_info, inf_times, im_size, model=None, meta=None, output=None):
292303
def save_inf_times(sys_info, inf_times, im_size, model=None, meta=None, output=None):
293304
"""Save inference time data collected using :function:`benchmark` with system information to a pickle file.
294305
This is primarily used through :function:`benchmark_videos`
@@ -358,6 +369,7 @@ def save_inf_times(sys_info, inf_times, im_size, model=None, meta=None, output=N
358369
return True
359370

360371

372+
361373
def benchmark(
362374
model_path: str,
363375
model_type: str,
@@ -371,6 +383,8 @@ def benchmark(
371383
dynamic: tuple[bool, float, int] = (False, 0.5, 10),
372384
n_frames: int = 1000,
373385
print_rate: bool = False,
386+
n_frames: int = 1000,
387+
print_rate: bool = False,
374388
precision: str = "FP32",
375389
display: bool = True,
376390
pcutoff: float = 0.5,
@@ -455,7 +469,10 @@ def benchmark(
455469
if not cap.isOpened():
456470
print(f"Error: Could not open video file {video_path}")
457471
return
458-
im_size = (int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)), int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT)))
472+
im_size = (
473+
int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)),
474+
int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT)),
475+
)
459476

460477
if pixels is not None:
461478
resize = np.sqrt(pixels / (im_size[0] * im_size[1]))
@@ -512,7 +529,9 @@ def benchmark(
512529
frame_index = 0
513530

514531
total_n_frames = cap.get(cv2.CAP_PROP_FRAME_COUNT)
515-
n_frames = int(n_frames if (n_frames > 0) and n_frames < total_n_frames else total_n_frames)
532+
n_frames = int(
533+
n_frames if (n_frames > 0) and n_frames < total_n_frames else total_n_frames
534+
)
516535
iterator = range(n_frames) if print_rate or display else tqdm(range(n_frames))
517536
for _ in iterator:
518537
ret, frame = cap.read()
@@ -527,6 +546,7 @@ def benchmark(
527546
start_time = time.perf_counter()
528547
if frame_index == 0:
529548
pose = dlc_live.init_inference(frame) # Loads model
549+
pose = dlc_live.init_inference(frame) # Loads model
530550
else:
531551
pose = dlc_live.get_pose(frame)
532552

@@ -535,7 +555,9 @@ def benchmark(
535555
times.append(inf_time)
536556

537557
if print_rate:
538-
print(f"Inference rate = {1 / inf_time:.3f} FPS", end="\r", flush=True)
558+
print(
559+
"Inference rate = {:.3f} FPS".format(1 / inf_time), end="\r", flush=True
560+
)
539561

540562
if save_video:
541563
draw_pose_and_write(
@@ -548,14 +570,18 @@ def benchmark(
548570
display_radius=display_radius,
549571
draw_keypoint_names=draw_keypoint_names,
550572
vwriter=vwriter,
573+
vwriter=vwriter,
551574
)
552575

553576
frame_index += 1
554577

555578
if print_rate:
556-
print(f"Mean inference rate: {np.mean(1 / np.array(times)[1:]):.3f} FPS")
579+
print(
580+
"Mean inference rate: {:.3f} FPS".format(np.mean(1 / np.array(times)[1:]))
581+
)
557582

558583
metadata = _get_metadata(video_path=video_path, cap=cap, dlc_live=dlc_live)
584+
metadata = _get_metadata(video_path=video_path, cap=cap, dlc_live=dlc_live)
559585

560586
cap.release()
561587

@@ -570,7 +596,9 @@ def benchmark(
570596
else:
571597
individuals = []
572598
n_individuals = len(individuals) or 1
573-
save_poses_to_files(video_path, save_dir, n_individuals, bodyparts, poses, timestamp=timestamp)
599+
save_poses_to_files(
600+
video_path, save_dir, n_individuals, bodyparts, poses, timestamp=timestamp
601+
)
574602

575603
return times, im_size, metadata
576604

@@ -583,6 +611,13 @@ def setup_video_writer(
583611
cmap: str,
584612
fps: float,
585613
frame_size: tuple[int, int],
614+
video_path: str,
615+
save_dir: str,
616+
timestamp: str,
617+
num_keypoints: int,
618+
cmap: str,
619+
fps: float,
620+
frame_size: tuple[int, int],
586621
):
587622
# Set colors and convert to RGB
588623
cmap_colors = getattr(cc, cmap)
@@ -591,7 +626,9 @@ def setup_video_writer(
591626
# Define output video path
592627
video_path = Path(video_path)
593628
video_name = video_path.stem # filename without extension
594-
output_video_path = Path(save_dir) / f"{video_name}_DLCLIVE_LABELLED_{timestamp}.mp4"
629+
output_video_path = (
630+
Path(save_dir) / f"{video_name}_DLCLIVE_LABELLED_{timestamp}.mp4"
631+
)
595632

596633
# Get video writer setup
597634
fourcc = cv2.VideoWriter_fourcc(*"mp4v")
@@ -605,6 +642,7 @@ def setup_video_writer(
605642
return colors, vwriter
606643

607644

645+
608646
def draw_pose_and_write(
609647
frame: np.ndarray,
610648
pose: np.ndarray,
@@ -621,7 +659,9 @@ def draw_pose_and_write(
621659

622660
if resize is not None and resize != 1.0:
623661
# Resize the frame
624-
frame = cv2.resize(frame, None, fx=resize, fy=resize, interpolation=cv2.INTER_LINEAR)
662+
frame = cv2.resize(
663+
frame, None, fx=resize, fy=resize, interpolation=cv2.INTER_LINEAR
664+
)
625665

626666
# Scale pose coordinates
627667
pose = pose.copy()
@@ -655,6 +695,7 @@ def draw_pose_and_write(
655695
vwriter.write(image=frame)
656696

657697

698+
def _get_metadata(video_path: str, cap: cv2.VideoCapture, dlc_live: DLCLive):
658699
def _get_metadata(video_path: str, cap: cv2.VideoCapture, dlc_live: DLCLive):
659700
try:
660701
fourcc = decode_fourcc(cap.get(cv2.CAP_PROP_FOURCC))
@@ -692,7 +733,9 @@ def _get_metadata(video_path: str, cap: cv2.VideoCapture, dlc_live: DLCLive):
692733
return meta
693734

694735

695-
def save_poses_to_files(video_path, save_dir, n_individuals, bodyparts, poses, timestamp):
736+
def save_poses_to_files(
737+
video_path, save_dir, n_individuals, bodyparts, poses, timestamp
738+
):
696739
"""
697740
Saves the detected keypoint poses from the video to CSV and HDF5 files.
698741
@@ -713,7 +756,7 @@ def save_poses_to_files(video_path, save_dir, n_individuals, bodyparts, poses, t
713756
-------
714757
None
715758
"""
716-
import pandas as pd # noqa E402
759+
import pandas as pd # noqa: F401
717760

718761
base_filename = Path(video_path).stem
719762
save_dir = Path(save_dir)
@@ -728,7 +771,8 @@ def save_poses_to_files(video_path, save_dir, n_individuals, bodyparts, poses, t
728771
else:
729772
individuals = [f"individual_{i}" for i in range(n_individuals)]
730773
pdindex = pd.MultiIndex.from_product(
731-
[individuals, bodyparts, ["x", "y", "likelihood"]], names=["individuals", "bodyparts", "coords"]
774+
[individuals, bodyparts, ["x", "y", "likelihood"]],
775+
names=["individuals", "bodyparts", "coords"],
732776
)
733777

734778
pose_df = pd.DataFrame(flattened_poses, columns=pdindex)
@@ -737,6 +781,7 @@ def save_poses_to_files(video_path, save_dir, n_individuals, bodyparts, poses, t
737781
pose_df.to_csv(csv_save_path, index=False)
738782

739783

784+
740785
def _create_poses_np_array(n_individuals: int, bodyparts: list, poses: list):
741786
# Create numpy array with poses:
742787
max_frame = max(p["frame"] for p in poses)
@@ -749,7 +794,9 @@ def _create_poses_np_array(n_individuals: int, bodyparts: list, poses: list):
749794
if pose.ndim == 2:
750795
pose = pose[np.newaxis, :, :]
751796
padded_pose = np.full(pose_target_shape, np.nan)
752-
slices = tuple(slice(0, min(pose.shape[i], pose_target_shape[i])) for i in range(3))
797+
slices = tuple(
798+
slice(0, min(pose.shape[i], pose_target_shape[i])) for i in range(3)
799+
)
753800
padded_pose[slices] = pose[slices]
754801
poses_array[frame] = padded_pose
755802

0 commit comments

Comments (0)