Skip to content

Commit 72d073b

Browse files
committed
add basic tests
1 parent 6e2ad64 commit 72d073b

8 files changed

Lines changed: 648 additions & 0 deletions

File tree

pytest.ini

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,9 @@
11
[pytest]
22
markers =
33
functional: functional tests
4+
5+
filterwarnings =
6+
# Suppress NumPy deprecation warning from Keras/TensorFlow about np.object
7+
ignore::FutureWarning:keras.*
8+
ignore::FutureWarning:tensorflow.*
9+
ignore:In the future `np.object` will be defined:FutureWarning

tests/test_config.py

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
"""
2+
Tests for configuration file reading
3+
"""
4+
import pytest
5+
from pathlib import Path
6+
from dlclive.core import config
7+
8+
9+
class TestConfig:
10+
"""Test configuration file reading"""
11+
12+
def test_read_yaml_success(self, tmp_path):
13+
"""Test successfully reading a YAML config file"""
14+
config_file = tmp_path / "pose_cfg.yaml"
15+
# Write YAML content directly
16+
yaml_content = """num_joints: 17
17+
all_joints:
18+
- head
19+
- neck
20+
- shoulder
21+
batch_size: 1
22+
"""
23+
config_file.write_text(yaml_content)
24+
25+
result = config.read_yaml(config_file)
26+
27+
assert result["num_joints"] == 17
28+
assert result["all_joints"] == ["head", "neck", "shoulder"]
29+
assert result["batch_size"] == 1
30+
31+
def test_read_yaml_nonexistent(self, tmp_path):
32+
"""Test that nonexistent config files raise FileNotFoundError"""
33+
nonexistent = tmp_path / "nonexistent.yaml"
34+
35+
with pytest.raises(FileNotFoundError):
36+
config.read_yaml(nonexistent)
37+
38+
def test_read_yaml_path_resolution(self, tmp_path):
39+
"""Test that paths are properly resolved"""
40+
config_file = tmp_path / "config.yaml"
41+
config_file.write_text("test: value")
42+
43+
# Test with relative path
44+
result = config.read_yaml(str(config_file))
45+
assert result["test"] == "value"
46+
47+
# Test with Path object
48+
result = config.read_yaml(config_file)
49+
assert result["test"] == "value"
50+

tests/test_dlclive.py

