"""
Data utilities for ``common.utils.data``.

@author: Junguang Jiang
@contact: [email protected]
"""
import itertools
import random
import numpy as np
import torch
from torch.utils.data import DataLoader, Dataset, Sampler
from typing import TypeVar, Iterable
T_co = TypeVar('T_co', covariant=True)
T = TypeVar('T')


def send_to_device(tensor, device):
"""
Recursively sends the elements in a nested list/tuple/dictionary of tensors to a given device.
Args:
tensor (nested list/tuple/dictionary of :obj:`torch.Tensor`):
The data to send to a given device.
device (:obj:`torch.device`):
The device to send the data to
Returns:
The same data structure as :obj:`tensor` with all tensors sent to the proper device.
"""
if isinstance(tensor, (list, tuple)):
return type(tensor)(send_to_device(t, device) for t in tensor)
elif isinstance(tensor, dict):
return type(tensor)({k: send_to_device(v, device) for k, v in tensor.items()})
elif not hasattr(tensor, "to"):
return tensor
return tensor.to(device)
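
# Usage sketch for ``send_to_device`` (the batch layout below is invented for
# illustration): nested containers are traversed recursively, every tensor is
# moved to the target device, and non-tensor leaves pass through unchanged.
#
#     batch = {"images": torch.randn(4, 3, 32, 32),
#              "meta": ["sample.jpg", torch.tensor([0, 1, 2, 3])]}
#     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
#     batch = send_to_device(batch, device)
#     assert batch["images"].device.type == device.type
#     assert batch["meta"][0] == "sample.jpg"  # strings have no ``.to`` and are returned as-is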


class ForeverDataIterator:
r"""A data iterator that will never stop producing data"""
def __init__(self, data_loader: DataLoader, device=None):
self.data_loader = data_loader
self.iter = iter(self.data_loader)
self.device = device
    def __next__(self):
        try:
            data = next(self.iter)
        except StopIteration:
            # restart the underlying loader once it has been exhausted
            self.iter = iter(self.data_loader)
            data = next(self.iter)
        if self.device is not None:
            data = send_to_device(data, self.device)
        return data

def __len__(self):
return len(self.data_loader)
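
# Usage sketch for ``ForeverDataIterator``: since ``__next__`` silently restarts
# the loader on ``StopIteration``, a training loop can draw any number of batches
# without tracking epochs. ``TensorDataset`` stands in for a real dataset here.
#
#     from torch.utils.data import TensorDataset
#     loader = DataLoader(TensorDataset(torch.randn(10, 3)), batch_size=4)
#     train_iter = ForeverDataIterator(loader)
#     for _ in range(100):        # far more steps than one epoch contains
#         x, = next(train_iter)   # never raises StopIteration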


class RandomMultipleGallerySampler(Sampler):
r"""Sampler from `In defense of the Triplet Loss for Person Re-Identification
(ICCV 2017) <https://arxiv.org/pdf/1703.07737v2.pdf>`_. Assume there are :math:`N` identities in the dataset, this
implementation simply samples :math:`K` images for every identity to form an iter of size :math:`N\times K`. During
training, we will call ``__iter__`` method of pytorch dataloader once we reach a ``StopIteration``, this guarantees
every image in the dataset will eventually be selected and we are not wasting any training data.
Args:
dataset(list): each element of this list is a tuple (image_path, person_id, camera_id)
num_instances(int, optional): number of images to sample for every identity (:math:`K` here)
"""

    def __init__(self, dataset, num_instances=4):
        super(RandomMultipleGallerySampler, self).__init__(dataset)
        self.dataset = dataset
        self.num_instances = num_instances
        # dataset index -> person id
        self.idx_to_pid = {}
        # person id -> camera ids of all images of that person
        self.cid_list_per_pid = {}
        # person id -> dataset indexes of all images of that person
        self.idx_list_per_pid = {}
        for idx, (_, pid, cid) in enumerate(dataset):
            if pid not in self.cid_list_per_pid:
                self.cid_list_per_pid[pid] = []
                self.idx_list_per_pid[pid] = []
            self.idx_to_pid[idx] = pid
            self.cid_list_per_pid[pid].append(cid)
            self.idx_list_per_pid[pid].append(idx)
        self.pid_list = list(self.idx_list_per_pid.keys())
        self.num_samples = len(self.pid_list)

    def __len__(self):
        # an upper bound: ``__iter__`` may yield fewer indexes when some identities
        # do not have enough images
        return self.num_samples * self.num_instances

    def __iter__(self):
        def select_idxes(element_list, target_element):
            assert isinstance(element_list, list)
            return [i for i, element in enumerate(element_list) if element != target_element]

        pid_idxes = torch.randperm(len(self.pid_list)).tolist()
        final_idxes = []
        for perm_id in pid_idxes:
            # pick one anchor image of this identity at random
            i = random.choice(self.idx_list_per_pid[self.pid_list[perm_id]])
            _, _, cid = self.dataset[i]
            final_idxes.append(i)

            pid_i = self.idx_to_pid[i]
            cid_list = self.cid_list_per_pid[pid_i]
            idx_list = self.idx_list_per_pid[pid_i]

            # prefer companions captured by cameras other than the anchor's
            selected_cid_list = select_idxes(cid_list, cid)
            if selected_cid_list:
                # sample without replacement whenever there are enough candidates
                cid_idxes = np.random.choice(selected_cid_list, size=self.num_instances - 1,
                                             replace=len(selected_cid_list) < self.num_instances)
                for cid_idx in cid_idxes:
                    final_idxes.append(idx_list[cid_idx])
            else:
                # every image of this identity shares the anchor's camera, so fall
                # back to other images of the same identity
                selected_idxes = select_idxes(idx_list, i)
                if not selected_idxes:
                    continue
                fallback_idxes = np.random.choice(selected_idxes, size=self.num_instances - 1,
                                                  replace=len(selected_idxes) < self.num_instances)
                for fallback_idx in fallback_idxes:
                    final_idxes.append(idx_list[fallback_idx])
        return iter(final_idxes)
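
# Usage sketch for ``RandomMultipleGallerySampler`` on a tiny synthetic re-ID
# index of (image_path, person_id, camera_id) tuples. Passed to a ``DataLoader``
# via its ``sampler`` argument, it yields each identity ``num_instances`` times
# per pass, drawn from different cameras whenever possible.
#
#     fake_dataset = [("img_%d.jpg" % i, i % 3, i % 2) for i in range(12)]
#     sampler = RandomMultipleGallerySampler(fake_dataset, num_instances=4)
#     idxes = list(iter(sampler))
#     assert len(idxes) <= len(sampler)  # may be shorter when identities lack images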


class CombineDataset(Dataset[T_co]):
r"""Dataset as a combination of multiple datasets.
The element of each dataset must be a list, and the i-th element of the combined dataset
is a list splicing of the i-th element of each sub dataset.
The length of the combined dataset is the minimum of the lengths of all sub datasets.
Arguments:
datasets (sequence): List of datasets to be concatenated
"""
def __init__(self, datasets: Iterable[Dataset]) -> None:
super(CombineDataset, self).__init__()
# Cannot verify that datasets is Sized
assert len(datasets) > 0, 'datasets should not be an empty iterable' # type: ignore
self.datasets = list(datasets)
def __len__(self):
return min([len(d) for d in self.datasets])
def __getitem__(self, idx):
return list(itertools.chain(*[d[idx] for d in self.datasets]))
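
# Usage sketch for ``CombineDataset``: two hypothetical list-yielding datasets
# are spliced element-wise, so every item of the combined dataset carries the
# fields of both. Plain Python lists of lists stand in for real ``Dataset``
# objects here, since only ``__getitem__`` and ``__len__`` are required.
#
#     d1 = [[torch.zeros(3), 0], [torch.ones(3), 1]]   # (feature, label)
#     d2 = [["img_0.jpg"], ["img_1.jpg"]]              # (path,)
#     combined = CombineDataset([d1, d2])
#     assert len(combined) == 2
#     assert combined[0][1] == 0 and combined[0][2] == "img_0.jpg"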


def concatenate(tensors):
"""concatenate multiple batches into one batch.
``tensors`` can be :class:`torch.Tensor`, List or Dict, but they must be the same data format.
"""
    if isinstance(tensors[0], torch.Tensor):
        return torch.cat(tensors, dim=0)
    elif isinstance(tensors[0], list):
        ret = []
        for i in range(len(tensors[0])):
            ret.append(concatenate([t[i] for t in tensors]))
        return ret
    elif isinstance(tensors[0], dict):
        ret = dict()
        for k in tensors[0].keys():
            ret[k] = concatenate([t[k] for t in tensors])
        return ret
    else:
        raise TypeError("unsupported data format: {}".format(type(tensors[0])))
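
# Usage sketch for ``concatenate``: two dict-structured batches (the keys are
# invented for illustration) are merged recursively, so every tensor leaf ends
# up concatenated along the batch dimension.
#
#     batch_a = {"x": torch.randn(4, 3), "y": torch.zeros(4, dtype=torch.long)}
#     batch_b = {"x": torch.randn(2, 3), "y": torch.ones(2, dtype=torch.long)}
#     merged = concatenate([batch_a, batch_b])
#     assert merged["x"].shape == (6, 3) and merged["y"].shape == (6,)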