Skip to content

Commit e38e3ad

Browse files
committed
add modelzoo utils (download and config)
1 parent 05ac8f0 commit e38e3ad

4 files changed

Lines changed: 374 additions & 0 deletions

File tree

dlclive/modelzoo/__init__.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
from dlclive.modelzoo.utils import (
2+
_MODELZOO_PATH,
3+
list_available_models,
4+
list_available_projects,
5+
list_available_combinations,
6+
load_super_animal_config,
7+
download_super_animal_snapshot,
8+
)
9+
from dlclive.modelzoo.pytorch_model_zoo_export import export_modelzoo_model
Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
import warnings
2+
from pathlib import Path
3+
from collections import OrderedDict
4+
5+
import torch
6+
7+
from dlclive.modelzoo.utils import load_super_animal_config, download_super_animal_snapshot
8+
9+
10+
def export_modelzoo_model(
    export_path: str | Path,
    super_animal: str,
    model_name: str,
    detector_name: str | None = None,
) -> None:
    """Export a SuperAnimal pose model (and optional detector) to one ``.pt`` file.

    Downloads the requested snapshot(s) from HuggingFace, resolves the model
    configuration, and saves a dict with keys ``config``, ``pose`` and
    ``detector`` via ``torch.save``. If a file already exists at
    ``export_path``, the export is skipped with a warning.

    Args:
        export_path: Destination file for the exported bundle. Missing parent
            directories are created.
        super_animal: Name of the SuperAnimal project (e.g.
            "superanimal_quadruped").
        model_name: Name of the pose model (e.g. "resnet_50").
        detector_name: Optional name of the detector model; when None, the
            exported "detector" entry is None.
    """
    # Normalize once; the original rebuilt Path(export_path) on every use.
    export_path = Path(export_path)
    export_path.parent.mkdir(parents=True, exist_ok=True)
    if export_path.exists():
        warnings.warn(f"Export path {export_path} already exists, skipping export", UserWarning)
        return

    model_cfg = load_super_animal_config(
        super_animal=super_animal,
        model_name=model_name,
        detector_name=detector_name,
    )

    def _load_model_weights(model_name: str, super_animal: str = super_animal) -> OrderedDict:
        """Download the model weights from huggingface and load them in torch state dict"""
        checkpoint: Path = download_super_animal_snapshot(dataset=super_animal, model_name=model_name)
        # weights_only=True avoids unpickling arbitrary objects from the checkpoint.
        return torch.load(checkpoint, map_location="cpu", weights_only=True)["model"]

    export_dict = {
        "config": model_cfg,
        "pose": _load_model_weights(model_name),
        "detector": _load_model_weights(detector_name) if detector_name is not None else None,
    }
    torch.save(export_dict, export_path)
41+
42+
43+
if __name__ == "__main__":
    """Example usage: export the quadruped ResNet-50 bundle into the package dir."""
    # Use the package-absolute import for consistency with the module-level
    # imports above; the original `from utils import _MODELZOO_PATH` only
    # resolves when the script is run from inside the package directory.
    from dlclive.modelzoo.utils import _MODELZOO_PATH

    model_name = "resnet_50"
    super_animal = "superanimal_quadruped"

    export_modelzoo_model(
        export_path=_MODELZOO_PATH / 'exported_models' / f'exported_{super_animal}_{model_name}.pt',
        super_animal=super_animal,
        model_name=model_name,
    )

dlclive/modelzoo/resolve_config.py

Lines changed: 136 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,136 @@
1+
"""
2+
Helper function to deal with default values in the model configuration.
3+
For instance, "num_bodyparts x 2" is replaced with the number of bodyparts multiplied by 2.
4+
"""
5+
# NOTE JR 2026-23-01: This is duplicate code, copied from the original DeepLabCut-Live codebase.
6+
7+
import copy
8+
9+
10+
def update_config(config: dict, max_individuals: int, device: str | None):
    """Fills in placeholder values in a SuperAnimal model configuration.

    Replaces "num_bodyparts"-style placeholders with concrete values, sets the
    generated individual names, and assigns the inference device to the model
    (and its detector, when present). Note: the only visible caller passes
    ``device=None`` by default, so the parameter is annotated as optional.

    Args:
        config: The default model configuration file.
        max_individuals: The maximum number of detections to make in an image
        device: The device to use to train/run inference on the model

    Returns:
        The model configuration for a SuperAnimal-pretrained model.
    """
    config = replace_default_values(
        config,
        num_bodyparts=len(config["metadata"]["bodyparts"]),
        num_individuals=max_individuals,
        backbone_output_channels=config["model"]["backbone_output_channels"],
    )
    # Individuals are synthetic labels: animal0, animal1, ...
    config["metadata"]["individuals"] = [f"animal{i}" for i in range(max_individuals)]

    config["device"] = device
    if config.get("detector", None) is not None:
        config["detector"]["device"] = device

    return config