Lines changed: 207 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,207 @@
1+
"""
2+
Tests for DLCLive core functionality - frame processing, cropping, etc.
3+
"""
4+
import pytest
5+
import numpy as np
6+
from pathlib import Path
7+
from unittest.mock import Mock, MagicMock, patch
8+
from dlclive import DLCLive
9+
from dlclive.exceptions import DLCLiveError
10+
11+
12+
class TestDLCLive:
13+
"""Test DLCLive class core functionality"""
14+
15+
@pytest.fixture
16+
def mock_runner(self):
17+
"""Create a mock runner for testing"""
18+
runner = Mock()
19+
runner.cfg = {"test": "config"}
20+
runner.precision = "FP32"
21+
runner.init_inference.return_value = np.zeros((17, 3))
22+
runner.get_pose.return_value = np.zeros((17, 3))
23+
runner.close.return_value = None
24+
runner.read_config.return_value = {"test": "config"}
25+
return runner
26+
27+
@pytest.fixture
28+
def sample_frame(self):
29+
"""Create a sample frame for testing"""
30+
return np.random.randint(0, 255, (480, 640, 3), dtype=np.uint8)
31+
32+
@patch('dlclive.factory.build_runner')
33+
def test_dlclive_initialization(self, mock_build_runner, mock_runner, tmp_path):
34+
"""Test DLCLive initialization"""
35+
model_path = tmp_path / "model.pt"
36+
model_path.write_text("test")
37+
mock_build_runner.return_value = mock_runner
38+
39+
dlc = DLCLive(model_path, model_type="pytorch")
40+
41+
assert dlc.path == model_path
42+
assert dlc.model_type == "pytorch"
43+
assert dlc.is_initialized == False
44+
assert dlc.cropping is None
45+
assert dlc.dynamic == (False, 0.5, 10)
46+
assert dlc.processor is None
47+
48+
@patch('dlclive.factory.build_runner')
49+
def test_dlclive_cfg_property(self, mock_build_runner, mock_runner, tmp_path):
50+
"""Test accessing cfg property"""
51+
model_path = tmp_path / "model.pt"
52+
model_path.write_text("test")
53+
mock_build_runner.return_value = mock_runner
54+
55+
dlc = DLCLive(model_path)
56+
assert dlc.cfg == {"test": "config"}
57+
58+
@patch('dlclive.factory.build_runner')
59+
def test_dlclive_precision_property(self, mock_build_runner, mock_runner, tmp_path):
60+
"""Test accessing precision property"""
61+
model_path = tmp_path / "model.pt"
62+
model_path.write_text("test")
63+
mock_build_runner.return_value = mock_runner
64+
65+
dlc = DLCLive(model_path)
66+
assert dlc.precision == "FP32"
67+
68+
@patch('dlclive.factory.build_runner')
69+
def test_dlclive_read_config(self, mock_build_runner, mock_runner, tmp_path):
70+
"""Test reading configuration"""
71+
model_path = tmp_path / "model.pt"
72+
model_path.write_text("test")
73+
mock_build_runner.return_value = mock_runner
74+
75+
dlc = DLCLive(model_path)
76+
config = dlc.read_config()
77+
assert config == {"test": "config"}
78+
79+
@patch('dlclive.factory.build_runner')
80+
def test_dlclive_parameterization(self, mock_build_runner, mock_runner, tmp_path):
81+
"""Test parameterization property"""
82+
model_path = tmp_path / "model.pt"
83+
model_path.write_text("test")
84+
mock_build_runner.return_value = mock_runner
85+
86+
dlc = DLCLive(model_path, cropping=[10, 100, 20, 200])
87+
params = dlc.parameterization
88+
89+
assert "path" in params
90+
assert "cfg" in params
91+
assert "model_type" in params
92+
assert params["cropping"] == [10, 100, 20, 200]
93+
94+
@patch('dlclive.factory.build_runner')
95+
@patch('dlclive.utils.img_to_rgb')
96+
def test_process_frame_cropping(self, mock_img_to_rgb, mock_build_runner,
97+
mock_runner, sample_frame, tmp_path):
98+
"""Test frame processing with cropping"""
99+
model_path = tmp_path / "model.pt"
100+
model_path.write_text("test")
101+
mock_build_runner.return_value = mock_runner
102+
mock_img_to_rgb.side_effect = lambda x: x
103+
104+
dlc = DLCLive(model_path, cropping=[10, 100, 20, 200])
105+
result = dlc.process_frame(sample_frame)
106+
107+
# Check that cropping was applied (result should be smaller)
108+
assert result.shape[0] == 180 # 200 - 20
109+
assert result.shape[1] == 90 # 100 - 10
110+
111+
@patch('dlclive.factory.build_runner')
112+
@patch('dlclive.utils.resize_frame')
113+
@patch('dlclive.utils.img_to_rgb')
114+
def test_process_frame_resize(self, mock_img_to_rgb, mock_resize,
115+
mock_build_runner, mock_runner, sample_frame, tmp_path):
116+
"""Test frame processing with resize"""
117+
model_path = tmp_path / "model.pt"
118+
model_path.write_text("test")
119+
mock_build_runner.return_value = mock_runner
120+
mock_img_to_rgb.side_effect = lambda x: x
121+
mock_resize.side_effect = lambda x, resize: x # No actual resize in test
122+
123+
dlc = DLCLive(model_path, resize=0.5)
124+
dlc.process_frame(sample_frame)
125+
126+
mock_resize.assert_called_once()
127+
128+
@patch('dlclive.factory.build_runner')
129+
def test_process_frame_dynamic_cropping(self, mock_build_runner, mock_runner,
130+
sample_frame, tmp_path):
131+
"""Test dynamic cropping functionality"""
132+
model_path = tmp_path / "model.pt"
133+
model_path.write_text("test")
134+
mock_build_runner.return_value = mock_runner
135+
136+
# Create pose with detected body parts
137+
pose = np.array([[100, 150, 0.8], [120, 160, 0.9], [80, 140, 0.7]])
138+
139+
dlc = DLCLive(model_path, dynamic=(True, 0.5, 10))
140+
dlc.pose = pose
141+
142+
result = dlc.process_frame(sample_frame)
143+
144+
# Check that dynamic cropping was applied
145+
assert dlc.dynamic_cropping is not None
146+
assert len(dlc.dynamic_cropping) == 4
147+
148+
@patch('dlclive.factory.build_runner')
149+
def test_init_inference_no_frame(self, mock_build_runner, mock_runner, tmp_path):
150+
"""Test that init_inference raises error with no frame"""
151+
model_path = tmp_path / "model.pt"
152+
model_path.write_text("test")
153+
mock_build_runner.return_value = mock_runner
154+
155+
dlc = DLCLive(model_path)
156+
157+
with pytest.raises(DLCLiveError, match="No frame provided"):
158+
dlc.init_inference()
159+
160+
@patch('dlclive.factory.build_runner')
161+
def test_get_pose_no_frame(self, mock_build_runner, mock_runner, tmp_path):
162+
"""Test that get_pose raises error with no frame"""
163+
model_path = tmp_path / "model.pt"
164+
model_path.write_text("test")
165+
mock_build_runner.return_value = mock_runner
166+
167+
dlc = DLCLive(model_path)
168+
169+
with pytest.raises(DLCLiveError, match="No frame provided"):
170+
dlc.get_pose()
171+
172+
@patch('dlclive.factory.build_runner')
173+
def test_close(self, mock_build_runner, mock_runner, tmp_path):
174+
"""Test closing DLCLive instance"""
175+
model_path = tmp_path / "model.pt"
176+
model_path.write_text("test")
177+
mock_build_runner.return_value = mock_runner
178+
179+
dlc = DLCLive(model_path)
180+
dlc.is_initialized = True
181+
dlc.close()
182+
183+
assert dlc.is_initialized == False
184+
mock_runner.close.assert_called_once()
185+
186+
@patch('dlclive.factory.build_runner')
187+
def test_post_process_pose_with_processor(self, mock_build_runner, mock_runner,
188+
sample_frame, tmp_path):
189+
"""Test pose post-processing with processor"""
190+
model_path = tmp_path / "model.pt"
191+
model_path.write_text("test")
192+
mock_build_runner.return_value = mock_runner
193+
194+
# Create mock processor
195+
mock_processor = Mock()
196+
mock_processor.process.return_value = np.ones((17, 3))
197+
198+
dlc = DLCLive(model_path, processor=mock_processor)
199+
dlc.pose = np.zeros((17, 3))
200+
201+
# Manually call _post_process_pose
202+
result = dlc._post_process_pose(sample_frame)
203+
204+
mock_processor.process.assert_called_once()
205+
np.testing.assert_array_equal(result, np.ones((17, 3)))
206+
207+

