Skip to content

Commit 3a39801

Browse files
committed
Add modelzoo tests from DeepLabCut
1 parent 6828c54 commit 3a39801

1 file changed

Lines changed: 51 additions & 0 deletions

File tree

tests/test_modelzoo.py

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
# NOTE JR 2026-23-01: This is duplicate code, copied from the original DeepLabCut-Live codebase.
2+
3+
import os
4+
5+
import pytest
6+
import dlclibrary
7+
from dlclibrary.dlcmodelzoo.modelzoo_download import MODELOPTIONS
8+
9+
from dlclive import modelzoo
10+
11+
12+
@pytest.mark.parametrize(
13+
"super_animal", ["superanimal_quadruped", "superanimal_topviewmouse"]
14+
)
15+
@pytest.mark.parametrize("model_name", ["hrnet_w32"])
16+
@pytest.mark.parametrize("detector_name", [None, "fasterrcnn_resnet50_fpn_v2"])
17+
def test_get_config_model_paths(super_animal, model_name, detector_name):
18+
model_config = modelzoo.load_super_animal_config(
19+
super_animal=super_animal,
20+
model_name=model_name,
21+
detector_name=detector_name,
22+
)
23+
24+
assert isinstance(model_config, dict)
25+
if detector_name is None:
26+
assert model_config["method"].lower() == "bu"
27+
assert "detector" not in model_config
28+
else:
29+
assert model_config["method"].lower() == "td"
30+
assert "detector" in model_config
31+
32+
33+
def test_download_huggingface_model(tmp_path_factory, model="full_cat"):
34+
folder = tmp_path_factory.mktemp("temp")
35+
dlclibrary.download_huggingface_model(model, str(folder))
36+
37+
assert os.path.exists(folder / "pose_cfg.yaml")
38+
assert any(f.startswith("snapshot-") for f in os.listdir(folder))
39+
# Verify that the Hugging Face folder was removed
40+
assert not any(f.startswith("models--") for f in os.listdir(folder))
41+
42+
43+
def test_download_huggingface_wrong_model():
44+
with pytest.raises(ValueError):
45+
dlclibrary.download_huggingface_model("wrong_model_name")
46+
47+
48+
@pytest.mark.skip(reason="slow")
49+
@pytest.mark.parametrize("model", MODELOPTIONS)
50+
def test_download_all_models(tmp_path_factory, model):
51+
test_download_huggingface_model(tmp_path_factory, model)

0 commit comments

Comments
 (0)