Skip to content

Commit f105709

Browse files
committed
Infer n_individuals and and n_bodyparts from config metadata.
- add fields for n_individuals and n_bodyparts read from the pytorch model config. - add option to infer single_animal mode from n_individuals in model config. This commit does not change default behaviour. Only when passing single_animal = None explicitly to PyTorchRunner, the single_animal mode will be inferred from the model config.
1 parent 595c295 commit f105709

1 file changed

Lines changed: 17 additions & 3 deletions

File tree

dlclive/pose_estimation_pytorch/runner.py

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -118,7 +118,10 @@ class PyTorchRunner(BaseRunner):
118118
path: The path to the model to run inference with.
119119
device: The device on which to run inference, e.g. "cpu", "cuda", "cuda:0"
120120
precision: The precision of the model. One of "FP16" or "FP32".
121-
single_animal: This option is only available for single-animal pose estimation
121+
single_animal: bool | None, default=True
122+
Set to True if the model is a single-animal model, False if it is a multi-animal model.
123+
If set to None, single_animal mode will be inferred from the model configuration.
124+
This option is introduced for single-animal pose estimation
122125
models. It makes the code behave in exactly the same way as DeepLabCut-Live
123126
with version < 3.0.0. This ensures backwards compatibility with any
124127
Processors that were implemented.
@@ -131,15 +134,16 @@ def __init__(
131134
path: str | Path,
132135
device: str = "auto",
133136
precision: Literal["FP16", "FP32"] = "FP32",
134-
single_animal: bool = True,
137+
single_animal: bool | None = None,
135138
dynamic: dict | dynamic_cropping.DynamicCropper | None = None,
136139
top_down_config: dict | TopDownConfig | None = None,
137140
) -> None:
138141
super().__init__(path)
139142
self.device = _parse_device(device)
140143
self.precision = precision
141144
self.single_animal = single_animal
142-
145+
self.n_individuals = None
146+
self.n_bodyparts = None
143147
self.cfg = None
144148
self.detector = None
145149
self.model = None
@@ -259,6 +263,16 @@ def load_model(self) -> None:
259263
raw_data = torch.load(self.path, map_location="cpu", weights_only=True)
260264

261265
self.cfg = raw_data["config"]
266+
267+
# Infer n_bodyparts and n_individuals from model configuration
268+
individuals = self.cfg.get("metadata", {}).get("individuals", ['idv1'])
269+
bodyparts = self.cfg.get("metadata", {}).get("bodyparts", [])
270+
self.n_individuals = len(individuals)
271+
self.n_bodyparts = len(bodyparts)
272+
# If single_animal is not set, infer it from n_individuals in model configuration
273+
if self.single_animal is None:
274+
self.single_animal = self.n_individuals == 1
275+
262276
self.model = models.PoseModel.build(self.cfg["model"])
263277
self.model.load_state_dict(raw_data["pose"])
264278
self.model = self.model.to(self.device)

0 commit comments

Comments
 (0)