Shortcuts

Source code for mmpose.codecs.associative_embedding

# Copyright (c) OpenMMLab. All rights reserved.
from itertools import product
from typing import Any, List, Optional, Tuple

import numpy as np
import torch
from munkres import Munkres
from torch import Tensor

from mmpose.registry import KEYPOINT_CODECS
from mmpose.utils.tensor_utils import to_numpy
from .base import BaseKeypointCodec
from .utils import (batch_heatmap_nms, generate_gaussian_heatmaps,
                    generate_udp_gaussian_heatmaps, refine_keypoints,
                    refine_keypoints_dark_udp)


def _py_max_match(scores):
    """Solve the assignment problem on a cost matrix with Munkres.

    Args:
        scores(np.ndarray): cost matrix.

    Returns:
        np.ndarray: the assignment returned by ``Munkres.compute``, converted
        to an integer array. Callers in this module iterate over it as
        ``(row, col)`` pairs.
    """
    assignment = Munkres().compute(scores)
    return np.array(assignment).astype(int)


def _group_keypoints_by_tags(vals: np.ndarray,
                             tags: np.ndarray,
                             locs: np.ndarray,
                             keypoint_order: List[int],
                             val_thr: float,
                             tag_thr: float = 1.0,
                             max_groups: Optional[int] = None) -> np.ndarray:
    """Group the keypoints by tags using Munkres algorithm.

    Note:

        - keypoint number: K
        - candidate number: M
        - tag dimension: L
        - coordinate dimension: D
        - group number: G

    Args:
        vals (np.ndarray): The heatmap response values of keypoints in shape
            (K, M)
        tags (np.ndarray): The tags of the keypoint candidates in shape
            (K, M, L)
        locs (np.ndarray): The locations of the keypoint candidates in shape
            (K, M, D)
        keypoint_order (List[int]): The grouping order of the keypoints.
            The grouping usually starts from keypoints around the head and
            torso, and gradually moves out to the limbs
        val_thr (float): The threshold of the keypoint response value.
            Candidates whose value is not strictly greater are discarded
        tag_thr (float): The maximum allowed tag distance when matching a
            keypoint to a group. A keypoint with larger tag distance to any
            of the existing groups will initialize a new group
        max_groups (int, optional): The maximum group number. ``None`` means
            no limitation. Defaults to ``None``

    Returns:
        np.ndarray: grouped keypoints in shape (G, K, D+1), where the last
        dimension is the concatenated keypoint coordinates and scores.
        Keypoints that were not assigned to a group are left as zeros.
    """

    K, M, D = locs.shape
    assert vals.shape == tags.shape[:2] == (K, M)
    assert len(keypoint_order) == K

    # Column layout of one candidate row: [coords (D) | score (1) | tag (L)].
    # The offsets are derived from D instead of assuming 2D coordinates.
    val_col = D
    tag_col = D + 1
    default_ = np.zeros((K, tag_col + tags.shape[2]), dtype=np.float32)

    # Both dicts are keyed by the first tag component of the candidate that
    # opened the group. ``joint_dict`` holds the (K, D+1+L) rows of the group,
    # ``tag_dict`` the tags of its members (averaged as the group reference).
    joint_dict = {}
    tag_dict = {}
    for i, idx in enumerate(keypoint_order):
        cand_tags = tags[idx]
        joints = np.concatenate((locs[idx], vals[idx, :, None], cand_tags), 1)
        mask = joints[:, val_col] > val_thr
        cand_tags = cand_tags[mask]  # shape: [M', L]
        joints = joints[mask]  # shape: [M', D + 1 + L]

        if joints.shape[0] == 0:
            continue

        if i == 0 or len(joint_dict) == 0:
            # Nothing to match against yet: every candidate opens a group
            for tag, joint in zip(cand_tags, joints):
                key = tag[0]
                joint_dict.setdefault(key, np.copy(default_))[idx] = joint
                tag_dict[key] = [tag]
        else:
            # shape: [G]
            grouped_keys = list(joint_dict.keys())
            # shape: [G, L]
            grouped_tags = [
                np.mean(tag_dict[key], axis=0) for key in grouped_keys
            ]

            # shape: [M', G, L]
            diff = joints[:, None, tag_col:] - \
                np.array(grouped_tags)[None, :, :]
            # shape: [M', G]
            diff_normed = np.linalg.norm(diff, ord=2, axis=2)
            diff_saved = np.copy(diff_normed)
            # Matching cost: rounded tag distance dominates, the response
            # value breaks ties in favour of stronger candidates
            diff_normed = np.round(diff_normed) * 100 - \
                joints[:, val_col:val_col + 1]

            num_added = diff.shape[0]
            num_grouped = diff.shape[1]

            if num_added > num_grouped:
                # Pad with prohibitively expensive dummy groups so that every
                # candidate gets a column in the assignment
                diff_normed = np.concatenate(
                    (diff_normed,
                     np.zeros((num_added, num_added - num_grouped),
                              dtype=np.float32) + 1e10),
                    axis=1)

            pairs = _py_max_match(diff_normed)
            for row, col in pairs:
                if (row < num_added and col < num_grouped
                        and diff_saved[row][col] < tag_thr):
                    # Close enough to an existing group: join it
                    key = grouped_keys[col]
                    joint_dict[key][idx] = joints[row]
                    tag_dict[key].append(cand_tags[row])
                else:
                    # Matched to a dummy column or too far away: new group
                    key = cand_tags[row][0]
                    joint_dict.setdefault(key, np.copy(default_))[idx] = \
                        joints[row]
                    tag_dict[key] = [cand_tags[row]]

    joint_dict_keys = list(joint_dict.keys())[:max_groups]

    if joint_dict_keys:
        results = np.array([joint_dict[key]
                            for key in joint_dict_keys]).astype(np.float32)
        # Drop the tag columns, keep coordinates and score
        results = results[..., :D + 1]
    else:
        results = np.empty((0, K, D + 1), dtype=np.float32)
    return results


