Skip to content

Commit 9ee9768

Browse files
committed
refactor modelzoo export: more explicit handling of torchvision detector case
1 parent 15b265c commit 9ee9768

2 files changed

Lines changed: 22 additions & 12 deletions

File tree

dlclive/modelzoo/pytorch_model_zoo_export.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ def export_modelzoo_model(
2424
export_path: Arbitrary destination path for the exported .pt file.
2525
super_animal: Super animal dataset name (e.g. "superanimal_quadruped").
2626
model_name: Pose model architecture name (e.g. "resnet_50").
27-
detector_name: Optional detector model name. If provided, detector
27+
detector_name: Detector model name for top-down models. If provided, detector
2828
weights are included in the export.
2929
"""
3030
Path(export_path).parent.mkdir(parents=True, exist_ok=True)

dlclive/modelzoo/utils.py

Lines changed: 21 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -97,6 +97,21 @@ def add_metadata(
9797
return config
9898

9999

100+
def _get_torchvision_detector_config(detector_name: str) -> dict:
101+
"""Get a torchvision detector configuration for the superanimal humanbody model"""
102+
if detector_name is None:
103+
raise ValueError(f"Detector name is required for superanimal humanbody models. Must be one of {SUPPORTED_TORCHVISION_DETECTORS}.")
104+
if detector_name not in SUPPORTED_TORCHVISION_DETECTORS:
105+
raise ValueError(f"Unsupported humanbody detector {detector_name}. Should be one of {SUPPORTED_TORCHVISION_DETECTORS}")
106+
return {
107+
"type": "TorchvisionDetectorAdaptor",
108+
"model": detector_name,
109+
"weights": "COCO_V1",
110+
"num_classes": None,
111+
"box_score_thresh": 0.6,
112+
}
113+
114+
100115
# NOTE - DUPLICATED @deruyter92 2026-01-23: Copied from the original DeepLabCut codebase
101116
# from deeplabcut/pose_estimation_pytorch/modelzoo/utils.py
102117
def load_super_animal_config(
@@ -126,7 +141,7 @@ def load_super_animal_config(
126141
model_config = add_metadata(project_config, model_config)
127142
model_config = update_config(model_config, max_individuals, device)
128143

129-
if detector_name is None and super_animal != "superanimal_humanbody":
144+
if detector_name is None:
130145
model_config["method"] = "BU"
131146
else:
132147
model_config["method"] = "TD"
@@ -135,16 +150,11 @@ def load_super_animal_config(
135150
)
136151
detector_cfg = read_config_as_dict(detector_cfg_path)
137152
model_config["detector"] = detector_cfg
138-
if super_animal == "superanimal_humanbody":
139-
# Apply specific updates required to run the torchvision detector with pretrained weights
140-
assert detector_name in SUPPORTED_TORCHVISION_DETECTORS
141-
model_config["detector"]['model']= {
142-
"type": "TorchvisionDetectorAdaptor",
143-
"model": detector_name,
144-
"weights": "COCO_V1",
145-
"num_classes": None,
146-
"box_score_thresh": 0.6,
147-
}
153+
154+
if super_animal == "superanimal_humanbody":
155+
# Raises ValueError if Detector name is not one of SUPPORTED_TORCHVISION_DETECTORS
156+
torchvision_detector_config = _get_torchvision_detector_config(detector_name)
157+
model_config["detector"]["model"] = torchvision_detector_config
148158
return model_config
149159

150160

0 commit comments

Comments
 (0)