Source code for mmpose.models.backbones.dstformer

# Copyright (c) OpenMMLab. All rights reserved.
import torch
import torch.nn as nn
from mmcv.cnn.bricks import DropPath
from mmengine.model import BaseModule, constant_init
from mmengine.model.weight_init import trunc_normal_

from mmpose.registry import MODELS
from .base_backbone import BaseBackbone


class Attention(BaseModule):

    def __init__(self,
                 dim,
                 num_heads=8,
                 qkv_bias=False,
                 qk_scale=None,
                 attn_drop=0.,
                 proj_drop=0.,
                 mode='spatial'):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = qk_scale or head_dim**-0.5

        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.mode = mode

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.proj_drop = nn.Dropout(proj_drop)

        self.attn_count_s = None
        self.attn_count_t = None

    def forward(self, x, seq_len=1):
        B, N, C = x.shape

        if self.mode == 'temporal':
            qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C //
                                      self.num_heads).permute(2, 0, 3, 1, 4)
            q, k, v = qkv[0], qkv[1], qkv[
                2]  # make torchscript happy (cannot use tensor as tuple)
            x = self.forward_temporal(q, k, v, seq_len=seq_len)
        elif self.mode == 'spatial':
            qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C //
                                      self.num_heads).permute(2, 0, 3, 1, 4)
            q, k, v = qkv[0], qkv[1], qkv[
                2]  # make torchscript happy (cannot use tensor as tuple)
            x = self.forward_spatial(q, k, v)
        else:
            raise NotImplementedError(self.mode)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x

    def forward_spatial(self, q, k, v):
        B, _, N, C = q.shape
        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)

        x = attn @ v
        x = x.transpose(1, 2).reshape(B, N, C * self.num_heads)
        return x

    def forward_temporal(self, q, k, v, seq_len=8):
        B, _, N, C = q.shape
        qt = q.reshape(-1, seq_len, self.num_heads, N,
                       C).permute(0, 2, 3, 1, 4)  # (B, H, N, T, C)
        kt = k.reshape(-1, seq_len, self.num_heads, N,
                       C).permute(0, 2, 3, 1, 4)  # (B, H, N, T, C)
        vt = v.reshape(-1, seq_len, self.num_heads, N,
                       C).permute(0, 2, 3, 1, 4)  # (B, H, N, T, C)

        attn = (qt @ kt.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)

        x = attn @ vt  # (B, H, N, T, C)
        x = x.permute(0, 3, 2, 1, 4).reshape(B, N, C * self.num_heads)
        return x
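

# Usage sketch (hypothetical helper, illustrative only; not part of the
# upstream module): both attention modes consume tokens flattened to
# (batch * frames, keypoints, dim). The spatial mode attends over the
# keypoint axis of each frame, while the temporal mode regroups the first
# axis by ``seq_len`` and attends over time for each keypoint; either way
# the output shape matches the input.
def _attention_shape_demo():
    tokens = torch.rand(2 * 3, 17, 64)  # 2 clips x 3 frames, 17 joints, dim 64
    attn_s = Attention(dim=64, num_heads=4, mode='spatial')
    attn_t = Attention(dim=64, num_heads=4, mode='temporal')
    out_s = attn_s(tokens)             # attention across the 17 keypoints
    out_t = attn_t(tokens, seq_len=3)  # attention across the 3 frames
    assert out_s.shape == out_t.shape == tokens.shape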


class AttentionBlock(BaseModule):

    def __init__(self,
                 dim,
                 num_heads,
                 mlp_ratio=4.,
                 mlp_out_ratio=1.,
                 qkv_bias=True,
                 qk_scale=None,
                 drop=0.,
                 attn_drop=0.,
                 drop_path=0.,
                 st_mode='st'):
        super().__init__()

        self.st_mode = st_mode
        self.norm1_s = nn.LayerNorm(dim, eps=1e-06)
        self.norm1_t = nn.LayerNorm(dim, eps=1e-06)

        self.attn_s = Attention(
            dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            qk_scale=qk_scale,
            attn_drop=attn_drop,
            proj_drop=drop,
            mode='spatial')
        self.attn_t = Attention(
            dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            qk_scale=qk_scale,
            attn_drop=attn_drop,
            proj_drop=drop,
            mode='temporal')

        self.drop_path = DropPath(
            drop_path) if drop_path > 0. else nn.Identity()
        self.norm2_s = nn.LayerNorm(dim, eps=1e-06)
        self.norm2_t = nn.LayerNorm(dim, eps=1e-06)

        mlp_hidden_dim = int(dim * mlp_ratio)
        mlp_out_dim = int(dim * mlp_out_ratio)
        self.mlp_s = nn.Sequential(
            nn.Linear(dim, mlp_hidden_dim), nn.GELU(),
            nn.Linear(mlp_hidden_dim, mlp_out_dim), nn.Dropout(drop))
        self.mlp_t = nn.Sequential(
            nn.Linear(dim, mlp_hidden_dim), nn.GELU(),
            nn.Linear(mlp_hidden_dim, mlp_out_dim), nn.Dropout(drop))

    def forward(self, x, seq_len=1):
        if self.st_mode == 'st':
            x = x + self.drop_path(self.attn_s(self.norm1_s(x), seq_len))
            x = x + self.drop_path(self.mlp_s(self.norm2_s(x)))
            x = x + self.drop_path(self.attn_t(self.norm1_t(x), seq_len))
            x = x + self.drop_path(self.mlp_t(self.norm2_t(x)))
        elif self.st_mode == 'ts':
            x = x + self.drop_path(self.attn_t(self.norm1_t(x), seq_len))
            x = x + self.drop_path(self.mlp_t(self.norm2_t(x)))
            x = x + self.drop_path(self.attn_s(self.norm1_s(x), seq_len))
            x = x + self.drop_path(self.mlp_s(self.norm2_s(x)))
        else:
            raise NotImplementedError(self.st_mode)
        return x
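

# Usage sketch (hypothetical helper, illustrative only; not part of the
# upstream module): an AttentionBlock applies spatial and temporal attention,
# each followed by an MLP, in the order given by ``st_mode``: spatial-first
# for 'st', temporal-first for 'ts'. DSTFormer below keeps one stack of each
# ordering and fuses their outputs.
def _attention_block_demo():
    tokens = torch.rand(2 * 3, 17, 64)  # (batch*frames, keypoints, dim)
    blk_st = AttentionBlock(dim=64, num_heads=4, st_mode='st')
    blk_ts = AttentionBlock(dim=64, num_heads=4, st_mode='ts')
    assert blk_st(tokens, seq_len=3).shape == tokens.shape
    assert blk_ts(tokens, seq_len=3).shape == tokens.shape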


@MODELS.register_module()
class DSTFormer(BaseBackbone):
    """Dual-stream Spatio-temporal Transformer Module.

    Args:
        in_channels (int): Number of input channels.
        feat_size (int): Number of feature channels. Default: 256.
        depth (int): The network depth. Default: 5.
        num_heads (int): Number of heads in multi-head self-attention blocks.
            Default: 8.
        mlp_ratio (int, optional): The expansion ratio of FFN. Default: 4.
        num_keypoints (int): Number of keypoints. Default: 17.
        seq_len (int): The sequence length. Default: 243.
        qkv_bias (bool, optional): If True, add a learnable bias to q, k, v.
            Default: True.
        qk_scale (float | None, optional): Override default qk scale of
            ``head_dim ** -0.5`` if set. Default: None.
        drop_rate (float, optional): Dropout ratio of input. Default: 0.
        attn_drop_rate (float, optional): Dropout ratio of attention weight.
            Default: 0.
        drop_path_rate (float, optional): Stochastic depth rate. Default: 0.
        att_fuse (bool): Whether to fuse the results of attention blocks.
            Default: True.
        init_cfg (dict or list[dict], optional): Initialization config dict.
            Default: None.

    Example:
        >>> from mmpose.models import DSTFormer
        >>> import torch
        >>> self = DSTFormer(in_channels=3)
        >>> self.eval()
        >>> inputs = torch.rand(1, 2, 17, 3)
        >>> level_outputs = self.forward(inputs)
        >>> print(tuple(level_outputs.shape))
        (1, 2, 17, 256)
    """

    def __init__(self,
                 in_channels,
                 feat_size=256,
                 depth=5,
                 num_heads=8,
                 mlp_ratio=4,
                 num_keypoints=17,
                 seq_len=243,
                 qkv_bias=True,
                 qk_scale=None,
                 drop_rate=0.,
                 attn_drop_rate=0.,
                 drop_path_rate=0.,
                 att_fuse=True,
                 init_cfg=None):
        super().__init__(init_cfg=init_cfg)

        self.in_channels = in_channels
        self.feat_size = feat_size

        self.joints_embed = nn.Linear(in_channels, feat_size)
        self.pos_drop = nn.Dropout(p=drop_rate)

        # stochastic depth decay rule
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]

        self.blocks_st = nn.ModuleList([
            AttentionBlock(
                dim=feat_size,
                num_heads=num_heads,
                mlp_ratio=mlp_ratio,
                qkv_bias=qkv_bias,
                qk_scale=qk_scale,
                drop=drop_rate,
                attn_drop=attn_drop_rate,
                drop_path=dpr[i],
                st_mode='st') for i in range(depth)
        ])
        self.blocks_ts = nn.ModuleList([
            AttentionBlock(
                dim=feat_size,
                num_heads=num_heads,
                mlp_ratio=mlp_ratio,
                qkv_bias=qkv_bias,
                qk_scale=qk_scale,
                drop=drop_rate,
                attn_drop=attn_drop_rate,
                drop_path=dpr[i],
                st_mode='ts') for i in range(depth)
        ])
        self.norm = nn.LayerNorm(feat_size, eps=1e-06)

        self.temp_embed = nn.Parameter(torch.zeros(1, seq_len, 1, feat_size))
        self.spat_embed = nn.Parameter(
            torch.zeros(1, num_keypoints, feat_size))
        trunc_normal_(self.temp_embed, std=.02)
        trunc_normal_(self.spat_embed, std=.02)

        self.att_fuse = att_fuse
        if self.att_fuse:
            self.attn_regress = nn.ModuleList(
                [nn.Linear(feat_size * 2, 2) for i in range(depth)])
            for i in range(depth):
                self.attn_regress[i].weight.data.fill_(0)
                self.attn_regress[i].bias.data.fill_(0.5)
    def forward(self, x):
        if len(x.shape) == 3:
            x = x[None, :]
        assert len(x.shape) == 4

        B, F, K, C = x.shape
        x = x.reshape(-1, K, C)
        BF = x.shape[0]
        x = self.joints_embed(x)  # (BF, K, feat_size)
        x = x + self.spat_embed
        _, K, C = x.shape
        x = x.reshape(-1, F, K, C) + self.temp_embed[:, :F, :, :]
        x = x.reshape(BF, K, C)  # (BF, K, feat_size)
        x = self.pos_drop(x)

        for idx, (blk_st,
                  blk_ts) in enumerate(zip(self.blocks_st, self.blocks_ts)):
            x_st = blk_st(x, F)
            x_ts = blk_ts(x, F)
            if self.att_fuse:
                att = self.attn_regress[idx]
                alpha = torch.cat([x_st, x_ts], dim=-1)
                BF, K = alpha.shape[:2]
                alpha = att(alpha)
                alpha = alpha.softmax(dim=-1)
                x = x_st * alpha[:, :, 0:1] + x_ts * alpha[:, :, 1:2]
            else:
                x = (x_st + x_ts) * 0.5
        x = self.norm(x)  # (BF, K, feat_size)
        x = x.reshape(B, F, K, -1)
        return x
    def init_weights(self):
        """Initialize the weights in backbone."""
        super(DSTFormer, self).init_weights()

        if (isinstance(self.init_cfg, dict)
                and self.init_cfg['type'] == 'Pretrained'):
            return

        for m in self.modules():
            if isinstance(m, nn.Linear):
                trunc_normal_(m.weight, std=.02)
                if isinstance(m, nn.Linear) and m.bias is not None:
                    constant_init(m.bias, 0)
            elif isinstance(m, nn.LayerNorm):
                constant_init(m.bias, 0)
                constant_init(m.weight, 1.0)
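

# Usage sketch (hypothetical helper, illustrative only; not part of the
# upstream module): because the class is registered in ``MODELS``, it can
# also be built from a config dict, which is how mmpose configs typically
# instantiate backbones. The keyword values below are example choices that
# simply repeat the defaults, not required settings.
def _build_from_cfg_demo():
    cfg = dict(
        type='DSTFormer',
        in_channels=3,
        feat_size=256,
        depth=5,
        num_heads=8,
        seq_len=243)
    backbone = MODELS.build(cfg)
    inputs = torch.rand(1, 2, 17, 3)  # (batch, frames, keypoints, channels)
    feats = backbone(inputs)          # (batch, frames, keypoints, feat_size)
    assert feats.shape == (1, 2, 17, 256)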