34+
35+
36+
def replace_default_values(
37+
config: dict | list,
38+
num_bodyparts: int | None = None,
39+
num_individuals: int | None = None,
40+
backbone_output_channels: int | None = None,
41+
**kwargs,
42+
) -> dict:
43+
"""Replaces placeholder values in a model configuration with their actual values.
44+
45+
This method allows to create template PyTorch configurations for models with values
46+
such as "num_bodyparts", which are replaced with the number of bodyparts for a
47+
project when making its Pytorch configuration.
48+
49+
This code can also do some basic arithmetic. You can write "num_bodyparts x 2" (or
50+
any factor other than 2) for location refinement channels, and the number of
51+
channels will be twice the number of bodyparts. You can write
52+
"backbone_output_channels // 2" for the number of channels in a layer, and it will
53+
be half the number of channels output by the backbone. You can write
54+
"num_bodyparts + 1" (such as for DEKR heatmaps, where a "center" bodypart is added).
55+
56+
The three base placeholder values that can be computed are "num_bodyparts",
57+
"num_individuals" and "backbone_output_channels". You can add more through the
58+
keyword arguments (such as "paf_graph": list[tuple[int, int]] or
59+
"paf_edges_to_keep": list[int] for DLCRNet models).
60+
61+
Args:
62+
config: the configuration in which to replace default values
63+
num_bodyparts: the number of bodyparts
64+
num_individuals: the number of individuals
65+
backbone_output_channels: the number of backbone output channels
66+
kwargs: other placeholder values to fill in
67+
68+
Returns:
69+
the configuration with placeholder values replaced
70+
71+
Raises:
72+
ValueError: if there is a placeholder value who's "updated" value was not
73+
given to the method
74+
"""
75+
76+
def get_updated_value(variable: str) -> int | list[int]:
77+
var_parts = variable.strip().split(" ")
78+
var_name = var_parts[0]
79+
if updated_values[var_name] is None:
80+
raise ValueError(
81+
f"Found {variable} in the configuration file, but there is no default "
82+
f"value for this variable."
83+
)
84+
85+
if len(var_parts) == 1:
86+
return updated_values[var_name]
87+
elif len(var_parts) == 3:
88+
operator, factor = var_parts[1], var_parts[2]
89+
if not factor.isdigit():
90+
raise ValueError(f"F must be an integer in variable: {variable}")
91+
92+
factor = int(factor)
93+
if operator == "+":
94+
return updated_values[var_name] + factor
95+
elif operator == "x":
96+
return updated_values[var_name] * factor
97+
elif operator == "//":
98+
return updated_values[var_name] // factor
99+
else:
100+
raise ValueError(f"Unknown operator for variable: {variable}")
101+
102+
raise ValueError(
103+
f"Found {variable} in the configuration file, but cannot parse it."
104+
)
105+
106+
updated_values = {
107+
"num_bodyparts": num_bodyparts,
108+
"num_individuals": num_individuals,
109+
"backbone_output_channels": backbone_output_channels,
110+
**kwargs,
111+
}
112+
113+
config = copy.deepcopy(config)
114+
if isinstance(config, dict):
115+
keys_to_update = list(config.keys())
116+
elif isinstance(config, list):
117+
keys_to_update = range(len(config))
118+
else:
119+
raise ValueError(f"Config to update must be dict or list, found {type(config)}")
120+
121+
for k in keys_to_update:
122+
if isinstance(config[k], (list, dict)):
123+
config[k] = replace_default_values(
124+
config[k],
125+
num_bodyparts,
126+
num_individuals,
127+
backbone_output_channels,
128+
**kwargs,
129+
)
130+
elif (
131+
isinstance(config[k], str)
132+
and config[k].strip().split(" ")[0] in updated_values.keys()
133+
):
134+
config[k] = get_updated_value(config[k])
135+
136+
return config