tests/test_engine.py

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,66 @@
1+
"""
2+
Tests for the Engine class - engine detection and model type handling
3+
"""
4+
import pytest
5+
from pathlib import Path
6+
from dlclive.engine import Engine
7+
8+
9+
class TestEngine:
10+
"""Test Engine enum and detection methods"""
11+
12+
def test_engine_from_model_type_pytorch(self):
13+
"""Test detecting PyTorch engine from model type"""
14+
assert Engine.from_model_type("pytorch") == Engine.PYTORCH
15+
assert Engine.from_model_type("PyTorch") == Engine.PYTORCH
16+
assert Engine.from_model_type("PYTORCH") == Engine.PYTORCH
17+
18+
def test_engine_from_model_type_tensorflow(self):
19+
"""Test detecting TensorFlow engine from model type"""
20+
assert Engine.from_model_type("tensorflow") == Engine.TENSORFLOW
21+
assert Engine.from_model_type("base") == Engine.TENSORFLOW
22+
assert Engine.from_model_type("tensorrt") == Engine.TENSORFLOW
23+
assert Engine.from_model_type("lite") == Engine.TENSORFLOW
24+
25+
def test_engine_from_model_type_invalid(self):
26+
"""Test that invalid model types raise ValueError"""
27+
with pytest.raises(ValueError, match="Unknown model type"):
28+
Engine.from_model_type("invalid")
29+
30+
def test_engine_from_model_path_tensorflow_dir(self, tmp_path):
31+
"""Test detecting TensorFlow engine from directory with .pb and pose_cfg.yaml"""
32+
model_dir = tmp_path / "tensorflow_model"
33+
model_dir.mkdir()
34+
(model_dir / "pose_cfg.yaml").write_text("test")
35+
(model_dir / "snapshot-100.pb").write_text("test")
36+
37+
assert Engine.from_model_path(model_dir) == Engine.TENSORFLOW
38+
39+
def test_engine_from_model_path_pytorch_file(self, tmp_path):
40+
"""Test detecting PyTorch engine from .pt file"""
41+
model_file = tmp_path / "model.pt"
42+
model_file.write_text("test")
43+
44+
assert Engine.from_model_path(model_file) == Engine.PYTORCH
45+
46+
def test_engine_from_model_path_nonexistent(self, tmp_path):
47+
"""Test that nonexistent paths raise FileNotFoundError"""
48+
nonexistent = tmp_path / "nonexistent"
49+
with pytest.raises(FileNotFoundError):
50+
Engine.from_model_path(nonexistent)
51+
52+
def test_engine_from_model_path_invalid(self, tmp_path):
53+
"""Test that invalid model paths raise ValueError"""
54+
# Directory without required files
55+
invalid_dir = tmp_path / "invalid"
56+
invalid_dir.mkdir()
57+
with pytest.raises(ValueError, match="Could not determine engine"):
58+
Engine.from_model_path(invalid_dir)
59+
60+
# File with wrong extension
61+
wrong_ext = tmp_path / "model.txt"
62+
wrong_ext.write_text("test")
63+
with pytest.raises(ValueError, match="Could not determine engine"):
64+
Engine.from_model_path(wrong_ext)
65+
66+

