Shortcuts

Source code for dalib.adaptation.pada

"""
@author: Junguang Jiang
@contact: [email protected]
"""
from typing import Optional, List, Tuple

from torch.utils.data.dataloader import DataLoader
import torch.nn as nn
import torch
import torch.nn.functional as F


class AutomaticUpdateClassWeightModule(object):
    r"""
    Calculating class weight based on the output of classifier.
    See ``ClassWeightModule`` about the details of the calculation.
    Every N iterations, the class weight is updated automatically.

    Args:
        update_steps (int): N, the number of iterations to update class weight.
        data_loader (torch.utils.data.DataLoader): The data loader from which we can collect
            classification outputs.
        classifier (torch.nn.Module): Classifier.
        num_classes (int): Number of classes.
        device (torch.device): The device to run classifier.
        temperature (float, optional): T, temperature in ClassWeightModule. Default: 0.1
        partial_classes_index (list[int], optional): The index of partial classes. Note that this
            parameter is just for debugging, since in real-world dataset, we have no access to the
            index of partial classes. Default: None.

    Examples::

        >>> class_weight_module = AutomaticUpdateClassWeightModule(update_steps=500, ...)
        >>> num_iterations = 10000
        >>> for _ in range(num_iterations):
        >>>     class_weight_module.step()
        >>>     # weight for F.cross_entropy
        >>>     w_c = class_weight_module.get_class_weight_for_cross_entropy_loss()
        >>>     # weight for dalib.adaptation.dann.DomainAdversarialLoss;
        >>>     # labels_s holds the labels of the current source mini-batch
        >>>     w_s, w_t = class_weight_module.get_class_weight_for_adversarial_loss(labels_s)
    """

    def __init__(self, update_steps: int, data_loader: DataLoader,
                 classifier: nn.Module, num_classes: int,
                 device: torch.device, temperature: Optional[float] = 0.1,
                 partial_classes_index: Optional[List[int]] = None):
        self.update_steps = update_steps
        self.data_loader = data_loader
        self.classifier = classifier
        self.device = device
        self.class_weight_module = ClassWeightModule(temperature)
        # Until the first automatic update every class contributes equally.
        self.class_weight = torch.ones(num_classes).to(device)
        self.num_steps = 0
        self.partial_classes_index = partial_classes_index
        # Always define the attribute so that it exists even when no debugging
        # information was supplied.
        self.non_partial_classes_index = None
        if partial_classes_index is not None:
            # Build the lookup set once instead of scanning the list per class.
            partial_classes = {int(c) for c in partial_classes_index}
            self.non_partial_classes_index = [
                c for c in range(num_classes) if c not in partial_classes
            ]

    def step(self):
        """
        Advance the iteration counter by one. Whenever the counter becomes a
        multiple of ``update_steps``, the classifier is run over the whole
        ``data_loader`` and the class weight is recomputed from its outputs.
        """
        self.num_steps += 1
        if self.num_steps % self.update_steps == 0:
            all_outputs = collect_classification_results(self.data_loader, self.classifier, self.device)
            self.class_weight = self.class_weight_module(all_outputs)

    def get_class_weight_for_cross_entropy_loss(self):
        """
        Outputs: weight for F.cross_entropy

        Shape: :math:`(C, )` where C means the number of classes.
        """
        return self.class_weight

    def get_class_weight_for_adversarial_loss(self, source_labels: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Args:
            source_labels (torch.Tensor): class indices of the source mini-batch, used to look up
                the current weight of every source sample.

        Outputs:
            - w_s: source weight for :py:class:`~dalib.adaptation.dann.DomainAdversarialLoss`
            - w_t: target weight for :py:class:`~dalib.adaptation.dann.DomainAdversarialLoss`

        Shape:
            - w_s: :math:`(minibatch, )`
            - w_t: :math:`(minibatch, )`
        """
        class_weight_adv_source = self.class_weight[source_labels]
        # Every target sample receives the mean source weight, so both domains
        # carry the same total weight.
        class_weight_adv_target = torch.ones_like(class_weight_adv_source) * class_weight_adv_source.mean()
        return class_weight_adv_source, class_weight_adv_target

    def get_partial_classes_weight(self):
        """
        Get class weight averaged on the partial classes and non-partial classes respectively.

        .. warning::

            This function is just for debugging, since in real-world dataset, we have no access to
            the index of partial classes and this function will throw an error when
            `partial_classes_index` is None.

        Raises:
            AssertionError: if `partial_classes_index` is None.
        """
        # Raised explicitly (instead of via ``assert``) so that the check also
        # survives ``python -O`` while keeping the exception type unchanged.
        if self.partial_classes_index is None:
            raise AssertionError("partial_classes_index is None; partial class weights are unavailable")
        partial_weight = torch.mean(self.class_weight[self.partial_classes_index])
        non_partial_weight = torch.mean(self.class_weight[self.non_partial_classes_index])
        return partial_weight, non_partial_weight
class ClassWeightModule(nn.Module):
    r"""
    Calculating class weight based on the output of classifier.
    Introduced by `Partial Adversarial Domain Adaptation (ECCV 2018) <https://arxiv.org/abs/1808.04205>`_

    Given classification logits outputs :math:`\{\hat{y}_i\}_{i=1}^n`, where :math:`n` is the dataset size,
    the weight indicating the contribution of each class to the training can be calculated as follows

    .. math::
        \mathcal{\gamma} = \dfrac{1}{n} \sum_{i=1}^{n}\text{softmax}( \hat{y}_i / T),

    where :math:`\mathcal{\gamma}` is a :math:`|\mathcal{C}|`-dimensional weight vector quantifying the
    contribution of each class and T is a hyper-parameter called temperature.

    In practice, it's possible that some of the weights are very small, thus, we normalize weight
    :math:`\mathcal{\gamma}` by dividing its largest element, i.e.
    :math:`\mathcal{\gamma} \leftarrow \mathcal{\gamma} / max(\mathcal{\gamma})`

    Args:
        temperature (float, optional): hyper-parameter :math:`T`. Default: 0.1

    Shape:
        - Inputs: (minibatch, :math:`|\mathcal{C}|`)
        - Outputs: (:math:`|\mathcal{C}|`,)
    """

    def __init__(self, temperature: Optional[float] = 0.1):
        super(ClassWeightModule, self).__init__()
        self.temperature = temperature

    def forward(self, outputs: torch.Tensor):
        # Use the out-of-place ``detach()``: the in-place ``detach_()`` would
        # cut the caller's tensor out of its autograd graph as a side effect,
        # and it cannot be applied to views at all.
        logits = outputs.detach()
        softmax_outputs = F.softmax(logits / self.temperature, dim=1)
        # Average over the sample dimension, then rescale so the largest
        # class weight equals 1.
        class_weight = torch.mean(softmax_outputs, dim=0)
        class_weight = class_weight / torch.max(class_weight)
        return class_weight.view(-1)
def collect_classification_results(data_loader: DataLoader, classifier: nn.Module,
                                   device: torch.device) -> torch.Tensor:
    r"""
    Fetch data from `data_loader`, and then use `classifier` to collect classification results

    The classifier is switched to evaluation mode and run without gradient
    tracking; its previous training/evaluation mode is restored afterwards,
    even if an error occurs while iterating.

    Args:
        data_loader (torch.utils.data.DataLoader): Data loader yielding ``(images, target)`` pairs;
            only ``images`` is used.
        classifier (torch.nn.Module): A classifier.
        device (torch.device): The device the images are moved to before the forward pass.

    Returns:
        Classification results of all batches concatenated along dimension 0, i.e. of shape
        (N, :math:`|\mathcal{C}|`) where N is the total number of samples yielded by `data_loader`
        (assuming the classifier returns (minibatch, :math:`|\mathcal{C}|`) outputs).
    """
    was_training = classifier.training
    classifier.eval()
    all_outputs = []
    try:
        with torch.no_grad():
            for images, _ in data_loader:
                all_outputs.append(classifier(images.to(device)))
    finally:
        # Restore the caller's mode even when the forward pass or the data
        # loader raises, so a failure here cannot leave the model in eval mode.
        classifier.train(was_training)
    return torch.cat(all_outputs, dim=0)

Docs

Access comprehensive documentation for Transfer Learning Library

View Docs

Tutorials

Get started for Transfer Learning Library

Get Started