1313import sys
1414import time
1515import warnings
16- from pathlib import Path
1716from typing import TYPE_CHECKING
18-
17+ from pathlib import Path
18+ import argparse
19+ import os
1920import colorcet as cc
2021import cv2
2122import numpy as np
2425from pip ._internal .operations import freeze
2526from tqdm import tqdm
2627
28+ from dlclive import DLCLive
29+ from dlclive import VERSION
2730from dlclive .engine import Engine
2831from dlclive .utils import decode_fourcc
2932
30- from .dlclive import DLCLive
31- from .version import VERSION
32-
3333if TYPE_CHECKING :
34- import tensorflow
34+ try :
35+ import tensorflow
36+ except ImportError :
37+ tensorflow = None
3538
3639
3740def download_benchmarking_data (
@@ -56,16 +59,20 @@ def download_benchmarking_data(
5659 print (f"{ zip_path } already exists. Skipping download." )
5760 else :
5861
62+
5963 def show_progress (count , block_size , total_size ):
6064 pbar .update (block_size )
6165
6266 print (f"Downloading the benchmarking data from { url } ..." )
6367 pbar = tqdm (unit = "B" , total = 0 , position = 0 , desc = "Downloading" )
6468
65- filename , _ = urllib .request .urlretrieve (url , filename = zip_path , reporthook = show_progress )
69+ filename , _ = urllib .request .urlretrieve (
70+ url , filename = zip_path , reporthook = show_progress
71+ )
6672 pbar .close ()
6773
6874 print (f"Extracting { zip_path } to { target_dir } ..." )
75+ with zipfile .ZipFile (zip_path , "r" ) as zip_ref :
6976 with zipfile .ZipFile (zip_path , "r" ) as zip_ref :
7077 zip_ref .extractall (target_dir )
7178
@@ -88,6 +95,7 @@ def benchmark_videos(
8895 cmap = "bmy" ,
8996 save_poses = False ,
9097 save_video = False ,
98+ single_animal = False ,
9199):
92100 """Analyze videos using DeepLabCut-live exported models.
93101 Analyze multiple videos and/or multiple options for the size of the video
@@ -187,6 +195,7 @@ def benchmark_videos(
187195
188196 for i in range (len (resize )):
189197 print (f"\n Run { i + 1 } / { len (resize )} \n " )
198+ print (f"\n Run { i + 1 } / { len (resize )} \n " )
190199
191200 this_inf_times , this_im_size , meta = benchmark (
192201 model_path = model_path ,
@@ -206,6 +215,7 @@ def benchmark_videos(
206215 save_poses = save_poses ,
207216 save_video = save_video ,
208217 save_dir = output ,
218+ single_animal = single_animal ,
209219 )
210220
211221 inf_times .append (this_inf_times )
@@ -271,7 +281,7 @@ def get_system_info() -> dict:
271281 dev_type = "GPU"
272282 dev = [torch .cuda .get_device_name (torch .cuda .current_device ())]
273283 else :
274- from cpuinfo import get_cpu_info
284+ from cpuinfo import get_cpu_info # noqa: F401
275285
276286 dev_type = "CPU"
277287 dev = get_cpu_info ()
@@ -289,6 +299,7 @@ def get_system_info() -> dict:
289299 }
290300
291301
302+ def save_inf_times (sys_info , inf_times , im_size , model = None , meta = None , output = None ):
292303def save_inf_times (sys_info , inf_times , im_size , model = None , meta = None , output = None ):
293304 """Save inference time data collected using :function:`benchmark` with system information to a pickle file.
294305 This is primarily used through :function:`benchmark_videos`
@@ -358,6 +369,7 @@ def save_inf_times(sys_info, inf_times, im_size, model=None, meta=None, output=N
358369 return True
359370
360371
372+
361373def benchmark (
362374 model_path : str ,
363375 model_type : str ,
@@ -371,6 +383,8 @@ def benchmark(
371383 dynamic : tuple [bool , float , int ] = (False , 0.5 , 10 ),
372384 n_frames : int = 1000 ,
373385 print_rate : bool = False ,
386+ n_frames : int = 1000 ,
387+ print_rate : bool = False ,
374388 precision : str = "FP32" ,
375389 display : bool = True ,
376390 pcutoff : float = 0.5 ,
@@ -455,7 +469,10 @@ def benchmark(
455469 if not cap .isOpened ():
456470 print (f"Error: Could not open video file { video_path } " )
457471 return
458- im_size = (int (cap .get (cv2 .CAP_PROP_FRAME_WIDTH )), int (cap .get (cv2 .CAP_PROP_FRAME_HEIGHT )))
472+ im_size = (
473+ int (cap .get (cv2 .CAP_PROP_FRAME_WIDTH )),
474+ int (cap .get (cv2 .CAP_PROP_FRAME_HEIGHT )),
475+ )
459476
460477 if pixels is not None :
461478 resize = np .sqrt (pixels / (im_size [0 ] * im_size [1 ]))
@@ -512,7 +529,9 @@ def benchmark(
512529 frame_index = 0
513530
514531 total_n_frames = cap .get (cv2 .CAP_PROP_FRAME_COUNT )
515- n_frames = int (n_frames if (n_frames > 0 ) and n_frames < total_n_frames else total_n_frames )
532+ n_frames = int (
533+ n_frames if (n_frames > 0 ) and n_frames < total_n_frames else total_n_frames
534+ )
516535 iterator = range (n_frames ) if print_rate or display else tqdm (range (n_frames ))
517536 for _ in iterator :
518537 ret , frame = cap .read ()
@@ -527,6 +546,7 @@ def benchmark(
527546 start_time = time .perf_counter ()
528547 if frame_index == 0 :
529548 pose = dlc_live .init_inference (frame ) # Loads model
549+ pose = dlc_live .init_inference (frame ) # Loads model
530550 else :
531551 pose = dlc_live .get_pose (frame )
532552
@@ -535,7 +555,9 @@ def benchmark(
535555 times .append (inf_time )
536556
537557 if print_rate :
538- print (f"Inference rate = { 1 / inf_time :.3f} FPS" , end = "\r " , flush = True )
558+ print (
559+ "Inference rate = {:.3f} FPS" .format (1 / inf_time ), end = "\r " , flush = True
560+ )
539561
540562 if save_video :
541563 draw_pose_and_write (
@@ -548,14 +570,18 @@ def benchmark(
548570 display_radius = display_radius ,
549571 draw_keypoint_names = draw_keypoint_names ,
550572 vwriter = vwriter ,
573+ vwriter = vwriter ,
551574 )
552575
553576 frame_index += 1
554577
555578 if print_rate :
556- print (f"Mean inference rate: { np .mean (1 / np .array (times )[1 :]):.3f} FPS" )
579+ print (
580+ "Mean inference rate: {:.3f} FPS" .format (np .mean (1 / np .array (times )[1 :]))
581+ )
557582
558583 metadata = _get_metadata (video_path = video_path , cap = cap , dlc_live = dlc_live )
584+ metadata = _get_metadata (video_path = video_path , cap = cap , dlc_live = dlc_live )
559585
560586 cap .release ()
561587
@@ -570,7 +596,9 @@ def benchmark(
570596 else :
571597 individuals = []
572598 n_individuals = len (individuals ) or 1
573- save_poses_to_files (video_path , save_dir , n_individuals , bodyparts , poses , timestamp = timestamp )
599+ save_poses_to_files (
600+ video_path , save_dir , n_individuals , bodyparts , poses , timestamp = timestamp
601+ )
574602
575603 return times , im_size , metadata
576604
@@ -583,6 +611,13 @@ def setup_video_writer(
583611 cmap : str ,
584612 fps : float ,
585613 frame_size : tuple [int , int ],
614+ video_path : str ,
615+ save_dir : str ,
616+ timestamp : str ,
617+ num_keypoints : int ,
618+ cmap : str ,
619+ fps : float ,
620+ frame_size : tuple [int , int ],
586621):
587622 # Set colors and convert to RGB
588623 cmap_colors = getattr (cc , cmap )
@@ -591,7 +626,9 @@ def setup_video_writer(
591626 # Define output video path
592627 video_path = Path (video_path )
593628 video_name = video_path .stem # filename without extension
594- output_video_path = Path (save_dir ) / f"{ video_name } _DLCLIVE_LABELLED_{ timestamp } .mp4"
629+ output_video_path = (
630+ Path (save_dir ) / f"{ video_name } _DLCLIVE_LABELLED_{ timestamp } .mp4"
631+ )
595632
596633 # Get video writer setup
597634 fourcc = cv2 .VideoWriter_fourcc (* "mp4v" )
@@ -605,6 +642,7 @@ def setup_video_writer(
605642 return colors , vwriter
606643
607644
645+
608646def draw_pose_and_write (
609647 frame : np .ndarray ,
610648 pose : np .ndarray ,
@@ -621,7 +659,9 @@ def draw_pose_and_write(
621659
622660 if resize is not None and resize != 1.0 :
623661 # Resize the frame
624- frame = cv2 .resize (frame , None , fx = resize , fy = resize , interpolation = cv2 .INTER_LINEAR )
662+ frame = cv2 .resize (
663+ frame , None , fx = resize , fy = resize , interpolation = cv2 .INTER_LINEAR
664+ )
625665
626666 # Scale pose coordinates
627667 pose = pose .copy ()
@@ -655,6 +695,7 @@ def draw_pose_and_write(
655695 vwriter .write (image = frame )
656696
657697
698+ def _get_metadata (video_path : str , cap : cv2 .VideoCapture , dlc_live : DLCLive ):
658699def _get_metadata (video_path : str , cap : cv2 .VideoCapture , dlc_live : DLCLive ):
659700 try :
660701 fourcc = decode_fourcc (cap .get (cv2 .CAP_PROP_FOURCC ))
@@ -692,7 +733,9 @@ def _get_metadata(video_path: str, cap: cv2.VideoCapture, dlc_live: DLCLive):
692733 return meta
693734
694735
695- def save_poses_to_files (video_path , save_dir , n_individuals , bodyparts , poses , timestamp ):
736+ def save_poses_to_files (
737+ video_path , save_dir , n_individuals , bodyparts , poses , timestamp
738+ ):
696739 """
697740 Saves the detected keypoint poses from the video to CSV and HDF5 files.
698741
@@ -713,7 +756,7 @@ def save_poses_to_files(video_path, save_dir, n_individuals, bodyparts, poses, t
713756 -------
714757 None
715758 """
716- import pandas as pd # noqa E402
759+ import pandas as pd # noqa: F401
717760
718761 base_filename = Path (video_path ).stem
719762 save_dir = Path (save_dir )
@@ -728,7 +771,8 @@ def save_poses_to_files(video_path, save_dir, n_individuals, bodyparts, poses, t
728771 else :
729772 individuals = [f"individual_{ i } " for i in range (n_individuals )]
730773 pdindex = pd .MultiIndex .from_product (
731- [individuals , bodyparts , ["x" , "y" , "likelihood" ]], names = ["individuals" , "bodyparts" , "coords" ]
774+ [individuals , bodyparts , ["x" , "y" , "likelihood" ]],
775+ names = ["individuals" , "bodyparts" , "coords" ],
732776 )
733777
734778 pose_df = pd .DataFrame (flattened_poses , columns = pdindex )
@@ -737,6 +781,7 @@ def save_poses_to_files(video_path, save_dir, n_individuals, bodyparts, poses, t
737781 pose_df .to_csv (csv_save_path , index = False )
738782
739783
784+
740785def _create_poses_np_array (n_individuals : int , bodyparts : list , poses : list ):
741786 # Create numpy array with poses:
742787 max_frame = max (p ["frame" ] for p in poses )
@@ -749,7 +794,9 @@ def _create_poses_np_array(n_individuals: int, bodyparts: list, poses: list):
749794 if pose .ndim == 2 :
750795 pose = pose [np .newaxis , :, :]
751796 padded_pose = np .full (pose_target_shape , np .nan )
752- slices = tuple (slice (0 , min (pose .shape [i ], pose_target_shape [i ])) for i in range (3 ))
797+ slices = tuple (
798+ slice (0 , min (pose .shape [i ], pose_target_shape [i ])) for i in range (3 )
799+ )
753800 padded_pose [slices ] = pose [slices ]
754801 poses_array [frame ] = padded_pose
755802
0 commit comments