Skip to content

Commit eb15fa8

Browse files
authored
Merge pull request #163 from DeepLabCut/jaap/update_modelzoo
Fix modelzoo configs: add to package-data
2 parents b05bdb0 + 6062dc2 commit eb15fa8

4 files changed

Lines changed: 21 additions & 6 deletions

File tree

MANIFEST.in

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1,3 @@
1-
include dlclive/check_install/*
1+
include dlclive/check_install/*
2+
include dlclive/modelzoo/model_configs/*.yaml
3+
include dlclive/modelzoo/project_configs/*.yaml

dlclive/modelzoo/pytorch_model_zoo_export.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,18 @@ def export_modelzoo_model(
1414
detector_name: str | None = None,
1515
) -> None:
1616
"""
17+
Export a DeepLabCut Model Zoo model to a single .pt file.
1718
19+
Downloads the model configuration and weights from HuggingFace, bundles them
20+
together (optionally with a detector), and saves as a single torch archive.
21+
Skips export if the output file already exists.
22+
23+
Args:
24+
export_path: Arbitrary destination path for the exported .pt file.
25+
super_animal: Super animal dataset name (e.g. "superanimal_quadruped").
26+
model_name: Pose model architecture name (e.g. "resnet_50").
27+
detector_name: Optional detector model name. If provided, detector
28+
weights are included in the export.
1829
"""
1930
Path(export_path).parent.mkdir(parents=True, exist_ok=True)
2031
if Path(export_path).exists():

dlclive/modelzoo/utils.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
from pathlib import Path
1010

1111
from dlclibrary.dlcmodelzoo.modelzoo_download import download_huggingface_model
12+
from dlclibrary.dlcmodelzoo.modelzoo_download import _load_model_names as huggingface_model_paths
1213
from ruamel.yaml import YAML
1314

1415
from dlclive.modelzoo.resolve_config import update_config
@@ -49,10 +50,7 @@ def list_available_projects() -> list[str]:
4950

5051

5152
def list_available_combinations() -> list[str]:
52-
models = list_available_models()
53-
projects = list_available_projects()
54-
combinations = ["_".join([p, m]) for p in projects for m in models]
55-
return combinations
53+
return list(huggingface_model_paths.keys())
5654

5755

5856
def read_config_as_dict(config_path: str | Path) -> dict:

pyproject.toml

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -89,7 +89,11 @@ include-package-data = true
8989
include = ["dlclive*"]
9090

9191
[tool.setuptools.package-data]
92-
dlclive = ["check_install/*"]
92+
dlclive = [
93+
"check_install/*",
94+
"modelzoo/model_configs/*.yaml",
95+
"modelzoo/project_configs/*.yaml",
96+
]
9397

9498
# [tool.ruff]
9599
# lint.select = ["E", "F", "B", "I", "UP"]

0 commit comments

Comments
 (0)