Skip to content

Commit 287cf2f

Browse files
committed
Refactor assembler tests to use fixtures
Move test helpers into tests/conftest.py and introduce a suite of reusable pytest fixtures for assembler testing (headless_display_env, assembler graph/paf fixtures, scene factories, assembler/assembly/joint/link factories, and various canned dataset variants). Refactor tests/tests_core/test_assembler.py and tests/tests_core/test_assembly.py to consume the new fixtures, remove duplicated setup code, and simplify assertions. Also adjust serialization test filenames and tidy up identity/affinity-related test logic.
1 parent 7f67734 commit 287cf2f

3 files changed

Lines changed: 387 additions & 352 deletions

File tree

tests/conftest.py

Lines changed: 265 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,22 +1,279 @@
1+
from __future__ import annotations
2+
3+
import copy
4+
from collections.abc import Callable
5+
from typing import Any
6+
7+
import numpy as np
18
import pytest
29

10+
from dlclive.core.inferenceutils import Assembler
11+
312

13+
# --------------------------------------------------------------------------------------
14+
# Headless display fixture
15+
# --------------------------------------------------------------------------------------
416
@pytest.fixture
517
def headless_display_env(monkeypatch):
6-
# Import module under test
18+
"""Patch dlclive.display so tkinter is replaced with fake, non-GUI-safe objects."""
719
from test_display import FakeLabel, FakePhotoImage, FakeTk
820

921
import dlclive.display as display_mod
1022

11-
# Force tkinter availability and patch UI components
12-
monkeypatch.setattr(display_mod, "_TKINTER_AVAILABLE", True, raising=False)
13-
monkeypatch.setattr(display_mod, "Tk", FakeTk, raising=False)
14-
monkeypatch.setattr(display_mod, "Label", FakeLabel, raising=False)
23+
monkeypatch.setattr(display_mod, "_TKINTER_AVAILABLE", True)
24+
monkeypatch.setattr(display_mod, "Tk", FakeTk)
25+
monkeypatch.setattr(display_mod, "Label", FakeLabel)
1526

16-
# Patch ImageTk.PhotoImage
1727
class FakeImageTkModule:
1828
PhotoImage = FakePhotoImage
1929

