Shortcuts

Source code for dalib.adaptation.advent

"""
@author: Junguang Jiang
@contact: [email protected]
"""
from torch import nn
import torch
import torch.nn.functional as F
import numpy as np


[docs]class Discriminator(nn.Sequential): """ Domain discriminator model from `ADVENT: Adversarial Entropy Minimization for Domain Adaptation in Semantic Segmentation (CVPR 2019) <https://arxiv.org/abs/1811.12833>`_ Distinguish pixel-by-pixel whether the input predictions come from the source domain or the target domain. The source domain label is 1 and the target domain label is 0. Args: num_classes (int): num of classes in the predictions ndf (int): dimension of the hidden features Shape: - Inputs: :math:`(minibatch, C, H, W)` where :math:`C` is the number of classes - Outputs: :math:`(minibatch, 1, H, W)` """ def __init__(self, num_classes, ndf=64): super(Discriminator, self).__init__( nn.Conv2d(num_classes, ndf, kernel_size=4, stride=2, padding=1), nn.LeakyReLU(negative_slope=0.2, inplace=True), nn.Conv2d(ndf, ndf * 2, kernel_size=4, stride=2, padding=1), nn.LeakyReLU(negative_slope=0.2, inplace=True), nn.Conv2d(ndf * 2, ndf * 4, kernel_size=4, stride=2, padding=1), nn.LeakyReLU(negative_slope=0.2, inplace=True), nn.Conv2d(ndf * 4, ndf * 8, kernel_size=4, stride=2, padding=1), nn.LeakyReLU(negative_slope=0.2, inplace=True), nn.Conv2d(ndf * 8, 1, kernel_size=4, stride=2, padding=1), )
def prob_2_entropy(prob): """ convert probabilistic prediction maps to weighted self-information maps """ n, c, h, w = prob.size() return -torch.mul(prob, torch.log2(prob + 1e-30)) / np.log2(c) def bce_loss(y_pred, y_label): y_truth_tensor = torch.FloatTensor(y_pred.size()) y_truth_tensor.fill_(y_label) y_truth_tensor = y_truth_tensor.to(y_pred.get_device()) return F.binary_cross_entropy_with_logits(y_pred, y_truth_tensor)
[docs]class DomainAdversarialEntropyLoss(nn.Module): r"""The `Domain Adversarial Entropy Loss <https://arxiv.org/abs/1811.12833>`_ Minimizing entropy with adversarial learning through training a domain discriminator. Args: domain_discriminator (torch.nn.Module): A domain discriminator object, which predicts the domains of predictions. Its input shape is :math:`(minibatch, C, H, W)` and output shape is :math:`(minibatch, 1, H, W)` Inputs: - logits (tensor): logits output of segmentation model - domain_label (str, optional): whether the data comes from source or target. Choices: ['source', 'target']. Default: 'source' Shape: - logits: :math:`(minibatch, C, H, W)` where :math:`C` means the number of classes - Outputs: scalar. Examples:: >>> B, C, H, W = 2, 19, 512, 512 >>> discriminator = Discriminator(num_classes=C) >>> dann = DomainAdversarialEntropyLoss(discriminator) >>> # logits output on source domain and target domain >>> y_s, y_t = torch.randn(B, C, H, W), torch.randn(B, C, H, W) >>> loss = 0.5 * (dann(y_s, "source") + dann(y_t, "target")) """ def __init__(self, discriminator: nn.Module): super(DomainAdversarialEntropyLoss, self).__init__() self.discriminator = discriminator
[docs] def forward(self, logits, domain_label='source'): """ """ assert domain_label in ['source', 'target'] probability = F.softmax(logits, dim=1) entropy = prob_2_entropy(probability) domain_prediciton = self.discriminator(entropy) if domain_label == 'source': return bce_loss(domain_prediciton, 1) else: return bce_loss(domain_prediciton, 0)
[docs] def train(self, mode=True): r"""Sets the discriminator in training mode. In the training mode, all the parameters in discriminator will be set requires_grad=True. Args: mode (bool): whether to set training mode (``True``) or evaluation mode (``False``). Default: ``True``. """ self.discriminator.train(mode) for param in self.discriminator.parameters(): param.requires_grad = mode return self
[docs] def eval(self): r"""Sets the module in evaluation mode. In the training mode, all the parameters in discriminator will be set requires_grad=False. This is equivalent with :meth:`self.train(False) <torch.nn.Module.train>`. """ return self.train(False)

Docs

Access comprehensive documentation for Transfer Learning Library

View Docs

Tutorials

Get started for Transfer Learning Library

Get Started