dlclive/modelzoo/utils.py

Lines changed: 175 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,175 @@
1+
"""
2+
Utils for the DLC-Live Model Zoo
3+
"""
4+
# NOTE JR 2026-23-01: This file contains duplicated code from the DeepLabCut main repository.
5+
# This should be removed once a solution is found to address duplicate code.
6+
7+
import copy
8+
from pathlib import Path
9+
import logging
10+
11+
from ruamel.yaml import YAML
12+
13+
from dlclibrary.dlcmodelzoo.modelzoo_download import download_huggingface_model
14+
from dlclive.modelzoo.resolve_config import update_config
15+
16+
_MODELZOO_PATH = Path(__file__).parent
17+
18+
19+
def get_super_animal_model_config_path(model_name: str) -> Path:
    """Get the path to the model configuration file for a model and validate choice of model"""
    cfg_path = _MODELZOO_PATH / 'model_configs' / f"{model_name}.yaml"
    if cfg_path.exists():
        return cfg_path
    # Unknown model: fail with the list of valid choices.
    raise FileNotFoundError(
        f"Modelzoo model configuration file not found: {cfg_path} "
        f"Available models: {list_available_models()}"
    )
28+
29+
30+
def get_super_animal_project_config_path(super_animal: str) -> Path:
    """Get the path to the project configuration file for a project and validate choice of project"""
    cfg_path = _MODELZOO_PATH / 'project_configs' / f"{super_animal}.yaml"
    if not cfg_path.exists():
        # Keep a separating space before "Available" — the original f-strings
        # concatenated the path and the hint with no space between them.
        raise FileNotFoundError(
            f"Modelzoo project configuration file not found: {cfg_path} "
            f"Available projects: {list_available_projects()}"
        )
    return cfg_path
39+
40+
41+
def get_snapshot_folder_path() -> Path:
    """Return the directory into which model snapshots are downloaded."""
    snapshot_dir = _MODELZOO_PATH.joinpath("snapshots")
    return snapshot_dir
43+
44+
45+
def list_available_models() -> list[str]:
    """Names of the bundled model configs (one per YAML file in model_configs/)."""
    config_dir = _MODELZOO_PATH / "model_configs"
    return [cfg_file.stem for cfg_file in config_dir.glob("*.yaml")]
47+
48+
49+
def list_available_projects() -> list[str]:
    """Names of the bundled project configs (one per YAML file in project_configs/)."""
    config_dir = _MODELZOO_PATH / "project_configs"
    return [cfg_file.stem for cfg_file in config_dir.glob("*.yaml")]
51+
52+
53+
def list_available_combinations() -> list[str]:
    """All "<project>_<model>" identifiers formed from the available configs."""
    return [
        f"{project}_{model}"
        for project in list_available_projects()
        for model in list_available_models()
    ]
58+
59+
60+
def read_config_as_dict(config_path: str | Path) -> dict:
    """
    Args:
        config_path: the path to the configuration file to load

    Returns:
        The configuration file with pure Python classes
    """
    # "safe, pure" loading: plain Python dicts/lists/scalars, no object tags.
    yaml_reader = YAML(typ="safe", pure=True)
    with open(config_path, "r") as config_file:
        return yaml_reader.load(config_file)
72+
73+
74+
# NOTE JR 2026-23-01: This is duplicate code, copied from the original DeepLabCut-Live codebase.
75+
def add_metadata(project_config: dict, config: dict) -> dict:
    """Adds metadata to a pytorch pose configuration

    Args:
        project_config: the project configuration
        config: the pytorch pose configuration

    Returns:
        a deep copy of ``config`` with a "metadata" key added, populated from
        the project configuration (bodyparts, individuals, identity flag, ...)
    """
    config = copy.deepcopy(config)
    config["metadata"] = {
        "project_path": project_config["project_path"],
        "pose_config_path": "",
        # Multi-animal projects store bodyparts under "multianimalbodyparts";
        # fall back to the single-animal "bodyparts" key otherwise.
        "bodyparts": project_config.get("multianimalbodyparts") or project_config["bodyparts"],
        "unique_bodyparts": project_config.get("uniquebodyparts", []),
        "individuals": project_config.get("individuals", ["animal"]),
        "with_identity": project_config.get("identity", False),
    }
    return config
