# Source code for dglib.generalization.coral
"""
@author: Baixu Chen
@contact: [email protected]
"""
import torch
import torch.nn as nn
class CorrelationAlignmentLoss(nn.Module):
    r"""Correlation alignment (CORAL) loss between source and target features.

    Based on `Deep CORAL: Correlation Alignment for Deep Domain Adaptation (ECCV 2016)
    <https://arxiv.org/pdf/1607.01719.pdf>`_.

    Given source features :math:`f_S` and target features :math:`f_T`, the covariance matrices are given by

    .. math::
        C_S = \frac{1}{n_S-1}(f_S^Tf_S-\frac{1}{n_S}(\textbf{1}^Tf_S)^T(\textbf{1}^Tf_S))

    .. math::
        C_T = \frac{1}{n_T-1}(f_T^Tf_T-\frac{1}{n_T}(\textbf{1}^Tf_T)^T(\textbf{1}^Tf_T))

    where :math:`\textbf{1}` denotes a column vector with all elements equal to 1, :math:`n_S, n_T` denote the number
    of source and target samples, respectively. We use :math:`d` to denote the feature dimension and
    :math:`{\Vert\cdot\Vert}^2_F` to denote the squared matrix `Frobenius norm`. The paper defines the correlation
    alignment loss as

    .. math::
        l_{CORAL} = \frac{1}{4d^2}\Vert C_S-C_T \Vert^2_F

    .. note::
        This implementation does not return exactly the expression above. With :math:`\mu_S, \mu_T` the per-feature
        means of the two batches, ``forward`` returns

        .. math::
            \frac{1}{d}\Vert \mu_S-\mu_T \Vert^2_2 + \frac{1}{d^2}\Vert C_S-C_T \Vert^2_F

        i.e. the covariance term is averaged over its :math:`d^2` entries (no extra factor of :math:`1/4`) and a
        mean-alignment term is added.

    Inputs:
        - f_s (tensor): feature representations on source domain, :math:`f^s`
        - f_t (tensor): feature representations on target domain, :math:`f^t`

    Shape:
        - f_s: :math:`(n_S, d)` and f_t: :math:`(n_T, d)`, where d is the dimension of input features. Each
          covariance is normalized by its own sample count, so :math:`n_S` and :math:`n_T` may differ, but both
          must be at least 2.
        - Outputs: scalar.

    Raises:
        ValueError: if either input has fewer than two samples along dimension 0.
    """

    def __init__(self):
        super().__init__()

    def forward(self, f_s: torch.Tensor, f_t: torch.Tensor) -> torch.Tensor:
        n_s, n_t = f_s.shape[0], f_t.shape[0]
        # The unbiased covariance divides by (n - 1); with fewer than two
        # samples that is a division by zero (or by -1 for an empty batch)
        # and the loss would silently become NaN/inf.
        if n_s < 2 or n_t < 2:
            raise ValueError(
                "CorrelationAlignmentLoss needs at least 2 samples per domain, "
                "got {} source and {} target samples".format(n_s, n_t)
            )

        # Per-feature means, kept as (1, d) rows so they broadcast over the batch.
        mean_s = f_s.mean(0, keepdim=True)
        mean_t = f_t.mean(0, keepdim=True)

        # Center each batch, then form the (d, d) unbiased covariance matrices.
        cent_s = f_s - mean_s
        cent_t = f_t - mean_t
        cov_s = torch.mm(cent_s.t(), cent_s) / (n_s - 1)
        cov_t = torch.mm(cent_t.t(), cent_t) / (n_t - 1)

        # Both terms are averaged over their entries (d and d * d respectively).
        mean_diff = (mean_s - mean_t).pow(2).mean()
        cov_diff = (cov_s - cov_t).pow(2).mean()
        return mean_diff + cov_diff