Shortcuts

Source code for dglib.modules.sampler

"""
@author: Baixu Chen
@contact: [email protected]
"""
import random
import copy
import numpy as np
from torch.utils.data.dataset import ConcatDataset
from torch.utils.data.sampler import Sampler


class DefaultSampler(Sampler):
    r"""Traverse all :math:`N` domains, randomly select :math:`K` samples in each domain to form a mini-batch of size
    :math:`N\times K`.

    Within an epoch samples are drawn without replacement. Each mini-batch lists the :math:`K` samples of domain 0
    first, then those of domain 1, and so on. The epoch ends once any domain has fewer than :math:`K` unused samples
    left, so it contains exactly :math:`\min_d \lfloor n_d / K \rfloor` mini-batches, where :math:`n_d` is the number
    of samples in domain :math:`d`. Randomness comes from Python's :mod:`random` module.

    Args:
        data_source (ConcatDataset): dataset that contains data from multiple domains
        batch_size (int): mini-batch size (:math:`N\times K` here)

    Raises:
        ValueError: if ``batch_size`` is not positive, or if some domain holds fewer than :math:`K` samples
            (so not even one mini-batch can be formed).
    """

    def __init__(self, data_source: ConcatDataset, batch_size: int):
        # super(Sampler, self) starts the MRO lookup *after* Sampler, so Sampler.__init__ itself is skipped
        # (presumably to stay independent of its signature, which has varied between PyTorch releases).
        super(Sampler, self).__init__()
        self.num_all_domains = len(data_source.cumulative_sizes)
        # Global indices owned by each domain: domain d covers [cumulative_sizes[d-1], cumulative_sizes[d]),
        # with 0 as the lower bound of the first domain.
        self.sample_idxes_per_domain = []
        start = 0
        for end in data_source.cumulative_sizes:
            self.sample_idxes_per_domain.append(list(range(start, end)))
            start = end

        if batch_size <= 0:
            # K = 0 cannot form a mini-batch (and used to make the epoch loop spin forever).
            raise ValueError("batch_size should be a positive integer, but got {}".format(batch_size))
        assert batch_size % self.num_all_domains == 0
        self.batch_size_per_domain = batch_size // self.num_all_domains

        # The epoch length is deterministic, so compute it in closed form instead of running (and consuming
        # random numbers for) a whole throw-away epoch.
        num_rounds = self._num_rounds()
        if num_rounds == 0:
            raise ValueError("every domain needs at least batch_size // num_domains = {} samples".format(
                self.batch_size_per_domain))
        self.length = num_rounds * self.num_all_domains * self.batch_size_per_domain

    def _num_rounds(self):
        """Return how many full mini-batches fit before the smallest domain runs out of unused samples."""
        return min(len(idxes) // self.batch_size_per_domain for idxes in self.sample_idxes_per_domain)

    def __iter__(self):
        k = self.batch_size_per_domain
        # Shuffling each domain once and slicing consecutive K-sized chunks has the same distribution as repeatedly
        # sampling K of the remaining indices and removing them, but runs in linear rather than quadratic time.
        shuffled_per_domain = []
        for idxes in self.sample_idxes_per_domain:
            shuffled = list(idxes)
            random.shuffle(shuffled)
            shuffled_per_domain.append(shuffled)

        final_idxes = []
        for r in range(self._num_rounds()):
            # One mini-batch: K samples from every domain, in domain order.
            for shuffled in shuffled_per_domain:
                final_idxes.extend(shuffled[r * k:(r + 1) * k])
        return iter(final_idxes)

    def __len__(self):
        return self.length
class RandomDomainSampler(Sampler):
    r"""Randomly sample :math:`N` domains, then randomly select :math:`K` samples in each domain to form a mini-batch
    of size :math:`N\times K`.

    For every mini-batch, :math:`N` distinct domains are drawn uniformly at random and :math:`K` unused samples are
    drawn from each of them without replacement. A domain that started with fewer than :math:`K` samples instead
    contributes :math:`K` draws *with* replacement from its unused samples. The epoch ends after the first mini-batch
    in which a selected domain is left with fewer than :math:`K` unused samples, so the number of mini-batches per
    epoch is random. ``len(sampler)`` reports the length of the most recently generated epoch (one epoch is generated
    during construction). Randomness comes from Python's :mod:`random` module only.

    Args:
        data_source (ConcatDataset): dataset that contains data from multiple domains
        batch_size (int): mini-batch size (:math:`N\times K` here)
        n_domains_per_batch (int): number of domains to select in a single mini-batch (:math:`N` here)

    Raises:
        ValueError: if ``batch_size`` or ``n_domains_per_batch`` is not positive, or if some domain is empty.
    """

    def __init__(self, data_source: ConcatDataset, batch_size: int, n_domains_per_batch: int):
        # super(Sampler, self) starts the MRO lookup *after* Sampler, so Sampler.__init__ itself is skipped
        # (presumably to stay independent of its signature, which has varied between PyTorch releases).
        super(Sampler, self).__init__()
        self.n_domains_in_dataset = len(data_source.cumulative_sizes)
        self.n_domains_per_batch = n_domains_per_batch
        assert self.n_domains_in_dataset >= self.n_domains_per_batch

        # Global indices owned by each domain: domain d covers [cumulative_sizes[d-1], cumulative_sizes[d]),
        # with 0 as the lower bound of the first domain.
        self.sample_idxes_per_domain = []
        start = 0
        for end in data_source.cumulative_sizes:
            self.sample_idxes_per_domain.append(list(range(start, end)))
            start = end

        if batch_size <= 0 or n_domains_per_batch <= 0:
            raise ValueError("batch_size and n_domains_per_batch should be positive, but got {} and {}".format(
                batch_size, n_domains_per_batch))
        if any(len(idxes) == 0 for idxes in self.sample_idxes_per_domain):
            # An empty domain cannot contribute samples, not even with replacement; fail now rather than
            # whenever it happens to be drawn.
            raise ValueError("every domain must contain at least one sample")
        assert batch_size % n_domains_per_batch == 0
        self.batch_size_per_domain = batch_size // n_domains_per_batch
        self.length = len(self._sample_idxes())

    def _sample_idxes(self):
        """Generate the flat list of sample indices for one epoch."""
        k = self.batch_size_per_domain
        # Pre-shuffle each domain once and hand out consecutive chunks via a cursor. This has the same distribution
        # as sampling K remaining indices without replacement each time, without the quadratic cost of list.remove.
        pools = []
        for idxes in self.sample_idxes_per_domain:
            pool = list(idxes)
            random.shuffle(pool)
            pools.append(pool)
        cursors = [0] * self.n_domains_in_dataset
        domain_idxes = list(range(self.n_domains_in_dataset))

        final_idxes = []
        stop_flag = False
        while not stop_flag:
            for domain in random.sample(domain_idxes, self.n_domains_per_batch):
                pool, pos = pools[domain], cursors[domain]
                if len(pool) - pos < k:
                    # Only reachable for a domain that started with fewer than K samples (otherwise the epoch would
                    # already have stopped): draw K with replacement from its unused samples. Fewer than K remain
                    # afterwards, so this is the epoch's last mini-batch.
                    final_idxes.extend(random.choices(pool[pos:], k=k))
                    stop_flag = True
                else:
                    final_idxes.extend(pool[pos:pos + k])
                    cursors[domain] = pos + k
                    if len(pool) - cursors[domain] < k:
                        stop_flag = True
        return final_idxes

    def __iter__(self):
        final_idxes = self._sample_idxes()
        # Epoch length is random; keep __len__ consistent with the epoch actually being yielded.
        self.length = len(final_idxes)
        return iter(final_idxes)

    def __len__(self):
        return self.length

Docs

Access comprehensive documentation for the Transfer Learning Library

View Docs

Tutorials

Get started with the Transfer Learning Library

Get Started