|
| 1 | +""" |
| 2 | +Utils for the DLC-Live Model Zoo |
| 3 | +""" |
| 4 | +# NOTE JR 2026-23-01: This file contains duplicated code from the DeepLabCut main repository. |
| 5 | +# This should be removed once a solution is found to address duplicate code. |
| 6 | + |
| 7 | +import copy |
| 8 | +from pathlib import Path |
| 9 | +import logging |
| 10 | + |
| 11 | +from ruamel.yaml import YAML |
| 12 | + |
| 13 | +from dlclibrary.dlcmodelzoo.modelzoo_download import download_huggingface_model |
| 14 | +from dlclive.modelzoo.resolve_config import update_config |
| 15 | + |
| 16 | +_MODELZOO_PATH = Path(__file__).parent |
| 17 | + |
| 18 | + |
| 19 | +def get_super_animal_model_config_path(model_name: str) -> Path: |
| 20 | + """Get the path to the model configuration file for a model and validate choice of model""" |
| 21 | + cfg_path = _MODELZOO_PATH / 'model_configs' / f"{model_name}.yaml" |
| 22 | + if not cfg_path.exists(): |
| 23 | + raise FileNotFoundError( |
| 24 | + f"Modelzoo model configuration file not found: {cfg_path} " |
| 25 | + f"Available models: {list_available_models()}" |
| 26 | + ) |
| 27 | + return cfg_path |
| 28 | + |
| 29 | + |
| 30 | +def get_super_animal_project_config_path(super_animal: str) -> Path: |
| 31 | + """Get the path to the project configuration file for a project and validate choice of project""" |
| 32 | + cfg_path = _MODELZOO_PATH / 'project_configs' / f"{super_animal}.yaml" |
| 33 | + if not cfg_path.exists(): |
| 34 | + raise FileNotFoundError( |
| 35 | + f"Modelzoo project configuration file not found: {cfg_path}" |
| 36 | + f"Available projects: {list_available_projects()}" |
| 37 | + ) |
| 38 | + return cfg_path |
| 39 | + |
| 40 | + |
| 41 | +def get_snapshot_folder_path() -> Path: |
| 42 | + return _MODELZOO_PATH / 'snapshots' |
| 43 | + |
| 44 | + |
| 45 | +def list_available_models() -> list[str]: |
| 46 | + return [p.stem for p in _MODELZOO_PATH.glob('model_configs/*.yaml')] |
| 47 | + |
| 48 | + |
| 49 | +def list_available_projects() -> list[str]: |
| 50 | + return [p.stem for p in _MODELZOO_PATH.glob('project_configs/*.yaml')] |
| 51 | + |
| 52 | + |
| 53 | +def list_available_combinations() -> list[str]: |
| 54 | + models = list_available_models() |
| 55 | + projects = list_available_projects() |
| 56 | + combinations = ['_'.join([p, m]) for p in projects for m in models] |
| 57 | + return combinations |
| 58 | + |
| 59 | + |
| 60 | +def read_config_as_dict(config_path: str | Path) -> dict: |
| 61 | + """ |
| 62 | + Args: |
| 63 | + config_path: the path to the configuration file to load |
| 64 | +
|
| 65 | + Returns: |
| 66 | + The configuration file with pure Python classes |
| 67 | + """ |
| 68 | + with open(config_path, "r") as f: |
| 69 | + cfg = YAML(typ='safe', pure=True).load(f) |
| 70 | + |
| 71 | + return cfg |
| 72 | + |
| 73 | + |
| 74 | +# NOTE JR 2026-23-01: This is duplicate code, copied from the original DeepLabCut-Live codebase. |
| 75 | +def add_metadata(project_config: dict, config: dict,) -> dict: |
| 76 | + """Adds metadata to a pytorch pose configuration |
| 77 | +
|
| 78 | + Args: |
| 79 | + project_config: the project configuration |
| 80 | + config: the pytorch pose configuration |
| 81 | + pose_config_path: the path where the pytorch pose configuration will be saved |
| 82 | +
|
| 83 | + Returns: |
| 84 | + the configuration with a `meta` key added |
| 85 | + """ |
| 86 | + config = copy.deepcopy(config) |
| 87 | + config["metadata"] = { |
| 88 | + "project_path": project_config["project_path"], |
| 89 | + "pose_config_path": "", |
| 90 | + "bodyparts": project_config.get("multianimalbodyparts") or project_config["bodyparts"], |
| 91 | + "unique_bodyparts": project_config.get("uniquebodyparts", []), |
| 92 | + "individuals": project_config.get("individuals", ["animal"]), |
| 93 | + "with_identity": project_config.get("identity", False), |
| 94 | + } |
| 95 | + return config |
| 96 | + |
| 97 | + |
| 98 | +# NOTE JR 2026-23-01: This is duplicate code, copied from the original DeepLabCut-Live codebase. |
| 99 | +def load_super_animal_config( |
| 100 | + super_animal: str, |
| 101 | + model_name: str, |
| 102 | + detector_name: str | None = None, |
| 103 | + max_individuals: int = 30, |
| 104 | + device: str | None = None, |
| 105 | +) -> dict: |
| 106 | + """Loads the model configuration file for a model, detector and SuperAnimal |
| 107 | +
|
| 108 | + Args: |
| 109 | + super_animal: The name of the SuperAnimal for which to create the model config. |
| 110 | + model_name: The name of the model for which to create the model config. |
| 111 | + detector_name: The name of the detector for which to create the model config. |
| 112 | + max_individuals: The maximum number of detections to make in an image |
| 113 | + device: The device to use to train/run inference on the model |
| 114 | +
|
| 115 | + Returns: |
| 116 | + The model configuration for a SuperAnimal-pretrained model. |
| 117 | + """ |
| 118 | + project_cfg_path = get_super_animal_project_config_path(super_animal=super_animal) |
| 119 | + project_config = read_config_as_dict(project_cfg_path) |
| 120 | + |
| 121 | + model_cfg_path = get_super_animal_model_config_path(model_name=model_name) |
| 122 | + model_config = read_config_as_dict(model_cfg_path) |
| 123 | + model_config = add_metadata(project_config, model_config) |
| 124 | + model_config = update_config(model_config, max_individuals, device) |
| 125 | + |
| 126 | + if detector_name is None and super_animal != "superanimal_humanbody": |
| 127 | + model_config["method"] = "BU" |
| 128 | + else: |
| 129 | + model_config["method"] = "TD" |
| 130 | + if super_animal != "superanimal_humanbody": |
| 131 | + detector_cfg_path = get_super_animal_model_config_path( |
| 132 | + model_name=detector_name |
| 133 | + ) |
| 134 | + detector_cfg = read_config_as_dict(detector_cfg_path) |
| 135 | + model_config["detector"] = detector_cfg |
| 136 | + return model_config |
| 137 | + |
| 138 | + |
| 139 | +def download_super_animal_snapshot(dataset: str, model_name: str) -> Path: |
| 140 | + """Downloads a SuperAnimal snapshot |
| 141 | +
|
| 142 | + Args: |
| 143 | + dataset: The name of the SuperAnimal dataset for which to download a snapshot. |
| 144 | + model_name: The name of the model for which to download a snapshot. |
| 145 | +
|
| 146 | + Returns: |
| 147 | + The path to the downloaded snapshot. |
| 148 | +
|
| 149 | + Raises: |
| 150 | + RuntimeError if the model fails to download. |
| 151 | + """ |
| 152 | + snapshot_dir = get_snapshot_folder_path() |
| 153 | + model_name = f"{dataset}_{model_name}" |
| 154 | + model_filename = f"{model_name}.pt" |
| 155 | + model_path = snapshot_dir / model_filename |
| 156 | + |
| 157 | + if model_path.exists(): |
| 158 | + logging.info(f"Snapshot {model_path} already exists, skipping download") |
| 159 | + return model_path |
| 160 | + |
| 161 | + try: |
| 162 | + download_huggingface_model( |
| 163 | + model_name, target_dir=str(snapshot_dir), rename_mapping=model_filename |
| 164 | + ) |
| 165 | + |
| 166 | + if not model_path.exists(): |
| 167 | + raise RuntimeError(f"Failed to download {model_name} to {model_path}") |
| 168 | + |
| 169 | + except Exception as e: |
| 170 | + logging.error(f"Failed to download superanimal snapshot {model_name} to {model_path}: {e}") |
| 171 | + raise e |
| 172 | + |
| 173 | + return model_path |
| 174 | + |
| 175 | + |
0 commit comments