Skip to content

Commit f3c8fba

Browse files
type hints for benchmark (and also some missing ones in dlclive) (#117)
Co-authored-by: Mackenzie Mathis <mathis@rowland.harvard.edu>
1 parent 5422e6a commit f3c8fba

2 files changed

Lines changed: 36 additions & 36 deletions

File tree

dlclive/benchmark.py

Lines changed: 33 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
import sys
1313
import warnings
1414
import subprocess
15-
import typing
15+
from typing import List, Optional, Tuple, Union
1616
import pickle
1717
import colorcet as cc
1818
from PIL import ImageColor
@@ -144,22 +144,22 @@ def get_system_info() -> dict:
144144

145145

146146
def benchmark(
147-
model_path,
148-
video_path,
149-
tf_config=None,
150-
resize=None,
151-
pixels=None,
152-
cropping=None,
153-
dynamic=(False, 0.5, 10),
154-
n_frames=1000,
155-
print_rate=False,
156-
display=False,
157-
pcutoff=0.0,
158-
display_radius=3,
159-
cmap="bmy",
160-
save_poses=False,
161-
save_video=False,
162-
output=None,
147+
model_path: str,
148+
video_path: str,
149+
tf_config: Optional[tf.ConfigProto] = None,
150+
resize: Optional[float] = None,
151+
pixels: Optional[int] = None,
152+
cropping: Optional[List[int]] = None,
153+
dynamic: Tuple[bool, float, int] = (False, 0.5, 10),
154+
n_frames: int = 1000,
155+
print_rate: bool = False,
156+
display: bool = False,
157+
pcutoff: float = 0.0,
158+
display_radius: int = 3,
159+
cmap: str = "bmy",
160+
save_poses: bool = False,
161+
save_video: bool = False,
162+
output: Optional[str] = None,
163163
) -> typing.Tuple[np.ndarray, tuple, bool, dict]:
164164
""" Analyze DeepLabCut-live exported model on a video:
165165
Calculate inference time,
@@ -512,22 +512,22 @@ def save_inf_times(
512512

513513

514514
def benchmark_videos(
515-
model_path,
516-
video_path,
517-
output=None,
518-
n_frames=1000,
519-
tf_config=None,
520-
resize=None,
521-
pixels=None,
522-
cropping=None,
523-
dynamic=(False, 0.5, 10),
524-
print_rate=False,
525-
display=False,
526-
pcutoff=0.5,
527-
display_radius=3,
528-
cmap="bmy",
529-
save_poses=False,
530-
save_video=False,
515+
model_path: str,
516+
video_path: Union[str, List[str]],
517+
output: Optional[str] = None,
518+
n_frames: int = 1000,
519+
tf_config: Optional[tf.ConfigProto] = None,
520+
resize: Optional[Union[float, List[float]]] = None,
521+
pixels: Optional[Union[int, List[int]]] = None,
522+
cropping: Optional[List[int]] = None,
523+
dynamic: Tuple[bool, float, int] = (False, 0.5, 10),
524+
print_rate: bool = False,
525+
display: bool = False,
526+
pcutoff: float = 0.5,
527+
display_radius: int = 3,
528+
cmap: str = "bmy",
529+
save_poses: bool = False,
530+
save_video: bool = False,
531531
):
532532
"""Analyze videos using DeepLabCut-live exported models.
533533
Analyze multiple videos and/or multiple options for the size of the video

dlclive/dlclive.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -181,7 +181,7 @@ def parameterization(self) -> dict:
181181
"""
182182
return {param: getattr(self, param) for param in self.PARAMETERS}
183183

184-
def process_frame(self, frame):
184+
def process_frame(self, frame: np.ndarray) -> np.ndarray:
185185
"""
186186
Crops an image according to the object's cropping and dynamic properties.
187187
@@ -237,7 +237,7 @@ def process_frame(self, frame):
237237

238238
return frame
239239

240-
def init_inference(self, frame=None, **kwargs):
240+
def init_inference(self, frame=None, **kwargs) -> np.ndarray:
241241
"""
242242
Load model and perform inference on first frame -- the first inference is usually very slow.
243243
@@ -376,7 +376,7 @@ def init_inference(self, frame=None, **kwargs):
376376

377377
return pose
378378

379-
def get_pose(self, frame=None, **kwargs):
379+
def get_pose(self, frame=None, **kwargs) -> np.ndarray:
380380
"""
381381
Get the pose of an image
382382

0 commit comments

Comments
 (0)