mmpose.models.utils.check_and_update_config 源代码
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Dict, Optional, Tuple, Union
from mmengine.config import Config, ConfigDict
from mmengine.dist import master_only
from mmengine.logging import MMLogger
ConfigType = Union[Config, ConfigDict]
def process_input_transform(input_transform: str, head: Dict, head_new: Dict,
head_deleted_dict: Dict, head_append_dict: Dict,
neck_new: Dict, input_index: Tuple[int],
align_corners: bool) -> None:
"""Process the input_transform field and update head and neck
dictionaries."""
if input_transform == 'resize_concat':
in_channels = head_new.pop('in_channels')
head_deleted_dict['in_channels'] = str(in_channels)
in_channels = sum([in_channels[i] for i in input_index])
head_new['in_channels'] = in_channels
head_append_dict['in_channels'] = str(in_channels)
neck_new.update(
dict(
type='FeatureMapProcessor',
concat=True,
select_index=input_index,
))
if align_corners:
neck_new['align_corners'] = align_corners
elif input_transform == 'select':
if input_index != (-1, ):
neck_new.update(
dict(type='FeatureMapProcessor', select_index=input_index))
if isinstance(head['in_channels'], tuple):
in_channels = head_new.pop('in_channels')
head_deleted_dict['in_channels'] = str(in_channels)
if isinstance(input_index, int):
in_channels = in_channels[input_index]
else:
in_channels = tuple([in_channels[i] for i in input_index])
head_new['in_channels'] = in_channels
head_append_dict['in_channels'] = str(in_channels)
if align_corners:
neck_new['align_corners'] = align_corners
else:
raise ValueError(f'model.head get invalid value for argument '
f'input_transform: {input_transform}')
def process_extra_field(extra: Dict, head_new: Dict, head_deleted_dict: Dict,
head_append_dict: Dict, neck_new: Dict) -> None:
"""Process the extra field and update head and neck dictionaries."""
head_deleted_dict['extra'] = 'dict('
for key, value in extra.items():
head_deleted_dict['extra'] += f'{key}={value},'
head_deleted_dict['extra'] = head_deleted_dict['extra'][:-1] + ')'
if 'final_conv_kernel' in extra:
kernel_size = extra['final_conv_kernel']
if kernel_size > 1:
padding = kernel_size // 2
head_new['final_layer'] = dict(
kernel_size=kernel_size, padding=padding)
head_append_dict[
'final_layer'] = f'dict(kernel_size={kernel_size}, ' \
f'padding={padding})'
else:
head_new['final_layer'] = dict(kernel_size=kernel_size)
head_append_dict[
'final_layer'] = f'dict(kernel_size={kernel_size})'
if 'upsample' in extra:
neck_new.update(
dict(
type='FeatureMapProcessor',
scale_factor=float(extra['upsample']),
apply_relu=True,
))
def process_has_final_layer(has_final_layer: bool, head_new: Dict,
head_deleted_dict: Dict,
head_append_dict: Dict) -> None:
"""Process the has_final_layer field and update the head dictionary."""
head_deleted_dict['has_final_layer'] = str(has_final_layer)
if not has_final_layer:
if 'final_layer' not in head_new:
head_new['final_layer'] = None
head_append_dict['final_layer'] = 'None'
[文档]def check_and_update_config(neck: Optional[ConfigType],
head: ConfigType) -> Tuple[Optional[Dict], Dict]:
"""Check and update the configuration of the head and neck components.
Args:
neck (Optional[ConfigType]): Configuration for the neck component.
head (ConfigType): Configuration for the head component.
Returns:
Tuple[Optional[Dict], Dict]: Updated configurations for the neck
and head components.
"""
head_new, neck_new = head.copy(), neck.copy() if isinstance(neck,
dict) else {}
head_deleted_dict, head_append_dict = {}, {}
if 'input_transform' in head:
input_transform = head_new.pop('input_transform')
head_deleted_dict['input_transform'] = f'\'{input_transform}\''
else:
input_transform = 'select'
if 'input_index' in head:
input_index = head_new.pop('input_index')
head_deleted_dict['input_index'] = str(input_index)
else:
input_index = (-1, )
if 'align_corners' in head:
align_corners = head_new.pop('align_corners')
head_deleted_dict['align_corners'] = str(align_corners)
else:
align_corners = False
process_input_transform(input_transform, head, head_new, head_deleted_dict,
head_append_dict, neck_new, input_index,
align_corners)
if 'extra' in head:
extra = head_new.pop('extra')
process_extra_field(extra, head_new, head_deleted_dict,
head_append_dict, neck_new)
if 'has_final_layer' in head:
has_final_layer = head_new.pop('has_final_layer')
process_has_final_layer(has_final_layer, head_new, head_deleted_dict,
head_append_dict)
display_modifications(head_deleted_dict, head_append_dict, neck_new)
neck_new = neck_new if len(neck_new) else None
return neck_new, head_new
@master_only
def display_modifications(head_deleted_dict: Dict, head_append_dict: Dict,
neck: Dict) -> None:
"""Display the modifications made to the head and neck configurations.
Args:
head_deleted_dict (Dict): Dictionary of deleted fields in the head.
head_append_dict (Dict): Dictionary of appended fields in the head.
neck (Dict): Updated neck configuration.
"""
if len(head_deleted_dict) + len(head_append_dict) == 0:
return
old_model_info, new_model_info = build_model_info(head_deleted_dict,
head_append_dict, neck)
total_info = '\nThe config you are using is outdated. '\
'The following section of the config:\n```\n'
total_info += old_model_info
total_info += '```\nshould be updated to\n```\n'
total_info += new_model_info
total_info += '```\nFor more information, please refer to '\
'https://mmpose.readthedocs.io/en/latest/' \
'guide_to_framework.html#step3-model'
logger: MMLogger = MMLogger.get_current_instance()
logger.warning(total_info)
def build_model_info(head_deleted_dict: Dict, head_append_dict: Dict,
neck: Dict) -> Tuple[str, str]:
"""Build the old and new model information strings.
Args:
head_deleted_dict (Dict): Dictionary of deleted fields in the head.
head_append_dict (Dict): Dictionary of appended fields in the head.
neck (Dict): Updated neck configuration.
Returns:
Tuple[str, str]: Old and new model information strings.
"""
old_head_info = build_head_info(head_deleted_dict)
new_head_info = build_head_info(head_append_dict)
neck_info = build_neck_info(neck)
old_model_info = 'model=dict(\n' + ' ' * 4 + '...,\n' + old_head_info
new_model_info = 'model=dict(\n' + ' ' * 4 + '...,\n' \
+ neck_info + new_head_info
return old_model_info, new_model_info
def build_head_info(head_dict: Dict) -> str:
"""Build the head information string.
Args:
head_dict (Dict): Dictionary of fields in the head configuration.
Returns:
str: Head information string.
"""
head_info = ' ' * 4 + 'head=dict(\n'
for key, value in head_dict.items():
head_info += ' ' * 8 + f'{key}={value},\n'
head_info += ' ' * 8 + '...),\n'
return head_info
def build_neck_info(neck: Dict) -> str:
"""Build the neck information string.
Args:
neck (Dict): Updated neck configuration.
Returns:
str: Neck information string.
"""
if len(neck) > 0:
neck = neck.copy()
neck_info = ' ' * 4 + 'neck=dict(\n' + ' ' * 8 + \
f'type=\'{neck.pop("type")}\',\n'
for key, value in neck.items():
neck_info += ' ' * 8 + f'{key}={str(value)},\n'
neck_info += ' ' * 4 + '),\n'
else:
neck_info = ''
return neck_info