Shortcuts

Source code for dalib.adaptation.osbp

"""
@author: Junguang Jiang
@contact: [email protected]
"""
from typing import Optional
import torch
import torch.nn as nn
import torch.nn.functional as F

from common.modules.classifier import Classifier as ClassifierBase
from ..modules.grl import GradientReverseLayer


class UnknownClassBinaryCrossEntropy(nn.Module):
    r"""
    Binary cross entropy loss to make a boundary for unknown samples, proposed by
    `Open Set Domain Adaptation by Backpropagation (ECCV 2018) <https://arxiv.org/abs/1804.10427>`_.

    Given a sample on target domain :math:`x_t` and its classification outputs :math:`y`, the binary cross entropy
    loss is defined as

    .. math::
        L_{\text{adv}}(x_t) = -t \text{log}(p(y=C+1|x_t)) - (1-t)\text{log}(1-p(y=C+1|x_t))

    where t is a hyper-parameter and C is the number of known classes. The probability of the unknown class is
    taken from the last column of the softmax over :math:`y`. Each term is averaged over the minibatch, and a small
    constant (1e-6) is added inside each logarithm to avoid ``log(0)``.

    Args:
        t (float): Predefined hyper-parameter. Default: 0.5

    Inputs:
        - y (tensor): classification outputs (before softmax).

    Shape:
        - y: :math:`(minibatch, C+1)` where C is the number of known classes.
        - Outputs: scalar
    """

    def __init__(self, t: Optional[float] = 0.5):
        super(UnknownClassBinaryCrossEntropy, self).__init__()
        # Target probability for the unknown class; the "known" side uses 1 - t.
        self.t = t

    def forward(self, y: torch.Tensor) -> torch.Tensor:
        # y: N x (C+1); the last column corresponds to the unknown class.
        unknown_class_prob = F.softmax(y, dim=1)[:, -1]
        known_class_prob = 1. - unknown_class_prob
        # The targets t and 1 - t are identical for every sample, so they are
        # applied as scalars instead of being materialised as an (N, 1) tensor.
        # This avoids a host allocation plus a host-to-device copy per call.
        return - self.t * torch.mean(torch.log(unknown_class_prob + 1e-6)) \
            - (1. - self.t) * torch.mean(torch.log(known_class_prob + 1e-6))
class ImageClassifier(ClassifierBase):
    """Classifier with a two-stage bottleneck and an optional gradient reversal before the head.

    The forward pass runs ``backbone -> pool_layer -> bottleneck -> [grl] -> head``. ``pool_layer`` and
    ``head`` are not created here; presumably ``ClassifierBase`` sets them up (verify against the base class).

    Args:
        backbone (torch.nn.Module): feature extractor; must expose ``out_features``.
        num_classes (int): forwarded to ``ClassifierBase``.
        bottleneck_dim (int): width of both bottleneck stages. Default: 256
        **kwargs: forwarded unchanged to ``ClassifierBase``.
    """

    def __init__(self, backbone: nn.Module, num_classes: int, bottleneck_dim: Optional[int] = 256, **kwargs):
        # Two Linear -> BatchNorm1d -> ReLU -> Dropout stages: the first maps the
        # backbone width to bottleneck_dim, the second keeps bottleneck_dim.
        stages = []
        in_dim = backbone.out_features
        for _ in range(2):
            stages.extend([
                nn.Linear(in_dim, bottleneck_dim),
                nn.BatchNorm1d(bottleneck_dim),
                nn.ReLU(),
                nn.Dropout(),
            ])
            in_dim = bottleneck_dim
        bottleneck = nn.Sequential(*stages)
        super().__init__(backbone, num_classes, bottleneck, bottleneck_dim, **kwargs)
        self.grl = GradientReverseLayer()

    def forward(self, x: torch.Tensor, grad_reverse: Optional[bool] = False):
        """Classify ``x``.

        When ``grad_reverse`` is true, the bottleneck features pass through ``self.grl`` before the head, and
        those post-GRL features are what gets returned. In training mode the result is ``(outputs, features)``;
        otherwise only ``outputs`` is returned.
        """
        features = self.bottleneck(self.pool_layer(self.backbone(x)))
        if grad_reverse:
            features = self.grl(features)
        outputs = self.head(features)
        return (outputs, features) if self.training else outputs

Docs

Access comprehensive documentation for the Transfer Learning Library

View Docs

Tutorials

Get started with the Transfer Learning Library

Get Started