Shortcuts

mmpose.models.backbones.hourglass_ae 源代码

# Copyright (c) OpenMMLab. All rights reserved.
import copy

import torch.nn as nn
from mmcv.cnn import ConvModule, MaxPool2d
from mmengine.model import BaseModule

from mmpose.registry import MODELS
from .base_backbone import BaseBackbone


class HourglassAEModule(BaseModule):
    """Modified Hourglass Module for HourglassNet_AE backbone.

    The module is generated recursively and uses ``ConvModule`` (3x3
    convolutions) as its base unit. Each level has two branches that are
    summed at the end:

    - an "up" branch: one 3x3 conv applied to the input as-is;
    - a "low" branch: 2x2 max-pooling, a conv that changes the channel count
      from ``stage_channels[0]`` to ``stage_channels[1]``, the nested
      sub-module (another ``HourglassAEModule`` while ``depth > 1``, a plain
      conv otherwise), a conv back to ``stage_channels[0]`` channels and a
      nearest-neighbour 2x upsampling.

    Args:
        depth (int): Depth of current HourglassModule. One nested
            ``HourglassAEModule`` is created per level above 1.
        stage_channels (list[int]): Feature channels of sub-modules in current
            and follow-up HourglassModule. Entries ``0`` and ``1`` are read at
            every level and the sequence is shifted by one for the nested
            level, so it needs at least ``depth + 1`` entries.
        norm_cfg (dict): Dictionary to construct and config norm layer. It is
            applied to every conv of this module and of all nested levels.
        init_cfg (dict or list[dict], optional): Initialization config dict.
            Default: None
    """

    def __init__(self,
                 depth,
                 stage_channels,
                 norm_cfg=dict(type='BN', requires_grad=True),
                 init_cfg=None):
        # Protect mutable default arguments
        norm_cfg = copy.deepcopy(norm_cfg)
        super().__init__(init_cfg=init_cfg)

        self.depth = depth

        cur_channel = stage_channels[0]
        next_channel = stage_channels[1]

        # Branch that stays at the input resolution.
        self.up1 = ConvModule(
            cur_channel, cur_channel, 3, padding=1, norm_cfg=norm_cfg)

        # Branch that works at half resolution.
        self.pool1 = MaxPool2d(2, 2)

        self.low1 = ConvModule(
            cur_channel, next_channel, 3, padding=1, norm_cfg=norm_cfg)

        if self.depth > 1:
            # Forward ``norm_cfg`` so that nested levels use the same norm
            # layer as this one instead of falling back to the default BN.
            self.low2 = HourglassAEModule(
                depth - 1, stage_channels[1:], norm_cfg=norm_cfg)
        else:
            self.low2 = ConvModule(
                next_channel, next_channel, 3, padding=1, norm_cfg=norm_cfg)

        self.low3 = ConvModule(
            next_channel, cur_channel, 3, padding=1, norm_cfg=norm_cfg)

        self.up2 = nn.UpsamplingNearest2d(scale_factor=2)

    def forward(self, x):
        """Model forward function.

        Args:
            x (Tensor): Input feature map with ``stage_channels[0]`` channels.
                NOTE(review): the pooled branch is halved and then doubled at
                every level, so the spatial size presumably has to be
                divisible by ``2 ** depth`` for the final sum to line up.

        Returns:
            Tensor: Sum of the full-resolution branch and the upsampled
            half-resolution branch.
        """
        up1 = self.up1(x)
        pool1 = self.pool1(x)
        low1 = self.low1(pool1)
        low2 = self.low2(low1)
        low3 = self.low3(low2)
        up2 = self.up2(low3)
        return up1 + up2


