Shortcuts

Source code for mmpose.structures.utils

# Copyright (c) OpenMMLab. All rights reserved.
import warnings
from typing import List

import cv2
import numpy as np
import torch
from mmengine.structures import InstanceData, PixelData
from mmengine.utils import is_list_of

from .bbox.transforms import get_warp_matrix
from .pose_data_sample import PoseDataSample


def merge_data_samples(data_samples: List[PoseDataSample]) -> PoseDataSample:
    """Combine several data samples into one data sample.

    A typical use is gathering the top-down predictions made for the
    different bboxes of one image. The result carries the metainfo of the
    first input sample. Instance-level data (``gt_instances`` /
    ``pred_instances``) is concatenated across all samples. Heatmaps found in
    ``pred_fields`` / ``gt_fields`` are each passed through
    :func:`revert_heatmap` and then reduced with an element-wise maximum.

    Whether a field is merged at all is decided by looking at the first
    sample only.

    Args:
        data_samples (List[:obj:`PoseDataSample`]): The data samples to merge.

    Returns:
        PoseDataSample: The merged data sample. An empty
        :obj:`PoseDataSample` (after a warning) if the input list is empty.

    Raises:
        ValueError: If ``data_samples`` is not a list of
            :obj:`PoseDataSample`.
    """
    if not is_list_of(data_samples, PoseDataSample):
        raise ValueError('Invalid input type, should be a list of '
                         ':obj:`PoseDataSample`')

    if not data_samples:
        warnings.warn('Try to merge an empty list of data samples.')
        return PoseDataSample()

    first = data_samples[0]
    merged = PoseDataSample(metainfo=first.metainfo)

    # Instance-level data: concatenate over all samples.
    for attr in ('gt_instances', 'pred_instances'):
        if attr in first:
            per_sample = [getattr(sample, attr) for sample in data_samples]
            setattr(merged, attr, InstanceData.cat(per_sample))

    # Pixel-level data: revert every heatmap, then keep the element-wise
    # maximum over all samples.
    for attr in ('pred_fields', 'gt_fields'):
        if attr in first and 'heatmaps' in getattr(first, attr):
            reverted = [
                revert_heatmap(
                    getattr(sample, attr).heatmaps, sample.input_center,
                    sample.input_scale, sample.ori_shape)
                for sample in data_samples
            ]
            fields = PixelData()
            fields.set_data({'heatmaps': np.max(reverted, axis=0)})
            setattr(merged, attr, fields)

    return merged
def revert_heatmap(heatmap, input_center, input_scale, img_shape):
    """Revert predicted heatmap on the original image.

    The heatmap is warped with the inverse of the matrix returned by
    ``get_warp_matrix`` for the given center/scale (rotation 0), producing
    an array whose spatial size is ``img_shape``.

    Args:
        heatmap (np.ndarray or torch.tensor): predicted heatmap. A 3-D
            input is treated as ``[K, H, W]``; otherwise the first two
            axes are taken as ``[H, W]``.
        input_center (np.ndarray): bounding box center coordinate
            (reshaped to ``(2, )``).
        input_scale (np.ndarray): bounding box scale (reshaped to
            ``(2, )``).
        img_shape (tuple or list): size of original image as
            ``(height, width)``.

    Returns:
        np.ndarray: the warped heatmap. ``[K, img_h, img_w]`` for a 3-D
        input, otherwise whatever ``cv2.warpAffine`` returns for the input.
    """
    if torch.is_tensor(heatmap):
        heatmap = heatmap.cpu().detach().numpy()

    ndim = heatmap.ndim
    # [K, H, W] -> [H, W, K]
    if ndim == 3:
        heatmap = heatmap.transpose(1, 2, 0)

    hm_h, hm_w = heatmap.shape[:2]
    img_h, img_w = img_shape
    warp_mat = get_warp_matrix(
        input_center.reshape((2, )),
        input_scale.reshape((2, )),
        rot=0,
        output_size=(hm_w, hm_h),
        inv=True)

    heatmap = cv2.warpAffine(
        heatmap, warp_mat, (img_w, img_h), flags=cv2.INTER_LINEAR)

    # [H, W, K] -> [K, H, W]
    if ndim == 3:
        if heatmap.ndim == 2:
            # OpenCV returns a 2-D array for single-channel images, so a
            # [H, W, 1] input loses its channel axis. Restore it, otherwise
            # the transpose below fails for K == 1.
            heatmap = heatmap[:, :, np.newaxis]
        heatmap = heatmap.transpose(2, 0, 1)

    return heatmap
def split_instances(instances: InstanceData) -> List[InstanceData]:
    """Break a batch of instances into one plain ``dict`` per instance.

    Each dict holds ``keypoints`` and ``keypoint_scores`` (converted with
    ``.tolist()``). When ``instances`` contains ``bboxes`` a ``bbox`` entry
    is added, and when it additionally contains ``bbox_scores`` a
    ``bbox_score`` entry is added (left unconverted).

    Args:
        instances (InstanceData): The instances to split, or ``None``.

    Returns:
        list: One dict per instance; an empty list if ``instances`` is
        ``None``.
    """
    # Nothing was detected by the model: hand back an empty list.
    if instances is None:
        return []

    has_bboxes = 'bboxes' in instances
    # NOTE(review): treated as nested under the bbox check -- the collapsed
    # source does not show the indentation; confirm against the original.
    has_bbox_scores = has_bboxes and 'bbox_scores' in instances

    per_instance = []
    for idx, kpts in enumerate(instances.keypoints):
        entry = {
            'keypoints': kpts.tolist(),
            'keypoint_scores': instances.keypoint_scores[idx].tolist(),
        }
        if has_bboxes:
            # NOTE(review): the value is a 1-tuple wrapping the coordinate
            # list, not the list itself. This looks accidental, but it is
            # kept because consumers may already index into it.
            entry['bbox'] = (instances.bboxes[idx].tolist(), )
        if has_bbox_scores:
            entry['bbox_score'] = instances.bbox_scores[idx]
        per_instance.append(entry)

    return per_instance
Read the Docs v: latest
Versions
latest
0.x
dev-1.x
Downloads
epub
On Read the Docs
Project Home
Builds

Free document hosting provided by Read the Docs.