1010import sys
1111import time
1212import warnings
13+ from typing import TYPE_CHECKING
1314from pathlib import Path
14-
15+ import argparse
16+ import os
1517import colorcet as cc
1618import cv2
1719import numpy as np
2325
2426from dlclive import DLCLive
2527from dlclive import VERSION
26- from dlclive import __file__ as dlcfile
2728from dlclive .engine import Engine
2829from dlclive .utils import decode_fourcc
2930
31+ if TYPE_CHECKING :
32+ try :
33+ import tensorflow
34+ except ImportError :
35+ tensorflow = None
36+
3037
3138def download_benchmarking_data (
3239 target_dir = "." ,
@@ -49,17 +56,20 @@ def download_benchmarking_data(
4956 if os .path .exists (zip_path ):
5057 print (f"{ zip_path } already exists. Skipping download." )
5158 else :
59+
5260 def show_progress (count , block_size , total_size ):
5361 pbar .update (block_size )
5462
5563 print (f"Downloading the benchmarking data from { url } ..." )
5664 pbar = tqdm (unit = "B" , total = 0 , position = 0 , desc = "Downloading" )
5765
58- filename , _ = urllib .request .urlretrieve (url , filename = zip_path , reporthook = show_progress )
66+ filename , _ = urllib .request .urlretrieve (
67+ url , filename = zip_path , reporthook = show_progress
68+ )
5969 pbar .close ()
6070
6171 print (f"Extracting { zip_path } to { target_dir } ..." )
62- with zipfile .ZipFile (zip_path , 'r' ) as zip_ref :
72+ with zipfile .ZipFile (zip_path , "r" ) as zip_ref :
6373 zip_ref .extractall (target_dir )
6474
6575
@@ -81,6 +91,7 @@ def benchmark_videos(
8191 cmap = "bmy" ,
8292 save_poses = False ,
8393 save_video = False ,
94+ single_animal = False ,
8495):
8596 """Analyze videos using DeepLabCut-live exported models.
8697 Analyze multiple videos and/or multiple options for the size of the video
@@ -168,7 +179,7 @@ def benchmark_videos(
168179 im_size_out = []
169180
170181 for i in range (len (resize )):
171- print (f"\n Run { i + 1 } / { len (resize )} \n " )
182+ print (f"\n Run { i + 1 } / { len (resize )} \n " )
172183
173184 this_inf_times , this_im_size , meta = benchmark (
174185 model_path = model_path ,
@@ -188,6 +199,7 @@ def benchmark_videos(
188199 save_poses = save_poses ,
189200 save_video = save_video ,
190201 save_dir = output ,
202+ single_animal = single_animal ,
191203 )
192204
193205 inf_times .append (this_inf_times )
@@ -257,7 +269,7 @@ def get_system_info() -> dict:
257269 dev_type = "GPU"
258270 dev = [torch .cuda .get_device_name (torch .cuda .current_device ())]
259271 else :
260- from cpuinfo import get_cpu_info
272+ from cpuinfo import get_cpu_info # noqa: F401
261273
262274 dev_type = "CPU"
263275 dev = get_cpu_info ()
@@ -275,9 +287,7 @@ def get_system_info() -> dict:
275287 }
276288
277289
278- def save_inf_times (
279- sys_info , inf_times , im_size , model = None , meta = None , output = None
280- ):
290+ def save_inf_times (sys_info , inf_times , im_size , model = None , meta = None , output = None ):
281291 """Save inference time data collected using :function:`benchmark` with system information to a pickle file.
282292 This is primarily used through :function:`benchmark_videos`
283293
@@ -346,6 +356,7 @@ def save_inf_times(
346356
347357 return True
348358
359+
349360def benchmark (
350361 model_path : str ,
351362 model_type : str ,
@@ -357,8 +368,8 @@ def benchmark(
357368 single_animal : bool = True ,
358369 cropping : list [int ] | None = None ,
359370 dynamic : tuple [bool , float , int ] = (False , 0.5 , 10 ),
360- n_frames : int = 1000 ,
361- print_rate : bool = False ,
371+ n_frames : int = 1000 ,
372+ print_rate : bool = False ,
362373 precision : str = "FP32" ,
363374 display : bool = True ,
364375 pcutoff : float = 0.5 ,
@@ -434,7 +445,10 @@ def benchmark(
434445 if not cap .isOpened ():
435446 print (f"Error: Could not open video file { video_path } " )
436447 return
437- im_size = (int (cap .get (cv2 .CAP_PROP_FRAME_WIDTH )), int (cap .get (cv2 .CAP_PROP_FRAME_HEIGHT )))
448+ im_size = (
449+ int (cap .get (cv2 .CAP_PROP_FRAME_WIDTH )),
450+ int (cap .get (cv2 .CAP_PROP_FRAME_HEIGHT )),
451+ )
438452
439453 if pixels is not None :
440454 resize = np .sqrt (pixels / (im_size [0 ] * im_size [1 ]))
@@ -492,9 +506,7 @@ def benchmark(
492506
493507 total_n_frames = cap .get (cv2 .CAP_PROP_FRAME_COUNT )
494508 n_frames = int (
495- n_frames
496- if (n_frames > 0 ) and n_frames < total_n_frames
497- else total_n_frames
509+ n_frames if (n_frames > 0 ) and n_frames < total_n_frames else total_n_frames
498510 )
499511 iterator = range (n_frames ) if print_rate or display else tqdm (range (n_frames ))
500512 for _ in iterator :
@@ -510,7 +522,7 @@ def benchmark(
510522
511523 start_time = time .perf_counter ()
512524 if frame_index == 0 :
513- pose = dlc_live .init_inference (frame ) # Loads model
525+ pose = dlc_live .init_inference (frame ) # Loads model
514526 else :
515527 pose = dlc_live .get_pose (frame )
516528
@@ -519,7 +531,9 @@ def benchmark(
519531 times .append (inf_time )
520532
521533 if print_rate :
522- print ("Inference rate = {:.3f} FPS" .format (1 / inf_time ), end = "\r " , flush = True )
534+ print (
535+ "Inference rate = {:.3f} FPS" .format (1 / inf_time ), end = "\r " , flush = True
536+ )
523537
524538 if save_video :
525539 draw_pose_and_write (
@@ -531,19 +545,17 @@ def benchmark(
531545 pcutoff = pcutoff ,
532546 display_radius = display_radius ,
533547 draw_keypoint_names = draw_keypoint_names ,
534- vwriter = vwriter
548+ vwriter = vwriter ,
535549 )
536550
537551 frame_index += 1
538552
539553 if print_rate :
540- print ("Mean inference rate: {:.3f} FPS" .format (np .mean (1 / np .array (times )[1 :])))
554+ print (
555+ "Mean inference rate: {:.3f} FPS" .format (np .mean (1 / np .array (times )[1 :]))
556+ )
541557
542- metadata = _get_metadata (
543- video_path = video_path ,
544- cap = cap ,
545- dlc_live = dlc_live
546- )
558+ metadata = _get_metadata (video_path = video_path , cap = cap , dlc_live = dlc_live )
547559
548560 cap .release ()
549561
@@ -558,19 +570,21 @@ def benchmark(
558570 else :
559571 individuals = []
560572 n_individuals = len (individuals ) or 1
561- save_poses_to_files (video_path , save_dir , n_individuals , bodyparts , poses , timestamp = timestamp )
573+ save_poses_to_files (
574+ video_path , save_dir , n_individuals , bodyparts , poses , timestamp = timestamp
575+ )
562576
563577 return times , im_size , metadata
564578
565579
566580def setup_video_writer (
567- video_path :str ,
568- save_dir :str ,
569- timestamp :str ,
570- num_keypoints :int ,
571- cmap :str ,
572- fps :float ,
573- frame_size :tuple [int , int ],
581+ video_path : str ,
582+ save_dir : str ,
583+ timestamp : str ,
584+ num_keypoints : int ,
585+ cmap : str ,
586+ fps : float ,
587+ frame_size : tuple [int , int ],
574588):
575589 # Set colors and convert to RGB
576590 cmap_colors = getattr (cc , cmap )
@@ -582,7 +596,9 @@ def setup_video_writer(
582596 # Define output video path
583597 video_path = Path (video_path )
584598 video_name = video_path .stem # filename without extension
585- output_video_path = Path (save_dir ) / f"{ video_name } _DLCLIVE_LABELLED_{ timestamp } .mp4"
599+ output_video_path = (
600+ Path (save_dir ) / f"{ video_name } _DLCLIVE_LABELLED_{ timestamp } .mp4"
601+ )
586602
587603 # Get video writer setup
588604 fourcc = cv2 .VideoWriter_fourcc (* "mp4v" )
@@ -595,6 +611,7 @@ def setup_video_writer(
595611
596612 return colors , vwriter
597613
614+
598615def draw_pose_and_write (
599616 frame : np .ndarray ,
600617 pose : np .ndarray ,
@@ -611,7 +628,9 @@ def draw_pose_and_write(
611628
612629 if resize is not None and resize != 1.0 :
613630 # Resize the frame
614- frame = cv2 .resize (frame , None , fx = resize , fy = resize , interpolation = cv2 .INTER_LINEAR )
631+ frame = cv2 .resize (
632+ frame , None , fx = resize , fy = resize , interpolation = cv2 .INTER_LINEAR
633+ )
615634
616635 # Scale pose coordinates
617636 pose = pose .copy ()
@@ -642,15 +661,10 @@ def draw_pose_and_write(
642661 lineType = cv2 .LINE_AA ,
643662 )
644663
645-
646664 vwriter .write (image = frame )
647665
648666
649- def _get_metadata (
650- video_path : str ,
651- cap : cv2 .VideoCapture ,
652- dlc_live : DLCLive
653- ):
667+ def _get_metadata (video_path : str , cap : cv2 .VideoCapture , dlc_live : DLCLive ):
654668 try :
655669 fourcc = decode_fourcc (cap .get (cv2 .CAP_PROP_FOURCC ))
656670 except Exception :
@@ -687,7 +701,9 @@ def _get_metadata(
687701 return meta
688702
689703
690- def save_poses_to_files (video_path , save_dir , n_individuals , bodyparts , poses , timestamp ):
704+ def save_poses_to_files (
705+ video_path , save_dir , n_individuals , bodyparts , poses , timestamp
706+ ):
691707 """
692708 Saves the detected keypoint poses from the video to CSV and HDF5 files.
693709
@@ -708,7 +724,7 @@ def save_poses_to_files(video_path, save_dir, n_individuals, bodyparts, poses, t
708724 -------
709725 None
710726 """
711- import pandas as pd
727+ import pandas as pd # noqa: F401
712728
713729 base_filename = Path (video_path ).stem
714730 save_dir = Path (save_dir )
@@ -725,14 +741,16 @@ def save_poses_to_files(video_path, save_dir, n_individuals, bodyparts, poses, t
725741 else :
726742 individuals = [f"individual_{ i } " for i in range (n_individuals )]
727743 pdindex = pd .MultiIndex .from_product (
728- [individuals , bodyparts , ["x" , "y" , "likelihood" ]], names = ["individuals" , "bodyparts" , "coords" ]
744+ [individuals , bodyparts , ["x" , "y" , "likelihood" ]],
745+ names = ["individuals" , "bodyparts" , "coords" ],
729746 )
730747
731748 pose_df = pd .DataFrame (flattened_poses , columns = pdindex )
732749
733750 pose_df .to_hdf (h5_save_path , key = "df_with_missing" , mode = "w" )
734751 pose_df .to_csv (csv_save_path , index = False )
735752
753+
736754def _create_poses_np_array (n_individuals : int , bodyparts : list , poses : list ):
737755 # Create numpy array with poses:
738756 max_frame = max (p ["frame" ] for p in poses )
@@ -745,17 +763,15 @@ def _create_poses_np_array(n_individuals: int, bodyparts: list, poses: list):
745763 if pose .ndim == 2 :
746764 pose = pose [np .newaxis , :, :]
747765 padded_pose = np .full (pose_target_shape , np .nan )
748- slices = tuple (slice (0 , min (pose .shape [i ], pose_target_shape [i ])) for i in range (3 ))
766+ slices = tuple (
767+ slice (0 , min (pose .shape [i ], pose_target_shape [i ])) for i in range (3 )
768+ )
749769 padded_pose [slices ] = pose [slices ]
750770 poses_array [frame ] = padded_pose
751771
752772 return poses_array
753773
754774
755- import argparse
756- import os
757-
758-
759775def main ():
760776 """Provides a command line interface to benchmark_videos function."""
761777 parser = argparse .ArgumentParser (
0 commit comments