1010#
1111
1212
13- # NOTE - DUPLICATED @C-Achard 2026-26-01 : Copied from the original DeepLabCut codebase
13+ # NOTE - DUPLICATED @C-Achard 2026-01-26 : Copied from the original DeepLabCut codebase
1414# from deeplabcut/core/inferenceutils.py
1515from __future__ import annotations
1616
@@ -66,7 +66,9 @@ def __init__(self, j1, j2, affinity=1):
6666 self ._length = sqrt ((j1 .pos [0 ] - j2 .pos [0 ]) ** 2 + (j1 .pos [1 ] - j2 .pos [1 ]) ** 2 )
6767
6868 def __repr__ (self ):
69- return f"Link { self .idx } , affinity={ self .affinity :.2f} , length={ self .length :.2f} "
69+ return (
70+ f"Link { self .idx } , affinity={ self .affinity :.2f} , length={ self .length :.2f} "
71+ )
7072
7173 @property
7274 def confidence (self ):
@@ -264,7 +266,10 @@ def __init__(
264266 self .max_overlap = max_overlap
265267 self ._has_identity = "identity" in self [0 ]
266268 if identity_only and not self ._has_identity :
267- warnings .warn ("The network was not trained with identity; setting `identity_only` to False." , stacklevel = 2 )
269+ warnings .warn (
270+ "The network was not trained with identity; setting `identity_only` to False." ,
271+ stacklevel = 2 ,
272+ )
268273 self .identity_only = identity_only & self ._has_identity
269274 self .nan_policy = nan_policy
270275 self .force_fusion = force_fusion
@@ -345,15 +350,19 @@ def calibrate(self, train_data_file):
345350 pass
346351 n_bpts = len (df .columns .get_level_values ("bodyparts" ).unique ())
347352 if n_bpts == 1 :
348- warnings .warn ("There is only one keypoint; skipping calibration..." , stacklevel = 2 )
353+ warnings .warn (
354+ "There is only one keypoint; skipping calibration..." , stacklevel = 2
355+ )
349356 return
350357
351358 xy = df .to_numpy ().reshape ((- 1 , n_bpts , 2 ))
352359 frac_valid = np .mean (~ np .isnan (xy ), axis = (1 , 2 ))
353360 # Only keeps skeletons that are more than 90% complete
354361 xy = xy [frac_valid >= 0.9 ]
355362 if not xy .size :
356- warnings .warn ("No complete poses were found. Skipping calibration..." , stacklevel = 2 )
363+ warnings .warn (
364+ "No complete poses were found. Skipping calibration..." , stacklevel = 2
365+ )
357366 return
358367
359368 # TODO Normalize dists by longest length?
@@ -369,9 +378,14 @@ def calibrate(self, train_data_file):
369378 self .safe_edge = True
370379 except np .linalg .LinAlgError :
371380 # Covariance matrix estimation fails due to numerical singularities
372- warnings .warn ("The assembler could not be robustly calibrated. Continuing without it..." , stacklevel = 2 )
381+ warnings .warn (
382+ "The assembler could not be robustly calibrated. Continuing without it..." ,
383+ stacklevel = 2 ,
384+ )
373385
374- def calc_assembly_mahalanobis_dist (self , assembly , return_proba = False , nan_policy = "little" ):
386+ def calc_assembly_mahalanobis_dist (
387+ self , assembly , return_proba = False , nan_policy = "little"
388+ ):
375389 if self ._kde is None :
376390 raise ValueError ("Assembler should be calibrated first with training data." )
377391
@@ -425,7 +439,9 @@ def _flatten_detections(data_dict):
425439 ids = [np .ones (len (arr ), dtype = int ) * - 1 for arr in confidence ]
426440 else :
427441 ids = [arr .argmax (axis = 1 ) for arr in ids ]
428- for i , (coords , conf , id_ ) in enumerate (zip (coordinates , confidence , ids , strict = False )):
442+ for i , (coords , conf , id_ ) in enumerate (
443+ zip (coordinates , confidence , ids , strict = False )
444+ ):
429445 if not np .any (coords ):
430446 continue
431447 for xy , p , g in zip (coords , conf , id_ , strict = False ):
@@ -450,7 +466,9 @@ def extract_best_links(self, joints_dict, costs, trees=None):
450466 aff [np .isnan (aff )] = 0
451467
452468 if trees :
453- vecs = np .vstack ([[* det_s .pos , * det_t .pos ] for det_s in dets_s for det_t in dets_t ])
469+ vecs = np .vstack (
470+ [[* det_s .pos , * det_t .pos ] for det_s in dets_s for det_t in dets_t ]
471+ )
454472 dists = []
455473 for n , tree in enumerate (trees , start = 1 ):
456474 d , _ = tree .query (vecs )
@@ -459,8 +477,15 @@ def extract_best_links(self, joints_dict, costs, trees=None):
459477 aff *= w .reshape (aff .shape )
460478
461479 if self .greedy :
462- conf = np .asarray ([[det_s .confidence * det_t .confidence for det_t in dets_t ] for det_s in dets_s ])
463- rows , cols = np .where ((conf >= self .pcutoff * self .pcutoff ) & (aff >= self .min_affinity ))
480+ conf = np .asarray (
481+ [
482+ [det_s .confidence * det_t .confidence for det_t in dets_t ]
483+ for det_s in dets_s
484+ ]
485+ )
486+ rows , cols = np .where (
487+ (conf >= self .pcutoff * self .pcutoff ) & (aff >= self .min_affinity )
488+ )
464489 candidates = sorted (
465490 zip (rows , cols , aff [rows , cols ], lengths [rows , cols ], strict = False ),
466491 key = lambda x : x [2 ],
@@ -476,14 +501,18 @@ def extract_best_links(self, joints_dict, costs, trees=None):
476501 if len (i_seen ) == self .max_n_individuals :
477502 break
478503 else : # Optimal keypoint pairing
479- inds_s = sorted (range (len (dets_s )), key = lambda x : dets_s [x ].confidence , reverse = True )[
480- : self .max_n_individuals
504+ inds_s = sorted (
505+ range (len (dets_s )), key = lambda x : dets_s [x ].confidence , reverse = True
506+ )[: self .max_n_individuals ]
507+ inds_t = sorted (
508+ range (len (dets_t )), key = lambda x : dets_t [x ].confidence , reverse = True
509+ )[: self .max_n_individuals ]
510+ keep_s = [
511+ ind for ind in inds_s if dets_s [ind ].confidence >= self .pcutoff
481512 ]
482- inds_t = sorted ( range ( len ( dets_t )), key = lambda x : dets_t [ x ]. confidence , reverse = True ) [
483- : self .max_n_individuals
513+ keep_t = [
514+ ind for ind in inds_t if dets_t [ ind ]. confidence >= self .pcutoff
484515 ]
485- keep_s = [ind for ind in inds_s if dets_s [ind ].confidence >= self .pcutoff ]
486- keep_t = [ind for ind in inds_t if dets_t [ind ].confidence >= self .pcutoff ]
487516 aff = aff [np .ix_ (keep_s , keep_t )]
488517 rows , cols = linear_sum_assignment (aff , maximize = True )
489518 for row , col in zip (rows , cols , strict = False ):
@@ -522,7 +551,9 @@ def push_to_stack(i):
522551 if new_ind in assembled :
523552 continue
524553 if safe_edge :
525- d_old = self .calc_assembly_mahalanobis_dist (assembly , nan_policy = nan_policy )
554+ d_old = self .calc_assembly_mahalanobis_dist (
555+ assembly , nan_policy = nan_policy
556+ )
526557 success = assembly .add_link (best , store_dict = True )
527558 if not success :
528559 assembly ._dict = dict ()
@@ -575,7 +606,9 @@ def build_assemblies(self, links):
575606 continue
576607 assembly = Assembly (self .n_multibodyparts )
577608 assembly .add_link (link )
578- self ._fill_assembly (assembly , lookup , assembled , self .safe_edge , self .nan_policy )
609+ self ._fill_assembly (
610+ assembly , lookup , assembled , self .safe_edge , self .nan_policy
611+ )
579612 for assembly_link in assembly ._links :
580613 i , j = assembly_link .idx
581614 lookup [i ].pop (j )
@@ -587,7 +620,10 @@ def build_assemblies(self, links):
587620 n_extra = len (assemblies ) - self .max_n_individuals
588621 if n_extra > 0 :
589622 if self .safe_edge :
590- ds_old = [self .calc_assembly_mahalanobis_dist (assembly ) for assembly in assemblies ]
623+ ds_old = [
624+ self .calc_assembly_mahalanobis_dist (assembly )
625+ for assembly in assemblies
626+ ]
591627 while len (assemblies ) > self .max_n_individuals :
592628 ds = []
593629 for i , j in itertools .combinations (range (len (assemblies )), 2 ):
@@ -719,7 +755,10 @@ def _assemble(self, data_dict, ind_frame):
719755 for _ , group in groups :
720756 ass = Assembly (self .n_multibodyparts )
721757 for joint in sorted (group , key = lambda x : x .confidence , reverse = True ):
722- if joint .confidence >= self .pcutoff and joint .label < self .n_multibodyparts :
758+ if (
759+ joint .confidence >= self .pcutoff
760+ and joint .label < self .n_multibodyparts
761+ ):
723762 ass .add_joint (joint )
724763 if len (ass ):
725764 assemblies .append (ass )
@@ -748,15 +787,21 @@ def _assemble(self, data_dict, ind_frame):
748787 assembled .update (assembled_ )
749788
750789 # Remove invalid assemblies
751- discarded = set (joint for joint in joints if joint .idx not in assembled and np .isfinite (joint .confidence ))
790+ discarded = set (
791+ joint
792+ for joint in joints
793+ if joint .idx not in assembled and np .isfinite (joint .confidence )
794+ )
752795 for assembly in assemblies [::- 1 ]:
753796 if 0 < assembly .n_links < self .min_n_links or not len (assembly ):
754797 for link in assembly ._links :
755798 discarded .update ((link .j1 , link .j2 ))
756799 assemblies .remove (assembly )
757800 if 0 < self .max_overlap < 1 : # Non-maximum pose suppression
758801 if self ._kde is not None :
759- scores = [- self .calc_assembly_mahalanobis_dist (ass ) for ass in assemblies ]
802+ scores = [
803+ - self .calc_assembly_mahalanobis_dist (ass ) for ass in assemblies
804+ ]
760805 else :
761806 scores = [ass ._affinity for ass in assemblies ]
762807 lst = list (zip (scores , assemblies , strict = False ))
@@ -825,7 +870,9 @@ def wrapped(i):
825870 n_frames = len (self .metadata ["imnames" ])
826871 with multiprocessing .Pool (n_processes ) as p :
827872 with tqdm (total = n_frames ) as pbar :
828- for i , (assemblies , unique ) in p .imap_unordered (wrapped , range (n_frames ), chunksize = chunk_size ):
873+ for i , (assemblies , unique ) in p .imap_unordered (
874+ wrapped , range (n_frames ), chunksize = chunk_size
875+ ):
829876 if assemblies :
830877 self .assemblies [i ] = assemblies
831878 if unique is not None :
@@ -844,7 +891,9 @@ def parse_metadata(data):
844891 params ["joint_names" ] = data ["metadata" ]["all_joints_names" ]
845892 params ["num_joints" ] = len (params ["joint_names" ])
846893 params ["paf_graph" ] = data ["metadata" ]["PAFgraph" ]
847- params ["paf" ] = data ["metadata" ].get ("PAFinds" , np .arange (len (params ["joint_names" ])))
894+ params ["paf" ] = data ["metadata" ].get (
895+ "PAFinds" , np .arange (len (params ["joint_names" ]))
896+ )
848897 params ["bpts" ] = params ["ibpts" ] = range (params ["num_joints" ])
849898 params ["imnames" ] = [fn for fn in list (data ) if fn != "metadata" ]
850899 return params
@@ -934,7 +983,11 @@ def calc_object_keypoint_similarity(
934983 else :
935984 oks = []
936985 xy_preds = [xy_pred ]
937- combos = (pair for l in range (len (symmetric_kpts )) for pair in itertools .combinations (symmetric_kpts , l + 1 ))
986+ combos = (
987+ pair
988+ for l in range (len (symmetric_kpts ))
989+ for pair in itertools .combinations (symmetric_kpts , l + 1 )
990+ )
938991 for pairs in combos :
939992 # Swap corresponding keypoints
940993 tmp = xy_pred .copy ()
@@ -971,7 +1024,9 @@ def match_assemblies(
9711024 num_ground_truth = len (ground_truth )
9721025
9731026 # Sort predictions by score
974- inds_pred = np .argsort ([ins .affinity if ins .n_links else ins .confidence for ins in predictions ])[::- 1 ]
1027+ inds_pred = np .argsort (
1028+ [ins .affinity if ins .n_links else ins .confidence for ins in predictions ]
1029+ )[::- 1 ]
9751030 predictions = np .asarray (predictions )[inds_pred ]
9761031
9771032 # indices of unmatched ground truth assemblies
@@ -1078,7 +1133,9 @@ def find_outlier_assemblies(dict_of_assemblies, criterion="area", qs=(5, 95)):
10781133 raise ValueError (f"Invalid criterion { criterion } ." )
10791134
10801135 if len (qs ) != 2 :
1081- raise ValueError ("Two percentiles (for lower and upper bounds) should be given." )
1136+ raise ValueError (
1137+ "Two percentiles (for lower and upper bounds) should be given."
1138+ )
10821139
10831140 tuples = []
10841141 for frame_ind , assemblies in dict_of_assemblies .items ():
@@ -1182,7 +1239,9 @@ def evaluate_assembly_greedy(
11821239 oks = np .asarray ([match .oks for match in all_matched ])[sorted_pred_indices ]
11831240
11841241 # Compute prediction and recall
1185- p , r = _compute_precision_and_recall (total_gt_assemblies , oks , oks_t , recall_thresholds )
1242+ p , r = _compute_precision_and_recall (
1243+ total_gt_assemblies , oks , oks_t , recall_thresholds
1244+ )
11861245 precisions .append (p )
11871246 recalls .append (r )
11881247
@@ -1255,7 +1314,9 @@ def evaluate_assembly(
12551314 precisions = []
12561315 recalls = []
12571316 for t in oks_thresholds :
1258- p , r = _compute_precision_and_recall (total_gt_assemblies , oks , t , recall_thresholds )
1317+ p , r = _compute_precision_and_recall (
1318+ total_gt_assemblies , oks , t , recall_thresholds
1319+ )
12591320 precisions .append (p )
12601321 recalls .append (r )
12611322
0 commit comments