Shortcuts

Source code for dalib.adaptation.jan

"""
@author: Junguang Jiang
@contact: [email protected]
"""
from typing import Optional, Sequence
import torch
import torch.nn as nn

from common.modules.classifier import Classifier as ClassifierBase
from ..modules.grl import GradientReverseLayer
from ..modules.kernels import GaussianKernel
from .dan import _update_index_matrix


__all__ = ['JointMultipleKernelMaximumMeanDiscrepancy', 'ImageClassifier']



class JointMultipleKernelMaximumMeanDiscrepancy(nn.Module):
    r"""The Joint Multiple Kernel Maximum Mean Discrepancy (JMMD) used in
    `Deep Transfer Learning with Joint Adaptation Networks (ICML 2017) <https://arxiv.org/abs/1605.06636>`_

    Given source domain :math:`\mathcal{D}_s` of :math:`n_s` labeled points and target domain :math:`\mathcal{D}_t`
    of :math:`n_t` unlabeled points drawn i.i.d. from P and Q respectively, the deep networks will generate
    activations in layers :math:`\mathcal{L}` as :math:`\{(z_i^{s1}, ..., z_i^{s|\mathcal{L}|})\}_{i=1}^{n_s}` and
    :math:`\{(z_i^{t1}, ..., z_i^{t|\mathcal{L}|})\}_{i=1}^{n_t}`. The empirical estimate of
    :math:`\hat{D}_{\mathcal{L}}(P, Q)` is computed as the squared distance between the empirical kernel mean
    embeddings as

    .. math::
        \hat{D}_{\mathcal{L}}(P, Q) &=
        \dfrac{1}{n_s^2} \sum_{i=1}^{n_s}\sum_{j=1}^{n_s} \prod_{l\in\mathcal{L}} k^l(z_i^{sl}, z_j^{sl}) \\
        &+ \dfrac{1}{n_t^2} \sum_{i=1}^{n_t}\sum_{j=1}^{n_t} \prod_{l\in\mathcal{L}} k^l(z_i^{tl}, z_j^{tl}) \\
        &- \dfrac{2}{n_s n_t} \sum_{i=1}^{n_s}\sum_{j=1}^{n_t} \prod_{l\in\mathcal{L}} k^l(z_i^{sl}, z_j^{tl}). \\

    Args:
        kernels (tuple(tuple(torch.nn.Module))): kernel functions, where `kernels[r]` corresponds to kernel
            :math:`k^{\mathcal{L}[r]}`.
        linear (bool): whether to use the linear version of JAN. Default: True
        thetas (list(Theta)): one module per layer, applied to that layer's concatenated source/target
            features before the kernels are evaluated. Use the adversarial version of JAN if not None.
            Default: None (an ``nn.Identity`` is used for every layer)

    Inputs:
        - z_s (tuple(tensor)): multiple layers' activations from the source domain, :math:`z^s`
        - z_t (tuple(tensor)): multiple layers' activations from the target domain, :math:`z^t`

    Shape:
        - :math:`z^{sl}` and :math:`z^{tl}`: :math:`(minibatch, *)` where * means any dimension
        - Outputs: scalar

    .. note::
        Activations :math:`z^{sl}` and :math:`z^{tl}` must have the same shape.

    .. note::
        The kernel values will add up when there are multiple kernels for a certain layer.

    .. note::
        The batch size is read from the first source activation. A batch size of 1 makes the
        ``2 / (n - 1)`` correction term divide by zero.

    Examples::

        >>> feature_dim = 1024
        >>> batch_size = 10
        >>> layer1_kernels = (GaussianKernel(alpha=0.5), GaussianKernel(1.), GaussianKernel(2.))
        >>> layer2_kernels = (GaussianKernel(1.), )
        >>> loss = JointMultipleKernelMaximumMeanDiscrepancy((layer1_kernels, layer2_kernels))
        >>> # layer1 features from source domain and target domain
        >>> z1_s, z1_t = torch.randn(batch_size, feature_dim), torch.randn(batch_size, feature_dim)
        >>> # layer2 features from source domain and target domain
        >>> z2_s, z2_t = torch.randn(batch_size, feature_dim), torch.randn(batch_size, feature_dim)
        >>> output = loss((z1_s, z2_s), (z1_t, z2_t))
    """

    def __init__(self, kernels: Sequence[Sequence[nn.Module]], linear: Optional[bool] = True,
                 thetas: Optional[Sequence[nn.Module]] = None):
        super().__init__()
        # NOTE(review): `kernels` and `thetas` are stored as plain Python sequences, so
        # nn.Module does not register them as submodules; their parameters are not
        # reported by `.parameters()` and they are not moved by `.to(device)` on this
        # module. Callers must handle them explicitly.
        self.kernels = kernels
        # Cached between calls and handed back to `_update_index_matrix` on the next forward.
        self.index_matrix = None
        self.linear = linear

        if thetas:
            self.thetas = thetas
        else:
            # Non-adversarial variant: leave every layer's features untouched.
            self.thetas = [nn.Identity() for _ in kernels]

    def forward(self, z_s: Sequence[torch.Tensor], z_t: Sequence[torch.Tensor]) -> torch.Tensor:
        """Compute the JMMD loss between source and target activations.

        Args:
            z_s: per-layer activations from the source domain.
            z_t: per-layer activations from the target domain.

        Returns:
            The loss as a tensor (scalar).
        """
        batch_size = int(z_s[0].size(0))
        self.index_matrix = _update_index_matrix(batch_size, self.index_matrix, self.linear).to(z_s[0].device)

        # Start from all ones so the per-layer kernel matrices can be multiplied in.
        kernel_matrix = torch.ones_like(self.index_matrix)
        # zip() stops at the shortest of the four sequences, so any surplus layers,
        # kernel groups or thetas are silently ignored.
        for layer_z_s, layer_z_t, layer_kernels, theta in zip(z_s, z_t, self.kernels, self.thetas):
            layer_features = theta(torch.cat([layer_z_s, layer_z_t], dim=0))
            # Add up the matrix of each kernel, then take the product across layers.
            kernel_matrix *= sum(kernel(layer_features) for kernel in layer_kernels)

        # Add 2 / (n-1) to make up for the value on the diagonal
        # to ensure loss is positive in the non-linear version.
        # The term is added regardless of `self.linear`.
        loss = (kernel_matrix * self.index_matrix).sum() + 2. / float(batch_size - 1)
        return loss
