Shortcuts

Source code for mmpose.models.utils.check_and_update_config

# Copyright (c) OpenMMLab. All rights reserved.
from typing import Dict, Optional, Tuple, Union

from mmengine.config import Config, ConfigDict
from mmengine.dist import master_only
from mmengine.logging import MMLogger

# Type alias for the config arguments handled in this module (see
# ``check_and_update_config``): either an mmengine ``Config`` or a
# ``ConfigDict``.
ConfigType = Union[Config, ConfigDict]


def process_input_transform(input_transform: str, head: Dict, head_new: Dict,
                            head_deleted_dict: Dict, head_append_dict: Dict,
                            neck_new: Dict,
                            input_index: Union[int, Tuple[int, ...]],
                            align_corners: bool) -> None:
    """Translate the legacy ``input_transform`` head option into a neck.

    Feature-map selection / concatenation is expressed as a
    ``FeatureMapProcessor`` neck, and the head's ``in_channels`` is adjusted
    to the channels of the feature maps that neck emits. All dictionaries
    except ``head`` are updated in place.

    Args:
        input_transform (str): Legacy transform, ``'resize_concat'`` or
            ``'select'``.
        head (Dict): The original head config; only read.
        head_new (Dict): Working copy of the head config to be updated.
        head_deleted_dict (Dict): Collects display strings of head fields
            that were removed or replaced.
        head_append_dict (Dict): Collects display strings of head fields
            that were added or replaced.
        neck_new (Dict): Working neck config to be updated.
        input_index (int | Tuple[int, ...]): Index or indices of the input
            feature maps to use.
        align_corners (bool): Copied into the neck config when truthy.

    Raises:
        ValueError: If ``input_transform`` is not a supported value.
    """
    # Normalize a bare integer index so it can be iterated uniformly.
    indices = (input_index, ) if isinstance(input_index, int) else input_index

    if input_transform == 'resize_concat':
        in_channels = head_new.pop('in_channels')
        head_deleted_dict['in_channels'] = str(in_channels)
        # Concatenated features carry the summed channels of the inputs.
        in_channels = sum(in_channels[i] for i in indices)
        head_new['in_channels'] = in_channels
        head_append_dict['in_channels'] = str(in_channels)

        neck_new.update(
            dict(
                type='FeatureMapProcessor',
                concat=True,
                select_index=input_index,
            ))
        if align_corners:
            neck_new['align_corners'] = align_corners

    elif input_transform == 'select':
        # The default index (-1, ) leaves both head and neck untouched.
        if input_index != (-1, ):
            neck_new.update(
                dict(type='FeatureMapProcessor', select_index=input_index))
            # Per-level channels may be a tuple or a list (configs loaded
            # from JSON/YAML turn tuples into lists); a single int needs no
            # adjustment. ``get`` tolerates heads without ``in_channels``.
            if isinstance(head.get('in_channels'), (list, tuple)):
                in_channels = head_new.pop('in_channels')
                head_deleted_dict['in_channels'] = str(in_channels)
                if isinstance(input_index, int):
                    in_channels = in_channels[input_index]
                else:
                    in_channels = tuple(in_channels[i] for i in input_index)
                head_new['in_channels'] = in_channels
                head_append_dict['in_channels'] = str(in_channels)
            if align_corners:
                neck_new['align_corners'] = align_corners

    else:
        raise ValueError(f'model.head get invalid value for argument '
                         f'input_transform: {input_transform}')