20-
monkeypatch.setattr(display_mod, "ImageTk", FakeImageTkModule, raising=False)
21-
30+
monkeypatch.setattr(display_mod, "ImageTk", FakeImageTkModule)
2231
return display_mod
32+
33+
34+
# --------------------------------------------------------------------------------------
35+
# Assembler/assembly test fixtures
36+
# --------------------------------------------------------------------------------------
37+
@pytest.fixture
38+
def assembler_graph_and_pafs() -> tuple[list[tuple[int, int]], list[int]]:
39+
"""Standard 2‑joint graph used throughout the test suite."""
40+
return ([(0, 1)], [0])
41+
42+
43+
@pytest.fixture
44+
def make_assembler_metadata() -> Callable[..., dict[str, Any]]:
45+
"""Return a factory that builds minimal Assembler metadata dictionaries."""
46+
47+
def _factory(graph, paf_inds, n_bodyparts, frame_keys):
48+
return {
49+
"metadata": {
50+
"all_joints_names": [f"b{i}" for i in range(n_bodyparts)],
51+
"PAFgraph": graph,
52+
"PAFinds": paf_inds,
53+
},
54+
**{k: {} for k in frame_keys},
55+
}
56+
57+
return _factory
58+
59+
60+
@pytest.fixture
61+
def make_assembler_frame() -> Callable[..., dict[str, Any]]:
62+
"""Return a factory that builds a frame dict compatible with _flatten_detections."""
63+
64+
def _factory(
65+
coordinates_per_label,
66+
confidence_per_label,
67+
identity_per_label=None,
68+
costs=None,
69+
):
70+
frame = {
71+
"coordinates": [coordinates_per_label],
72+
"confidence": confidence_per_label,
73+
"costs": costs or {},
74+
}
75+
if identity_per_label is not None:
76+
frame["identity"] = identity_per_label
77+
return frame
78+
79+
return _factory
80+
81+
82+
@pytest.fixture
83+
def simple_two_label_scene(make_assembler_frame) -> dict[str, Any]:
84+
"""Deterministic scene with predictable affinities for testing."""
85+
coords0 = np.array([[0.0, 0.0], [100.0, 100.0]])
86+
coords1 = np.array([[5.0, 0.0], [110.0, 100.0]])
87+
conf0 = np.array([0.9, 0.6])
88+
conf1 = np.array([0.8, 0.7])
89+
90+
aff = np.array([[0.95, 0.1], [0.05, 0.9]])
91+
92+
lens = np.array(
93+
[
94+
[np.hypot(*(coords1[0] - coords0[0])), np.hypot(*(coords1[1] - coords0[0]))],
95+
[np.hypot(*(coords1[0] - coords0[1])), np.hypot(*(coords1[1] - coords0[1]))],
96+
]
97+
)
98+
99+
return make_assembler_frame(
100+
coordinates_per_label=[coords0, coords1],
101+
confidence_per_label=[conf0, conf1],
102+
identity_per_label=None,
103+
costs={0: {"distance": lens, "m1": aff}},
104+
)
105+
106+
107+
@pytest.fixture
108+
def scene_copy(simple_two_label_scene) -> dict[str, Any]:
109+
"""Return a deep copy of the simple_two_label_scene fixture."""
110+
return copy.deepcopy(simple_two_label_scene)
111+
112+
113+
@pytest.fixture
114+
def assembler_data(
115+
assembler_graph_and_pafs,
116+
make_assembler_metadata,
117+
simple_two_label_scene,
118+
) -> tuple[dict[str, Any], list[tuple[int, int]], list[int]]:
119+
"""Full metadata + two identical frames ('0', '1')."""
120+
graph, paf_inds = assembler_graph_and_pafs
121+
data = make_assembler_metadata(graph, paf_inds, n_bodyparts=2, frame_keys=["0", "1"])
122+
data["0"] = simple_two_label_scene
123+
data["1"] = simple_two_label_scene
124+
return data, graph, paf_inds
125+
126+
127+
@pytest.fixture
128+
def assembler_data_single_frame(
129+
assembler_graph_and_pafs,
130+
make_assembler_metadata,
131+
simple_two_label_scene,
132+
) -> tuple[dict[str, Any], list[tuple[int, int]], list[int]]:
133+
"""Metadata + a single frame ('0'). Used by most tests."""
134+
graph, paf_inds = assembler_graph_and_pafs
135+
data = make_assembler_metadata(graph, paf_inds, n_bodyparts=2, frame_keys=["0"])
136+
data["0"] = simple_two_label_scene
137+
return data, graph, paf_inds
138+
139+
140+
@pytest.fixture
141+
def assembler_data_two_frames_nudged(
142+
assembler_graph_and_pafs,
143+
make_assembler_metadata,
144+
simple_two_label_scene,
145+
) -> tuple[dict[str, Any], list[tuple[int, int]], list[int]]:
146+
"""Two frames where frame '1' is a nudged copy of frame '0'."""
147+
graph, paf_inds = assembler_graph_and_pafs
148+
data = make_assembler_metadata(graph, paf_inds, n_bodyparts=2, frame_keys=["0", "1"])
149+
150+
frame0 = simple_two_label_scene
151+
frame1 = copy.deepcopy(simple_two_label_scene)
152+
frame1["coordinates"][0][0] += np.array([[1.0, 0.0], [1.0, 0.0]])
153+
frame1["coordinates"][0][1] += np.array([[1.0, 0.0], [1.0, 0.0]])
154+
155+
data["0"] = frame0
156+
data["1"] = frame1
157+
return data, graph, paf_inds
158+
159+
160+
@pytest.fixture
161+
def assembler_data_no_detections(
162+
assembler_graph_and_pafs,
163+
make_assembler_metadata,
164+
make_assembler_frame,
165+
) -> tuple[dict[str, Any], list[tuple[int, int]], list[int]]:
166+
"""Metadata + a single frame ('0') with zero detections for both labels."""
167+
graph, paf_inds = assembler_graph_and_pafs
168+
data = make_assembler_metadata(graph, paf_inds, n_bodyparts=2, frame_keys=["0"])
169+
170+
frame = make_assembler_frame(
171+
coordinates_per_label=[np.zeros((0, 2)), np.zeros((0, 2))],
172+
confidence_per_label=[np.zeros((0,)), np.zeros((0,))],
173+
identity_per_label=None,
174+
costs={},
175+
)
176+
data["0"] = frame
177+
return data, graph, paf_inds
178+
179+
180+
@pytest.fixture
181+
def make_assembler() -> Callable[..., Assembler]:
182+
"""
183+
Factory to create an Assembler with sensible defaults for this test suite.
184+
Override any parameter per-test via kwargs.
185+
"""
186+
187+
def _factory(data: dict[str, Any], **overrides) -> Assembler:
188+
defaults = dict(
189+
max_n_individuals=2,
190+
n_multibodyparts=2,
191+
min_n_links=1,
192+
pcutoff=0.1,
193+
min_affinity=0.05,
194+
)
195+
defaults.update(overrides)
196+
return Assembler(data, **defaults)
197+
198+
return _factory
199+
200+
201+
# --------------------------------------------------------------------------------------
202+
# Assembly / Joint / Link test fixtures
203+
# --------------------------------------------------------------------------------------
204+
from dlclive.core.inferenceutils import Assembly, Joint, Link # noqa: E402
205+
206+
207+
@pytest.fixture
208+
def make_assembly() -> Callable[..., Assembly]:
209+
"""Factory to create an Assembly with the given size."""
210+
211+
def _factory(size: int) -> Assembly:
212+
return Assembly(size=size)
213+
214+
return _factory
215+
216+
217+
@pytest.fixture
218+
def make_joint() -> Callable[..., Joint]:
219+
"""Factory to create a Joint with sensible defaults."""
220+
221+
def _factory(
222+
pos=(0.0, 0.0),
223+
confidence: float = 1.0,
224+
label: int = 0,
225+
idx: int = 0,
226+
group: int = -1,
227+
) -> Joint:
228+
return Joint(pos=pos, confidence=confidence, label=label, idx=idx, group=group)
229+
230+
return _factory
231+
232+
233+
@pytest.fixture
234+
def make_link() -> Callable[..., Link]:
235+
"""Factory to create a Link between two joints."""
236+
237+
def _factory(j1: Joint, j2: Joint, affinity: float = 1.0) -> Link:
238+
return Link(j1, j2, affinity=affinity)
239+
240+
return _factory
241+
242+
243+
@pytest.fixture
244+
def two_overlap_assemblies(make_assembly) -> tuple[Assembly, Assembly]:
245+
"""Two assemblies with partial overlap used by intersection tests."""
246+
ass1 = make_assembly(2)
247+
ass1.data[0, :2] = [0, 0]
248+
ass1.data[1, :2] = [10, 10]
249+
ass1._visible.update({0, 1})
250+
251+
ass2 = make_assembly(2)
252+
ass2.data[0, :2] = [5, 5]
253+
ass2.data[1, :2] = [15, 15]
254+
ass2._visible.update({0, 1})
255+
return ass1, ass2
256+
257+
258+
@pytest.fixture
259+
def soft_identity_assembly(make_assembly) -> Assembly:
260+
"""Assembly configured for soft_identity tests."""
261+
assemb = make_assembly(3)
262+
assemb.data[:] = np.nan
263+
assemb.data[0] = [0, 0, 1.0, 0]
264+
assemb.data[1] = [5, 5, 0.5, 0]
265+
assemb.data[2] = [10, 10, 1.0, 1]
266+
assemb._visible = {0, 1, 2}
267+
return assemb
268+
269+
270+
@pytest.fixture
271+
def four_joint_chain(make_joint, make_link) -> tuple[Joint, Joint, Joint, Joint, Link, Link]:
272+
"""Four joints and two links: (0-1) and (2-3)."""
273+
j0 = make_joint((0, 0), 1.0, label=0, idx=10)
274+
j1 = make_joint((1, 0), 1.0, label=1, idx=11)
275+
j2 = make_joint((2, 0), 1.0, label=2, idx=12)
276+
j3 = make_joint((3, 0), 1.0, label=3, idx=13)
277+
l01 = make_link(j0, j1, affinity=0.5)
278+
l23 = make_link(j2, j3, affinity=0.8)
279+
return j0, j1, j2, j3, l01, l23

0 commit comments

Comments
 (0)