@MODELS.register_module()
class HourglassAENet(BaseBackbone):
    """Hourglass-AE Network proposed by Newell et al.

    Associative Embedding: End-to-End Learning for Joint
    Detection and Grouping.

    More details can be found in the `paper
    <https://arxiv.org/abs/1611.05424>`__ .

    Args:
        downsample_times (int): Downsample times in a HourglassModule.
            Default: 4
        num_stacks (int): Number of HourglassModule modules stacked,
            1 for Hourglass-52, 2 for Hourglass-104. Default: 1
        out_channels (int): Number of channels of the feature map produced
            by the 1x1 output conv of every stack. Default: 34
        stage_channels (list[int]): Feature channel of each sub-module in a
            HourglassModule. Must have more than ``downsample_times``
            entries. Default: ``(256, 384, 512, 640, 768)``
        feat_channels (int): Feature channel of conv after a HourglassModule.
            The stem also emits ``feat_channels`` channels and feeds the
            hourglass directly, whose first conv takes ``stage_channels[0]``
            channels, so the two values are expected to be equal.
            Default: 256
        norm_cfg (dict): Dictionary to construct and config norm layer.
        init_cfg (dict or list[dict], optional): Initialization config dict.
            Default:
            ``[
                dict(type='Normal', std=0.001, layer=['Conv2d']),
                dict(
                    type='Constant',
                    val=1,
                    layer=['_BatchNorm', 'GroupNorm'])
            ]``

    Example:
        >>> from mmpose.models import HourglassAENet
        >>> import torch
        >>> self = HourglassAENet()
        >>> self.eval()
        >>> inputs = torch.rand(1, 3, 512, 512)
        >>> level_outputs = self.forward(inputs)
        >>> for level_output in level_outputs:
        ...     print(tuple(level_output.shape))
        (1, 34, 128, 128)
    """

    def __init__(
        self,
        downsample_times=4,
        num_stacks=1,
        out_channels=34,
        stage_channels=(256, 384, 512, 640, 768),
        feat_channels=256,
        norm_cfg=dict(type='BN', requires_grad=True),
        init_cfg=[
            dict(type='Normal', std=0.001, layer=['Conv2d']),
            dict(type='Constant', val=1, layer=['_BatchNorm', 'GroupNorm'])
        ],
    ):
        # Protect mutable default arguments
        norm_cfg = copy.deepcopy(norm_cfg)
        super().__init__(init_cfg=init_cfg)

        self.num_stacks = num_stacks
        assert self.num_stacks >= 1
        assert len(stage_channels) > downsample_times

        # Stem: a stride-2 7x7 conv followed later by a 2x2 max-pool, ending
        # with ``feat_channels`` channels.
        self.stem = nn.Sequential(
            ConvModule(3, 64, 7, padding=3, stride=2, norm_cfg=norm_cfg),
            ConvModule(64, 128, 3, padding=1, norm_cfg=norm_cfg),
            MaxPool2d(2, 2),
            ConvModule(128, 128, 3, padding=1, norm_cfg=norm_cfg),
            ConvModule(128, feat_channels, 3, padding=1, norm_cfg=norm_cfg),
        )

        # One hourglass plus two 3x3 convs per stack.
        self.hourglass_modules = nn.ModuleList([
            nn.Sequential(
                HourglassAEModule(
                    downsample_times, stage_channels, norm_cfg=norm_cfg),
                ConvModule(
                    feat_channels,
                    feat_channels,
                    3,
                    padding=1,
                    norm_cfg=norm_cfg),
                ConvModule(
                    feat_channels,
                    feat_channels,
                    3,
                    padding=1,
                    norm_cfg=norm_cfg)) for _ in range(num_stacks)
        ])

        # The output conv consumes the hourglass feature, which leaves the
        # Sequential above with ``feat_channels`` channels, so that is the
        # input width to use here (not ``stage_channels[0]``).
        self.out_convs = nn.ModuleList([
            ConvModule(
                feat_channels,
                out_channels,
                1,
                padding=0,
                norm_cfg=None,
                act_cfg=None) for _ in range(num_stacks)
        ])

        # 1x1 convs that map a stack's output and feature back to
        # ``feat_channels`` so they can be added to the input of the next
        # stack; the last stack has no successor, hence ``num_stacks - 1``.
        self.remap_out_convs = nn.ModuleList([
            ConvModule(
                out_channels,
                feat_channels,
                1,
                norm_cfg=norm_cfg,
                act_cfg=None) for _ in range(num_stacks - 1)
        ])

        self.remap_feature_convs = nn.ModuleList([
            ConvModule(
                feat_channels,
                feat_channels,
                1,
                norm_cfg=norm_cfg,
                act_cfg=None) for _ in range(num_stacks - 1)
        ])

        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        """Model forward function.

        Args:
            x (Tensor): Input image batch with 3 channels.

        Returns:
            list[Tensor]: One output feature map per stack, in stack order,
            each with ``out_channels`` channels.
        """
        inter_feat = self.stem(x)
        out_feats = []
        last_stack = self.num_stacks - 1

        for ind, (single_hourglass, out_conv) in enumerate(
                zip(self.hourglass_modules, self.out_convs)):
            hourglass_feat = single_hourglass(inter_feat)
            out_feat = out_conv(hourglass_feat)
            out_feats.append(out_feat)

            if ind < last_stack:
                # Feed the next stack with the current input plus the
                # remapped prediction and the remapped hourglass feature.
                inter_feat = inter_feat + self.remap_out_convs[ind](
                    out_feat) + self.remap_feature_convs[ind](hourglass_feat)

        return out_feats
Read the Docs v: latest
Versions
latest
0.x
dev-1.x
Downloads
epub
On Read the Docs
Project Home
Builds

Free document hosting provided by Read the Docs.