Skip to content

Commit a2a1a25

Browse files
committed
Add single_animal option to benchmark_videos
Expose a single_animal flag on benchmark_videos (default False) and forward it to the underlying analysis call. This allows benchmarking to run in single-animal mode when using DeepLabCut-live exported models.
1 parent: 5e66526 · commit: a2a1a25

1 file changed

Lines changed: 64 additions & 48 deletions

File tree

dlclive/benchmark.py

Lines changed: 64 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,10 @@
1010
import sys
1111
import time
1212
import warnings
13+
from typing import TYPE_CHECKING
1314
from pathlib import Path
14-
15+
import argparse
16+
import os
1517
import colorcet as cc
1618
import cv2
1719
import numpy as np
@@ -23,10 +25,15 @@
2325

2426
from dlclive import DLCLive
2527
from dlclive import VERSION
26-
from dlclive import __file__ as dlcfile
2728
from dlclive.engine import Engine
2829
from dlclive.utils import decode_fourcc
2930

31+
if TYPE_CHECKING:
32+
try:
33+
import tensorflow
34+
except ImportError:
35+
tensorflow = None
36+
3037

3138
def download_benchmarking_data(
3239
target_dir=".",
@@ -49,17 +56,20 @@ def download_benchmarking_data(
4956
if os.path.exists(zip_path):
5057
print(f"{zip_path} already exists. Skipping download.")
5158
else:
59+
5260
def show_progress(count, block_size, total_size):
5361
pbar.update(block_size)
5462

5563
print(f"Downloading the benchmarking data from {url} ...")
5664
pbar = tqdm(unit="B", total=0, position=0, desc="Downloading")
5765

58-
filename, _ = urllib.request.urlretrieve(url, filename=zip_path, reporthook=show_progress)
66+
filename, _ = urllib.request.urlretrieve(
67+
url, filename=zip_path, reporthook=show_progress
68+
)
5969
pbar.close()
6070

6171
print(f"Extracting {zip_path} to {target_dir} ...")
62-
with zipfile.ZipFile(zip_path, 'r') as zip_ref:
72+
with zipfile.ZipFile(zip_path, "r") as zip_ref:
6373
zip_ref.extractall(target_dir)
6474

6575

@@ -81,6 +91,7 @@ def benchmark_videos(
8191
cmap="bmy",
8292
save_poses=False,
8393
save_video=False,
94+
single_animal=False,
8495
):
8596
"""Analyze videos using DeepLabCut-live exported models.
8697
Analyze multiple videos and/or multiple options for the size of the video
@@ -168,7 +179,7 @@ def benchmark_videos(
168179
im_size_out = []
169180

170181
for i in range(len(resize)):
171-
print(f"\nRun {i+1} / {len(resize)}\n")
182+
print(f"\nRun {i + 1} / {len(resize)}\n")
172183

173184
this_inf_times, this_im_size, meta = benchmark(
174185
model_path=model_path,
@@ -188,6 +199,7 @@ def benchmark_videos(
188199
save_poses=save_poses,
189200
save_video=save_video,
190201
save_dir=output,
202+
single_animal=single_animal,
191203
)
192204

193205
inf_times.append(this_inf_times)
@@ -257,7 +269,7 @@ def get_system_info() -> dict:
257269
dev_type = "GPU"
258270
dev = [torch.cuda.get_device_name(torch.cuda.current_device())]
259271
else:
260-
from cpuinfo import get_cpu_info
272+
from cpuinfo import get_cpu_info # noqa: F401
261273

262274
dev_type = "CPU"
263275
dev = get_cpu_info()
@@ -275,9 +287,7 @@ def get_system_info() -> dict:
275287
}
276288

277289

278-
def save_inf_times(
279-
sys_info, inf_times, im_size, model=None, meta=None, output=None
280-
):
290+
def save_inf_times(sys_info, inf_times, im_size, model=None, meta=None, output=None):
281291
"""Save inference time data collected using :function:`benchmark` with system information to a pickle file.
282292
This is primarily used through :function:`benchmark_videos`
283293
@@ -346,6 +356,7 @@ def save_inf_times(
346356

347357
return True
348358

359+
349360
def benchmark(
350361
model_path: str,
351362
model_type: str,
@@ -357,8 +368,8 @@ def benchmark(
357368
single_animal: bool = True,
358369
cropping: list[int] | None = None,
359370
dynamic: tuple[bool, float, int] = (False, 0.5, 10),
360-
n_frames: int =1000,
361-
print_rate: bool=False,
371+
n_frames: int = 1000,
372+
print_rate: bool = False,
362373
precision: str = "FP32",
363374
display: bool = True,
364375
pcutoff: float = 0.5,
@@ -434,7 +445,10 @@ def benchmark(
434445
if not cap.isOpened():
435446
print(f"Error: Could not open video file {video_path}")
436447
return
437-
im_size = (int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)), int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT)))
448+
im_size = (
449+
int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)),
450+
int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT)),
451+
)
438452