tests/test_exceptions.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
"""
2+
Tests for exception classes
3+
"""
4+
import pytest
5+
from dlclive.exceptions import DLCLiveError, DLCLiveWarning
6+
7+
8+
class TestExceptions:
9+
"""Test exception and warning classes"""
10+
11+
def test_dlclive_error(self):
12+
"""Test DLCLiveError can be raised and caught"""
13+
with pytest.raises(DLCLiveError):
14+
raise DLCLiveError("Test error message")
15+
16+
def test_dlclive_error_message(self):
17+
"""Test DLCLiveError preserves error message"""
18+
error_msg = "Custom error message"
19+
with pytest.raises(DLCLiveError, match=error_msg):
20+
raise DLCLiveError(error_msg)
21+
22+
def test_dlclive_error_inheritance(self):
23+
"""Test DLCLiveError is an Exception"""
24+
error = DLCLiveError("test")
25+
assert isinstance(error, Exception)
26+
27+
def test_dlclive_warning(self):
28+
"""Test DLCLiveWarning can be issued"""
29+
with pytest.warns(DLCLiveWarning):
30+
import warnings
31+
warnings.warn("Test warning", DLCLiveWarning)
32+
33+
def test_dlclive_warning_inheritance(self):
34+
"""Test DLCLiveWarning is a Warning"""
35+
assert issubclass(DLCLiveWarning, Warning)
36+
37+

0 commit comments

Comments
 (0)