Skip to content

Commit 333f714

Browse files
committed
update runner: consider pretrained detectors (no weights in raw_data)
1 parent 01a2f2e commit 333f714

1 file changed

Lines changed: 18 additions & 3 deletions

File tree

dlclive/pose_estimation_pytorch/runner.py

Lines changed: 18 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -268,10 +268,24 @@ def load_model(self) -> None:
268268
self.model = self.model.half()
269269

270270
self.detector = None
271-
if self.dynamic is None and raw_data.get("detector") is not None:
271+
detector_cfg = self.cfg.get("detector")
272+
has_detector_weights = raw_data.get("detector") is not None
273+
if detector_cfg is not None:
274+
detector_model_cfg = detector_cfg["model"]
275+
uses_pretrained = (
276+
detector_model_cfg.get("pretrained", False)
277+
or detector_model_cfg.get("weights") is not None
278+
)
279+
else:
280+
uses_pretrained = False
281+
282+
if self.dynamic is None and (has_detector_weights or uses_pretrained):
272283
self.detector = models.DETECTORS.build(self.cfg["detector"]["model"])
273284
self.detector.to(self.device)
274-
self.detector.load_state_dict(raw_data["detector"])
285+
286+
if has_detector_weights:
287+
self.detector.load_state_dict(raw_data["detector"])
288+
275289
self.detector.eval()
276290
if self.precision == "FP16":
277291
self.detector = self.detector.half()
@@ -281,7 +295,8 @@ def load_model(self) -> None:
281295
self.top_down_config.read_config(self.cfg)
282296

283297
detector_transforms = [v2.ToDtype(torch.float32, scale=True)]
284-
if self.cfg["detector"]["data"]["inference"].get("normalize_images", False):
298+
detector_data_cfg = detector_cfg.get("data", {}).get("inference", {})
299+
if detector_data_cfg.get("normalize_images", False):
285300
detector_transforms.append(v2.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]))
286301
self.detector_transform = v2.Compose(detector_transforms)
287302

0 commit comments

Comments
 (0)