# Copyright (c) OpenMMLab. All rights reserved.
import torch
import torch.nn as nn
import torch.nn.functional as F

from mmpose.registry import MODELS

@MODELS.register_module()
class BCELoss(nn.Module):
    """Binary cross-entropy loss with optional per-label weighting.

    Args:
        use_target_weight (bool): If True, the element-wise loss is
            multiplied by ``target_weight`` before it is averaged.
            Default: False.
        loss_weight (float): Scalar multiplier applied to the final loss.
            Default: 1.0.
        with_logits (bool): If True, use
            ``F.binary_cross_entropy_with_logits``; otherwise use
            ``F.binary_cross_entropy``. Default: False.
    """

    def __init__(self,
                 use_target_weight=False,
                 loss_weight=1.,
                 with_logits=False):
        super().__init__()
        # Pick the functional criterion once so ``forward`` stays branch-free
        # with respect to the logits/probabilities choice.
        if with_logits:
            self.criterion = F.binary_cross_entropy_with_logits
        else:
            self.criterion = F.binary_cross_entropy
        self.use_target_weight = use_target_weight
        self.loss_weight = loss_weight

    def forward(self, output, target, target_weight=None):
        """Compute the (optionally weighted) binary cross-entropy.

        Note:
            - batch_size: N
            - num_labels: K

        Args:
            output (torch.Tensor[N, K]): Predicted classification scores.
            target (torch.Tensor[N, K]): Ground-truth classification.
            target_weight (torch.Tensor[N, K] or torch.Tensor[N]):
                Per-label or per-sample weights. Must not be None when
                ``use_target_weight`` is True.

        Returns:
            torch.Tensor: The loss scaled by ``loss_weight``.
        """
        if not self.use_target_weight:
            # Unweighted path: rely on the criterion's default reduction.
            return self.criterion(output, target) * self.loss_weight

        assert target_weight is not None
        elementwise = self.criterion(output, target, reduction='none')
        if target_weight.dim() == 1:
            # Per-sample weights: add a label axis so they broadcast over K.
            target_weight = target_weight[:, None]
        weighted = (elementwise * target_weight).mean()
        return weighted * self.loss_weight
@MODELS.register_module()
class JSDiscretLoss(nn.Module):
    """Discrete Jensen-Shannon divergence loss for DSNT with Gaussian
    heatmaps.

    Modified from the official implementation.

    Args:
        use_target_weight (bool): If True, both heatmaps are multiplied by
            ``target_weight`` before the divergence is computed.
            Default: True.
        size_average (bool): If True, the loss is divided by the batch
            size (``len(gt_hm)``). Default: True.
    """

    def __init__(
        self,
        use_target_weight=True,
        size_average: bool = True,
    ):
        super().__init__()
        self.use_target_weight = use_target_weight
        self.size_average = size_average
        self.kl_loss = nn.KLDivLoss(reduction='none')

    def kl(self, p, q):
        """Element-wise Kullback-Leibler divergence term of ``p`` from
        ``q``."""
        # A tiny epsilon keeps the logarithm finite where ``q`` is zero.
        eps = 1e-24
        log_q = (q + eps).log()
        return self.kl_loss(log_q, p)

    def js(self, pred_hm, gt_hm):
        """Element-wise Jensen-Shannon divergence between two heatmaps."""
        mixture = 0.5 * (pred_hm + gt_hm)
        kl_pred = self.kl(pred_hm, mixture)
        kl_gt = self.kl(gt_hm, mixture)
        return 0.5 * (kl_pred + kl_gt)

    def forward(self, pred_hm, gt_hm, target_weight=None):
        """Compute the summed JS divergence between heatmaps.

        Args:
            pred_hm (torch.Tensor[N, K, H, W]): Predicted heatmaps.
            gt_hm (torch.Tensor[N, K, H, W]): Target heatmaps.
            target_weight (torch.Tensor[N, K] or torch.Tensor[N]):
                Weights across different labels. Must not be None when
                ``use_target_weight`` is True.

        Returns:
            torch.Tensor: Loss value.
        """
        if self.use_target_weight:
            assert target_weight is not None
            assert pred_hm.ndim >= target_weight.ndim
            # Append trailing singleton axes so the weights broadcast over
            # the remaining heatmap dimensions.
            n_missing = pred_hm.ndim - target_weight.ndim
            target_weight = target_weight[(..., ) + (None, ) * n_missing]
            weighted_pred = pred_hm * target_weight
            weighted_gt = gt_hm * target_weight
            loss = self.js(weighted_pred, weighted_gt)
        else:
            loss = self.js(pred_hm, gt_hm)

        if self.size_average:
            loss = loss / len(gt_hm)

        return loss.sum()