def process_extra_field(extra: Dict, head_new: Dict, head_deleted_dict: Dict,
                        head_append_dict: Dict, neck_new: Dict) -> None:
    """Translate the legacy ``extra`` head option into new-style fields.

    ``extra['final_conv_kernel']`` becomes the head's ``final_layer`` config
    and ``extra['upsample']`` becomes an upsampling ``FeatureMapProcessor``
    neck. Other keys of ``extra`` are only recorded for display. All
    dictionaries except ``extra`` are updated in place.

    Args:
        extra (Dict): The legacy ``extra`` dict removed from the head config.
        head_new (Dict): Working copy of the head config to be updated.
        head_deleted_dict (Dict): Collects display strings of removed head
            fields; receives a ``dict(...)`` rendering of ``extra``.
        head_append_dict (Dict): Collects display strings of added head
            fields.
        neck_new (Dict): Working neck config to be updated.
    """
    # Render as ``dict(k1=v1,k2=v2)``. Joining (instead of trimming a
    # trailing comma) keeps an empty ``extra`` well-formed as ``dict()``.
    head_deleted_dict['extra'] = 'dict(' + ','.join(
        f'{key}={value}' for key, value in extra.items()) + ')'

    if 'final_conv_kernel' in extra:
        kernel_size = extra['final_conv_kernel']
        if kernel_size > 1:
            padding = kernel_size // 2
            head_new['final_layer'] = dict(
                kernel_size=kernel_size, padding=padding)
            head_append_dict['final_layer'] = (
                f'dict(kernel_size={kernel_size}, padding={padding})')
        else:
            head_new['final_layer'] = dict(kernel_size=kernel_size)
            head_append_dict['final_layer'] = (
                f'dict(kernel_size={kernel_size})')

    if 'upsample' in extra:
        neck_new.update(
            dict(
                type='FeatureMapProcessor',
                scale_factor=float(extra['upsample']),
                apply_relu=True,
            ))


def process_has_final_layer(has_final_layer: bool, head_new: Dict,
                            head_deleted_dict: Dict,
                            head_append_dict: Dict) -> None:
    """Translate the legacy ``has_final_layer`` head option.

    ``has_final_layer=False`` is expressed as ``final_layer=None`` in the
    head config, unless ``head_new`` already defines ``final_layer`` (e.g.
    one derived from ``extra['final_conv_kernel']``), which is kept.

    Args:
        has_final_layer (bool): The legacy flag removed from the head config.
        head_new (Dict): Working copy of the head config to be updated.
        head_deleted_dict (Dict): Collects display strings of removed head
            fields.
        head_append_dict (Dict): Collects display strings of added head
            fields.
    """
    head_deleted_dict['has_final_layer'] = str(has_final_layer)
    # Record ``final_layer=None`` for display only when it is actually
    # written; otherwise the suggested new config would claim ``None`` while
    # ``head_new`` keeps the existing final layer.
    if not has_final_layer and 'final_layer' not in head_new:
        head_new['final_layer'] = None
        head_append_dict['final_layer'] = 'None'


def check_and_update_config(neck: Optional[ConfigType],
                            head: ConfigType) -> Tuple[Optional[Dict], Dict]:
    """Check and update the configuration of the head and neck components.

    Legacy head fields (``input_transform``, ``input_index``,
    ``align_corners``, ``extra`` and ``has_final_layer``) are removed from a
    copy of the head config and re-expressed as new-style head fields and/or
    neck settings. The collected changes are passed to
    ``display_modifications`` for reporting.

    Args:
        neck (Optional[ConfigType]): Configuration for the neck component.
            Anything that is not a dict (e.g. ``None``) is treated as an
            empty neck.
        head (ConfigType): Configuration for the head component. Only read;
            a shallow copy is updated instead.

    Returns:
        Tuple[Optional[Dict], Dict]: Updated configurations for the neck and
        head components. The neck is ``None`` when it ends up empty.
    """
    head_new = head.copy()
    neck_new = neck.copy() if isinstance(neck, dict) else {}
    head_deleted_dict, head_append_dict = {}, {}

    # Pop each legacy field (using a default when it is absent) and record
    # how it was written so the warning can show the old config.
    if 'input_transform' in head:
        input_transform = head_new.pop('input_transform')
        head_deleted_dict['input_transform'] = f'\'{input_transform}\''
    else:
        input_transform = 'select'

    if 'input_index' in head:
        input_index = head_new.pop('input_index')
        head_deleted_dict['input_index'] = str(input_index)
    else:
        input_index = (-1, )

    if 'align_corners' in head:
        align_corners = head_new.pop('align_corners')
        head_deleted_dict['align_corners'] = str(align_corners)
    else:
        align_corners = False

    process_input_transform(input_transform, head, head_new,
                            head_deleted_dict, head_append_dict, neck_new,
                            input_index, align_corners)

    if 'extra' in head:
        extra = head_new.pop('extra')
        process_extra_field(extra, head_new, head_deleted_dict,
                            head_append_dict, neck_new)

    if 'has_final_layer' in head:
        has_final_layer = head_new.pop('has_final_layer')
        process_has_final_layer(has_final_layer, head_new, head_deleted_dict,
                                head_append_dict)

    display_modifications(head_deleted_dict, head_append_dict, neck_new)

    # An empty neck config is reported as "no neck".
    neck_new = neck_new or None
    return neck_new, head_new
