Source code for mmpose.models.utils.rtmcc_block
# Copyright (c) OpenMMLab. All rights reserved.
import math

import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn.bricks import DropPath
from mmengine.utils import digit_version
from mmengine.utils.dl_utils import TORCH_VERSION

from .transformer import ScaleNorm


def rope(x, dim):
"""Applies Rotary Position Embedding to input tensor.
Args:
x (torch.Tensor): Input tensor.
dim (int | list[int]): The spatial dimension(s) to apply
rotary position embedding.
Returns:
torch.Tensor: The tensor after applying rotary position
embedding.
Reference:
`RoFormer: Enhanced Transformer with Rotary
Position Embedding <https://arxiv.org/abs/2104.09864>`_
"""
shape = x.shape
if isinstance(dim, int):
dim = [dim]
spatial_shape = [shape[i] for i in dim]
total_len = 1
for i in spatial_shape:
total_len *= i
position = torch.reshape(
torch.arange(total_len, dtype=torch.int, device=x.device),
spatial_shape)
for i in range(dim[-1] + 1, len(shape) - 1, 1):
position = torch.unsqueeze(position, dim=-1)
half_size = shape[-1] // 2
freq_seq = -torch.arange(
half_size, dtype=torch.int, device=x.device) / float(half_size)
inv_freq = 10000**-freq_seq
sinusoid = position[..., None] * inv_freq[None, None, :]
sin = torch.sin(sinusoid)
cos = torch.cos(sinusoid)
x1, x2 = torch.chunk(x, 2, dim=-1)
return torch.cat([x1 * cos - x2 * sin, x2 * cos + x1 * sin], dim=-1)
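

# Illustrative usage sketch (hypothetical helper, not part of the original
# module): applies ``rope`` over the token dimension of a ``[B, K, C]``
# tensor. The last dimension must be even because the channels are split
# into two halves before rotation.
def _rope_demo():
    x = torch.randn(2, 17, 128)  # [batch, tokens, channels]
    out = rope(x, dim=1)         # rotary embedding along the 17-token axis
    assert out.shape == x.shape  # RoPE preserves the input shape
    return out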


class Scale(nn.Module):
"""Scale vector by element multiplications.
Args:
dim (int): The dimension of the scale vector.
init_value (float, optional): The initial value of the scale vector.
Defaults to 1.0.
trainable (bool, optional): Whether the scale vector is trainable.
Defaults to True.
"""

    def __init__(self, dim, init_value=1., trainable=True):
super().__init__()
self.scale = nn.Parameter(
init_value * torch.ones(dim), requires_grad=trainable)

    def forward(self, x):
"""Forward function."""
return x * self.scale
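

# Illustrative usage sketch (hypothetical helper, not part of the original
# module): ``Scale`` holds a learnable per-channel vector and multiplies it
# onto the last dimension of its input, as used for the residual branch
# scaling (``res_scale``) in ``RTMCCBlock`` below.
def _scale_demo():
    scale = Scale(dim=256, init_value=1.0, trainable=True)
    feats = torch.randn(4, 17, 256)  # [B, K, C]
    return scale(feats)              # same shape, scaled per channel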


class RTMCCBlock(nn.Module):
"""Gated Attention Unit (GAU) in RTMBlock.
Args:
num_token (int): The number of tokens.
in_token_dims (int): The input token dimension.
out_token_dims (int): The output token dimension.
expansion_factor (int, optional): The expansion factor of the
intermediate token dimension. Defaults to 2.
s (int, optional): The self-attention feature dimension.
Defaults to 128.
eps (float, optional): The minimum value in clamp. Defaults to 1e-5.
dropout_rate (float, optional): The dropout rate. Defaults to 0.0.
drop_path (float, optional): The drop path rate. Defaults to 0.0.
attn_type (str, optional): Type of attention which should be one of
the following options:
- 'self-attn': Self-attention.
- 'cross-attn': Cross-attention.
Defaults to 'self-attn'.
act_fn (str, optional): The activation function which should be one
of the following options:
- 'ReLU': ReLU activation.
- 'SiLU': SiLU activation.
Defaults to 'SiLU'.
bias (bool, optional): Whether to use bias in linear layers.
Defaults to False.
use_rel_bias (bool, optional): Whether to use relative bias.
Defaults to True.
pos_enc (bool, optional): Whether to use rotary position
embedding. Defaults to False.
Reference:
`Transformer Quality in Linear Time
<https://arxiv.org/abs/2202.10447>`_
"""

    def __init__(self,
num_token,
in_token_dims,
out_token_dims,
expansion_factor=2,
s=128,
eps=1e-5,
dropout_rate=0.,
drop_path=0.,
attn_type='self-attn',
act_fn='SiLU',
bias=False,
use_rel_bias=True,
pos_enc=False):
super(RTMCCBlock, self).__init__()
self.s = s
self.num_token = num_token
self.use_rel_bias = use_rel_bias
self.attn_type = attn_type
self.pos_enc = pos_enc
self.drop_path = DropPath(drop_path) \
if drop_path > 0. else nn.Identity()
self.e = int(in_token_dims * expansion_factor)
if use_rel_bias:
if attn_type == 'self-attn':
self.w = nn.Parameter(
torch.rand([2 * num_token - 1], dtype=torch.float))
else:
self.a = nn.Parameter(torch.rand([1, s], dtype=torch.float))
self.b = nn.Parameter(torch.rand([1, s], dtype=torch.float))
self.o = nn.Linear(self.e, out_token_dims, bias=bias)
if attn_type == 'self-attn':
self.uv = nn.Linear(in_token_dims, 2 * self.e + self.s, bias=bias)
self.gamma = nn.Parameter(torch.rand((2, self.s)))
self.beta = nn.Parameter(torch.rand((2, self.s)))
else:
self.uv = nn.Linear(in_token_dims, self.e + self.s, bias=bias)
self.k_fc = nn.Linear(in_token_dims, self.s, bias=bias)
self.v_fc = nn.Linear(in_token_dims, self.e, bias=bias)
nn.init.xavier_uniform_(self.k_fc.weight)
nn.init.xavier_uniform_(self.v_fc.weight)
self.ln = ScaleNorm(in_token_dims, eps=eps)
nn.init.xavier_uniform_(self.uv.weight)
if act_fn == 'SiLU' or act_fn == nn.SiLU:
assert digit_version(TORCH_VERSION) >= digit_version('1.7.0'), \
'SiLU activation requires PyTorch version >= 1.7'
self.act_fn = nn.SiLU(True)
elif act_fn == 'ReLU' or act_fn == nn.ReLU:
self.act_fn = nn.ReLU(True)
else:
raise NotImplementedError
if in_token_dims == out_token_dims:
self.shortcut = True
self.res_scale = Scale(in_token_dims)
else:
self.shortcut = False
self.sqrt_s = math.sqrt(s)
self.dropout_rate = dropout_rate
if dropout_rate > 0.:
self.dropout = nn.Dropout(dropout_rate)

    def rel_pos_bias(self, seq_len, k_len=None):
"""Add relative position bias."""
if self.attn_type == 'self-attn':
t = F.pad(self.w[:2 * seq_len - 1], [0, seq_len]).repeat(seq_len)
t = t[..., :-seq_len].reshape(-1, seq_len, 3 * seq_len - 2)
r = (2 * seq_len - 1) // 2
t = t[..., r:-r]
else:
a = rope(self.a.repeat(seq_len, 1), dim=0)
b = rope(self.b.repeat(k_len, 1), dim=0)
t = torch.bmm(a, b.permute(0, 2, 1))
return t

    def _forward(self, inputs):
"""GAU Forward function."""
if self.attn_type == 'self-attn':
x = inputs
else:
x, k, v = inputs
x = self.ln(x)
# [B, K, in_token_dims] -> [B, K, e + e + s]
uv = self.uv(x)
uv = self.act_fn(uv)
if self.attn_type == 'self-attn':
# [B, K, e + e + s] -> [B, K, e], [B, K, e], [B, K, s]
u, v, base = torch.split(uv, [self.e, self.e, self.s], dim=2)
# [B, K, 1, s] * [1, 1, 2, s] + [2, s] -> [B, K, 2, s]
base = base.unsqueeze(2) * self.gamma[None, None, :] + self.beta
if self.pos_enc:
base = rope(base, dim=1)
# [B, K, 2, s] -> [B, K, s], [B, K, s]
q, k = torch.unbind(base, dim=2)
else:
# [B, K, e + s] -> [B, K, e], [B, K, s]
u, q = torch.split(uv, [self.e, self.s], dim=2)
k = self.k_fc(k) # -> [B, K, s]
v = self.v_fc(v) # -> [B, K, e]
if self.pos_enc:
q = rope(q, 1)
k = rope(k, 1)
# [B, K, s].permute() -> [B, s, K]
# [B, K, s] x [B, s, K] -> [B, K, K]
qk = torch.bmm(q, k.permute(0, 2, 1))
if self.use_rel_bias:
if self.attn_type == 'self-attn':
bias = self.rel_pos_bias(q.size(1))
else:
bias = self.rel_pos_bias(q.size(1), k.size(1))
qk += bias[:, :q.size(1), :k.size(1)]
# [B, K, K]
kernel = torch.square(F.relu(qk / self.sqrt_s))
if self.dropout_rate > 0.:
kernel = self.dropout(kernel)
# [B, K, K] x [B, K, e] -> [B, K, e]
x = u * torch.bmm(kernel, v)
# [B, K, e] -> [B, K, out_token_dims]
x = self.o(x)
return x

    def forward(self, x):
"""Forward function."""
if self.shortcut:
if self.attn_type == 'cross-attn':
res_shortcut = x[0]
else:
res_shortcut = x
main_branch = self.drop_path(self._forward(x))
return self.res_scale(res_shortcut) + main_branch
else:
return self.drop_path(self._forward(x))
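

# Illustrative usage sketch (hypothetical helper, not part of the original
# module): builds a self-attention RTMCCBlock and runs a forward pass. The
# hyper-parameters below are illustrative, not the defaults of any particular
# RTMPose config. With equal input/output token dims the block adds a scaled
# residual shortcut; in 'cross-attn' mode the input is a tuple ``(x, k, v)``
# instead of a single tensor.
def _rtmcc_block_demo():
    block = RTMCCBlock(
        num_token=17,        # e.g. one token per keypoint
        in_token_dims=256,
        out_token_dims=256,  # equal dims enable the residual shortcut
        expansion_factor=2,
        s=128,
        pos_enc=True)
    x = torch.randn(2, 17, 256)  # [B, K, in_token_dims]
    out = block(x)               # [B, K, out_token_dims]
    return out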