Skip to content

Commit 02821e1

Browse files
committed
Use Engine enum and improve model detection
Introduce a temporary Engine enum and expose it via dlclivegui.temp to standardize model engine types. Update DLCProcessorSettings to use the Engine enum for model_type. Improve model backend detection in the main window by calling DLCLiveProcessor.get_model_backend and raising a clear error if detection fails; tighten file dialog filters to focus on PyTorch models. Enhance engine detection to recognize .pth files. Minor import cleanup in dlc_processor and a clarifying test comment about timeouts to reduce flakiness in CI.
1 parent 45e3499 commit 02821e1

6 files changed

Lines changed: 24 additions & 10 deletions

File tree

dlclivegui/config.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,8 @@
77

88
from pydantic import BaseModel, Field, field_validator, model_validator
99

10+
from dlclivegui.temp import Engine
11+
1012
Rotation = Literal[0, 90, 180, 270]
1113
TileLayout = Literal["auto", "2x2", "1x4", "4x1"]
1214
Precision = Literal["FP32", "FP16"]
@@ -239,7 +241,7 @@ class DLCProcessorSettings(BaseModel):
239241
resize: float = Field(default=1.0, gt=0)
240242
precision: Precision = "FP32"
241243
additional_options: dict[str, Any] = Field(default_factory=dict)
242-
model_type: Literal["pytorch"] = "pytorch"
244+
model_type: Engine = "pytorch"
243245
single_animal: bool = True
244246

245247
@field_validator("dynamic", mode="before")

dlclivegui/gui/main_window.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -876,6 +876,14 @@ def _parse_json(self, value: str) -> dict:
876876

877877
def _dlc_settings_from_ui(self) -> DLCProcessorSettings:
878878
model_path = self.model_path_edit.text().strip()
879+
try:
880+
DLCLiveProcessor.get_model_backend(model_path)
881+
except Exception as e:
882+
raise RuntimeError(
883+
"Could not determine model backend from path."
884+
"Please ensure the model file is valid and has an appropriate extension "
885+
"(.pt, .pth for PyTorch or model directory for TensorFlow)."
886+
) from e
879887
return DLCProcessorSettings(
880888
model_path=model_path,
881889
model_directory=self._config.dlc.model_directory, # Preserve from config
@@ -969,9 +977,9 @@ def _action_browse_model(self) -> None:
969977
dlg.setFileMode(QFileDialog.FileMode.ExistingFile)
970978
dlg.setNameFilters(
971979
[
972-
"Model files (*.pt *.pth *.pb)",
980+
"Model files (*.pt *.pth)",
973981
"PyTorch models (*.pt *.pth)",
974-
"TensorFlow models (*.pb)",
982+
# "TensorFlow models (*.pb)",
975983
"All files (*.*)",
976984
]
977985
)

dlclivegui/services/dlc_processor.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919

2020
# from dlclivegui.config import DLCProcessorSettings
2121
from dlclivegui.processors.processor_utils import instantiate_from_scan
22+
from dlclivegui.temp import Engine # type: ignore # TODO use main package enum when released
2223

2324
logger = logging.getLogger(__name__)
2425

@@ -29,7 +30,6 @@
2930
from dlclive import (
3031
DLCLive, # type: ignore
3132
)
32-
from dlclivegui.temp import Engine # type: ignore # TODO use main package one when released
3333
except Exception as e: # pragma: no cover - handled gracefully
3434
logger.error(f"dlclive package could not be imported: {e}")
3535
DLCLive = None # type: ignore[assignment]

dlclivegui/temp/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
from engine import Engine # type: ignore
2+
3+
__all__ = ["Engine"]

dlclivegui/temp/engine.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from enum import Enum
22
from pathlib import Path
33

4+
45
class Engine(Enum):
56
TENSORFLOW = "tensorflow"
67
PYTORCH = "pytorch"
@@ -27,7 +28,7 @@ def from_model_path(cls, model_path: str | Path) -> "Engine":
2728
if has_cfg and has_pb:
2829
return cls.TENSORFLOW
2930
elif path.is_file():
30-
if path.suffix == ".pt":
31+
if path.suffix in (".pt", ".pth"):
3132
return cls.PYTORCH
3233

33-
raise ValueError(f"Could not determine engine from model path: {model_path}")
34+
raise ValueError(f"Could not determine engine from model path: {model_path}")

tests/services/test_dlc_processor.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -63,10 +63,10 @@ def test_worker_processes_frames(qtbot, monkeypatch_dlclive, settings_model):
6363
proc.enqueue_frame(frame, timestamp=2.0 + i)
6464
qtbot.wait(5) # ms
6565

66-
# NOTE @C-Achard this still fails randomly
67-
# the timeout has to be surprisingly large here
68-
# not sure if it's qtbot or threading scheduling delays
69-
# Should be fixed now.
66+
# NOTE @C-Achard The timeout here is intentionally large to account for potential
67+
# Qt event-loop and threading scheduling delays in CI environments.
68+
# This was previously flaky with a smaller timeout; increasing it should
69+
# keep the test stable.
7070
qtbot.waitUntil(lambda: proc.get_stats().frames_processed >= 3, timeout=3000)
7171

7272
finally:

0 commit comments

Comments
 (0)