# Source code for mmpose.models.pose_estimators.topdown

# Copyright (c) OpenMMLab. All rights reserved.
from itertools import zip_longest
from typing import Optional

from torch import Tensor

from mmpose.registry import MODELS
from mmpose.utils.typing import (ConfigType, InstanceList, OptConfigType,
                                 OptMultiConfig, PixelDataList, SampleList)
from .base import BasePoseEstimator

@MODELS.register_module()
class TopdownPoseEstimator(BasePoseEstimator):
    """Base class for top-down pose estimators.

    Args:
        backbone (dict): The backbone config
        neck (dict, optional): The neck config. Defaults to ``None``
        head (dict, optional): The head config. Defaults to ``None``
        train_cfg (dict, optional): The runtime config for training process.
            Defaults to ``None``
        test_cfg (dict, optional): The runtime config for testing process.
            Defaults to ``None``
        data_preprocessor (dict, optional): The data preprocessing config to
            build the instance of :class:`BaseDataPreprocessor`. Defaults to
            ``None``
        init_cfg (dict, optional): The config to control the initialization.
            Defaults to ``None``
        metainfo (dict): Meta information for dataset, such as keypoints
            definition and properties. If set, the metainfo of the input data
            batch will be overridden. For more details, please refer to
            prepare_datasets.html#create-a-custom-dataset-info-config-file-for-the-dataset
            (section "Create a custom dataset info config file for the
            dataset"). Defaults to ``None``
    """

    def __init__(self,
                 backbone: ConfigType,
                 neck: OptConfigType = None,
                 head: OptConfigType = None,
                 train_cfg: OptConfigType = None,
                 test_cfg: OptConfigType = None,
                 data_preprocessor: OptConfigType = None,
                 init_cfg: OptMultiConfig = None,
                 metainfo: Optional[dict] = None):
        # All construction is delegated to the base estimator.
        super().__init__(
            backbone=backbone,
            neck=neck,
            head=head,
            train_cfg=train_cfg,
            test_cfg=test_cfg,
            data_preprocessor=data_preprocessor,
            init_cfg=init_cfg,
            metainfo=metainfo)

    def loss(self, inputs: Tensor, data_samples: SampleList) -> dict:
        """Calculate losses from a batch of inputs and data samples.

        Args:
            inputs (Tensor): Inputs with shape (N, C, H, W).
            data_samples (List[:obj:`PoseDataSample`]): The batch
                data samples.

        Returns:
            dict: A dictionary of losses. It is empty when the model has no
            head, because only the head contributes losses here.
        """
        feats = self.extract_feat(inputs)

        losses = dict()

        if self.with_head:
            losses.update(
                self.head.loss(feats, data_samples, train_cfg=self.train_cfg))

        return losses

    def predict(self, inputs: Tensor, data_samples: SampleList) -> SampleList:
        """Predict results from a batch of inputs and data samples with post-
        processing.

        Args:
            inputs (Tensor): Inputs with shape (N, C, H, W)
            data_samples (List[:obj:`PoseDataSample`]): The batch
                data samples

        Returns:
            list[:obj:`PoseDataSample`]: The pose estimation results of the
            input images. The return value is `PoseDataSample` instances with
            ``pred_instances`` and ``pred_fields``(optional) field , and
            ``pred_instances`` usually contains the following keys:

                - keypoints (Tensor): predicted keypoint coordinates in shape
                    (num_instances, K, D) where K is the keypoint number and D
                    is the keypoint dimension
                - keypoint_scores (Tensor): predicted keypoint scores in shape
                    (num_instances, K)

        Raises:
            AssertionError: If the model has no head.
        """
        assert self.with_head, (
            'The model must have head to perform prediction.')

        if self.test_cfg.get('flip_test', False):
            # Flip test: extract features from the original inputs and from
            # inputs flipped along the last (W) axis. Both are passed to the
            # head as a list; presumably the head merges them.
            _feats = self.extract_feat(inputs)
            _feats_flip = self.extract_feat(inputs.flip(-1))
            feats = [_feats, _feats_flip]
        else:
            feats = self.extract_feat(inputs)

        preds = self.head.predict(feats, data_samples, test_cfg=self.test_cfg)

        # A head may return either only the instance predictions, or a tuple
        # of (instance predictions, field predictions such as heatmaps).
        if isinstance(preds, tuple):
            batch_pred_instances, batch_pred_fields = preds
        else:
            batch_pred_instances = preds
            batch_pred_fields = None

        results = self.add_pred_to_datasample(batch_pred_instances,
                                              batch_pred_fields, data_samples)

        return results

    def add_pred_to_datasample(self, batch_pred_instances: InstanceList,
                               batch_pred_fields: Optional[PixelDataList],
                               batch_data_samples: SampleList) -> SampleList:
        """Add predictions into data samples.

        Keypoint coordinates are mapped from model-input space back to image
        space. The bounding boxes and box scores are copied from each
        sample's ``gt_instances``. If ``test_cfg`` defines
        ``output_keypoint_indices``, only those keypoints are kept. The same
        selection applies to the matching heatmap channels.

        Args:
            batch_pred_instances (List[InstanceData]): The predicted instances
                of the input data batch
            batch_pred_fields (List[PixelData], optional): The predicted
                fields (e.g. heatmaps) of the input batch. ``None`` or an
                empty list means there are no field predictions.
            batch_data_samples (List[PoseDataSample]): The input data batch

        Returns:
            List[PoseDataSample]: A list of data samples where the predictions
            are stored in the ``pred_instances`` field of each data sample.
            The same list object as ``batch_data_samples`` is returned, with
            each sample modified in place.

        Raises:
            AssertionError: If the number of predicted instances, or of
                non-empty predicted fields, differs from the number of data
                samples.
        """
        assert len(batch_pred_instances) == len(batch_data_samples)
        if not batch_pred_fields:
            batch_pred_fields = []
        else:
            # ``zip_longest`` below would silently pad a length mismatch with
            # ``None``. That drops fields or misaligns samples, so require a
            # one-to-one match.
            assert len(batch_pred_fields) == len(batch_data_samples), (
                'The number of predicted fields must match the number of '
                'data samples.')

        output_keypoint_indices = self.test_cfg.get('output_keypoint_indices',
                                                    None)

        for pred_instances, pred_fields, data_sample in zip_longest(
                batch_pred_instances, batch_pred_fields, batch_data_samples):

            gt_instances = data_sample.gt_instances

            # Convert keypoint coordinates from input space to image space:
            # normalise by the input size, rescale to the bbox size, then
            # shift by the bbox top-left corner (center - 0.5 * scale).
            # NOTE(review): bbox_scales/bbox_centers broadcast against the
            # keypoint array; this presumably assumes one instance per
            # sample in top-down mode. Confirm before batching instances.
            bbox_centers = gt_instances.bbox_centers
            bbox_scales = gt_instances.bbox_scales
            input_size = data_sample.metainfo['input_size']

            pred_instances.keypoints = pred_instances.keypoints / input_size \
                * bbox_scales + bbox_centers - 0.5 * bbox_scales

            if output_keypoint_indices is not None:
                # Select the output keypoints with the given indices. Record
                # the original keypoint count first; it is used below to
                # recognise per-keypoint heatmap channels.
                num_keypoints = pred_instances.keypoints.shape[1]
                # Only data fields are sliced. Writing a metainfo key back
                # via ``set_field`` (field_type='data') would raise.
                for key, value in pred_instances.items():
                    if key.startswith('keypoint'):
                        pred_instances.set_field(
                            value[:, output_keypoint_indices], key)

            # Add the bbox information from the input instances to the
            # predictions.
            pred_instances.bboxes = gt_instances.bboxes
            pred_instances.bbox_scores = gt_instances.bbox_scores

            data_sample.pred_instances = pred_instances

            if pred_fields is not None:
                if output_keypoint_indices is not None:
                    # Select the output heatmap channels with the keypoint
                    # indices, but only for fields whose channel count
                    # matches num_keypoints. Metainfo is skipped: it may not
                    # be array-like.
                    for key, value in pred_fields.items():
                        if value.shape[0] != num_keypoints:
                            continue
                        pred_fields.set_field(value[output_keypoint_indices],
                                              key)
                data_sample.pred_fields = pred_fields

        return batch_data_samples
# Read the Docs v: latest
# On Read the Docs
# Project Home
#
# Free document hosting provided by Read the Docs.