Source code for common.utils.analysis.a_distance
"""
@author: Junguang Jiang
@contact: [email protected]
"""
from torch.utils.data import TensorDataset
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torch.optim import SGD
from ..meter import AverageMeter
from ..metric import binary_accuracy
class ANet(nn.Module):
def __init__(self, in_feature):
super(ANet, self).__init__()
self.layer = nn.Linear(in_feature, 1)
self.sigmoid = nn.Sigmoid()
def forward(self, x):
x = self.layer(x)
x = self.sigmoid(x)
return x
[docs]def calculate(source_feature: torch.Tensor, target_feature: torch.Tensor,
device, progress=True, training_epochs=10):
"""
Calculate the :math:`\mathcal{A}`-distance, which is a measure for distribution discrepancy.
The definition is :math:`dist_\mathcal{A} = 2 (1-2\epsilon)`, where :math:`\epsilon` is the
test error of a classifier trained to discriminate the source from the target.
Args:
source_feature (tensor): features from source domain in shape :math:`(minibatch, F)`
target_feature (tensor): features from target domain in shape :math:`(minibatch, F)`
device (torch.device)
progress (bool): if True, displays a the progress of training A-Net
training_epochs (int): the number of epochs when training the classifier
Returns:
:math:`\mathcal{A}`-distance
"""
source_label = torch.ones((source_feature.shape[0], 1))
target_label = torch.zeros((target_feature.shape[0], 1))
feature = torch.cat([source_feature, target_feature], dim=0)
label = torch.cat([source_label, target_label], dim=0)
dataset = TensorDataset(feature, label)
length = len(dataset)
train_size = int(0.8 * length)
val_size = length - train_size
train_set, val_set = torch.utils.data.random_split(dataset, [train_size, val_size])
train_loader = DataLoader(train_set, batch_size=2, shuffle=True)
val_loader = DataLoader(val_set, batch_size=8, shuffle=False)
anet = ANet(feature.shape[1]).to(device)
optimizer = SGD(anet.parameters(), lr=0.01)
a_distance = 2.0
for epoch in range(training_epochs):
anet.train()
for (x, label) in train_loader:
x = x.to(device)
label = label.to(device)
anet.zero_grad()
y = anet(x)
loss = F.binary_cross_entropy(y, label)
loss.backward()
optimizer.step()
anet.eval()
meter = AverageMeter("accuracy", ":4.2f")
with torch.no_grad():
for (x, label) in val_loader:
x = x.to(device)
label = label.to(device)
y = anet(x)
acc = binary_accuracy(y, label)
meter.update(acc, x.shape[0])
error = 1 - meter.avg / 100
a_distance = 2 * (1 - 2 * error)
if progress:
print("epoch {} accuracy: {} A-dist: {}".format(epoch, meter.avg, a_distance))
return a_distance