class Theta(nn.Module):
    r"""Learnable linear map over features, wrapped in two gradient reverse layers.

    maximize loss respect to :math:`\theta`
    minimize loss respect to features

    Args:
        dim (int): feature size; the inner linear layer maps ``dim`` to ``dim``.
    """

    def __init__(self, dim: int):
        super().__init__()
        self.grl1 = GradientReverseLayer()
        self.grl2 = GradientReverseLayer()
        self.layer1 = nn.Linear(dim, dim)
        # Identity weight and zero bias: the layer initially returns its input unchanged.
        nn.init.eye_(self.layer1.weight)
        nn.init.zeros_(self.layer1.bias)

    def forward(self, features: torch.Tensor) -> torch.Tensor:
        # NOTE(review): assuming GradientReverseLayer negates gradients on the backward
        # pass, `layer1` sees one reversal (maximize) while the incoming features see
        # two (minimize) -- confirm against the GradientReverseLayer implementation.
        features = self.grl1(features)
        return self.grl2(self.layer1(features))


class ImageClassifier(ClassifierBase):
    """Classifier whose bottleneck is Linear -> BatchNorm1d -> ReLU -> Dropout(0.5).

    Args:
        backbone (torch.nn.Module): feature extractor; must expose ``out_features``,
            which is used as the input size of the bottleneck's linear layer.
        num_classes (int): forwarded to the base classifier.
        bottleneck_dim (int, optional): output size of the bottleneck. Default: 256
        **kwargs: forwarded unchanged to the base classifier.
    """

    def __init__(self, backbone: nn.Module, num_classes: int, bottleneck_dim: Optional[int] = 256, **kwargs):
        bottleneck = nn.Sequential(
            # NOTE: pooling/flattening (nn.AdaptiveAvgPool2d(output_size=(1, 1)),
            # nn.Flatten()) is intentionally not part of this bottleneck.
            nn.Linear(backbone.out_features, bottleneck_dim),
            nn.BatchNorm1d(bottleneck_dim),
            nn.ReLU(),
            nn.Dropout(0.5)
        )
        super().__init__(backbone, num_classes, bottleneck, bottleneck_dim, **kwargs)

Docs

Access comprehensive documentation for Transfer Learning Library

View Docs

Tutorials

Get started for Transfer Learning Library

Get Started