Shortcuts

Source code for dglib.generalization.groupdro

"""
Modified from https://github.com/facebookresearch/DomainBed
@author: Baixu Chen
@contact: [email protected]
"""
import torch


class AutomaticUpdateDomainWeightModule:
    r"""
    Maintaining group weight based on loss history of all domains according to
    `Distributionally Robust Neural Networks for Group Shifts: On the Importance of Regularization
    for Worst-Case Generalization (ICLR 2020) <https://arxiv.org/pdf/1911.08731.pdf>`_.

    Suppose we have :math:`N` domains. During each iteration, we first calculate unweighted loss
    among all domains, resulting in :math:`loss\in R^N`. Then we update domain weight by

    .. math::
        w_k = w_k * \text{exp}(\eta * loss_k), \forall k \in [1, N]

    where :math:`\eta` is the hyper parameter which ensures smoother change of weight.
    As :math:`w \in R^N` denotes a distribution, we `normalize` :math:`w` by its sum.
    At last, weighted loss is calculated as our objective

    .. math::
        objective = \sum_{k=1}^N w_k * loss_k

    Args:
        num_domains (int): The number of source domains.
        eta (float): Hyper parameter eta.
        device (torch.device): The device to run on.
    """

    def __init__(self, num_domains: int, eta: float, device):
        # Start from the uniform distribution over all domains.
        self.domain_weight = torch.ones(num_domains).to(device) / num_domains
        self.eta = eta

    def get_domain_weight(self, sampled_domain_idxes) -> torch.Tensor:
        """Get domain weight to calculate final objective.

        Inputs:
            - sampled_domain_idxes (list): sampled domain indexes in current mini-batch

        Shape:
            - sampled_domain_idxes: :math:`(D, )` where D means the number of sampled domains
              in current mini-batch
            - Outputs: :math:`(D, )`
        """
        # Restrict to the sampled domains and renormalize so the returned weights sum to 1.
        domain_weight = self.domain_weight[sampled_domain_idxes]
        domain_weight = domain_weight / domain_weight.sum()
        return domain_weight

    def update(self, sampled_domain_losses: torch.Tensor, sampled_domain_idxes) -> None:
        """Update domain weight using loss of current mini-batch.

        Inputs:
            - sampled_domain_losses (tensor): losses of the sampled domains in current mini-batch
            - sampled_domain_idxes (list): sampled domain indexes in current mini-batch

        Shape:
            - sampled_domain_losses: :math:`(D, )` where D means the number of sampled domains
              in current mini-batch
            - sampled_domain_idxes: :math:`(D, )`
        """
        # The weights are bookkeeping only; no gradient should flow through them.
        sampled_domain_losses = sampled_domain_losses.detach()

        # Element-wise loop (rather than fancy-index assignment) so that a domain index
        # that appears more than once receives every one of its multiplicative updates.
        for loss, idx in zip(sampled_domain_losses, sampled_domain_idxes):
            self.domain_weight[idx] *= (self.eta * loss).exp()

        # Keep the weights a probability distribution. Without this the stored values grow
        # by exp(eta * loss) every step and eventually overflow to inf, which turns the
        # normalization in `get_domain_weight` into inf / inf = NaN. Dividing the whole
        # vector by its sum leaves all ratios between domains unchanged.
        self.domain_weight /= self.domain_weight.sum()

Docs

Access comprehensive documentation for Transfer Learning Library

View Docs

Tutorials

Get started for Transfer Learning Library

Get Started