Source code for dglib.modules.sampler
"""
@author: Baixu Chen
@contact: [email protected]
"""
import random
import copy
import numpy as np
from torch.utils.data.dataset import ConcatDataset
from torch.utils.data.sampler import Sampler
[docs]class DefaultSampler(Sampler):
r"""Traverse all :math:`N` domains, randomly select :math:`K` samples in each domain to form a mini-batch of size
:math:`N\times K`.
Args:
data_source (ConcatDataset): dataset that contains data from multiple domains
batch_size (int): mini-batch size (:math:`N\times K` here)
"""
def __init__(self, data_source: ConcatDataset, batch_size: int):
super(Sampler, self).__init__()
self.num_all_domains = len(data_source.cumulative_sizes)
self.sample_idxes_per_domain = []
start = 0
for end in data_source.cumulative_sizes:
idxes = [idx for idx in range(start, end)]
self.sample_idxes_per_domain.append(idxes)
start = end
assert batch_size % self.num_all_domains == 0
self.batch_size_per_domain = batch_size // self.num_all_domains
self.length = len(list(self.__iter__()))
def __iter__(self):
sample_idxes_per_domain = copy.deepcopy(self.sample_idxes_per_domain)
final_idxes = []
stop_flag = False
while not stop_flag:
for domain in range(self.num_all_domains):
sample_idxes = sample_idxes_per_domain[domain]
selected_idxes = random.sample(sample_idxes, self.batch_size_per_domain)
final_idxes.extend(selected_idxes)
for idx in selected_idxes:
sample_idxes_per_domain[domain].remove(idx)
remaining_size = len(sample_idxes_per_domain[domain])
if remaining_size < self.batch_size_per_domain:
stop_flag = True
return iter(final_idxes)
def __len__(self):
return self.length
[docs]class RandomDomainSampler(Sampler):
r"""Randomly sample :math:`N` domains, then randomly select :math:`K` samples in each domain to form a mini-batch of
size :math:`N\times K`.
Args:
data_source (ConcatDataset): dataset that contains data from multiple domains
batch_size (int): mini-batch size (:math:`N\times K` here)
n_domains_per_batch (int): number of domains to select in a single mini-batch (:math:`N` here)
"""
def __init__(self, data_source: ConcatDataset, batch_size: int, n_domains_per_batch: int):
super(Sampler, self).__init__()
self.n_domains_in_dataset = len(data_source.cumulative_sizes)
self.n_domains_per_batch = n_domains_per_batch
assert self.n_domains_in_dataset >= self.n_domains_per_batch
self.sample_idxes_per_domain = []
start = 0
for end in data_source.cumulative_sizes:
idxes = [idx for idx in range(start, end)]
self.sample_idxes_per_domain.append(idxes)
start = end
assert batch_size % n_domains_per_batch == 0
self.batch_size_per_domain = batch_size // n_domains_per_batch
self.length = len(list(self.__iter__()))
def __iter__(self):
sample_idxes_per_domain = copy.deepcopy(self.sample_idxes_per_domain)
domain_idxes = [idx for idx in range(self.n_domains_in_dataset)]
final_idxes = []
stop_flag = False
while not stop_flag:
selected_domains = random.sample(domain_idxes, self.n_domains_per_batch)
for domain in selected_domains:
sample_idxes = sample_idxes_per_domain[domain]
if len(sample_idxes) < self.batch_size_per_domain:
selected_idxes = np.random.choice(sample_idxes, self.batch_size_per_domain, replace=True)
else:
selected_idxes = random.sample(sample_idxes, self.batch_size_per_domain)
final_idxes.extend(selected_idxes)
for idx in selected_idxes:
if idx in sample_idxes_per_domain[domain]:
sample_idxes_per_domain[domain].remove(idx)
remaining_size = len(sample_idxes_per_domain[domain])
if remaining_size < self.batch_size_per_domain:
stop_flag = True
return iter(final_idxes)
def __len__(self):
return self.length