Shortcuts

Source code for mmpose.datasets.datasets.base.base_coco_style_dataset

# Copyright (c) OpenMMLab. All rights reserved.
import copy
import os.path as osp
from copy import deepcopy
from itertools import chain, filterfalse, groupby
from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union

import numpy as np
from mmengine.dataset import BaseDataset, force_full_init
from mmengine.fileio import exists, get_local_path, load
from mmengine.logging import MessageHub
from mmengine.utils import is_list_of
from xtcocotools.coco import COCO

from mmpose.registry import DATASETS
from mmpose.structures.bbox import bbox_xywh2xyxy
from ..utils import parse_pose_metainfo


@DATASETS.register_module()
class BaseCocoStyleDataset(BaseDataset):
    """Base class for COCO-style datasets.

    Args:
        ann_file (str): Annotation file path. Default: ''.
        bbox_file (str, optional): Detection result file path. If
            ``bbox_file`` is set, detected bboxes loaded from this file will
            be used instead of ground-truth bboxes. This setting is only for
            evaluation in ``'topdown'`` mode: setting it when ``test_mode``
            is ``False`` or ``data_mode`` is not ``'topdown'`` raises
            ``ValueError``. Default: ``None``.
        data_mode (str): Specifies the mode of data samples: ``'topdown'`` or
            ``'bottomup'``. In ``'topdown'`` mode, each data sample contains
            one instance; while in ``'bottomup'`` mode, each data sample
            contains all instances in an image. Default: ``'topdown'``
        metainfo (dict, optional): Meta information for dataset, such as class
            information. Default: ``None``.
        data_root (str, optional): The root directory for ``data_prefix`` and
            ``ann_file``. Default: ``None``.
        data_prefix (dict, optional): Prefix for training data. Default:
            ``dict(img='')``.
        filter_cfg (dict, optional): Config for filter data. Default: `None`.
        indices (int or Sequence[int], optional): Support using first few
            data in annotation file to facilitate training/testing on a smaller
            dataset. Default: ``None`` which means using all ``data_infos``.
        serialize_data (bool, optional): Whether to hold memory using
            serialized objects, when enabled, data loader workers can use
            shared RAM from master process instead of making a copy.
            Default: ``True``.
        pipeline (list, optional): Processing pipeline. Default: [].
        test_mode (bool, optional): ``test_mode=True`` means in test phase.
            Default: ``False``.
        lazy_init (bool, optional): Whether to defer loading annotations
            during instantiation. In some cases, such as visualization, only
            the meta information of the dataset is needed, so loading the
            annotation file is unnecessary. ``BaseDataset`` can skip loading
            annotations to save time by setting ``lazy_init=True``.
            Default: ``False``.
        max_refetch (int, optional): If ``BaseDataset.prepare_data`` returns
            ``None`` for a sample, the maximum number of extra attempts to
            fetch a valid one. Default: 1000.
        sample_interval (int, optional): Only images whose ``img_id`` is
            divisible by this value are loaded from the annotation file
            (not applied to ``bbox_file`` detections). Default: 1.
    """

    METAINFO: dict = dict()

    def __init__(self,
                 ann_file: str = '',
                 bbox_file: Optional[str] = None,
                 data_mode: str = 'topdown',
                 metainfo: Optional[dict] = None,
                 data_root: Optional[str] = None,
                 data_prefix: dict = dict(img=''),
                 filter_cfg: Optional[dict] = None,
                 indices: Optional[Union[int, Sequence[int]]] = None,
                 serialize_data: bool = True,
                 pipeline: List[Union[dict, Callable]] = [],
                 test_mode: bool = False,
                 lazy_init: bool = False,
                 max_refetch: int = 1000,
                 sample_interval: int = 1):
        # NOTE(review): ``data_prefix`` and ``pipeline`` use mutable defaults;
        # this is only safe as long as neither this class nor BaseDataset
        # mutates them in place -- confirm before relying on it.

        if data_mode not in {'topdown', 'bottomup'}:
            raise ValueError(
                f'{self.__class__.__name__} got invalid data_mode: '
                f'{data_mode}. Should be "topdown" or "bottomup".')
        self.data_mode = data_mode

        if bbox_file:
            if self.data_mode != 'topdown':
                raise ValueError(
                    f'{self.__class__.__name__} is set to {self.data_mode} '
                    'mode, while "bbox_file" is only '
                    'supported in topdown mode.')

            if not test_mode:
                raise ValueError(
                    f'{self.__class__.__name__} has `test_mode==False` '
                    'while "bbox_file" is only '
                    'supported when `test_mode==True`.')

        # These attributes are read by ``load_data_list`` and
        # ``_load_annotations``; they are set before ``super().__init__``
        # because BaseDataset may load the data list during construction
        # (when not lazily initialised).
        self.bbox_file = bbox_file
        self.sample_interval = sample_interval

        super().__init__(
            ann_file=ann_file,
            metainfo=metainfo,
            data_root=data_root,
            data_prefix=data_prefix,
            filter_cfg=filter_cfg,
            indices=indices,
            serialize_data=serialize_data,
            pipeline=pipeline,
            test_mode=test_mode,
            lazy_init=lazy_init,
            max_refetch=max_refetch)

        if self.test_mode:
            # save the ann_file into MessageHub for CocoMetric
            message = MessageHub.get_current_instance()
            dataset_name = self.metainfo['dataset_name']
            message.update_info_dict(
                {f'{dataset_name}_ann_file': self.ann_file})

    @classmethod
    def _load_metainfo(cls, metainfo: Optional[dict] = None) -> dict:
        """Collect meta information from the dictionary of meta.

        Args:
            metainfo (dict, optional): Raw data of pose meta information.
                When ``None``, a deep copy of the class-level ``METAINFO``
                is used instead.

        Returns:
            dict: Parsed meta information (returned unparsed if empty).

        Raises:
            TypeError: If ``metainfo`` is not a dict.
        """

        if metainfo is None:
            # deep copy so the class-level METAINFO is never mutated
            metainfo = deepcopy(cls.METAINFO)

        if not isinstance(metainfo, dict):
            raise TypeError(
                f'metainfo should be a dict, but got {type(metainfo)}')

        # parse pose metainfo if it has been assigned
        if metainfo:
            metainfo = parse_pose_metainfo(metainfo)
        return metainfo

    @force_full_init
    def prepare_data(self, idx) -> Any:
        """Get data processed by ``self.pipeline``.

        :class:`BaseCocoStyleDataset` overrides this method from
        :class:`mmengine.dataset.BaseDataset` to attach the dataset itself
        (as ``data_info['dataset']``) before the ``data_info`` is passed to
        the pipeline.

        Args:
            idx (int): The index of ``data_info``.

        Returns:
            Any: Depends on ``self.pipeline``.
        """
        data_info = self.get_data_info(idx)

        # Mixed image transformations require multiple source images for
        # effective blending. Therefore, we assign the 'dataset' field in
        # `data_info` to provide these auxiliary images.
        # Note: The 'dataset' assignment should not occur within the
        # `get_data_info` function, as doing so may cause the mixed image
        # transformations to stall or hang.
        data_info['dataset'] = self

        return self.pipeline(data_info)

    def get_data_info(self, idx: int) -> dict:
        """Get data info by index.

        The data info from ``BaseDataset.get_data_info`` is extended with a
        deep copy of the metainfo entries needed by the pipeline and model.

        Args:
            idx (int): Index of data info.

        Returns:
            dict: Data info.

        Raises:
            AssertionError: If the data info already contains one of the
                reserved metainfo keys.
        """
        data_info = super().get_data_info(idx)

        # Add metainfo items that are required in the pipeline and the model
        metainfo_keys = [
            'dataset_name', 'upper_body_ids', 'lower_body_ids', 'flip_pairs',
            'dataset_keypoint_weights', 'flip_indices', 'skeleton_links'
        ]

        for key in metainfo_keys:
            assert key not in data_info, (
                f'"{key}" is a reserved key for `metainfo`, but already '
                'exists in the `data_info`.')

            # deep copy so pipeline transforms cannot mutate shared metainfo
            data_info[key] = deepcopy(self._metainfo[key])

        return data_info

    def load_data_list(self) -> List[dict]:
        """Load data list from COCO annotation file or person detection result
        file.

        Returns:
            List[dict]: One entry per instance in ``'topdown'`` mode (or per
            detection when ``bbox_file`` is set), one entry per image in
            ``'bottomup'`` mode.
        """

        if self.bbox_file:
            data_list = self._load_detection_results()
        else:
            instance_list, image_list = self._load_annotations()

            if self.data_mode == 'topdown':
                data_list = self._get_topdown_data_infos(instance_list)
            else:
                data_list = self._get_bottomup_data_infos(
                    instance_list, image_list)

        # The COCO index is only needed while building the data list; drop
        # it afterwards.
        if hasattr(self, 'coco'):
            del self.coco

        return data_list

    def _load_annotations(self) -> Tuple[List[dict], List[dict]]:
        """Load data from annotations in COCO format.

        Side effects: sets ``self.coco`` and, when the annotation file has a
        ``'categories'`` section, ``self._metainfo['CLASSES']``.

        Returns:
            Tuple[List[dict], List[dict]]: ``(instance_list, image_list)``.
            Instances are emitted image by image, so all instances of one
            image are contiguous in ``instance_list``.
        """

        assert exists(self.ann_file), (
            f'Annotation file `{self.ann_file}` does not exist')

        with get_local_path(self.ann_file) as local_path:
            self.coco = COCO(local_path)
        # set the metainfo about categories, which is a list of dict
        # and each dict contains the 'id', 'name', etc. about this category
        if 'categories' in self.coco.dataset:
            self._metainfo['CLASSES'] = self.coco.loadCats(
                self.coco.getCatIds())

        instance_list = []
        image_list = []

        for img_id in self.coco.getImgIds():
            # subsample images by id (see ``sample_interval``)
            if img_id % self.sample_interval != 0:
                continue
            img = self.coco.loadImgs(img_id)[0]
            img.update({
                'img_id':
                img_id,
                'img_path':
                osp.join(self.data_prefix['img'], img['file_name']),
            })
            image_list.append(img)

            ann_ids = self.coco.getAnnIds(imgIds=img_id)
            for ann in self.coco.loadAnns(ann_ids):

                instance_info = self.parse_data_info(
                    dict(raw_ann_info=ann, raw_img_info=img))

                # skip invalid instance annotation.
                if not instance_info:
                    continue

                instance_list.append(instance_info)
        return instance_list, image_list

    def parse_data_info(self, raw_data_info: dict) -> Optional[dict]:
        """Parse raw COCO annotation of an instance.

        Args:
            raw_data_info (dict): Raw data information loaded from
                ``ann_file``. It should have following contents:

                - ``'raw_ann_info'``: Raw annotation of an instance
                - ``'raw_img_info'``: Raw information of the image that
                    contains the instance

        Returns:
            dict | None: Parsed instance annotation, or ``None`` when the
            annotation lacks ``'bbox'`` or ``'keypoints'``. The bbox is
            clipped to the image and stored as xyxy with shape ``(1, 4)``;
            keypoints have shape ``(1, K, 2)`` and visibility ``(1, K)``
            capped at 1.
        """

        ann = raw_data_info['raw_ann_info']
        img = raw_data_info['raw_img_info']

        # filter invalid instance
        if 'bbox' not in ann or 'keypoints' not in ann:
            return None

        img_w, img_h = img['width'], img['height']

        # get bbox in shape [1, 4], formatted as xywh
        x, y, w, h = ann['bbox']
        x1 = np.clip(x, 0, img_w - 1)
        y1 = np.clip(y, 0, img_h - 1)
        x2 = np.clip(x + w, 0, img_w - 1)
        y2 = np.clip(y + h, 0, img_h - 1)

        bbox = np.array([x1, y1, x2, y2], dtype=np.float32).reshape(1, 4)

        # keypoints in shape [1, K, 2] and keypoints_visible in [1, K]
        _keypoints = np.array(
            ann['keypoints'], dtype=np.float32).reshape(1, -1, 3)
        keypoints = _keypoints[..., :2]
        keypoints_visible = np.minimum(1, _keypoints[..., 2])

        if 'num_keypoints' in ann:
            num_keypoints = ann['num_keypoints']
        else:
            # count keypoints with at least one non-zero coordinate
            num_keypoints = np.count_nonzero(keypoints.max(axis=2))

        if 'area' in ann:
            area = np.array(ann['area'], dtype=np.float32)
        else:
            # fall back to a scaled clipped-bbox area; the 0.53 factor is a
            # heuristic whose origin is not documented here
            area = np.clip((x2 - x1) * (y2 - y1) * 0.53, a_min=1.0, a_max=None)
            area = np.array(area, dtype=np.float32)

        data_info = {
            'img_id': ann['image_id'],
            'img_path': img['img_path'],
            'bbox': bbox,
            'bbox_score': np.ones(1, dtype=np.float32),
            'num_keypoints': num_keypoints,
            'keypoints': keypoints,
            'keypoints_visible': keypoints_visible,
            'area': area,
            'iscrowd': ann.get('iscrowd', 0),
            'segmentation': ann.get('segmentation', None),
            'id': ann['id'],
            'category_id': np.array(ann['category_id']),
            # store the raw annotation of the instance
            # it is useful for evaluation without providing ann_file
            'raw_ann_info': copy.deepcopy(ann),
        }

        if 'crowdIndex' in img:
            data_info['crowd_index'] = img['crowdIndex']

        return data_info

    @staticmethod
    def _is_valid_instance(data_info: Dict) -> bool:
        """Check a data info is an instance with valid bbox and keypoint
        annotations.

        An instance is invalid if it is a crowd annotation, has zero
        ``num_keypoints``, a bbox with non-positive width or height, or
        keypoint coordinates that are all <= 0. Missing keys are not
        treated as invalid.
        """
        # crowd annotation
        if 'iscrowd' in data_info and data_info['iscrowd']:
            return False
        # invalid keypoints
        if 'num_keypoints' in data_info and data_info['num_keypoints'] == 0:
            return False
        # invalid bbox
        if 'bbox' in data_info:
            bbox = data_info['bbox'][0]
            w, h = bbox[2:4] - bbox[:2]
            if w <= 0 or h <= 0:
                return False
        # invalid keypoints
        if 'keypoints' in data_info:
            if np.max(data_info['keypoints']) <= 0:
                return False
        return True

    def _get_topdown_data_infos(self, instance_list: List[Dict]) -> List[Dict]:
        """Organize the data list in top-down mode.

        Returns only the instances accepted by ``_is_valid_instance``.
        """
        # sanitize data samples
        data_list_tp = list(filter(self._is_valid_instance, instance_list))

        return data_list_tp

    def _get_bottomup_data_infos(self, instance_list: List[Dict],
                                 image_list: List[Dict]) -> List[Dict]:
        """Organize the data list in bottom-up mode.

        Per-instance fields are merged into one entry per image: ndarray
        fields are concatenated (or stacked when 0-d), tuple/list fields are
        flattened, anything else is kept as a list. Segmentations of invalid
        instances are collected under ``'invalid_segs'``. In test mode,
        images without any instance are appended with empty annotations.
        """

        # bottom-up data list
        data_list_bu = []

        used_img_ids = set()

        # group instances by img_id
        # NOTE: ``groupby`` only merges *consecutive* items, so this relies on
        # ``instance_list`` keeping all instances of an image contiguous, as
        # ``_load_annotations`` does. Overrides must preserve that ordering.
        for img_id, data_infos in groupby(instance_list,
                                          lambda x: x['img_id']):
            used_img_ids.add(img_id)
            data_infos = list(data_infos)

            # image data
            img_path = data_infos[0]['img_path']
            data_info_bu = {
                'img_id': img_id,
                'img_path': img_path,
            }

            for key in data_infos[0].keys():
                if key not in data_info_bu:
                    seq = [d[key] for d in data_infos]
                    if isinstance(seq[0], np.ndarray):
                        if seq[0].ndim > 0:
                            seq = np.concatenate(seq, axis=0)
                        else:
                            seq = np.stack(seq, axis=0)
                    elif isinstance(seq[0], (tuple, list)):
                        seq = list(chain.from_iterable(seq))
                    data_info_bu[key] = seq

            # The segmentation annotation of invalid objects will be used
            # to generate valid region mask in the pipeline.
            invalid_segs = []
            for data_info_invalid in filterfalse(self._is_valid_instance,
                                                 data_infos):
                if 'segmentation' in data_info_invalid:
                    invalid_segs.append(data_info_invalid['segmentation'])
            data_info_bu['invalid_segs'] = invalid_segs

            data_list_bu.append(data_info_bu)

        # add images without instance for evaluation
        if self.test_mode:
            for img_info in image_list:
                if img_info['img_id'] not in used_img_ids:
                    data_info_bu = {
                        'img_id': img_info['img_id'],
                        'img_path': img_info['img_path'],
                        'id': list(),
                        'raw_ann_info': None,
                    }
                    data_list_bu.append(data_info_bu)

        return data_list_bu

    def _load_detection_results(self) -> List[dict]:
        """Load data from detection results with dummy keypoint annotations.

        Only detections with ``category_id == 1`` are kept. Each kept
        detection becomes one entry with an xyxy bbox, its detection score,
        all-zero keypoints and all-one keypoint visibility.
        """

        assert exists(self.ann_file), (
            f'Annotation file `{self.ann_file}` does not exist')
        assert exists(
            self.bbox_file), (f'Bbox file `{self.bbox_file}` does not exist')
        # load detection results
        det_results = load(self.bbox_file)
        assert is_list_of(
            det_results,
            dict), (f'BBox file `{self.bbox_file}` should be a list of dict, '
                    f'but got {type(det_results)}')

        # load coco annotations to build image id-to-name index
        with get_local_path(self.ann_file) as local_path:
            self.coco = COCO(local_path)
        # set the metainfo about categories, which is a list of dict
        # and each dict contains the 'id', 'name', etc. about this category.
        # Guarded like ``_load_annotations``: ``getCatIds()`` reads
        # ``dataset['categories']`` and fails when that section is absent.
        if 'categories' in self.coco.dataset:
            self._metainfo['CLASSES'] = self.coco.loadCats(
                self.coco.getCatIds())

        num_keypoints = self.metainfo['num_keypoints']
        data_list = []
        id_ = 0
        for det in det_results:
            # remove non-human instances
            if det['category_id'] != 1:
                continue

            img = self.coco.loadImgs(det['image_id'])[0]

            img_path = osp.join(self.data_prefix['img'], img['file_name'])
            bbox_xywh = np.array(
                det['bbox'][:4], dtype=np.float32).reshape(1, 4)
            bbox = bbox_xywh2xyxy(bbox_xywh)
            bbox_score = np.array(det['score'], dtype=np.float32).reshape(1)

            # use dummy keypoint location and visibility
            keypoints = np.zeros((1, num_keypoints, 2), dtype=np.float32)
            keypoints_visible = np.ones((1, num_keypoints), dtype=np.float32)

            data_list.append({
                'img_id': det['image_id'],
                'img_path': img_path,
                'img_shape': (img['height'], img['width']),
                'bbox': bbox,
                'bbox_score': bbox_score,
                'keypoints': keypoints,
                'keypoints_visible': keypoints_visible,
                'id': id_,
            })

            id_ += 1

        return data_list

    def filter_data(self) -> List[dict]:
        """Filter annotations according to filter_cfg. Defaults return full
        ``data_list``.

        If ``'bbox_score_thr'`` is in filter_cfg, annotations with a
        ``bbox_score`` below the threshold are filtered out.

        Raises:
            ValueError: If ``'bbox_score_thr'`` is given outside of
                ``'topdown'`` mode.
        """

        data_list = self.data_list

        if self.filter_cfg is None:
            return data_list

        # filter out annotations with a bbox_score below the threshold
        if 'bbox_score_thr' in self.filter_cfg:

            if self.data_mode != 'topdown':
                raise ValueError(
                    f'{self.__class__.__name__} is set to {self.data_mode} '
                    'mode, while "bbox_score_thr" is only supported in '
                    'topdown mode.')

            thr = self.filter_cfg['bbox_score_thr']
            data_list = list(
                filterfalse(lambda ann: ann['bbox_score'] < thr, data_list))

        return data_list