1313import sys
1414import time
1515import warnings
16- from typing import TYPE_CHECKING
1716from pathlib import Path
18- import argparse
19- import os
17+ from typing import TYPE_CHECKING
18+
2019import colorcet as cc
2120import cv2
2221import numpy as np
2524from pip ._internal .operations import freeze
2625from tqdm import tqdm
2726
28- from dlclive import DLCLive
29- from dlclive import VERSION
27+ from dlclive import VERSION , DLCLive
3028from dlclive .engine import Engine
3129from dlclive .utils import decode_fourcc
3230
@@ -56,20 +54,16 @@ def download_benchmarking_data(
5654 print (f"{ zip_path } already exists. Skipping download." )
5755 else :
5856
59-
6057 def show_progress (count , block_size , total_size ):
6158 pbar .update (block_size )
6259
6360 print (f"Downloading the benchmarking data from { url } ..." )
6461 pbar = tqdm (unit = "B" , total = 0 , position = 0 , desc = "Downloading" )
6562
66- filename , _ = urllib .request .urlretrieve (
67- url , filename = zip_path , reporthook = show_progress
68- )
63+ filename , _ = urllib .request .urlretrieve (url , filename = zip_path , reporthook = show_progress )
6964 pbar .close ()
7065
7166 print (f"Extracting { zip_path } to { target_dir } ..." )
72- with zipfile .ZipFile (zip_path , "r" ) as zip_ref :
7367 with zipfile .ZipFile (zip_path , "r" ) as zip_ref :
7468 zip_ref .extractall (target_dir )
7569
@@ -192,7 +186,6 @@ def benchmark_videos(
192186
193187 for i in range (len (resize )):
194188 print (f"\n Run { i + 1 } / { len (resize )} \n " )
195- print (f"\n Run { i + 1 } / { len (resize )} \n " )
196189
197190 this_inf_times , this_im_size , meta = benchmark (
198191 model_path = model_path ,
@@ -296,7 +289,6 @@ def get_system_info() -> dict:
296289 }
297290
298291
299- def save_inf_times (sys_info , inf_times , im_size , model = None , meta = None , output = None ):
300292def save_inf_times (sys_info , inf_times , im_size , model = None , meta = None , output = None ):
301293 """Save inference time data collected using :function:`benchmark` with system information to a pickle file.
302294 This is primarily used through :function:`benchmark_videos`
@@ -366,7 +358,6 @@ def save_inf_times(sys_info, inf_times, im_size, model=None, meta=None, output=N
366358 return True
367359
368360
369-
370361def benchmark (
371362 model_path : str ,
372363 model_type : str ,
@@ -380,8 +371,6 @@ def benchmark(
380371 dynamic : tuple [bool , float , int ] = (False , 0.5 , 10 ),
381372 n_frames : int = 1000 ,
382373 print_rate : bool = False ,
383- n_frames : int = 1000 ,
384- print_rate : bool = False ,
385374 precision : str = "FP32" ,
386375 display : bool = True ,
387376 pcutoff : float = 0.5 ,
@@ -526,9 +515,7 @@ def benchmark(
526515 frame_index = 0
527516
528517 total_n_frames = cap .get (cv2 .CAP_PROP_FRAME_COUNT )
529- n_frames = int (
530- n_frames if (n_frames > 0 ) and n_frames < total_n_frames else total_n_frames
531- )
518+ n_frames = int (n_frames if (n_frames > 0 ) and n_frames < total_n_frames else total_n_frames )
532519 iterator = range (n_frames ) if print_rate or display else tqdm (range (n_frames ))
533520 for _ in iterator :
534521 ret , frame = cap .read ()
@@ -543,7 +530,6 @@ def benchmark(
543530 start_time = time .perf_counter ()
544531 if frame_index == 0 :
545532 pose = dlc_live .init_inference (frame ) # Loads model
546- pose = dlc_live .init_inference (frame ) # Loads model
547533 else :
548534 pose = dlc_live .get_pose (frame )
549535
@@ -552,9 +538,7 @@ def benchmark(
552538 times .append (inf_time )
553539
554540 if print_rate :
555- print (
556- "Inference rate = {:.3f} FPS" .format (1 / inf_time ), end = "\r " , flush = True
557- )
541+ print (f"Inference rate = { 1 / inf_time :.3f} FPS" , end = "\r " , flush = True )
558542
559543 if save_video :
560544 draw_pose_and_write (
@@ -567,15 +551,12 @@ def benchmark(
567551 display_radius = display_radius ,
568552 draw_keypoint_names = draw_keypoint_names ,
569553 vwriter = vwriter ,
570- vwriter = vwriter ,
571554 )
572555
573556 frame_index += 1
574557
575558 if print_rate :
576- print (
577- "Mean inference rate: {:.3f} FPS" .format (np .mean (1 / np .array (times )[1 :]))
578- )
559+ print (f"Mean inference rate: { np .mean (1 / np .array (times )[1 :]):.3f} FPS" )
579560
580561 metadata = _get_metadata (video_path = video_path , cap = cap , dlc_live = dlc_live )
581562 metadata = _get_metadata (video_path = video_path , cap = cap , dlc_live = dlc_live )
@@ -593,9 +574,7 @@ def benchmark(
593574 else :
594575 individuals = []
595576 n_individuals = len (individuals ) or 1
596- save_poses_to_files (
597- video_path , save_dir , n_individuals , bodyparts , poses , timestamp = timestamp
598- )
577+ save_poses_to_files (video_path , save_dir , n_individuals , bodyparts , poses , timestamp = timestamp )
599578
600579 return times , im_size , metadata
601580
@@ -608,13 +587,6 @@ def setup_video_writer(
608587 cmap : str ,
609588 fps : float ,
610589 frame_size : tuple [int , int ],
611- video_path : str ,
612- save_dir : str ,
613- timestamp : str ,
614- num_keypoints : int ,
615- cmap : str ,
616- fps : float ,
617- frame_size : tuple [int , int ],
618590):
619591 # Set colors and convert to RGB
620592 cmap_colors = getattr (cc , cmap )
@@ -623,9 +595,7 @@ def setup_video_writer(
623595 # Define output video path
624596 video_path = Path (video_path )
625597 video_name = video_path .stem # filename without extension
626- output_video_path = (
627- Path (save_dir ) / f"{ video_name } _DLCLIVE_LABELLED_{ timestamp } .mp4"
628- )
598+ output_video_path = Path (save_dir ) / f"{ video_name } _DLCLIVE_LABELLED_{ timestamp } .mp4"
629599
630600 # Get video writer setup
631601 fourcc = cv2 .VideoWriter_fourcc (* "mp4v" )
@@ -639,7 +609,6 @@ def setup_video_writer(
639609 return colors , vwriter
640610
641611
642-
643612def draw_pose_and_write (
644613 frame : np .ndarray ,
645614 pose : np .ndarray ,
@@ -656,9 +625,7 @@ def draw_pose_and_write(
656625
657626 if resize is not None and resize != 1.0 :
658627 # Resize the frame
659- frame = cv2 .resize (
660- frame , None , fx = resize , fy = resize , interpolation = cv2 .INTER_LINEAR
661- )
628+ frame = cv2 .resize (frame , None , fx = resize , fy = resize , interpolation = cv2 .INTER_LINEAR )
662629
663630 # Scale pose coordinates
664631 pose = pose .copy ()
@@ -692,7 +659,6 @@ def draw_pose_and_write(
692659 vwriter .write (image = frame )
693660
694661
695- def _get_metadata (video_path : str , cap : cv2 .VideoCapture , dlc_live : DLCLive ):
696662def _get_metadata (video_path : str , cap : cv2 .VideoCapture , dlc_live : DLCLive ):
697663 try :
698664 fourcc = decode_fourcc (cap .get (cv2 .CAP_PROP_FOURCC ))
@@ -730,9 +696,7 @@ def _get_metadata(video_path: str, cap: cv2.VideoCapture, dlc_live: DLCLive):
730696 return meta
731697
732698
733- def save_poses_to_files (
734- video_path , save_dir , n_individuals , bodyparts , poses , timestamp
735- ):
699+ def save_poses_to_files (video_path , save_dir , n_individuals , bodyparts , poses , timestamp ):
736700 """
737701 Saves the detected keypoint poses from the video to CSV and HDF5 files.
738702
@@ -778,7 +742,6 @@ def save_poses_to_files(
778742 pose_df .to_csv (csv_save_path , index = False )
779743
780744
781-
782745def _create_poses_np_array (n_individuals : int , bodyparts : list , poses : list ):
783746 # Create numpy array with poses:
784747 max_frame = max (p ["frame" ] for p in poses )
@@ -791,9 +754,7 @@ def _create_poses_np_array(n_individuals: int, bodyparts: list, poses: list):
791754 if pose .ndim == 2 :
792755 pose = pose [np .newaxis , :, :]
793756 padded_pose = np .full (pose_target_shape , np .nan )
794- slices = tuple (
795- slice (0 , min (pose .shape [i ], pose_target_shape [i ])) for i in range (3 )
796- )
757+ slices = tuple (slice (0 , min (pose .shape [i ], pose_target_shape [i ])) for i in range (3 ))
797758 padded_pose [slices ] = pose [slices ]
798759 poses_array [frame ] = padded_pose
799760
0 commit comments