Skip to content

Commit 03fe25a

Browse files
committed
Fix dates in comments
1 parent 679919d commit 03fe25a

5 files changed

Lines changed: 121 additions & 43 deletions

File tree

dlclive/core/inferenceutils.py

Lines changed: 91 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
#
1111

1212

13-
# NOTE - DUPLICATED @C-Achard 2026-26-01: Copied from the original DeepLabCut codebase
13+
# NOTE - DUPLICATED @C-Achard 2026-01-26: Copied from the original DeepLabCut codebase
1414
# from deeplabcut/core/inferenceutils.py
1515
from __future__ import annotations
1616

@@ -66,7 +66,9 @@ def __init__(self, j1, j2, affinity=1):
6666
self._length = sqrt((j1.pos[0] - j2.pos[0]) ** 2 + (j1.pos[1] - j2.pos[1]) ** 2)
6767

6868
def __repr__(self):
69-
return f"Link {self.idx}, affinity={self.affinity:.2f}, length={self.length:.2f}"
69+
return (
70+
f"Link {self.idx}, affinity={self.affinity:.2f}, length={self.length:.2f}"
71+
)
7072

7173
@property
7274
def confidence(self):
@@ -264,7 +266,10 @@ def __init__(
264266
self.max_overlap = max_overlap
265267
self._has_identity = "identity" in self[0]
266268
if identity_only and not self._has_identity:
267-
warnings.warn("The network was not trained with identity; setting `identity_only` to False.", stacklevel=2)
269+
warnings.warn(
270+
"The network was not trained with identity; setting `identity_only` to False.",
271+
stacklevel=2,
272+
)
268273
self.identity_only = identity_only & self._has_identity
269274
self.nan_policy = nan_policy
270275
self.force_fusion = force_fusion
@@ -345,15 +350,19 @@ def calibrate(self, train_data_file):
345350
pass
346351
n_bpts = len(df.columns.get_level_values("bodyparts").unique())
347352
if n_bpts == 1:
348-
warnings.warn("There is only one keypoint; skipping calibration...", stacklevel=2)
353+
warnings.warn(
354+
"There is only one keypoint; skipping calibration...", stacklevel=2
355+
)
349356
return
350357

351358
xy = df.to_numpy().reshape((-1, n_bpts, 2))
352359
frac_valid = np.mean(~np.isnan(xy), axis=(1, 2))
353360
# Only keeps skeletons that are more than 90% complete
354361
xy = xy[frac_valid >= 0.9]
355362
if not xy.size:
356-
warnings.warn("No complete poses were found. Skipping calibration...", stacklevel=2)
363+
warnings.warn(
364+
"No complete poses were found. Skipping calibration...", stacklevel=2
365+
)
357366
return
358367

359368
# TODO Normalize dists by longest length?
@@ -369,9 +378,14 @@ def calibrate(self, train_data_file):
369378
self.safe_edge = True
370379
except np.linalg.LinAlgError:
371380
# Covariance matrix estimation fails due to numerical singularities
372-
warnings.warn("The assembler could not be robustly calibrated. Continuing without it...", stacklevel=2)
381+
warnings.warn(
382+
"The assembler could not be robustly calibrated. Continuing without it...",
383+
stacklevel=2,
384+
)
373385

374-
def calc_assembly_mahalanobis_dist(self, assembly, return_proba=False, nan_policy="little"):
386+
def calc_assembly_mahalanobis_dist(
387+
self, assembly, return_proba=False, nan_policy="little"
388+
):
375389
if self._kde is None:
376390
raise ValueError("Assembler should be calibrated first with training data.")
377391

@@ -425,7 +439,9 @@ def _flatten_detections(data_dict):
425439
ids = [np.ones(len(arr), dtype=int) * -1 for arr in confidence]
426440
else:
427441
ids = [arr.argmax(axis=1) for arr in ids]
428-
for i, (coords, conf, id_) in enumerate(zip(coordinates, confidence, ids, strict=False)):
442+
for i, (coords, conf, id_) in enumerate(
443+
zip(coordinates, confidence, ids, strict=False)
444+
):
429445
if not np.any(coords):
430446
continue
431447
for xy, p, g in zip(coords, conf, id_, strict=False):
@@ -450,7 +466,9 @@ def extract_best_links(self, joints_dict, costs, trees=None):
450466
aff[np.isnan(aff)] = 0
451467

452468
if trees:
453-
vecs = np.vstack([[*det_s.pos, *det_t.pos] for det_s in dets_s for det_t in dets_t])
469+
vecs = np.vstack(
470+
[[*det_s.pos, *det_t.pos] for det_s in dets_s for det_t in dets_t]
471+
)
454472
dists = []
455473
for n, tree in enumerate(trees, start=1):
456474
d, _ = tree.query(vecs)
@@ -459,8 +477,15 @@ def extract_best_links(self, joints_dict, costs, trees=None):
459477
aff *= w.reshape(aff.shape)
460478

461479
if self.greedy:
462-
conf = np.asarray([[det_s.confidence * det_t.confidence for det_t in dets_t] for det_s in dets_s])
463-
rows, cols = np.where((conf >= self.pcutoff * self.pcutoff) & (aff >= self.min_affinity))
480+
conf = np.asarray(
481+
[
482+
[det_s.confidence * det_t.confidence for det_t in dets_t]
483+
for det_s in dets_s
484+
]
485+
)
486+
rows, cols = np.where(
487+
(conf >= self.pcutoff * self.pcutoff) & (aff >= self.min_affinity)
488+
)
464489
candidates = sorted(
465490
zip(rows, cols, aff[rows, cols], lengths[rows, cols], strict=False),
466491
key=lambda x: x[2],
@@ -476,14 +501,18 @@ def extract_best_links(self, joints_dict, costs, trees=None):
476501
if len(i_seen) == self.max_n_individuals:
477502
break
478503
else: # Optimal keypoint pairing
479-
inds_s = sorted(range(len(dets_s)), key=lambda x: dets_s[x].confidence, reverse=True)[
480-
: self.max_n_individuals
504+
inds_s = sorted(
505+
range(len(dets_s)), key=lambda x: dets_s[x].confidence, reverse=True
506+
)[: self.max_n_individuals]
507+
inds_t = sorted(
508+
range(len(dets_t)), key=lambda x: dets_t[x].confidence, reverse=True
509+
)[: self.max_n_individuals]
510+
keep_s = [
511+
ind for ind in inds_s if dets_s[ind].confidence >= self.pcutoff
481512
]
482-
inds_t = sorted(range(len(dets_t)), key=lambda x: dets_t[x].confidence, reverse=True)[
483-
: self.max_n_individuals
513+
keep_t = [
514+
ind for ind in inds_t if dets_t[ind].confidence >= self.pcutoff
484515
]
485-
keep_s = [ind for ind in inds_s if dets_s[ind].confidence >= self.pcutoff]
486-
keep_t = [ind for ind in inds_t if dets_t[ind].confidence >= self.pcutoff]
487516
aff = aff[np.ix_(keep_s, keep_t)]
488517
rows, cols = linear_sum_assignment(aff, maximize=True)
489518
for row, col in zip(rows, cols, strict=False):
@@ -522,7 +551,9 @@ def push_to_stack(i):
522551
if new_ind in assembled:
523552
continue
524553
if safe_edge:
525-
d_old = self.calc_assembly_mahalanobis_dist(assembly, nan_policy=nan_policy)
554+
d_old = self.calc_assembly_mahalanobis_dist(
555+
assembly, nan_policy=nan_policy
556+
)
526557
success = assembly.add_link(best, store_dict=True)
527558
if not success:
528559
assembly._dict = dict()
@@ -575,7 +606,9 @@ def build_assemblies(self, links):
575606
continue
576607
assembly = Assembly(self.n_multibodyparts)
577608
assembly.add_link(link)
578-
self._fill_assembly(assembly, lookup, assembled, self.safe_edge, self.nan_policy)
609+
self._fill_assembly(
610+
assembly, lookup, assembled, self.safe_edge, self.nan_policy
611+
)
579612
for assembly_link in assembly._links:
580613
i, j = assembly_link.idx
581614
lookup[i].pop(j)
@@ -587,7 +620,10 @@ def build_assemblies(self, links):
587620
n_extra = len(assemblies) - self.max_n_individuals
588621
if n_extra > 0:
589622
if self.safe_edge:
590-
ds_old = [self.calc_assembly_mahalanobis_dist(assembly) for assembly in assemblies]
623+
ds_old = [
624+
self.calc_assembly_mahalanobis_dist(assembly)
625+
for assembly in assemblies
626+
]
591627
while len(assemblies) > self.max_n_individuals:
592628
ds = []
593629
for i, j in itertools.combinations(range(len(assemblies)), 2):
@@ -719,7 +755,10 @@ def _assemble(self, data_dict, ind_frame):
719755
for _, group in groups:
720756
ass = Assembly(self.n_multibodyparts)
721757
for joint in sorted(group, key=lambda x: x.confidence, reverse=True):
722-
if joint.confidence >= self.pcutoff and joint.label < self.n_multibodyparts:
758+
if (
759+
joint.confidence >= self.pcutoff
760+
and joint.label < self.n_multibodyparts
761+
):
723762
ass.add_joint(joint)
724763
if len(ass):
725764
assemblies.append(ass)
@@ -748,15 +787,21 @@ def _assemble(self, data_dict, ind_frame):
748787
assembled.update(assembled_)
749788

750789
# Remove invalid assemblies
751-
discarded = set(joint for joint in joints if joint.idx not in assembled and np.isfinite(joint.confidence))
790+
discarded = set(
791+
joint
792+
for joint in joints
793+
if joint.idx not in assembled and np.isfinite(joint.confidence)
794+
)
752795
for assembly in assemblies[::-1]:
753796
if 0 < assembly.n_links < self.min_n_links or not len(assembly):
754797
for link in assembly._links:
755798
discarded.update((link.j1, link.j2))
756799
assemblies.remove(assembly)
757800
if 0 < self.max_overlap < 1: # Non-maximum pose suppression
758801
if self._kde is not None:
759-
scores = [-self.calc_assembly_mahalanobis_dist(ass) for ass in assemblies]
802+
scores = [
803+
-self.calc_assembly_mahalanobis_dist(ass) for ass in assemblies
804+
]
760805
else:
761806
scores = [ass._affinity for ass in assemblies]
762807
lst = list(zip(scores, assemblies, strict=False))
@@ -825,7 +870,9 @@ def wrapped(i):
825870
n_frames = len(self.metadata["imnames"])
826871
with multiprocessing.Pool(n_processes) as p:
827872
with tqdm(total=n_frames) as pbar:
828-
for i, (assemblies, unique) in p.imap_unordered(wrapped, range(n_frames), chunksize=chunk_size):
873+
for i, (assemblies, unique) in p.imap_unordered(
874+
wrapped, range(n_frames), chunksize=chunk_size
875+
):
829876
if assemblies:
830877
self.assemblies[i] = assemblies
831878
if unique is not None:
@@ -844,7 +891,9 @@ def parse_metadata(data):
844891
params["joint_names"] = data["metadata"]["all_joints_names"]
845892
params["num_joints"] = len(params["joint_names"])
846893
params["paf_graph"] = data["metadata"]["PAFgraph"]
847-
params["paf"] = data["metadata"].get("PAFinds", np.arange(len(params["joint_names"])))
894+
params["paf"] = data["metadata"].get(
895+
"PAFinds", np.arange(len(params["joint_names"]))
896+
)
848897
params["bpts"] = params["ibpts"] = range(params["num_joints"])
849898
params["imnames"] = [fn for fn in list(data) if fn != "metadata"]
850899
return params
@@ -934,7 +983,11 @@ def calc_object_keypoint_similarity(
934983
else:
935984
oks = []
936985
xy_preds = [xy_pred]
937-
combos = (pair for l in range(len(symmetric_kpts)) for pair in itertools.combinations(symmetric_kpts, l + 1))
986+
combos = (
987+
pair
988+
for l in range(len(symmetric_kpts))
989+
for pair in itertools.combinations(symmetric_kpts, l + 1)
990+
)
938991
for pairs in combos:
939992
# Swap corresponding keypoints
940993
tmp = xy_pred.copy()
@@ -971,7 +1024,9 @@ def match_assemblies(
9711024
num_ground_truth = len(ground_truth)
9721025

9731026
# Sort predictions by score
974-
inds_pred = np.argsort([ins.affinity if ins.n_links else ins.confidence for ins in predictions])[::-1]
1027+
inds_pred = np.argsort(
1028+
[ins.affinity if ins.n_links else ins.confidence for ins in predictions]
1029+
)[::-1]
9751030
predictions = np.asarray(predictions)[inds_pred]
9761031

9771032
# indices of unmatched ground truth assemblies
@@ -1078,7 +1133,9 @@ def find_outlier_assemblies(dict_of_assemblies, criterion="area", qs=(5, 95)):
10781133
raise ValueError(f"Invalid criterion {criterion}.")
10791134

10801135
if len(qs) != 2:
1081-
raise ValueError("Two percentiles (for lower and upper bounds) should be given.")
1136+
raise ValueError(
1137+
"Two percentiles (for lower and upper bounds) should be given."
1138+
)
10821139

10831140
tuples = []
10841141
for frame_ind, assemblies in dict_of_assemblies.items():
@@ -1182,7 +1239,9 @@ def evaluate_assembly_greedy(
11821239
oks = np.asarray([match.oks for match in all_matched])[sorted_pred_indices]
11831240

11841241
# Compute prediction and recall
1185-
p, r = _compute_precision_and_recall(total_gt_assemblies, oks, oks_t, recall_thresholds)
1242+
p, r = _compute_precision_and_recall(
1243+
total_gt_assemblies, oks, oks_t, recall_thresholds
1244+
)
11861245
precisions.append(p)
11871246
recalls.append(r)
11881247

@@ -1255,7 +1314,9 @@ def evaluate_assembly(
12551314
precisions = []
12561315
recalls = []
12571316
for t in oks_thresholds:
1258-
p, r = _compute_precision_and_recall(total_gt_assemblies, oks, t, recall_thresholds)
1317+
p, r = _compute_precision_and_recall(
1318+
total_gt_assemblies, oks, t, recall_thresholds
1319+
)
12591320
precisions.append(p)
12601321
recalls.append(r)
12611322

dlclive/modelzoo/resolve_config.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
For instance, "num_bodyparts x 2" is replaced with the number of bodyparts multiplied by 2.
44
"""
55

6-
# NOTE - DUPLICATED @deruyter92 2026-23-01: Copied from the original DeepLabCut codebase
6+
# NOTE - DUPLICATED @deruyter92 2026-01-23: Copied from the original DeepLabCut codebase
77
# from deeplabcut/pose_estimation_pytorch/modelzoo/utils.py
88
import copy
99

@@ -99,7 +99,9 @@ def get_updated_value(variable: str) -> int | list[int]:
9999
else:
100100
raise ValueError(f"Unknown operator for variable: {variable}")
101101

102-
raise ValueError(f"Found {variable} in the configuration file, but cannot parse it.")
102+
raise ValueError(
103+
f"Found {variable} in the configuration file, but cannot parse it."
104+
)
103105

104106
updated_values = {
105107
"num_bodyparts": num_bodyparts,
@@ -125,7 +127,10 @@ def get_updated_value(variable: str) -> int | list[int]:
125127
backbone_output_channels,
126128
**kwargs,
127129
)
128-
elif isinstance(config[k], str) and config[k].strip().split(" ")[0] in updated_values.keys():
130+
elif (
131+
isinstance(config[k], str)
132+
and config[k].strip().split(" ")[0] in updated_values.keys()
133+
):
129134
config[k] = get_updated_value(config[k])
130135

131136
return config

dlclive/modelzoo/utils.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,7 @@ def read_config_as_dict(config_path: str | Path) -> dict:
6969
return cfg
7070

7171

72-
# NOTE - DUPLICATED @deruyter92 2026-23-01: Copied from the original DeepLabCut codebase
72+
# NOTE - DUPLICATED @deruyter92 2026-01-23: Copied from the original DeepLabCut codebase
7373
# from deeplabcut/pose_estimation_pytorch/config/make_pose_config.py
7474
def add_metadata(
7575
project_config: dict,
@@ -98,7 +98,7 @@ def add_metadata(
9898
return config
9999

100100

101-
# NOTE - DUPLICATED @deruyter92 2026-23-01: Copied from the original DeepLabCut codebase
101+
# NOTE - DUPLICATED @deruyter92 2026-01-23: Copied from the original DeepLabCut codebase
102102
# from deeplabcut/pose_estimation_pytorch/modelzoo/utils.py
103103
def load_super_animal_config(
104104
super_animal: str,

dlclive/pose_estimation_pytorch/dynamic_cropping.py

Lines changed: 16 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
#
99
# Licensed under GNU Lesser General Public License v3.0
1010

11-
# NOTE DUPLICATED @C-Achard 2026-26-01: Duplication between this file
11+
# NOTE DUPLICATED @C-Achard 2026-01-26: Duplication between this file
1212
# and deeplabcut/pose_estimation_pytorch/runners/dynamic_cropping.py
1313
# NOTE Testing already exists at deeplabcut/tests/pose_estimation_pytorch/runners/test_dynamic_cropper.py
1414
"""Modules to dynamically crop individuals out of videos to improve video analysis"""
@@ -82,7 +82,9 @@ def crop(self, image: torch.Tensor) -> torch.Tensor:
8282
height.
8383
"""
8484
if len(image) != 1:
85-
raise RuntimeError(f"DynamicCropper can only be used with batch size 1 (found image shape: {image.shape})")
85+
raise RuntimeError(
86+
f"DynamicCropper can only be used with batch size 1 (found image shape: {image.shape})"
87+
)
8688

8789
if self._shape is None:
8890
self._shape = image.shape[3], image.shape[2]
@@ -307,7 +309,9 @@ def crop(self, image: torch.Tensor) -> torch.Tensor:
307309
`crop` was previously called with an image of a different W or H.
308310
"""
309311
if len(image) != 1:
310-
raise RuntimeError(f"DynamicCropper can only be used with batch size 1 (found image shape: {image.shape})")
312+
raise RuntimeError(
313+
f"DynamicCropper can only be used with batch size 1 (found image shape: {image.shape})"
314+
)
311315

312316
if self._shape is None:
313317
self._shape = image.shape[3], image.shape[2]
@@ -394,7 +398,9 @@ def update(self, pose: torch.Tensor) -> torch.Tensor:
394398

395399
return pose
396400

397-
def _prepare_bounding_box(self, x1: int, y1: int, x2: int, y2: int) -> tuple[int, int, int, int]:
401+
def _prepare_bounding_box(
402+
self, x1: int, y1: int, x2: int, y2: int
403+
) -> tuple[int, int, int, int]:
398404
"""Prepares the bounding box for cropping.
399405
400406
Adds a margin around the bounding box, then transforms it into the target aspect
@@ -491,8 +497,12 @@ def generate_patches(self) -> list[tuple[int, int, int, int]]:
491497
Returns:
492498
A list of patch coordinates as tuples (x0, y0, x1, y1).
493499
"""
494-
patch_xs = self.split_array(self._shape[0], self._patch_counts[0], self._patch_overlap)
495-
patch_ys = self.split_array(self._shape[1], self._patch_counts[1], self._patch_overlap)
500+
patch_xs = self.split_array(
501+
self._shape[0], self._patch_counts[0], self._patch_overlap
502+
)
503+
patch_ys = self.split_array(
504+
self._shape[1], self._patch_counts[1], self._patch_overlap
505+
)
496506

497507
patches = []
498508
for y0, y1 in patch_ys:

0 commit comments

Comments
 (0)