Shortcuts

Source code for common.vision.models.keypoint_detection.loss

"""
Modified from https://github.com/microsoft/human-pose-estimation.pytorch
@author: Junguang Jiang
@contact: [email protected]
"""
import torch.nn as nn
import torch.nn.functional as F


class JointsMSELoss(nn.Module):
    """
    Typical MSE loss for keypoint detection.

    The element-wise squared error between predicted and label heatmaps is
    scaled by 0.5, optionally masked per keypoint, and then reduced.

    Args:
        reduction (str, optional): Specifies the reduction to apply to the output:
            ``'none'`` | ``'mean'``. ``'none'``: the loss is averaged over the
            spatial positions of each heatmap only, ``'mean'``: the loss is
            averaged over every element (batch, keypoints and spatial positions).
            Default: ``'mean'``

    Inputs:
        - output (tensor): heatmap predictions
        - target (tensor): heatmap labels
        - target_weight (tensor): whether the keypoint is visible. All keypoint is visible if None. Default: None.

    Shape:
        - output: :math:`(minibatch, K, H, W)` where K means the number of keypoints,
          H and W is the height and width of the heatmap respectively.
        - target: :math:`(minibatch, K, H, W)`.
        - target_weight: any shape that can be reshaped to :math:`(minibatch, K, 1)`,
          e.g. :math:`(minibatch, K)`.
        - Output: scalar by default. If :attr:`reduction` is ``'none'``, then :math:`(minibatch, K)`.

    Raises:
        ValueError: if :attr:`reduction` is neither ``'mean'`` nor ``'none'``.
    """

    def __init__(self, reduction='mean'):
        super(JointsMSELoss, self).__init__()
        # Fail fast: an unsupported value used to make forward() return None.
        if reduction not in ('mean', 'none'):
            raise ValueError(
                "reduction must be 'mean' or 'none', got {!r}".format(reduction))
        # Keep the element-wise loss so the per-keypoint mask can be applied
        # before any reduction happens.
        self.criterion = nn.MSELoss(reduction='none')
        self.reduction = reduction

    def forward(self, output, target, target_weight=None):
        B, K, _, _ = output.shape
        # Flatten the spatial dimensions: (B, K, H, W) -> (B, K, H * W).
        heatmaps_pred = output.reshape((B, K, -1))
        heatmaps_gt = target.reshape((B, K, -1))
        loss = self.criterion(heatmaps_pred, heatmaps_gt) * 0.5
        if target_weight is not None:
            # Broadcast one weight per keypoint over all of its pixels.
            # reshape (instead of view) also accepts non-contiguous weights.
            loss = loss * target_weight.reshape((B, K, 1))
        if self.reduction == 'mean':
            return loss.mean()
        if self.reduction == 'none':
            return loss.mean(dim=-1)
        # Only reachable if self.reduction was changed after construction.
        raise ValueError(
            "reduction must be 'mean' or 'none', got {!r}".format(self.reduction))
class JointsKLLoss(nn.Module):
    """
    KL Divergence for keypoint detection
    proposed by `Regressive Domain Adaptation for Unsupervised Keypoint Detection <https://arxiv.org/abs/2103.06175>`_.

    Each predicted heatmap is turned into a log-probability map with a softmax
    over its spatial positions; each label heatmap is shifted by
    :attr:`epsilon` and normalised to sum to one. The KL divergence between
    the two is summed over the spatial positions of every keypoint.

    Args:
        reduction (str, optional): Specifies the reduction to apply to the output:
            ``'none'`` | ``'mean'``. ``'none'``: the per-keypoint losses are
            averaged over the keypoints of each sample only, ``'mean'``: the
            per-keypoint losses are averaged over the batch and the keypoints.
            Default: ``'mean'``
        epsilon (float, optional): constant added to every element of the label
            heatmaps before they are normalised. Default: 0.

    Inputs:
        - output (tensor): heatmap predictions
        - target (tensor): heatmap labels
        - target_weight (tensor): whether the keypoint is visible. All keypoint is visible if None. Default: None.

    Shape:
        - output: :math:`(minibatch, K, H, W)` where K means the number of keypoints,
          H and W is the height and width of the heatmap respectively.
        - target: :math:`(minibatch, K, H, W)`.
        - target_weight: any shape that can be reshaped to :math:`(minibatch, K)`.
        - Output: scalar by default. If :attr:`reduction` is ``'none'``, then :math:`(minibatch)`.

    Raises:
        ValueError: if :attr:`reduction` is neither ``'mean'`` nor ``'none'``.
    """

    def __init__(self, reduction='mean', epsilon=0.):
        super(JointsKLLoss, self).__init__()
        # Fail fast: an unsupported value used to make forward() return None.
        if reduction not in ('mean', 'none'):
            raise ValueError(
                "reduction must be 'mean' or 'none', got {!r}".format(reduction))
        # Element-wise KL terms; summed over space in forward().
        self.criterion = nn.KLDivLoss(reduction='none')
        self.reduction = reduction
        self.epsilon = epsilon

    def forward(self, output, target, target_weight=None):
        B, K, _, _ = output.shape
        # Flatten the spatial dimensions: (B, K, H, W) -> (B, K, H * W).
        heatmaps_pred = output.reshape((B, K, -1))
        # nn.KLDivLoss expects log-probabilities as its input.
        heatmaps_pred = F.log_softmax(heatmaps_pred, dim=-1)
        heatmaps_gt = target.reshape((B, K, -1))
        heatmaps_gt = heatmaps_gt + self.epsilon
        normaliser = heatmaps_gt.sum(dim=-1, keepdim=True)
        # A heatmap whose entries sum to zero (e.g. an all-zero label with
        # epsilon == 0) would give 0 / 0 = NaN, and NaN survives the
        # multiplication by a zero target_weight. Dividing by 1 instead leaves
        # such a heatmap unchanged.
        normaliser = normaliser.masked_fill(normaliser == 0, 1.0)
        heatmaps_gt = heatmaps_gt / normaliser
        # Sum the element-wise terms over space -> one value per keypoint, (B, K).
        loss = self.criterion(heatmaps_pred, heatmaps_gt).sum(dim=-1)
        if target_weight is not None:
            # reshape (instead of view) also accepts non-contiguous weights.
            loss = loss * target_weight.reshape((B, K))
        if self.reduction == 'mean':
            return loss.mean()
        if self.reduction == 'none':
            # NOTE: this averages over the keypoint dimension, so the result
            # has shape (B,). Kept as-is for backward compatibility.
            return loss.mean(dim=-1)
        # Only reachable if self.reduction was changed after construction.
        raise ValueError(
            "reduction must be 'mean' or 'none', got {!r}".format(self.reduction))

Docs

Access comprehensive documentation for the Transfer Learning Library

View Docs

Tutorials

Get started with the Transfer Learning Library

Get Started