
Source code for mmpose.models.backbones.tcn

# Copyright (c) OpenMMLab. All rights reserved.
import copy

import torch.nn as nn
from mmcv.cnn import ConvModule, build_conv_layer
from mmengine.model import BaseModule

from mmpose.registry import MODELS
from ..utils.regularizations import WeightNormClipHook
from .base_backbone import BaseBackbone


class BasicTemporalBlock(BaseModule):
    """Basic block for VideoPose3D.

    Args:
        in_channels (int): Input channels of this block.
        out_channels (int): Output channels of this block.
        mid_channels (int): The output channels of conv1. Default: 1024.
        kernel_size (int): Size of the convolving kernel. Default: 3.
        dilation (int): Spacing between kernel elements. Default: 3.
        dropout (float): Dropout rate. Default: 0.25.
        causal (bool): Use causal convolutions instead of symmetric
            convolutions (for real-time applications). Default: False.
        residual (bool): Use residual connection. Default: True.
        use_stride_conv (bool): Use the optimized TCN that is designed
            specifically for single-frame batching, i.e. where batches have
            input length = receptive field, and output length = 1. This
            implementation replaces dilated convolutions with strided
            convolutions to avoid generating unused intermediate results.
            Default: False.
        conv_cfg (dict): dictionary to construct and config conv layer.
            Default: dict(type='Conv1d').
        norm_cfg (dict): dictionary to construct and config norm layer.
            Default: dict(type='BN1d').
        init_cfg (dict or list[dict], optional): Initialization config dict.
            Default: None
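
    Example:
        A minimal shape-check sketch (added for illustration; the numbers
        follow from the defaults ``kernel_size=3`` and ``dilation=3``,
        which trim ``(kernel_size - 1) * dilation = 6`` frames):

        >>> import torch
        >>> block = BasicTemporalBlock(in_channels=1024, out_channels=1024)
        >>> _ = block.eval()
        >>> x = torch.rand(1, 1024, 243)
        >>> tuple(block(x).shape)
        (1, 1024, 237)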
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 mid_channels=1024,
                 kernel_size=3,
                 dilation=3,
                 dropout=0.25,
                 causal=False,
                 residual=True,
                 use_stride_conv=False,
                 conv_cfg=dict(type='Conv1d'),
                 norm_cfg=dict(type='BN1d'),
                 init_cfg=None):
        # Protect mutable default arguments
        conv_cfg = copy.deepcopy(conv_cfg)
        norm_cfg = copy.deepcopy(norm_cfg)
        super().__init__(init_cfg=init_cfg)
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.mid_channels = mid_channels
        self.kernel_size = kernel_size
        self.dilation = dilation
        self.dropout = dropout
        self.causal = causal
        self.residual = residual
        self.use_stride_conv = use_stride_conv

        self.pad = (kernel_size - 1) * dilation // 2
        if use_stride_conv:
            self.stride = kernel_size
            self.causal_shift = kernel_size // 2 if causal else 0
            self.dilation = 1
        else:
            self.stride = 1
            self.causal_shift = kernel_size // 2 * dilation if causal else 0
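        # e.g. with the defaults kernel_size=3, dilation=3 and causal=False,
        # pad = (3 - 1) * 3 // 2 = 3 frames are cropped from each side of the
        # residual; causal=True instead moves the whole 6-frame crop to the
        # start of the sequence (causal_shift = 3), so each output frame only
        # depends on current and past inputs.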

        self.conv1 = nn.Sequential(
            ConvModule(
                in_channels,
                mid_channels,
                kernel_size=kernel_size,
                stride=self.stride,
                dilation=self.dilation,
                bias='auto',
                conv_cfg=conv_cfg,
                norm_cfg=norm_cfg))
        self.conv2 = nn.Sequential(
            ConvModule(
                mid_channels,
                out_channels,
                kernel_size=1,
                bias='auto',
                conv_cfg=conv_cfg,
                norm_cfg=norm_cfg))

        if residual and in_channels != out_channels:
            self.short_cut = build_conv_layer(conv_cfg, in_channels,
                                              out_channels, 1)
        else:
            self.short_cut = None

        self.dropout = nn.Dropout(dropout) if dropout > 0 else None

    def forward(self, x):
        """Forward function."""
        if self.use_stride_conv:
            assert self.causal_shift + self.kernel_size // 2 < x.shape[2]
        else:
            assert 0 <= self.pad + self.causal_shift < x.shape[2] - \
                self.pad + self.causal_shift <= x.shape[2]

        out = self.conv1(x)
        if self.dropout is not None:
            out = self.dropout(out)

        out = self.conv2(out)
        if self.dropout is not None:
            out = self.dropout(out)

        if self.residual:
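            # Crop the input so the skip connection matches the shorter conv
            # output: strided mode keeps every kernel_size-th frame; dilated
            # mode trims `pad` frames from each side (shifted by
            # `causal_shift` when causal).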
            if self.use_stride_conv:
                res = x[:, :, self.causal_shift +
                        self.kernel_size // 2::self.kernel_size]
            else:
                res = x[:, :,
                        (self.pad + self.causal_shift):(x.shape[2] - self.pad +
                                                        self.causal_shift)]

            if self.short_cut is not None:
                res = self.short_cut(res)
            out = out + res

        return out


@MODELS.register_module()
class TCN(BaseBackbone):
    """TCN backbone.

    Temporal Convolutional Networks.
    More details can be found in the
    `paper <https://arxiv.org/abs/1811.11742>`__ .

    Args:
        in_channels (int): Number of input channels, which equals to
            num_keypoints * num_features.
        stem_channels (int): Number of feature channels. Default: 1024.
        num_blocks (int): Number of basic temporal convolutional blocks.
            Default: 2.
        kernel_sizes (Sequence[int]): Sizes of the convolving kernel of
            each basic block. Default: ``(3, 3, 3)``.
        dropout (float): Dropout rate. Default: 0.25.
        causal (bool): Use causal convolutions instead of symmetric
            convolutions (for real-time applications). Default: False.
        residual (bool): Use residual connection. Default: True.
        use_stride_conv (bool): Use the TCN backbone optimized for
            single-frame batching, i.e. where batches have input length =
            receptive field, and output length = 1. This implementation
            replaces dilated convolutions with strided convolutions to
            avoid generating unused intermediate results. The weights are
            interchangeable with the reference implementation.
            Default: False.
        conv_cfg (dict): dictionary to construct and config conv layer.
            Default: dict(type='Conv1d').
        norm_cfg (dict): dictionary to construct and config norm layer.
            Default: dict(type='BN1d').
        max_norm (float|None): if not None, the weight of convolution
            layers will be clipped to have a maximum norm of max_norm.
        init_cfg (dict or list[dict], optional): Initialization config
            dict. Default:
            ``[
                dict(
                    type='Kaiming',
                    mode='fan_in',
                    nonlinearity='relu',
                    layer=['Conv2d']),
                dict(
                    type='Constant',
                    val=1,
                    layer=['_BatchNorm', 'GroupNorm'])
            ]``

    Example:
        >>> from mmpose.models import TCN
        >>> import torch
        >>> self = TCN(in_channels=34)
        >>> self.eval()
        >>> inputs = torch.rand(1, 34, 243)
        >>> level_outputs = self.forward(inputs)
        >>> for level_out in level_outputs:
        ...     print(tuple(level_out.shape))
        (1, 1024, 235)
        (1, 1024, 217)
    """

    def __init__(self,
                 in_channels,
                 stem_channels=1024,
                 num_blocks=2,
                 kernel_sizes=(3, 3, 3),
                 dropout=0.25,
                 causal=False,
                 residual=True,
                 use_stride_conv=False,
                 conv_cfg=dict(type='Conv1d'),
                 norm_cfg=dict(type='BN1d'),
                 max_norm=None,
                 init_cfg=[
                     dict(
                         type='Kaiming',
                         mode='fan_in',
                         nonlinearity='relu',
                         layer=['Conv2d']),
                     dict(
                         type='Constant',
                         val=1,
                         layer=['_BatchNorm', 'GroupNorm'])
                 ]):
        # Protect mutable default arguments
        conv_cfg = copy.deepcopy(conv_cfg)
        norm_cfg = copy.deepcopy(norm_cfg)
        super().__init__(init_cfg=init_cfg)
        self.in_channels = in_channels
        self.stem_channels = stem_channels
        self.num_blocks = num_blocks
        self.kernel_sizes = kernel_sizes
        self.dropout = dropout
        self.causal = causal
        self.residual = residual
        self.use_stride_conv = use_stride_conv
        self.max_norm = max_norm

        assert num_blocks == len(kernel_sizes) - 1
        for ks in kernel_sizes:
            assert ks % 2 == 1, 'Only odd filter widths are supported.'
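
        # The overall receptive field is the product of the kernel sizes,
        # e.g. (3, 3, 3) -> 27 frames: kernel_sizes[0] is consumed by
        # expand_conv and each remaining entry by one temporal block, hence
        # num_blocks == len(kernel_sizes) - 1 above.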
        self.expand_conv = ConvModule(
            in_channels,
            stem_channels,
            kernel_size=kernel_sizes[0],
            stride=kernel_sizes[0] if use_stride_conv else 1,
            bias='auto',
            conv_cfg=conv_cfg,
            norm_cfg=norm_cfg)

        dilation = kernel_sizes[0]
        self.tcn_blocks = nn.ModuleList()
        for i in range(1, num_blocks + 1):
            self.tcn_blocks.append(
                BasicTemporalBlock(
                    in_channels=stem_channels,
                    out_channels=stem_channels,
                    mid_channels=stem_channels,
                    kernel_size=kernel_sizes[i],
                    dilation=dilation,
                    dropout=dropout,
                    causal=causal,
                    residual=residual,
                    use_stride_conv=use_stride_conv,
                    conv_cfg=conv_cfg,
                    norm_cfg=norm_cfg))
            dilation *= kernel_sizes[i]

        if self.max_norm is not None:
            # Apply weight norm clip to conv layers
            weight_clip = WeightNormClipHook(self.max_norm)
            for module in self.modules():
                if isinstance(module, nn.modules.conv._ConvNd):
                    weight_clip.register(module)

        self.dropout = nn.Dropout(dropout) if dropout > 0 else None

    def forward(self, x):
        """Forward function."""
        x = self.expand_conv(x)

        if self.dropout is not None:
            x = self.dropout(x)

        outs = []
        for i in range(self.num_blocks):
            x = self.tcn_blocks[i](x)
            outs.append(x)

        return tuple(outs)
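
# A minimal usage sketch of the strided mode (illustrative; the shapes
# follow from the strides set up above): with use_stride_conv=True the
# receptive field equals the product of kernel_sizes, so a 27-frame clip
# collapses to a single output frame after the last block.
#
#     >>> import torch
#     >>> model = TCN(in_channels=34, use_stride_conv=True)
#     >>> _ = model.eval()
#     >>> inputs = torch.rand(1, 34, 27)
#     >>> [tuple(out.shape) for out in model(inputs)]
#     [(1, 1024, 3), (1, 1024, 1)]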