@master_only
def display_modifications(head_deleted_dict: Dict, head_append_dict: Dict,
                          neck: Dict) -> None:
    """Display the modifications made to the head and neck configurations.

    Logs a single warning that shows the outdated config section next to its
    updated form. Nothing is logged when no head field was deleted or
    appended.

    Args:
        head_deleted_dict (Dict): Dictionary of deleted fields in the head.
        head_append_dict (Dict): Dictionary of appended fields in the head.
        neck (Dict): Updated neck configuration.
    """
    if not head_deleted_dict and not head_append_dict:
        return

    old_model_info, new_model_info = build_model_info(head_deleted_dict,
                                                      head_append_dict, neck)

    total_info = ''.join([
        '\nThe config you are using is outdated. '
        'The following section of the config:\n```\n',
        old_model_info,
        '```\nshould be updated to\n```\n',
        new_model_info,
        '```\nFor more information, please refer to '
        'https://mmpose.readthedocs.io/en/latest/'
        'guide_to_framework.html#step3-model',
    ])
    logger: MMLogger = MMLogger.get_current_instance()
    logger.warning(total_info)


def build_model_info(head_deleted_dict: Dict, head_append_dict: Dict,
                     neck: Dict) -> Tuple[str, str]:
    """Build the old and new model information strings.

    Args:
        head_deleted_dict (Dict): Dictionary of deleted fields in the head.
        head_append_dict (Dict): Dictionary of appended fields in the head.
        neck (Dict): Updated neck configuration.

    Returns:
        Tuple[str, str]: Old and new model information strings. Only the
        new one includes the neck section.
    """
    prefix = 'model=dict(\n' + ' ' * 4 + '...,\n'
    old_model_info = prefix + build_head_info(head_deleted_dict)
    new_model_info = (prefix + build_neck_info(neck) +
                      build_head_info(head_append_dict))
    return old_model_info, new_model_info


def build_head_info(head_dict: Dict) -> str:
    """Build the head information string.

    Args:
        head_dict (Dict): Dictionary of fields in the head configuration;
            values are already-rendered display strings.

    Returns:
        str: Head information string, one ``key=value,`` line per field.
    """
    indent = ' ' * 8
    lines = [' ' * 4 + 'head=dict(\n']
    lines.extend(f'{indent}{key}={value},\n'
                 for key, value in head_dict.items())
    lines.append(indent + '...),\n')
    return ''.join(lines)


def build_neck_info(neck: Dict) -> str:
    """Build the neck information string.

    Args:
        neck (Dict): Updated neck configuration. It must contain ``type``
            when non-empty; it is not modified.

    Returns:
        str: Neck information string, or ``''`` for an empty (or ``None``)
        neck.
    """
    if not neck:
        return ''

    fields = neck.copy()
    neck_type = fields.pop('type')
    indent = ' ' * 8
    lines = [' ' * 4 + 'neck=dict(\n', f"{indent}type='{neck_type}',\n"]
    lines.extend(f'{indent}{key}={value!s},\n'
                 for key, value in fields.items())
    lines.append(' ' * 4 + '),\n')
    return ''.join(lines)
Read the Docs v: latest
Versions
latest
0.x
dev-1.x
Downloads
epub
On Read the Docs
Project Home
Builds

Free document hosting provided by Read the Docs.