
# Source code for mmpose.models.utils.misc

# Copyright (c) OpenMMLab. All rights reserved.
from functools import partial

import torch
from six.moves import map, zip

def multi_apply(func, *args, **kwargs):
    """Apply function to a list of arguments.

    Note:
        This function applies the ``func`` to multiple inputs and
        maps the multiple outputs of the ``func`` into different
        lists. Each list contains the same type of outputs corresponding
        to different inputs.

    Args:
        func (Function): A function that will be applied to a list of
            arguments. It is expected to return an iterable (typically a
            tuple) with the same number of outputs on every call.
        *args: One iterable per positional parameter of ``func``. They are
            consumed in lockstep, so the i-th call receives the i-th item
            of each iterable.
        **kwargs: Keyword arguments forwarded unchanged to every call of
            ``func``.

    Returns:
        tuple(list): A tuple containing multiple lists, each list contains
            a kind of returned results by the function. An empty tuple is
            returned when the inputs yield no items.
    """
    # Bind the shared keyword arguments once so that ``map`` only has to
    # supply the per-item positional arguments.
    pfunc = partial(func, **kwargs) if kwargs else func
    map_results = map(pfunc, *args)
    # Transpose the per-call result tuples into one list per output kind.
    return tuple(map(list, zip(*map_results)))

def filter_scores_and_topk(scores, score_thr, topk, results=None):
    """Filter results using score threshold and topk candidates.

    Args:
        scores (Tensor): The scores, shape (num_bboxes, K).
        score_thr (float): The score filter threshold.
        topk (int): The number of topk candidates.
        results (dict or list or Tensor, Optional): The results to
            which the filtering rule is to be applied. The shape
            of each item is (num_bboxes, N).

    Returns:
        tuple: Filtered results

            - scores (Tensor): The scores after being filtered,
              shape (num_bboxes_filtered, ).
            - labels (Tensor): The class labels, shape
              (num_bboxes_filtered, ).
            - anchor_idxs (Tensor): The anchor indexes, shape
              (num_bboxes_filtered, ).
            - filtered_results (dict or list or Tensor, Optional):
              The filtered results. The shape of each item is
              (num_bboxes_filtered, N). ``None`` when ``results`` is
              ``None``.

    Raises:
        NotImplementedError: If ``results`` is given but is not a dict,
            a list or a Tensor.
    """
    # Keep only the entries strictly above the threshold. ``valid_idxs``
    # holds one (row, column) index pair per surviving score, in the same
    # order as the flattened ``scores`` below.
    valid_mask = scores > score_thr
    scores = scores[valid_mask]
    valid_idxs = torch.nonzero(valid_mask)

    num_topk = min(topk, valid_idxs.size(0))
    # torch.sort is actually faster than .topk (at least on GPUs)
    scores, idxs = scores.sort(descending=True)
    scores = scores[:num_topk]
    topk_idxs = valid_idxs[idxs[:num_topk]]
    # Split the (row, column) pairs: the row selects the item in
    # ``results``, the column is returned as the label.
    keep_idxs, labels = topk_idxs.unbind(dim=1)

    filtered_results = None
    if results is not None:
        if isinstance(results, dict):
            filtered_results = {k: v[keep_idxs] for k, v in results.items()}
        elif isinstance(results, list):
            filtered_results = [result[keep_idxs] for result in results]
        elif isinstance(results, torch.Tensor):
            filtered_results = results[keep_idxs]
        else:
            raise NotImplementedError(f'Only supports dict or list or Tensor, '
                                      f'but get {type(results)}.')
    return scores, labels, keep_idxs, filtered_results