# Source code for dalib.translation.cycada
"""
@author: Junguang Jiang
@contact: [email protected]
"""
import torch.nn as nn
from torch import Tensor
class SemanticConsistency(nn.Module):
    r"""Semantic consistency loss.

    Introduced by
    `CyCADA: Cycle-Consistent Adversarial Domain Adaptation (ICML 2018) <https://arxiv.org/abs/1711.03213>`_.
    This helps to prevent label flipping during image translation.

    The loss is a :class:`torch.nn.CrossEntropyLoss` in which every target
    value listed in :attr:`ignore_index` is remapped to ``-1`` and ignored.
    Unlike plain cross entropy, several classes can be ignored at once.

    Args:
        ignore_index (tuple, optional): Target values that are ignored and do
            not contribute to the input gradient. Default: ().
        reduction (string, optional): Specifies the reduction to apply to the output:
            ``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction will
            be applied, ``'mean'``: the mean of the output over the non-ignored
            targets is taken, ``'sum'``: the output will be summed.
            Default: ``'mean'``

    Shape:
        - Input: :math:`(N, C)` where `C = number of classes`, or
          :math:`(N, C, d_1, d_2, ..., d_K)` with :math:`K \geq 1`
          in the case of `K`-dimensional loss.
        - Target: :math:`(N)` where each value is :math:`0 \leq \text{targets}[i] \leq C-1`, or
          :math:`(N, d_1, d_2, ..., d_K)` with :math:`K \geq 1` in the case of
          K-dimensional loss.
        - Output: scalar.
          If :attr:`reduction` is ``'none'``, then the same size as the target:
          :math:`(N)`, or
          :math:`(N, d_1, d_2, ..., d_K)` with :math:`K \geq 1` in the case
          of K-dimensional loss.

    Examples::

        >>> import torch
        >>> loss = SemanticConsistency()
        >>> input = torch.randn(3, 5, requires_grad=True)
        >>> target = torch.empty(3, dtype=torch.long).random_(5)
        >>> output = loss(input, target)
        >>> output.backward()
    """

    def __init__(self, ignore_index=(), reduction='mean'):
        super(SemanticConsistency, self).__init__()
        self.ignore_index = ignore_index
        # nn.CrossEntropyLoss accepts a single ignored value only, so every
        # class in ``ignore_index`` is funnelled onto the sentinel -1 in forward().
        self.loss = nn.CrossEntropyLoss(ignore_index=-1, reduction=reduction)

    def forward(self, input: Tensor, target: Tensor) -> Tensor:
        """Compute the cross entropy between ``input`` and ``target``.

        Targets equal to any value in ``self.ignore_index`` are excluded from
        the loss. The caller's ``target`` tensor is left unmodified.

        Args:
            input (Tensor): unnormalized class scores.
            target (Tensor): class indices.

        Returns:
            Tensor: the loss, reduced according to the ``reduction`` passed to
            the constructor.
        """
        masked_target = target
        for class_idx in self.ignore_index:
            # masked_fill is out-of-place: it returns a new tensor, so the
            # labels owned by the caller are never overwritten.
            masked_target = masked_target.masked_fill(masked_target == class_idx, -1)
        return self.loss(input, masked_target)