Shortcuts

Source code for dalib.translation.cyclegan.loss

"""
Modified from https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix
@author: Junguang Jiang
@contact: [email protected]
"""
import torch.nn as nn
import torch


class LeastSquaresGenerativeAdversarialLoss(nn.Module):
    r"""Least-squares adversarial objective from
    `Least Squares Generative Adversarial Network (LSGAN) <https://arxiv.org/abs/1611.04076>`_.

    The discriminator output is regressed with mean squared error onto a
    constant target: ``1`` when ``real`` is true, ``0`` otherwise.

    Args:
        reduction (str, optional): reduction handed to :class:`torch.nn.MSELoss`;
            one of ``'none'`` (element-wise losses are returned), ``'mean'``
            (sum of the losses divided by the number of elements) or ``'sum'``
            (losses summed). Default: ``'mean'``

    Inputs:
        - prediction (tensor): unnormalized discriminator predictions
        - real (bool): whether the target is the label for real images (1)
          or for fake images (0). Default: ``True``

    .. warning::
        Do not use sigmoid as the last layer of Discriminator.
    """

    def __init__(self, reduction='mean'):
        super().__init__()
        self.mse_loss = nn.MSELoss(reduction=reduction)

    def forward(self, prediction, real=True):
        # Pick the constant-target factory; it mirrors prediction's shape, dtype and device.
        make_target = torch.ones_like if real else torch.zeros_like
        return self.mse_loss(prediction, make_target(prediction))
class VanillaGenerativeAdversarialLoss(nn.Module):
    r"""Binary cross-entropy adversarial objective from
    `Vanilla Generative Adversarial Network <https://arxiv.org/abs/1406.2661>`_.

    Raw discriminator logits are scored with :class:`torch.nn.BCEWithLogitsLoss`
    against a constant target: ``1`` when ``real`` is true, ``0`` otherwise.

    Args:
        reduction (str, optional): reduction handed to
            :class:`torch.nn.BCEWithLogitsLoss`; one of ``'none'`` (element-wise
            losses are returned), ``'mean'`` (sum of the losses divided by the
            number of elements) or ``'sum'`` (losses summed). Default: ``'mean'``

    Inputs:
        - prediction (tensor): unnormalized discriminator predictions (logits)
        - real (bool): whether the target is the label for real images (1)
          or for fake images (0). Default: ``True``

    .. warning::
        Do not use sigmoid as the last layer of Discriminator —
        ``BCEWithLogitsLoss`` already applies it internally.
    """

    def __init__(self, reduction='mean'):
        super().__init__()
        self.bce_loss = nn.BCEWithLogitsLoss(reduction=reduction)

    def forward(self, prediction, real=True):
        # Pick the constant-target factory; it mirrors prediction's shape, dtype and device.
        make_target = torch.ones_like if real else torch.zeros_like
        return self.bce_loss(prediction, make_target(prediction))
class WassersteinGenerativeAdversarialLoss(nn.Module):
    r"""Critic objective from
    `Wasserstein Generative Adversarial Network <https://arxiv.org/abs/1701.07875>`_.

    For real samples the loss is ``-prediction``; for fake samples it is
    ``prediction``. Minimizing it therefore pushes critic scores up on real
    data and down on fake data. The chosen ``reduction`` is then applied.

    Args:
        reduction (str, optional): Specifies the reduction to apply to the output:
            ``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction will be applied,
            ``'mean'``: the sum of the output will be divided by the number of
            elements in the output, ``'sum'``: the output will be summed. Default: ``'mean'``

    Inputs:
        - prediction (tensor): unnormalized discriminator predictions
        - real (bool): if the ground truth label is for real images or fake images. Default: true

    Raises:
        ValueError: if ``reduction`` is not one of ``'none'``, ``'mean'``, ``'sum'``.

    .. warning::
        Do not use sigmoid as the last layer of Discriminator.
    """

    _REDUCTIONS = ('none', 'mean', 'sum')

    def __init__(self, reduction='mean'):
        super(WassersteinGenerativeAdversarialLoss, self).__init__()
        # Previously ``reduction`` was silently ignored (the result was always the
        # mean); validate it eagerly so a typo fails at construction time.
        if reduction not in self._REDUCTIONS:
            raise ValueError(
                "reduction must be one of {}, got {!r}".format(self._REDUCTIONS, reduction))
        self.reduction = reduction

    def forward(self, prediction, real=True):
        # Negate for real samples so that minimizing the loss raises their scores.
        signed = -prediction if real else prediction
        if self.reduction == 'mean':
            return signed.mean()
        if self.reduction == 'sum':
            return signed.sum()
        # 'none': element-wise losses, same shape as ``prediction``.
        return signed

Docs

Access comprehensive documentation for Transfer Learning Library

View Docs

Tutorials

Get started for Transfer Learning Library

Get Started