Skip to content

Commit a096c94

Browse files
authored
SimCCPredictor: add visilibity computation (#131)
1 parent f8fb374 commit a096c94

1 file changed

Lines changed: 14 additions & 2 deletions

File tree

  • dlclive/pose_estimation_pytorch/models/predictors

dlclive/pose_estimation_pytorch/models/predictors/sim_cc.py

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -42,21 +42,33 @@ class SimCCPredictor(BasePredictor):
4242
def __init__(
4343
self,
4444
simcc_split_ratio: float = 2.0,
45-
apply_softmax: bool = False,
4645
normalize_outputs: bool = False,
46+
apply_softmax: bool = True,
47+
sigma: float | int | tuple[float, ...] = 6.0,
48+
decode_beta: float = 150.0,
4749
) -> None:
4850
super().__init__()
4951
self.simcc_split_ratio = simcc_split_ratio
50-
self.apply_softmax = apply_softmax
5152
self.normalize_outputs = normalize_outputs
53+
self.apply_softmax = apply_softmax
54+
55+
if isinstance(sigma, (float, int)):
56+
self.sigma = np.array([sigma, sigma])
57+
else:
58+
self.sigma = np.array(sigma)
59+
self.decode_beta = decode_beta
5260

5361
def forward(
5462
self, stride: float, outputs: dict[str, torch.Tensor]
5563
) -> dict[str, torch.Tensor]:
5664
x, y = outputs["x"].detach(), outputs["y"].detach()
65+
5766
if self.normalize_outputs:
5867
x = get_simcc_normalized(x)
5968
y = get_simcc_normalized(y)
69+
else:
70+
x = x * (self.sigma[0] * self.decode_beta)
71+
y = y * (self.sigma[1] * self.decode_beta)
6072

6173
keypoints, scores = get_simcc_maximum(
6274
x.cpu().numpy(), y.cpu().numpy(), self.apply_softmax

0 commit comments

Comments
 (0)