# Source code for mmpose.models.utils.misc
# Copyright (c) OpenMMLab. All rights reserved.
from functools import partial
import torch
from six.moves import map, zip
def multi_apply(func, *args, **kwargs):
    """Call ``func`` once per input and regroup its outputs by position.

    Note:
        ``func`` is mapped over the iterables given in ``args`` in
        lockstep. The i-th value returned by every call is collected
        into the i-th output list, so each list holds one kind of
        output, ordered like the inputs.

    Args:
        func (Function): The callable applied to each group of inputs.
        *args: Iterables supplying the positional arguments of ``func``;
            one element is taken from each iterable per call.
        **kwargs: Keyword arguments passed unchanged to every call.

    Returns:
        tuple(list): One list per value returned by ``func``. The tuple
        is empty when no call is made.
    """
    # Bind the shared keyword arguments once instead of on every call.
    if kwargs:
        func = partial(func, **kwargs)
    per_call_outputs = map(func, *args)
    # Transpose "one tuple per call" into "one list per output kind".
    return tuple(list(group) for group in zip(*per_call_outputs))
def filter_scores_and_topk(scores, score_thr, topk, results=None):
    """Keep the highest-scoring entries that exceed a score threshold.

    Args:
        scores (Tensor): The scores, shape (num_bboxes, K).
        score_thr (float): Only scores strictly greater than this value
            are kept.
        topk (int): Upper bound on the number of entries kept.
        results (dict or list or Tensor, Optional): Data to filter with
            the same rule. Each item is indexed along its first
            dimension and has shape (num_bboxes, N).

    Returns:
        tuple: Filtered results

            - scores (Tensor): The kept scores in descending order,
              shape (num_bboxes_filtered, ).
            - labels (Tensor): The class labels (column indexes into
              ``scores``), shape (num_bboxes_filtered, ).
            - anchor_idxs (Tensor): The anchor indexes (row indexes into
              ``scores``), shape (num_bboxes_filtered, ).
            - filtered_results (dict or list or Tensor, Optional): The
              filtered results, or ``None`` when ``results`` is
              ``None``. The shape of each item is
              (num_bboxes_filtered, N).

    Raises:
        NotImplementedError: If ``results`` is neither ``None`` nor a
            dict, list or Tensor.
    """
    above_thr = scores > score_thr
    # (row, column) coordinates of every score above the threshold, in
    # the same order as the flattened ``candidate_scores`` below.
    candidate_idxs = torch.nonzero(above_thr)
    candidate_scores = scores[above_thr]

    keep_count = min(topk, candidate_idxs.size(0))
    # torch.sort is actually faster than .topk (at least on GPUs)
    sorted_scores, order = candidate_scores.sort(descending=True)
    order = order[:keep_count]
    scores = sorted_scores[:keep_count]
    keep_idxs, labels = candidate_idxs[order].unbind(dim=1)

    if results is None:
        filtered_results = None
    elif isinstance(results, dict):
        filtered_results = {
            key: value[keep_idxs]
            for key, value in results.items()
        }
    elif isinstance(results, list):
        filtered_results = [item[keep_idxs] for item in results]
    elif isinstance(results, torch.Tensor):
        filtered_results = results[keep_idxs]
    else:
        raise NotImplementedError(f'Only supports dict or list or Tensor, '
                                  f'but get {type(results)}.')
    return scores, labels, keep_idxs, filtered_results