Source code for mmpose.models.backbones.dstformer
# Copyright (c) OpenMMLab. All rights reserved.
import torch
import torch.nn as nn
from mmcv.cnn.bricks import DropPath
from mmengine.model import BaseModule, constant_init
from mmengine.model.weight_init import trunc_normal_

from mmpose.registry import MODELS
from .base_backbone import BaseBackbone


class Attention(BaseModule):
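    """Multi-head self-attention with a 'spatial' or 'temporal' mode.

    In 'spatial' mode attention is computed over the N tokens of each
    sample; in 'temporal' mode the flattened (batch * frame) samples are
    regrouped so attention is computed across frames for each token.
    """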
def __init__(self,
dim,
num_heads=8,
qkv_bias=False,
qk_scale=None,
attn_drop=0.,
proj_drop=0.,
mode='spatial'):
super().__init__()
self.num_heads = num_heads
head_dim = dim // num_heads
self.scale = qk_scale or head_dim**-0.5
self.attn_drop = nn.Dropout(attn_drop)
self.proj = nn.Linear(dim, dim)
self.mode = mode
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
self.proj_drop = nn.Dropout(proj_drop)
self.attn_count_s = None
        self.attn_count_t = None

    def forward(self, x, seq_len=1):
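        # x: (B, N, C) tokens; seq_len is only used in 'temporal' mode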
B, N, C = x.shape
if self.mode == 'temporal':
            qkv = self.qkv(x).reshape(B, N, 3, self.num_heads,
                                      C // self.num_heads)
            qkv = qkv.permute(2, 0, 3, 1, 4)  # (3, B, H, N, C_head)
            # make torchscript happy (cannot use tensor as tuple)
            q, k, v = qkv[0], qkv[1], qkv[2]
x = self.forward_temporal(q, k, v, seq_len=seq_len)
elif self.mode == 'spatial':
            qkv = self.qkv(x).reshape(B, N, 3, self.num_heads,
                                      C // self.num_heads)
            qkv = qkv.permute(2, 0, 3, 1, 4)  # (3, B, H, N, C_head)
            # make torchscript happy (cannot use tensor as tuple)
            q, k, v = qkv[0], qkv[1], qkv[2]
x = self.forward_spatial(q, k, v)
else:
raise NotImplementedError(self.mode)
x = self.proj(x)
x = self.proj_drop(x)
        return x

    def forward_spatial(self, q, k, v):
B, _, N, C = q.shape
attn = (q @ k.transpose(-2, -1)) * self.scale
attn = attn.softmax(dim=-1)
attn = self.attn_drop(attn)
x = attn @ v
x = x.transpose(1, 2).reshape(B, N, C * self.num_heads)
        return x

    def forward_temporal(self, q, k, v, seq_len=8):
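        # q, k, v arrive as (B*T, H, N, C_head); regroup so attention runs
        # over the T == seq_len frames of each token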
B, _, N, C = q.shape
qt = q.reshape(-1, seq_len, self.num_heads, N,
C).permute(0, 2, 3, 1, 4) # (B, H, N, T, C)
kt = k.reshape(-1, seq_len, self.num_heads, N,
C).permute(0, 2, 3, 1, 4) # (B, H, N, T, C)
vt = v.reshape(-1, seq_len, self.num_heads, N,
C).permute(0, 2, 3, 1, 4) # (B, H, N, T, C)
attn = (qt @ kt.transpose(-2, -1)) * self.scale
attn = attn.softmax(dim=-1)
attn = self.attn_drop(attn)
x = attn @ vt # (B, H, N, T, C)
x = x.permute(0, 3, 2, 1, 4).reshape(B, N, C * self.num_heads)
        return x


class AttentionBlock(BaseModule):
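    """Transformer block with paired spatial and temporal attention.

    ``st_mode`` selects the order: 'st' applies spatial attention (plus its
    MLP) before temporal attention, 'ts' the reverse.
    """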
def __init__(self,
dim,
num_heads,
mlp_ratio=4.,
mlp_out_ratio=1.,
qkv_bias=True,
qk_scale=None,
drop=0.,
attn_drop=0.,
drop_path=0.,
st_mode='st'):
super().__init__()
self.st_mode = st_mode
self.norm1_s = nn.LayerNorm(dim, eps=1e-06)
self.norm1_t = nn.LayerNorm(dim, eps=1e-06)
self.attn_s = Attention(
dim,
num_heads=num_heads,
qkv_bias=qkv_bias,
qk_scale=qk_scale,
attn_drop=attn_drop,
proj_drop=drop,
mode='spatial')
self.attn_t = Attention(
dim,
num_heads=num_heads,
qkv_bias=qkv_bias,
qk_scale=qk_scale,
attn_drop=attn_drop,
proj_drop=drop,
mode='temporal')
self.drop_path = DropPath(
drop_path) if drop_path > 0. else nn.Identity()
self.norm2_s = nn.LayerNorm(dim, eps=1e-06)
self.norm2_t = nn.LayerNorm(dim, eps=1e-06)
mlp_hidden_dim = int(dim * mlp_ratio)
mlp_out_dim = int(dim * mlp_out_ratio)
self.mlp_s = nn.Sequential(
nn.Linear(dim, mlp_hidden_dim), nn.GELU(),
nn.Linear(mlp_hidden_dim, mlp_out_dim), nn.Dropout(drop))
self.mlp_t = nn.Sequential(
nn.Linear(dim, mlp_hidden_dim), nn.GELU(),
            nn.Linear(mlp_hidden_dim, mlp_out_dim), nn.Dropout(drop))

    def forward(self, x, seq_len=1):
if self.st_mode == 'st':
x = x + self.drop_path(self.attn_s(self.norm1_s(x), seq_len))
x = x + self.drop_path(self.mlp_s(self.norm2_s(x)))
x = x + self.drop_path(self.attn_t(self.norm1_t(x), seq_len))
x = x + self.drop_path(self.mlp_t(self.norm2_t(x)))
elif self.st_mode == 'ts':
x = x + self.drop_path(self.attn_t(self.norm1_t(x), seq_len))
x = x + self.drop_path(self.mlp_t(self.norm2_t(x)))
x = x + self.drop_path(self.attn_s(self.norm1_s(x), seq_len))
x = x + self.drop_path(self.mlp_s(self.norm2_s(x)))
else:
raise NotImplementedError(self.st_mode)
        return x


@MODELS.register_module()
class DSTFormer(BaseBackbone):
"""Dual-stream Spatio-temporal Transformer Module.
Args:
in_channels (int): Number of input channels.
feat_size: Number of feature channels. Default: 256.
depth: The network depth. Default: 5.
num_heads: Number of heads in multi-Head self-attention blocks.
Default: 8.
mlp_ratio (int, optional): The expansion ratio of FFN. Default: 4.
num_keypoints: num_keypoints (int): Number of keypoints. Default: 17.
seq_len: The sequence length. Default: 243.
qkv_bias (bool, optional): If True, add a learnable bias to q, k, v.
Default: True.
qk_scale (float | None, optional): Override default qk scale of
head_dim ** -0.5 if set. Default: None.
drop_rate (float, optional): Dropout ratio of input. Default: 0.
attn_drop_rate (float, optional): Dropout ratio of attention weight.
Default: 0.
drop_path_rate (float, optional): Stochastic depth rate. Default: 0.
att_fuse: Whether to fuse the results of attention blocks.
Default: True.
init_cfg (dict or list[dict], optional): Initialization config dict.
Default: None
Example:
>>> from mmpose.models import DSTFormer
>>> import torch
>>> self = DSTFormer(in_channels=3)
>>> self.eval()
>>> inputs = torch.rand(1, 2, 17, 3)
>>> level_outputs = self.forward(inputs)
>>> print(tuple(level_outputs.shape))
(1, 2, 17, 512)
"""
def __init__(self,
in_channels,
feat_size=256,
depth=5,
num_heads=8,
mlp_ratio=4,
num_keypoints=17,
seq_len=243,
qkv_bias=True,
qk_scale=None,
drop_rate=0.,
attn_drop_rate=0.,
drop_path_rate=0.,
att_fuse=True,
init_cfg=None):
super().__init__(init_cfg=init_cfg)
self.in_channels = in_channels
self.feat_size = feat_size
self.joints_embed = nn.Linear(in_channels, feat_size)
self.pos_drop = nn.Dropout(p=drop_rate)
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)
] # stochastic depth decay rule
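        # Two parallel streams per depth: blocks_st runs spatial->temporal
        # attention and blocks_ts runs temporal->spatial; their outputs are
        # fused at every depth in ``forward``.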
self.blocks_st = nn.ModuleList([
AttentionBlock(
dim=feat_size,
num_heads=num_heads,
mlp_ratio=mlp_ratio,
qkv_bias=qkv_bias,
qk_scale=qk_scale,
drop=drop_rate,
attn_drop=attn_drop_rate,
drop_path=dpr[i],
st_mode='st') for i in range(depth)
])
self.blocks_ts = nn.ModuleList([
AttentionBlock(
dim=feat_size,
num_heads=num_heads,
mlp_ratio=mlp_ratio,
qkv_bias=qkv_bias,
qk_scale=qk_scale,
drop=drop_rate,
attn_drop=attn_drop_rate,
drop_path=dpr[i],
st_mode='ts') for i in range(depth)
])
self.norm = nn.LayerNorm(feat_size, eps=1e-06)
self.temp_embed = nn.Parameter(torch.zeros(1, seq_len, 1, feat_size))
self.spat_embed = nn.Parameter(
torch.zeros(1, num_keypoints, feat_size))
trunc_normal_(self.temp_embed, std=.02)
trunc_normal_(self.spat_embed, std=.02)
self.att_fuse = att_fuse
if self.att_fuse:
self.attn_regress = nn.ModuleList(
[nn.Linear(feat_size * 2, 2) for i in range(depth)])
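            # Zero weights and a 0.5 bias make the learned fusion start as
            # an even average of the two streams.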
for i in range(depth):
self.attn_regress[i].weight.data.fill_(0)
self.attn_regress[i].bias.data.fill_(0.5)

    def forward(self, x):
if len(x.shape) == 3:
x = x[None, :]
assert len(x.shape) == 4
B, F, K, C = x.shape
x = x.reshape(-1, K, C)
BF = x.shape[0]
x = self.joints_embed(x) # (BF, K, feat_size)
x = x + self.spat_embed
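        # add the per-frame temporal embedding, then flatten frames back
        # into the batch dimension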
_, K, C = x.shape
x = x.reshape(-1, F, K, C) + self.temp_embed[:, :F, :, :]
x = x.reshape(BF, K, C) # (BF, K, feat_size)
x = self.pos_drop(x)
for idx, (blk_st,
blk_ts) in enumerate(zip(self.blocks_st, self.blocks_ts)):
x_st = blk_st(x, F)
x_ts = blk_ts(x, F)
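            # fuse the two stream outputs, either with learned per-token
            # weights (att_fuse) or a plain average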
if self.att_fuse:
att = self.attn_regress[idx]
alpha = torch.cat([x_st, x_ts], dim=-1)
BF, K = alpha.shape[:2]
alpha = att(alpha)
alpha = alpha.softmax(dim=-1)
x = x_st * alpha[:, :, 0:1] + x_ts * alpha[:, :, 1:2]
else:
x = (x_st + x_ts) * 0.5
x = self.norm(x) # (BF, K, feat_size)
x = x.reshape(B, F, K, -1)
return x

    def init_weights(self):
"""Initialize the weights in backbone."""
super(DSTFormer, self).init_weights()
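        # a pretrained checkpoint is loaded by the parent class, so skip
        # the manual initialization below in that case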
if (isinstance(self.init_cfg, dict)
and self.init_cfg['type'] == 'Pretrained'):
return
for m in self.modules():
            if isinstance(m, nn.Linear):
                trunc_normal_(m.weight, std=.02)
                if m.bias is not None:
                    constant_init(m.bias, 0)
elif isinstance(m, nn.LayerNorm):
constant_init(m.bias, 0)
constant_init(m.weight, 1.0)