Skip to content

Commit 08ba8a4

Browse files
committed
dlc_processor: add lightweight validation of shape and dtype (+tests)
1 parent fd605f8 commit 08ba8a4

2 files changed

Lines changed: 109 additions & 3 deletions

File tree

dlclivegui/services/dlc_processor.py

Lines changed: 69 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
import numpy as np
1616
from PySide6.QtCore import QObject, Signal
1717

18-
from dlclivegui.config import DLCProcessorSettings
18+
from dlclivegui.config import DLCProcessorSettings, ModelType
1919
from dlclivegui.processors.processor_utils import instantiate_from_scan
2020
from dlclivegui.temp import Engine # type: ignore # TODO use main package enum when released
2121

@@ -37,6 +37,63 @@
3737
class PoseResult:
3838
pose: np.ndarray | None
3939
timestamp: float
40+
packet: "PosePacketV0 | None" = None
41+
42+
43+
@dataclass(slots=True, frozen=True)
44+
class PoseSource:
45+
backend: str # e.g. "DLCLive"
46+
model_type: ModelType | None = None
47+
48+
49+
@dataclass(slots=True, frozen=True)
50+
class PosePacketV0:
51+
schema_version: int = 0
52+
keypoints: np.ndarray | None = None
53+
keypoint_names: list[str] | None = None
54+
individual_ids: list[str] | None = None
55+
source: PoseSource = PoseSource(backend="DLCLive")
56+
raw: Any | None = None
57+
58+
59+
def validate_pose_array(pose: Any, *, source_backend: str = "DLCLive") -> np.ndarray:
60+
"""
61+
Validate pose output shape and dtype.
62+
63+
Accepted runner output shapes:
64+
- (K, 3): single-animal
65+
- (N, K, 3): multi-animal
66+
"""
67+
try:
68+
arr = np.asarray(pose)
69+
except Exception as exc:
70+
raise ValueError(
71+
f"{source_backend} returned an invalid pose output format: could not convert to array ({exc})"
72+
) from exc
73+
74+
if arr.ndim not in (2, 3):
75+
raise ValueError(
76+
f"{source_backend} returned an invalid pose output format: expected a 2D or 3D array, got ndim={arr.ndim}, shape={arr.shape!r}"
77+
)
78+
79+
if arr.shape[-1] != 3:
80+
raise ValueError(
81+
f"{source_backend} returned an invalid pose output format: expected last dimension size 3 (x, y, likelihood), got shape={arr.shape!r}"
82+
)
83+
84+
if arr.ndim == 2 and arr.shape[0] <= 0:
85+
raise ValueError(f"{source_backend} returned an invalid pose output format: expected at least one keypoint")
86+
if arr.ndim == 3 and (arr.shape[0] <= 0 or arr.shape[1] <= 0):
87+
raise ValueError(
88+
f"{source_backend} returned an invalid pose output format: expected at least one individual and one keypoint, got shape={arr.shape!r}"
89+
)
90+
91+
if not np.issubdtype(arr.dtype, np.number):
92+
raise ValueError(
93+
f"{source_backend} returned an invalid pose output format: expected numeric values, got dtype={arr.dtype}"
94+
)
95+
96+
return arr
4097

4198

4299
@dataclass
@@ -269,8 +326,17 @@ def _process_frame(
269326
# Time GPU inference (and processor overhead when present)
270327
with self._timed_processor() as proc_holder:
271328
inference_start = time.perf_counter()
272-
pose = self._dlc.get_pose(frame, frame_time=timestamp)
329+
raw_pose: Any = self._dlc.get_pose(frame, frame_time=timestamp)
273330
inference_time = time.perf_counter() - inference_start
331+
pose_arr: np.ndarray = validate_pose_array(raw_pose, source_backend="DLCLive")
332+
pose_packet = PosePacketV0(
333+
schema_version=0,
334+
keypoints=pose_arr,
335+
keypoint_names=None,
336+
individual_ids=None,
337+
source=PoseSource(backend="DLCLive", model_type=self._settings.model_type),
338+
raw=raw_pose,
339+
)
274340

275341
processor_overhead = 0.0
276342
gpu_inference_time = inference_time
@@ -280,7 +346,7 @@ def _process_frame(
280346

281347
# Emit pose (measure signal overhead)
282348
signal_start = time.perf_counter()
283-
self.pose_ready.emit(PoseResult(pose=pose, timestamp=timestamp))
349+
self.pose_ready.emit(PoseResult(pose=pose_packet.keypoints, timestamp=timestamp, packet=pose_packet))
284350
signal_time = time.perf_counter() - signal_start
285351

286352
end_ts = time.perf_counter()
Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
import numpy as np
2+
import pytest
3+
4+
from dlclivegui.services.dlc_processor import validate_pose_array
5+
6+
7+
@pytest.mark.unit
8+
def test_validate_pose_array_keeps_single_animal_shape():
9+
pose = np.ones((5, 3), dtype=np.float64)
10+
out = validate_pose_array(pose)
11+
assert out.shape == (5, 3)
12+
assert out.dtype == np.float64
13+
14+
15+
@pytest.mark.unit
16+
def test_validate_pose_array_accepts_multi_animal():
17+
pose = np.ones((2, 5, 3), dtype=np.float32)
18+
out = validate_pose_array(pose)
19+
assert out.shape == (2, 5, 3)
20+
21+
22+
@pytest.mark.unit
23+
@pytest.mark.parametrize(
24+
"bad_pose,expected",
25+
[
26+
(np.ones((5, 2), dtype=np.float32), "last dimension size 3"),
27+
(np.ones((2, 5, 4), dtype=np.float32), "last dimension size 3"),
28+
(np.ones((3,), dtype=np.float32), "expected a 2D or 3D array"),
29+
],
30+
)
31+
def test_validate_pose_array_rejects_invalid_shapes(bad_pose, expected):
32+
with pytest.raises(ValueError, match=expected):
33+
validate_pose_array(bad_pose)
34+
35+
36+
@pytest.mark.unit
37+
def test_validate_pose_array_rejects_non_numeric():
38+
pose = np.array([[["x", "y", "p"]]], dtype=object)
39+
with pytest.raises(ValueError, match="expected numeric values"):
40+
validate_pose_array(pose)

0 commit comments

Comments
 (0)