Source code for mmpose.codecs.spr
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Optional, Tuple, Union
import numpy as np
import torch
from torch import Tensor
from mmpose.registry import KEYPOINT_CODECS
from .base import BaseKeypointCodec
from .utils import (batch_heatmap_nms, generate_displacement_heatmap,
generate_gaussian_heatmaps, get_diagonal_lengths,
get_instance_root)
@KEYPOINT_CODECS.register_module()
class SPR(BaseKeypointCodec):
    """Encode/decode keypoints with Structured Pose Representation (SPR).

    See the paper `Single-stage multi-person pose machines`_
    by Nie et al (2019) for details

    Note:

        - instance number: N
        - keypoint number: K
        - keypoint dimension: D
        - image size: [w, h]
        - heatmap size: [W, H]

    Encoded:

        - heatmaps (np.ndarray): The generated heatmap in shape (1, H, W)
            where [W, H] is the `heatmap_size`. If the keypoint heatmap is
            generated together, the output heatmap shape is (K+1, H, W)
        - heatmap_weights (np.ndarray): The target weights for heatmaps which
            has same shape with heatmaps.
        - displacements (np.ndarray): The dense keypoint displacement in
            shape (K*2, H, W).
        - displacement_weights (np.ndarray): The target weights for
            displacements which has same shape with displacements.

    Args:
        input_size (tuple): Image size in [w, h]
        heatmap_size (tuple): Heatmap size in [W, H]
        sigma (float or tuple, optional): The sigma values of the Gaussian
            heatmaps. If sigma is a tuple, it includes both sigmas for root
            and keypoint heatmaps. ``None`` means the sigmas are computed
            automatically from the heatmap size (root sigma
            ``sqrt(W * H) / 32``, keypoint sigma half of that).
            Defaults to ``None``
        generate_keypoint_heatmaps (bool): Whether to generate Gaussian
            heatmaps for each keypoint. Defaults to ``False``
        root_type (str): The method to generate the instance root. Options
            are:

            - ``'kpt_center'``: Average coordinate of all visible keypoints.
            - ``'bbox_center'``: Center point of bounding boxes outlined by
                all visible keypoints.

            Defaults to ``'kpt_center'``

        minimal_diagonal_length (int or float): The threshold of diagonal
            length of instance bounding box. Small instances will not be
            used in training. Defaults to 5
        background_weight (float): Loss weight of background pixels.
            Defaults to 0.1
        decode_thr (float): The threshold of keypoint response value in
            heatmaps. Defaults to 0.01
        decode_nms_kernel (int): The kernel size of the NMS during decoding,
            which should be an odd integer. Defaults to 5
        decode_max_instances (int): The maximum number of instances
            to decode. Defaults to 30

    .. _`Single-stage multi-person pose machines`:
        https://arxiv.org/abs/1908.09220
    """

    field_mapping_table = dict(
        heatmaps='heatmaps',
        heatmap_weights='heatmap_weights',
        displacements='displacements',
        displacement_weights='displacement_weights',
    )

    def __init__(
        self,
        input_size: Tuple[int, int],
        heatmap_size: Tuple[int, int],
        sigma: Optional[Union[float, Tuple[float]]] = None,
        generate_keypoint_heatmaps: bool = False,
        root_type: str = 'kpt_center',
        minimal_diagonal_length: Union[int, float] = 5,
        background_weight: float = 0.1,
        decode_nms_kernel: int = 5,
        decode_max_instances: int = 30,
        decode_thr: float = 0.01,
    ):
        super().__init__()

        self.input_size = input_size
        self.heatmap_size = heatmap_size
        self.generate_keypoint_heatmaps = generate_keypoint_heatmaps
        self.root_type = root_type
        self.minimal_diagonal_length = minimal_diagonal_length
        self.background_weight = background_weight
        self.decode_nms_kernel = decode_nms_kernel
        self.decode_max_instances = decode_max_instances
        self.decode_thr = decode_thr

        # Per-axis ratio [w / W, h / H] between input image and heatmap;
        # used to map coordinates between the two spaces.
        self.scale_factor = (np.array(input_size) /
                             heatmap_size).astype(np.float32)

        if sigma is None:
            sigma = (heatmap_size[0] * heatmap_size[1])**0.5 / 32
            if generate_keypoint_heatmaps:
                # sigma for root heatmap and keypoint heatmaps. True division
                # is used on purpose: floor division would collapse any root
                # sigma below 2 (heatmaps smaller than ~64x64) to 0.
                self.sigma = (sigma, sigma / 2)
            else:
                self.sigma = (sigma, )
        else:
            if not isinstance(sigma, (tuple, list)):
                sigma = (sigma, )
            if generate_keypoint_heatmaps:
                assert len(sigma) == 2, 'sigma for keypoints must be given ' \
                                        'if `generate_keypoint_heatmaps` ' \
                                        'is True. e.g. sigma=(4, 2)'
            self.sigma = sigma

    def _get_heatmap_weights(self,
                             heatmaps,
                             fg_weight: float = 1,
                             bg_weight: float = 0):
        """Generate weight array for heatmaps.

        Args:
            heatmaps (np.ndarray): Root and keypoint (optional) heatmaps
            fg_weight (float): Weight for foreground pixels (heatmap value
                > 0). Defaults to 1.0
            bg_weight (float): Weight for background pixels. Defaults to 0.0

        Returns:
            np.ndarray: float32 heatmap weight array in the same shape with
            heatmaps
        """
        heatmap_weights = np.full(heatmaps.shape, bg_weight, dtype=np.float32)
        heatmap_weights[heatmaps > 0] = fg_weight
        return heatmap_weights

    def encode(self,
               keypoints: np.ndarray,
               keypoints_visible: Optional[np.ndarray] = None) -> dict:
        """Encode keypoints into root heatmaps and keypoint displacement
        fields. Note that the original keypoint coordinates should be in the
        input image space.

        Args:
            keypoints (np.ndarray): Keypoint coordinates in shape (N, K, D)
            keypoints_visible (np.ndarray): Keypoint visibilities in shape
                (N, K). ``None`` means all keypoints are visible.

        Returns:
            dict:
            - heatmaps (np.ndarray): The generated heatmap in shape
                (1, H, W) where [W, H] is the `heatmap_size`. If keypoint
                heatmaps are generated together, the shape is (K+1, H, W)
            - heatmap_weights (np.ndarray): The pixel-wise weight for heatmaps
                which has same shape with `heatmaps`
            - displacements (np.ndarray): The generated displacement fields in
                shape (K*D, H, W). The vector on each pixels represents the
                displacement of keypoints belong to the associated instance
                from this pixel.
            - displacement_weights (np.ndarray): The pixel-wise weight for
                displacements which has same shape with `displacements`
        """

        if keypoints_visible is None:
            keypoints_visible = np.ones(keypoints.shape[:2], dtype=np.float32)

        # keypoint coordinates in heatmap
        _keypoints = keypoints / self.scale_factor

        # compute the root and scale of each instance
        roots, roots_visible = get_instance_root(_keypoints, keypoints_visible,
                                                 self.root_type)
        diagonal_lengths = get_diagonal_lengths(_keypoints, keypoints_visible)

        # discard the small instances
        roots_visible[diagonal_lengths < self.minimal_diagonal_length] = 0

        # generate heatmaps
        heatmaps, _ = generate_gaussian_heatmaps(
            heatmap_size=self.heatmap_size,
            keypoints=roots[:, None],
            keypoints_visible=roots_visible[:, None],
            sigma=self.sigma[0])
        heatmap_weights = self._get_heatmap_weights(
            heatmaps, bg_weight=self.background_weight)

        if self.generate_keypoint_heatmaps:
            keypoint_heatmaps, _ = generate_gaussian_heatmaps(
                heatmap_size=self.heatmap_size,
                keypoints=_keypoints,
                keypoints_visible=keypoints_visible,
                sigma=self.sigma[1])

            keypoint_heatmaps_weights = self._get_heatmap_weights(
                keypoint_heatmaps, bg_weight=self.background_weight)

            # keypoint channels first, root channel last (decode reads the
            # root heatmap from the last channel)
            heatmaps = np.concatenate((keypoint_heatmaps, heatmaps), axis=0)
            heatmap_weights = np.concatenate(
                (keypoint_heatmaps_weights, heatmap_weights), axis=0)

        # generate displacements
        displacements, displacement_weights = \
            generate_displacement_heatmap(
                self.heatmap_size,
                _keypoints,
                keypoints_visible,
                roots,
                roots_visible,
                diagonal_lengths,
                self.sigma[0],
            )

        encoded = dict(
            heatmaps=heatmaps,
            heatmap_weights=heatmap_weights,
            displacements=displacements,
            displacement_weights=displacement_weights)

        return encoded

    def decode(
        self, heatmaps: Tensor, displacements: Tensor
    ) -> Tuple[Tensor, Tuple[Tensor, Optional[Tensor]]]:
        """Decode the keypoint coordinates from heatmaps and displacements. The
        decoded keypoint coordinates are in the input image space.

        Args:
            heatmaps (Tensor): Encoded root and keypoints (optional) heatmaps
                in shape (1, H, W) or (K+1, H, W); the root heatmap is the
                last channel
            displacements (Tensor): Encoded keypoints displacement fields
                in shape (K*D, H, W)

        Returns:
            tuple:
            - keypoints (Tensor): Decoded keypoint coordinates in shape
                (N, K, D)
            - scores (tuple):
                - root_scores (Tensor): The root scores in shape (N, )
                - keypoint_scores (Tensor): The keypoint scores in
                    shape (N, K). If keypoint heatmaps are not generated,
                    `keypoint_scores` will be `None`
        """
        _k, h, w = displacements.shape
        k = _k // 2
        # reshape (rather than view) also accepts non-contiguous inputs
        displacements = displacements.reshape(k, 2, h, w)

        # convert displacements to a dense keypoint prediction: the (x, y)
        # coordinate of every pixel plus its predicted offsets gives the
        # absolute keypoint positions in heatmap space
        grid_y = torch.arange(h).view(h, 1).expand(h, w)
        grid_x = torch.arange(w).view(1, w).expand(h, w)
        regular_grid = torch.stack([grid_x, grid_y], dim=0).to(displacements)
        posemaps = (regular_grid[None] + displacements).flatten(2)

        # find local maximum on root heatmap
        root_heatmap_peaks = batch_heatmap_nms(heatmaps[None, -1:],
                                               self.decode_nms_kernel)
        root_heatmap_peaks = root_heatmap_peaks.flatten()
        # topk raises when asked for more elements than the heatmap holds
        num_candidates = min(self.decode_max_instances,
                             root_heatmap_peaks.numel())
        root_scores, pos_idx = root_heatmap_peaks.topk(num_candidates)
        mask = root_scores > self.decode_thr
        root_scores, pos_idx = root_scores[mask], pos_idx[mask]

        keypoints = posemaps[:, :, pos_idx].permute(2, 0, 1).contiguous()

        if self.generate_keypoint_heatmaps and heatmaps.shape[0] == 1 + k:
            # compute scores for each keypoint (in heatmap space)
            keypoint_scores = self.get_keypoint_scores(heatmaps[:k], keypoints)
        else:
            keypoint_scores = None

        # rescale from heatmap space to input image space, per axis
        keypoints = keypoints * keypoints.new_tensor(self.scale_factor)
        return keypoints, (root_scores, keypoint_scores)

    def get_keypoint_scores(self, heatmaps: Tensor, keypoints: Tensor):
        """Calculate the keypoint scores with keypoints heatmaps and
        coordinates.

        Args:
            heatmaps (Tensor): Keypoint heatmaps in shape (K, H, W)
            keypoints (Tensor): Keypoint coordinates in heatmap space, in
                shape (N, K, D)

        Returns:
            Tensor: Keypoint scores in [N, K]
        """
        k, h, w = heatmaps.shape
        n = keypoints.shape[0]
        if n == 0:
            # no instance survived decoding; grid_sample / view cannot
            # handle an empty instance dimension
            return heatmaps.new_zeros((0, k))

        # Normalize pixel coordinates to [-1, 1] such that -1 and 1 hit the
        # centers of the corner pixels, i.e. the ``align_corners=True``
        # convention of ``grid_sample``. max(.., 1) guards 1-pixel axes.
        keypoints = torch.stack((
            keypoints[..., 0] / max(w - 1, 1) * 2 - 1,
            keypoints[..., 1] / max(h - 1, 1) * 2 - 1,
        ),
                                dim=-1)
        # (N, K, 2) -> (K, 1, N, 2): one sampling grid per keypoint channel
        keypoints = keypoints.transpose(0, 1).unsqueeze(1).contiguous()

        # output (K, 1, 1, N) -> (N, K)
        keypoint_scores = torch.nn.functional.grid_sample(
            heatmaps.unsqueeze(1),
            keypoints,
            padding_mode='border',
            align_corners=True).view(k, n).transpose(0, 1).contiguous()

        return keypoint_scores