96+
97+
98+
# NOTE JR 2026-23-01: This is duplicate code, copied from the original DeepLabCut-Live codebase.
99+
def load_super_animal_config(
    super_animal: str,
    model_name: str,
    detector_name: str | None = None,
    max_individuals: int = 30,
    device: str | None = None,
) -> dict:
    """Loads the model configuration file for a model, detector and SuperAnimal

    Reads the bundled project and model YAML configs, attaches project metadata,
    resolves placeholder values, and sets the pose-estimation "method" key:
    "BU" (bottom-up, no detector) or "TD" (top-down, with detector).

    Args:
        super_animal: The name of the SuperAnimal for which to create the model config.
        model_name: The name of the model for which to create the model config.
        detector_name: The name of the detector for which to create the model config.
        max_individuals: The maximum number of detections to make in an image
        device: The device to use to train/run inference on the model

    Returns:
        The model configuration for a SuperAnimal-pretrained model.
    """
    project_cfg_path = get_super_animal_project_config_path(super_animal=super_animal)
    project_config = read_config_as_dict(project_cfg_path)

    model_cfg_path = get_super_animal_model_config_path(model_name=model_name)
    model_config = read_config_as_dict(model_cfg_path)
    # Attach project metadata first, then fill placeholders / device settings.
    model_config = add_metadata(project_config, model_config)
    model_config = update_config(model_config, max_individuals, device)

    if detector_name is None and super_animal != "superanimal_humanbody":
        # No detector requested: run the pose model bottom-up.
        model_config["method"] = "BU"
    else:
        # "superanimal_humanbody" is always treated as top-down, even without a
        # detector_name — presumably its detector config is handled elsewhere;
        # TODO(review): confirm this is intentional and not a missing detector.
        model_config["method"] = "TD"
        if super_animal != "superanimal_humanbody":
            # Detector configs live alongside the model configs, so the same
            # lookup helper is reused here.
            detector_cfg_path = get_super_animal_model_config_path(
                model_name=detector_name
            )
            detector_cfg = read_config_as_dict(detector_cfg_path)
            model_config["detector"] = detector_cfg
    return model_config
137+
138+
139+
def download_super_animal_snapshot(dataset: str, model_name: str) -> Path:
    """Downloads a SuperAnimal snapshot

    The snapshot is stored as "<dataset>_<model_name>.pt" in the modelzoo
    snapshot folder; an existing file is reused without re-downloading.

    Args:
        dataset: The name of the SuperAnimal dataset for which to download a snapshot.
        model_name: The name of the model for which to download a snapshot.

    Returns:
        The path to the downloaded snapshot.

    Raises:
        RuntimeError if the model fails to download.
    """
    snapshot_dir = get_snapshot_folder_path()
    model_name = f"{dataset}_{model_name}"
    model_filename = f"{model_name}.pt"
    model_path = snapshot_dir / model_filename

    if model_path.exists():
        # Lazy %-style args: the message is only formatted if the record is emitted.
        logging.info("Snapshot %s already exists, skipping download", model_path)
        return model_path

    try:
        # NOTE(review): confirm `rename_mapping` accepts a bare filename here —
        # some dlclibrary versions expect a {original: renamed} dict.
        download_huggingface_model(
            model_name, target_dir=str(snapshot_dir), rename_mapping=model_filename
        )

        # The downloader can return without producing the expected file;
        # surface that as an explicit failure.
        if not model_path.exists():
            raise RuntimeError(f"Failed to download {model_name} to {model_path}")

    except Exception as e:
        logging.error(
            "Failed to download superanimal snapshot %s to %s: %s",
            model_name, model_path, e,
        )
        # Bare `raise` re-raises with the original traceback intact.
        raise

    return model_path
174+
175+

0 commit comments

Comments
 (0)