@MODELS.register_module()
class KLDiscretLoss(nn.Module):
    """Discrete KL divergence loss for SimCC with Gaussian label smoothing.

    Modified from the official implementation.

    Args:
        beta (float): Temperature factor multiplied into the inputs of the
            softmax. Default: 1.0.
        label_softmax (bool): If True, a softmax is also applied to the
            labels. Default: False.
        use_target_weight (bool): If True, per-row losses are multiplied by
            the flattened ``target_weight``. Default: True.
    """

    def __init__(self, beta=1.0, label_softmax=False, use_target_weight=True):
        super().__init__()
        self.beta = beta
        self.label_softmax = label_softmax
        self.use_target_weight = use_target_weight

        self.log_softmax = nn.LogSoftmax(dim=1)
        self.kl_loss = nn.KLDivLoss(reduction='none')

    def criterion(self, dec_outs, labels):
        """Per-row KL divergence, averaged over dimension 1.

        Both inputs are scaled by ``beta`` before the (log-)softmax.
        """
        log_pt = self.log_softmax(dec_outs * self.beta)
        if self.label_softmax:
            labels = F.softmax(labels * self.beta, dim=1)
        elementwise = self.kl_loss(log_pt, labels)
        return torch.mean(elementwise, dim=1)

    def forward(self, pred_simcc, gt_simcc, target_weight):
        """Compute the SimCC KL loss summed over both axes.

        Args:
            pred_simcc (Tuple[Tensor, Tensor]): Predicted SimCC vectors of
                x-axis and y-axis.
            gt_simcc (Tuple[Tensor, Tensor]): Target representations.
            target_weight (torch.Tensor[N, K] or torch.Tensor[N]):
                Weights across different labels. Only read when
                ``use_target_weight`` is True.

        Returns:
            torch.Tensor: Summed loss divided by the number of joints
            (``pred_simcc[0].size(1)``).
        """
        num_joints = pred_simcc[0].size(1)
        weight = target_weight.reshape(-1) if self.use_target_weight else 1.

        # Flatten everything but the last axis so each row is one
        # distribution, then weight and sum the per-row losses per axis.
        per_axis_losses = (
            self.criterion(
                pred.reshape(-1, pred.size(-1)),
                target.reshape(-1, target.size(-1)),
            ).mul(weight).sum() for pred, target in zip(pred_simcc, gt_simcc))

        return sum(per_axis_losses) / num_joints
@MODELS.register_module()
class InfoNCELoss(nn.Module):
    """InfoNCE loss for training a discriminative representation space in a
    contrastive manner.

    `Representation Learning with Contrastive Predictive Coding
    arXiv: <https://arxiv.org/abs/1807.03748>`_.

    Args:
        temperature (float, optional): The temperature to use in the softmax
            function. Higher temperatures lead to softer probability
            distributions. Must be positive. Defaults to 1.0.
        loss_weight (float, optional): The weight to apply to the loss.
            Defaults to 1.0.
    """

    def __init__(self, temperature: float = 1.0, loss_weight=1.0) -> None:
        super(InfoNCELoss, self).__init__()
        assert temperature > 0, f'the argument `temperature` must be ' \
                                f'positive, but got {temperature}'
        self.temp = temperature
        self.loss_weight = loss_weight

    def forward(self, features: torch.Tensor) -> torch.Tensor:
        """Computes the InfoNCE loss.

        Args:
            features (Tensor): A 2-D tensor whose rows are the feature
                representations of different samples (a 2-D input is
                required by the matrix product below).

        Returns:
            Tensor: A scalar (0-dim) tensor containing the InfoNCE loss,
            summed over samples and scaled by ``loss_weight``.
        """
        n = features.size(0)
        # L2-normalise each row so the dot products below are cosine
        # similarities.
        features_norm = F.normalize(features, dim=1)
        # Pairwise similarity matrix (n x n), sharpened by the temperature.
        logits = features_norm.mm(features_norm.t()) / self.temp
        # Row i is classified against all rows; the target for row i is
        # index i, i.e. the diagonal of the similarity matrix.
        targets = torch.arange(n, dtype=torch.long, device=features.device)
        loss = F.cross_entropy(logits, targets, reduction='sum')
        return loss * self.loss_weight
