Source code for dalib.adaptation.jan
"""
@author: Junguang Jiang
@contact: [email protected]
"""
from typing import Optional, Sequence
import torch
import torch.nn as nn
from common.modules.classifier import Classifier as ClassifierBase
from ..modules.grl import GradientReverseLayer
from ..modules.kernels import GaussianKernel
from .dan import _update_index_matrix
__all__ = ['JointMultipleKernelMaximumMeanDiscrepancy', 'ImageClassifier']
class JointMultipleKernelMaximumMeanDiscrepancy(nn.Module):
r"""The Joint Multiple Kernel Maximum Mean Discrepancy (JMMD) used in
`Deep Transfer Learning with Joint Adaptation Networks (ICML 2017) <https://arxiv.org/abs/1605.06636>`_
Given source domain :math:`\mathcal{D}_s` of :math:`n_s` labeled points and target domain :math:`\mathcal{D}_t`
of :math:`n_t` unlabeled points drawn i.i.d. from :math:`P` and :math:`Q` respectively, the deep networks will generate
activations in layers :math:`\mathcal{L}` as :math:`\{(z_i^{s1}, ..., z_i^{s|\mathcal{L}|})\}_{i=1}^{n_s}` and
:math:`\{(z_i^{t1}, ..., z_i^{t|\mathcal{L}|})\}_{i=1}^{n_t}`. The empirical estimate of
:math:`\hat{D}_{\mathcal{L}}(P, Q)` is computed as the squared distance between the empirical kernel mean
embeddings:
.. math::
\hat{D}_{\mathcal{L}}(P, Q) &=
\dfrac{1}{n_s^2} \sum_{i=1}^{n_s}\sum_{j=1}^{n_s} \prod_{l\in\mathcal{L}} k^l(z_i^{sl}, z_j^{sl}) \\
&+ \dfrac{1}{n_t^2} \sum_{i=1}^{n_t}\sum_{j=1}^{n_t} \prod_{l\in\mathcal{L}} k^l(z_i^{tl}, z_j^{tl}) \\
&- \dfrac{2}{n_s n_t} \sum_{i=1}^{n_s}\sum_{j=1}^{n_t} \prod_{l\in\mathcal{L}} k^l(z_i^{sl}, z_j^{tl}).
Args:
kernels (tuple(tuple(torch.nn.Module))): kernel functions, where `kernels[r]` corresponds to kernel :math:`k^{\mathcal{L}[r]}`.
linear (bool): whether to use the linear version of JAN. Default: True
thetas (list(Theta)): use the adversarial version of JAN if not None. Default: None
Inputs:
- z_s (tuple(tensor)): multiple layers' activations from the source domain, :math:`z^s`
- z_t (tuple(tensor)): multiple layers' activations from the target domain, :math:`z^t`
Shape:
- :math:`z^{sl}` and :math:`z^{tl}`: :math:`(minibatch, *)` where * means any dimension
- Outputs: scalar
.. note::
Activations :math:`z^{sl}` and :math:`z^{tl}` must have the same shape.
.. note::
When a layer has multiple kernels, their kernel values are summed before taking the product across layers.
Examples::
>>> feature_dim = 1024
>>> batch_size = 10
>>> layer1_kernels = (GaussianKernel(alpha=0.5), GaussianKernel(1.), GaussianKernel(2.))
>>> layer2_kernels = (GaussianKernel(1.), )
>>> loss = JointMultipleKernelMaximumMeanDiscrepancy((layer1_kernels, layer2_kernels))
>>> # layer1 features from source domain and target domain
>>> z1_s, z1_t = torch.randn(batch_size, feature_dim), torch.randn(batch_size, feature_dim)
>>> # layer2 features from source domain and target domain
>>> z2_s, z2_t = torch.randn(batch_size, feature_dim), torch.randn(batch_size, feature_dim)
>>> output = loss((z1_s, z2_s), (z1_t, z2_t))
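>>> # Sketch of the adversarial variant (hypothetical, for illustration):
>>> # pass one Theta per adapted layer, matching that layer's feature dimension
>>> thetas = [Theta(feature_dim), Theta(feature_dim)]
>>> adv_loss = JointMultipleKernelMaximumMeanDiscrepancy(
...     (layer1_kernels, layer2_kernels), linear=False, thetas=thetas)
>>> output = adv_loss((z1_s, z2_s), (z1_t, z2_t))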
"""
def __init__(self, kernels: Sequence[Sequence[nn.Module]], linear: bool = True, thetas: Optional[Sequence[nn.Module]] = None):
super(JointMultipleKernelMaximumMeanDiscrepancy, self).__init__()
self.kernels = kernels
self.index_matrix = None
self.linear = linear
if thetas:
self.thetas = thetas
else:
self.thetas = [nn.Identity() for _ in kernels]
def forward(self, z_s: Sequence[torch.Tensor], z_t: Sequence[torch.Tensor]) -> torch.Tensor:
batch_size = int(z_s[0].size(0))
self.index_matrix = _update_index_matrix(batch_size, self.index_matrix, self.linear).to(z_s[0].device)
kernel_matrix = torch.ones_like(self.index_matrix)
for layer_z_s, layer_z_t, layer_kernels, theta in zip(z_s, z_t, self.kernels, self.thetas):
layer_features = torch.cat([layer_z_s, layer_z_t], dim=0)
layer_features = theta(layer_features)
kernel_matrix *= sum(
[kernel(layer_features) for kernel in layer_kernels])  # sum the kernel matrices within each layer, then multiply across layers
# Add 2 / (n-1) to make up for the value on the diagonal
# to ensure loss is positive in the non-linear version
loss = (kernel_matrix * self.index_matrix).sum() + 2. / float(batch_size - 1)
return loss
class Theta(nn.Module):
"""
maximize loss respect to :math:`\theta`
minimize loss respect to features
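
Examples::

>>> # A sketch with illustrative dimensions: gradients through Theta are
>>> # reversed twice on the path to the input features and once on the path
>>> # to the linear layer's parameters, so theta ascends the JMMD loss while
>>> # the features descend it.
>>> theta = Theta(dim=256)
>>> transformed = theta(torch.randn(10, 256))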
"""
def __init__(self, dim: int):
super(Theta, self).__init__()
self.grl1 = GradientReverseLayer()
self.grl2 = GradientReverseLayer()
self.layer1 = nn.Linear(dim, dim)
nn.init.eye_(self.layer1.weight)
nn.init.zeros_(self.layer1.bias)
def forward(self, features: torch.Tensor) -> torch.Tensor:
features = self.grl1(features)
return self.grl2(self.layer1(features))
class ImageClassifier(ClassifierBase):
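"""ImageClassifier for JAN: a bottleneck (Linear -> BatchNorm1d -> ReLU -> Dropout)
is inserted between the backbone and the classification head.

Examples::

>>> # A sketch of a typical JAN training step. The batches ``x_s``/``x_t``,
>>> # ``labels_s``, the ``jmmd`` criterion and the ``trade_off`` weight are
>>> # assumed to be defined elsewhere, and ``backbone`` is assumed to expose
>>> # ``out_features``.
>>> classifier = ImageClassifier(backbone, num_classes=31)
>>> y_s, f_s = classifier(x_s)  # in training mode, returns (logits, bottleneck features)
>>> y_t, f_t = classifier(x_t)
>>> cls_loss = F.cross_entropy(y_s, labels_s)
>>> # JAN aligns the bottleneck features and the softmaxed predictions jointly
>>> transfer_loss = jmmd((f_s, F.softmax(y_s, dim=1)), (f_t, F.softmax(y_t, dim=1)))
>>> loss = cls_loss + trade_off * transfer_loss
"""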
def __init__(self, backbone: nn.Module, num_classes: int, bottleneck_dim: Optional[int] = 256, **kwargs):
bottleneck = nn.Sequential(
# nn.AdaptiveAvgPool2d(output_size=(1, 1)),
# nn.Flatten(),
nn.Linear(backbone.out_features, bottleneck_dim),
nn.BatchNorm1d(bottleneck_dim),
nn.ReLU(),
nn.Dropout(0.5)
)
super(ImageClassifier, self).__init__(backbone, num_classes, bottleneck, bottleneck_dim, **kwargs)