@KEYPOINT_CODECS.register_module()
class AssociativeEmbedding(BaseKeypointCodec):
    """Encode/decode keypoints with the method introduced in "Associative
    Embedding". This is an asymmetric codec, where the keypoints are
    represented as gaussian heatmaps and position indices during encoding, and
    restored from predicted heatmaps and group tags.

    See the paper `Associative Embedding: End-to-End Learning for Joint
    Detection and Grouping`_ by Newell et al (2017) for details

    Note:

        - instance number: N
        - keypoint number: K
        - keypoint dimension: D
        - embedding tag dimension: L
        - image size: [w, h]
        - heatmap size: [W, H]

    Encoded:

        - heatmaps (np.ndarray): The generated heatmap in shape (K, H, W)
            where [W, H] is the `heatmap_size`
        - keypoint_indices (np.ndarray): The keypoint position indices in shape
            (N, K, 2). Each keypoint's index is [i, v], where i is the position
            index in the heatmap (:math:`i=y*w+x`) and v is the visibility
        - keypoint_weights (np.ndarray): The target weights in shape (N, K)

    Args:
        input_size (tuple): Image size in [w, h]
        heatmap_size (tuple): Heatmap size in [W, H]
        sigma (float): The sigma value of the Gaussian heatmap. If ``None``,
            it is set to ``sqrt(W * H) / 64``. Defaults to ``None``
        use_udp (bool): Whether use unbiased data processing. See
            `UDP (CVPR 2020)`_ for details. Defaults to ``False``
        decode_keypoint_order (List[int]): The grouping order of the
            keypoint indices. The grouping usually starts from keypoints
            around the head and torso, and gradually moves out to the limbs
        decode_nms_kernel (int): The kernel size of the NMS during decoding,
            which should be an odd integer. Defaults to 5
        decode_gaussian_kernel (int): The kernel size of the Gaussian blur
            during decoding, which should be an odd integer. It is only used
            when ``self.use_udp==True``. Defaults to 3
        decode_keypoint_thr (float): The threshold of keypoint response value
            in heatmaps. Defaults to 0.1
        decode_tag_thr (float): The maximum allowed tag distance when matching
            a keypoint to a group. A keypoint with larger tag distance to any
            of the existing groups will initialize a new group. Defaults to
            1.0
        decode_topk (int): The number of top-k candidates of each keypoint
            that will be retrieved from the heatmaps during decoding.
            Defaults to 30
        decode_center_shift (float): An offset added during decoding to the
            refined coordinates of keypoints whose score is positive.
            Defaults to 0.0
        decode_max_instances (int, optional): The maximum number of instances
            to decode. ``None`` means no limitation to the instance number.
            Defaults to ``None``

    .. _`Associative Embedding: End-to-End Learning for Joint Detection and
    Grouping`: https://arxiv.org/abs/1611.05424
    .. _`UDP (CVPR 2020)`: https://arxiv.org/abs/1911.07524
    """

    def __init__(
        self,
        input_size: Tuple[int, int],
        heatmap_size: Tuple[int, int],
        sigma: Optional[float] = None,
        use_udp: bool = False,
        decode_keypoint_order: List[int] = [],
        decode_nms_kernel: int = 5,
        decode_gaussian_kernel: int = 3,
        decode_keypoint_thr: float = 0.1,
        decode_tag_thr: float = 1.0,
        decode_topk: int = 30,
        decode_center_shift=0.0,
        decode_max_instances: Optional[int] = None,
    ) -> None:
        super().__init__()
        self.input_size = input_size
        self.heatmap_size = heatmap_size
        self.use_udp = use_udp
        self.decode_nms_kernel = decode_nms_kernel
        self.decode_gaussian_kernel = decode_gaussian_kernel
        self.decode_keypoint_thr = decode_keypoint_thr
        self.decode_tag_thr = decode_tag_thr
        self.decode_topk = decode_topk
        self.decode_center_shift = decode_center_shift
        self.decode_max_instances = decode_max_instances
        # Copy so the shared default list / caller's list is never aliased
        self.decode_keypoint_order = decode_keypoint_order.copy()

        # Ratio between input-image and heatmap coordinates, per axis
        if self.use_udp:
            self.scale_factor = ((np.array(input_size) - 1) /
                                 (np.array(heatmap_size) - 1)).astype(
                                     np.float32)
        else:
            self.scale_factor = (np.array(input_size) /
                                 heatmap_size).astype(np.float32)

        if sigma is None:
            sigma = (heatmap_size[0] * heatmap_size[1])**0.5 / 64
        self.sigma = sigma

    def encode(
        self,
        keypoints: np.ndarray,
        keypoints_visible: Optional[np.ndarray] = None
    ) -> dict:
        """Encode keypoints into heatmaps and position indices. Note that the
        original keypoint coordinates should be in the input image space.

        Args:
            keypoints (np.ndarray): Keypoint coordinates in shape (N, K, D)
            keypoints_visible (np.ndarray): Keypoint visibilities in shape
                (N, K). If ``None``, all keypoints are treated as visible

        Returns:
            dict:
            - heatmaps (np.ndarray): The generated heatmap in shape
                (K, H, W) where [W, H] is the `heatmap_size`
            - keypoint_indices (np.ndarray): The keypoint position indices
                in shape (N, K, 2). Each keypoint's index is [i, v], where i
                is the position index in the heatmap (:math:`i=y*w+x`) and v
                is the visibility
            - keypoint_weights (np.ndarray): The target weights in shape
                (N, K)
        """

        if keypoints_visible is None:
            keypoints_visible = np.ones(keypoints.shape[:2], dtype=np.float32)

        # keypoint coordinates in heatmap
        _keypoints = keypoints / self.scale_factor

        if self.use_udp:
            heatmaps, keypoint_weights = generate_udp_gaussian_heatmaps(
                heatmap_size=self.heatmap_size,
                keypoints=_keypoints,
                keypoints_visible=keypoints_visible,
                sigma=self.sigma)
        else:
            heatmaps, keypoint_weights = generate_gaussian_heatmaps(
                heatmap_size=self.heatmap_size,
                keypoints=_keypoints,
                keypoints_visible=keypoints_visible,
                sigma=self.sigma)

        keypoint_indices = self._encode_keypoint_indices(
            heatmap_size=self.heatmap_size,
            keypoints=_keypoints,
            keypoints_visible=keypoints_visible)

        encoded = dict(
            heatmaps=heatmaps,
            keypoint_indices=keypoint_indices,
            keypoint_weights=keypoint_weights)

        return encoded

    def _encode_keypoint_indices(self, heatmap_size: Tuple[int, int],
                                 keypoints: np.ndarray,
                                 keypoints_visible: np.ndarray) -> np.ndarray:
        """Convert heatmap-space keypoints to flattened position indices.

        Args:
            heatmap_size (tuple): Heatmap size in [W, H]
            keypoints (np.ndarray): Keypoint coordinates in heatmap space.
                Each keypoint is unpacked as ``x, y``, so the last dimension
                must be 2
            keypoints_visible (np.ndarray): Keypoint visibilities in shape
                (N, K)

        Returns:
            np.ndarray: int64 array in shape (N, K, 2) holding [i, v] per
            keypoint, where :math:`i=y*w+x` and v is 1 only if the keypoint
            visibility exceeds 0.5 and (x, y) lies inside the heatmap.
        """
        w, h = heatmap_size
        N, K, _ = keypoints.shape
        keypoint_indices = np.zeros((N, K, 2), dtype=np.int64)

        for n, k in product(range(N), range(K)):
            x, y = (keypoints[n, k] + 0.5).astype(np.int64)
            index = y * w + x
            vis = (keypoints_visible[n, k] > 0.5 and 0 <= x < w and 0 <= y < h)
            keypoint_indices[n, k] = [index, vis]

        return keypoint_indices

    def decode(self, encoded: Any) -> Tuple[np.ndarray, np.ndarray]:
        """Not supported by this codec; always raises
        ``NotImplementedError``."""
        raise NotImplementedError()

    def _get_batch_topk(self, batch_heatmaps: Tensor, batch_tags: Tensor,
                        k: int):
        """Get top-k response values from the heatmaps and corresponding tag
        values from the tagging heatmaps.

        Args:
            batch_heatmaps (Tensor): Keypoint detection heatmaps in shape
                (B, K, H, W)
            batch_tags (Tensor): Tagging heatmaps in shape (B, C, H, W), where
                the tag dim C is 2*K when using flip testing, or K otherwise
            k (int): The number of top responses to get

        Returns:
            tuple:
            - topk_vals (Tensor): Top-k response values of each heatmap in
                shape (B, K, Topk)
            - topk_tags (Tensor): The corresponding embedding tags of the
                top-k responses, in shape (B, K, Topk, L)
            - topk_locs (Tensor): The location of the top-k responses in each
                heatmap, in shape (B, K, Topk, 2) where last dimension
                represents x and y coordinates
        """
        B, K, H, W = batch_heatmaps.shape
        L = batch_tags.shape[1] // K

        # shape of topk_val, top_indices: (B, K, TopK)
        topk_vals, topk_indices = batch_heatmaps.flatten(-2, -1).topk(
            k, dim=-1)

        # Pick, for each of the L tag channels, the tag value at every
        # top-k heatmap position
        topk_tags_per_kpts = [
            torch.gather(_tag, dim=2, index=topk_indices)
            for _tag in torch.unbind(batch_tags.view(B, L, K, H * W), dim=1)
        ]

        topk_tags = torch.stack(topk_tags_per_kpts, dim=-1)  # (B, K, TopK, L)
        # Flattened index -> (x, y)
        topk_locs = torch.stack([topk_indices % W, topk_indices // W],
                                dim=-1)  # (B, K, TopK, 2)

        return topk_vals, topk_tags, topk_locs

    def _group_keypoints(self, batch_vals: np.ndarray, batch_tags: np.ndarray,
                         batch_locs: np.ndarray):
        """Group keypoints into groups (each represents an instance) by tags.

        Args:
            batch_vals (np.ndarray): Heatmap response values of keypoint
                candidates in shape (B, K, Topk)
            batch_tags (np.ndarray): Tags of keypoint candidates in shape
                (B, K, Topk, L)
            batch_locs (np.ndarray): Locations of keypoint candidates in shape
                (B, K, Topk, 2)

        Returns:
            List[np.ndarray]: Grouping results of a batch, each element is a
            np.ndarray (in shape [N, K, D+1]) that contains the groups
            detected in an image, including both keypoint coordinates and
            scores.
        """
        return [
            _group_keypoints_by_tags(
                vals,
                tags,
                locs,
                keypoint_order=self.decode_keypoint_order,
                val_thr=self.decode_keypoint_thr,
                tag_thr=self.decode_tag_thr,
                max_groups=self.decode_max_instances)
            for vals, tags, locs in zip(batch_vals, batch_tags, batch_locs)
        ]

    def _fill_missing_keypoints(self, keypoints: np.ndarray,
                                keypoint_scores: np.ndarray,
                                heatmaps: np.ndarray, tags: np.ndarray):
        """Fill the missing keypoints in the initial predictions.

        The inputs ``keypoints`` and ``keypoint_scores`` are modified in
        place and also returned.

        Args:
            keypoints (np.ndarray): Keypoint predictions in shape (N, K, D)
            keypoint_scores (np.ndarray): Keypoint score predictions in shape
                (N, K), in which 0 means the corresponding keypoint is
                missing in the initial prediction
            heatmaps (np.ndarray): Heatmaps in shape (K, H, W)
            tags (np.ndarray): Tagging heatmaps in shape (C, H, W) where
                C=L*K

        Returns:
            tuple:
            - keypoints (np.ndarray): Keypoint predictions with missing
                ones filled
            - keypoint_scores (np.ndarray): Keypoint score predictions with
                missing ones filled
        """

        N, K = keypoints.shape[:2]
        H, W = heatmaps.shape[1:]
        L = tags.shape[0] // K
        # Per-keypoint tag maps, each in shape (L, H, W)
        keypoint_tags = [tags[k::K] for k in range(K)]

        for n in range(N):
            # Calculate the instance tag (mean tag of detected keypoints)
            detected_tags = []
            for k in range(K):
                if keypoint_scores[n, k] > 0:
                    x, y = keypoints[n, k, :2].astype(np.int64)
                    x = np.clip(x, 0, W - 1)
                    y = np.clip(y, 0, H - 1)
                    detected_tags.append(keypoint_tags[k][:, y, x])

            if not detected_tags:
                # No detected keypoint to derive a reference tag from; the
                # mean of an empty list would be NaN (or fail to reshape for
                # L > 1), so leave this instance untouched
                continue

            tag = np.mean(detected_tags, axis=0)
            tag = tag.reshape(L, 1, 1)

            # Search maximum response of the missing keypoints
            for k in range(K):
                if keypoint_scores[n, k] > 0:
                    continue
                dist_map = np.linalg.norm(
                    keypoint_tags[k] - tag, ord=2, axis=0)
                # Rounded tag distance dominates; the heatmap response
                # decides among positions with equal rounded distance
                cost_map = np.round(dist_map) * 100 - heatmaps[k]  # H, W
                y, x = np.unravel_index(np.argmin(cost_map), shape=(H, W))
                keypoints[n, k] = [x, y]
                keypoint_scores[n, k] = heatmaps[k, y, x]

        return keypoints, keypoint_scores

    def batch_decode(
        self, batch_heatmaps: Tensor, batch_tags: Tensor
    ) -> Tuple[List[np.ndarray], List[np.ndarray], List[np.ndarray]]:
        """Decode the keypoint coordinates from a batch of heatmaps and tagging
        heatmaps. The decoded keypoint coordinates are in the input image
        space.

        Args:
            batch_heatmaps (Tensor): Keypoint detection heatmaps in shape
                (B, K, H, W)
            batch_tags (Tensor): Tagging heatmaps in shape (B, C, H, W), where
                :math:`C=L*K`

        Returns:
            tuple:
            - batch_keypoints (List[np.ndarray]): Decoded keypoint coordinates
                of the batch, each is in shape (N, K, D)
            - batch_keypoint_scores (List[np.ndarray]): Decoded keypoint
                scores of the batch, each is in shape (N, K). It usually
                represents the confidence of the keypoint prediction
            - batch_instance_scores (List[np.ndarray]): Per-instance scores
                of the batch, each is in shape (N, ). Computed as the mean
                keypoint score of each group before missing keypoints are
                filled
        """
        B, _, H, W = batch_heatmaps.shape
        assert batch_tags.shape[0] == B and batch_tags.shape[2:4] == (H, W), (
            f'Mismatched shapes of heatmap ({batch_heatmaps.shape}) and '
            f'tagging map ({batch_tags.shape})')

        # Heatmap NMS
        batch_heatmaps_peak = batch_heatmap_nms(batch_heatmaps,
                                                self.decode_nms_kernel)

        # Get top-k in each heatmap and convert to numpy
        batch_topk_vals, batch_topk_tags, batch_topk_locs = to_numpy(
            self._get_batch_topk(
                batch_heatmaps_peak, batch_tags, k=self.decode_topk))

        # Group keypoint candidates into groups (instances)
        batch_groups = self._group_keypoints(batch_topk_vals, batch_topk_tags,
                                             batch_topk_locs)

        # Convert to numpy
        batch_heatmaps_np = to_numpy(batch_heatmaps)
        batch_tags_np = to_numpy(batch_tags)

        # Refine the keypoint prediction
        batch_keypoints = []
        batch_keypoint_scores = []
        batch_instance_scores = []
        for groups, heatmaps, tags in zip(batch_groups, batch_heatmaps_np,
                                          batch_tags_np):

            keypoints, scores = groups[..., :-1], groups[..., -1]
            instance_scores = scores.mean(axis=-1)

            # NOTE(review): the nesting below was reconstructed from a
            # whitespace-collapsed listing -- confirm that the center shift
            # belongs to the non-UDP branch only and that missing-keypoint
            # filling runs for both branches.
            if keypoints.size > 0:
                # refine keypoint coordinates according to heatmap distribution
                if self.use_udp:
                    keypoints = refine_keypoints_dark_udp(
                        keypoints,
                        heatmaps,
                        blur_kernel_size=self.decode_gaussian_kernel)
                else:
                    keypoints = refine_keypoints(keypoints, heatmaps)
                    keypoints += self.decode_center_shift * \
                        (scores > 0).astype(keypoints.dtype)[..., None]

                # identify missing keypoints
                keypoints, scores = self._fill_missing_keypoints(
                    keypoints, scores, heatmaps, tags)

            batch_keypoints.append(keypoints)
            batch_keypoint_scores.append(scores)
            batch_instance_scores.append(instance_scores)

        # restore keypoint scale
        batch_keypoints = [
            kpts * self.scale_factor for kpts in batch_keypoints
        ]

        return batch_keypoints, batch_keypoint_scores, batch_instance_scores
Read the Docs v: latest
Versions
latest
0.x
dev-1.x
Downloads
epub
On Read the Docs
Project Home
Builds

Free document hosting provided by Read the Docs.