439453
if pixels is not None:
440454
resize = np.sqrt(pixels / (im_size[0] * im_size[1]))
@@ -492,9 +506,7 @@ def benchmark(
492506

493507
total_n_frames = cap.get(cv2.CAP_PROP_FRAME_COUNT)
494508
n_frames = int(
495-
n_frames
496-
if (n_frames > 0) and n_frames < total_n_frames
497-
else total_n_frames
509+
n_frames if (n_frames > 0) and n_frames < total_n_frames else total_n_frames
498510
)
499511
iterator = range(n_frames) if print_rate or display else tqdm(range(n_frames))
500512
for _ in iterator:
@@ -510,7 +522,7 @@ def benchmark(
510522

511523
start_time = time.perf_counter()
512524
if frame_index == 0:
513-
pose = dlc_live.init_inference(frame) # Loads model
525+
pose = dlc_live.init_inference(frame) # Loads model
514526
else:
515527
pose = dlc_live.get_pose(frame)
516528

@@ -519,7 +531,9 @@ def benchmark(
519531
times.append(inf_time)
520532

521533
if print_rate:
522-
print("Inference rate = {:.3f} FPS".format(1 / inf_time), end="\r", flush=True)
534+
print(
535+
"Inference rate = {:.3f} FPS".format(1 / inf_time), end="\r", flush=True
536+
)
523537

524538
if save_video:
525539
draw_pose_and_write(
@@ -531,19 +545,17 @@ def benchmark(
531545
pcutoff=pcutoff,
532546
display_radius=display_radius,
533547
draw_keypoint_names=draw_keypoint_names,
534-
vwriter=vwriter
548+
vwriter=vwriter,
535549
)
536550

537551
frame_index += 1
538552

539553
if print_rate:
540-
print("Mean inference rate: {:.3f} FPS".format(np.mean(1 / np.array(times)[1:])))
554+
print(
555+
"Mean inference rate: {:.3f} FPS".format(np.mean(1 / np.array(times)[1:]))
556+
)
541557

542-
metadata = _get_metadata(
543-
video_path=video_path,
544-
cap=cap,
545-
dlc_live=dlc_live
546-
)
558+
metadata = _get_metadata(video_path=video_path, cap=cap, dlc_live=dlc_live)
547559

548560
cap.release()
549561

@@ -558,19 +570,21 @@ def benchmark(
558570
else:
559571
individuals = []
560572
n_individuals = len(individuals) or 1
561-
save_poses_to_files(video_path, save_dir, n_individuals, bodyparts, poses, timestamp=timestamp)
573+
save_poses_to_files(
574+
video_path, save_dir, n_individuals, bodyparts, poses, timestamp=timestamp
575+
)
562576

563577
return times, im_size, metadata
564578

565579

566580
def setup_video_writer(
567-
video_path:str,
568-
save_dir:str,
569-
timestamp:str,
570-
num_keypoints:int,
571-
cmap:str,
572-
fps:float,
573-
frame_size:tuple[int, int],
581+
video_path: str,
582+
save_dir: str,
583+
timestamp: str,
584+
num_keypoints: int,
585+
cmap: str,
586+
fps: float,
587+
frame_size: tuple[int, int],
574588
):
575589
# Set colors and convert to RGB
576590
cmap_colors = getattr(cc, cmap)
@@ -582,7 +596,9 @@ def setup_video_writer(
582596
# Define output video path
583597
video_path = Path(video_path)
584598
video_name = video_path.stem # filename without extension
585-
output_video_path = Path(save_dir) / f"{video_name}_DLCLIVE_LABELLED_{timestamp}.mp4"
599+
output_video_path = (
600+
Path(save_dir) / f"{video_name}_DLCLIVE_LABELLED_{timestamp}.mp4"
601+
)
586602

587603
# Get video writer setup
588604
fourcc = cv2.VideoWriter_fourcc(*"mp4v")
@@ -595,6 +611,7 @@ def setup_video_writer(
595611

596612
return colors, vwriter
597613

614+
598615
def draw_pose_and_write(
599616
frame: np.ndarray,
600617
pose: np.ndarray,
@@ -611,7 +628,9 @@ def draw_pose_and_write(
611628

612629
if resize is not None and resize != 1.0:
613630
# Resize the frame
614-
frame = cv2.resize(frame, None, fx=resize, fy=resize, interpolation=cv2.INTER_LINEAR)
631+
frame = cv2.resize(
632+
frame, None, fx=resize, fy=resize, interpolation=cv2.INTER_LINEAR
633+
)
615634

616635
# Scale pose coordinates
617636
pose = pose.copy()
@@ -642,15 +661,10 @@ def draw_pose_and_write(
642661
lineType=cv2.LINE_AA,
643662
)
644663

645-
646664
vwriter.write(image=frame)
647665

648666

649-
def _get_metadata(
650-
video_path: str,
651-
cap: cv2.VideoCapture,
652-
dlc_live: DLCLive
653-
):
667+
def _get_metadata(video_path: str, cap: cv2.VideoCapture, dlc_live: DLCLive):
654668
try:
655669
fourcc = decode_fourcc(cap.get(cv2.CAP_PROP_FOURCC))
656670
except Exception:
@@ -687,7 +701,9 @@ def _get_metadata(
687701
return meta
688702

689703

690-
def save_poses_to_files(video_path, save_dir, n_individuals, bodyparts, poses, timestamp):
704+
def save_poses_to_files(
705+
video_path, save_dir, n_individuals, bodyparts, poses, timestamp
706+
):
691707
"""
692708
Saves the detected keypoint poses from the video to CSV and HDF5 files.
693709
@@ -708,7 +724,7 @@ def save_poses_to_files(video_path, save_dir, n_individuals, bodyparts, poses, t
708724
-------
709725
None
710726
"""
711-
import pandas as pd
727+
import pandas as pd # noqa: F401
712728

713729
base_filename = Path(video_path).stem
714730
save_dir = Path(save_dir)
@@ -725,14 +741,16 @@ def save_poses_to_files(video_path, save_dir, n_individuals, bodyparts, poses, t
725741
else:
726742
individuals = [f"individual_{i}" for i in range(n_individuals)]
727743
pdindex = pd.MultiIndex.from_product(
728-
[individuals, bodyparts, ["x", "y", "likelihood"]], names=["individuals", "bodyparts", "coords"]
744+
[individuals, bodyparts, ["x", "y", "likelihood"]],
745+
names=["individuals", "bodyparts", "coords"],
729746
)
730747

731748
pose_df = pd.DataFrame(flattened_poses, columns=pdindex)
732749

733750
pose_df.to_hdf(h5_save_path, key="df_with_missing", mode="w")
734751
pose_df.to_csv(csv_save_path, index=False)
735752

753+
736754
def _create_poses_np_array(n_individuals: int, bodyparts: list, poses: list):
737755
# Create numpy array with poses:
738756
max_frame = max(p["frame"] for p in poses)
@@ -745,17 +763,15 @@ def _create_poses_np_array(n_individuals: int, bodyparts: list, poses: list):
745763
if pose.ndim == 2:
746764
pose = pose[np.newaxis, :, :]
747765
padded_pose = np.full(pose_target_shape, np.nan)
748-
slices = tuple(slice(0, min(pose.shape[i], pose_target_shape[i])) for i in range(3))
766+
slices = tuple(
767+
slice(0, min(pose.shape[i], pose_target_shape[i])) for i in range(3)
768+
)
749769
padded_pose[slices] = pose[slices]
750770
poses_array[frame] = padded_pose
751771

752772
return poses_array
753773

754774

755-
import argparse
756-
import os
757-
758-
759775
def main():
760776
"""Provides a command line interface to benchmark_videos function."""
761777
parser = argparse.ArgumentParser(

0 commit comments

Comments (0)