Source code for common.vision.datasets.regression.mpi3d
"""
@author: Junguang Jiang
@contact: [email protected]
"""
from typing import Optional, Sequence
import os
from .._util import download as download_data, check_exits
from .image_regression import ImageRegression
class MPI3D(ImageRegression):
    """`MPI3D <https://arxiv.org/abs/1906.03292>`_ Dataset.

    Args:
        root (str): Root directory of dataset
        task (str): The task (domain) to create dataset. Choices are the keys of
            :attr:`image_list`: ``'RL'``: real, ``'RC'``: realistic and ``'T'``: toy.
        split (str, optional): The dataset split, supports ``train``, or ``test``.
        factors (sequence[str]): Factors selected. Each one must be a member of
            :attr:`FACTORS`. Default: ('horizontal axis', 'vertical axis').
        download (bool, optional): If true, downloads the dataset from the internet and puts it
            in root directory. If false, every entry of :attr:`download_list` is checked
            for presence under ``root`` instead.
        target_transform (callable, optional): A function/transform that takes in the target
            (already restricted to the selected ``factors``) and transforms it. Its result is
            then divided by 39.
        **kwargs: Forwarded unchanged to the :class:`ImageRegression` constructor
            (e.g. ``transform``).

    .. note:: In `root`, there will exist following files after downloading.
        ::

            real/
                ...
            realistic/
            toy/
            image_list/
                real_train.txt
                realistic_train.txt
                toy_train.txt
                real_test.txt
                realistic_test.txt
                toy_test.txt
    """
    # Each entry is unpacked positionally into ``download_data(root, *entry)``;
    # the first field is also the name handed to ``check_exits``.
    download_list = [
        ("image_list", "image_list.zip", "https://cloud.tsinghua.edu.cn/f/f0ff24df967b42479d9e/?dl=1"),
        ("real", "real.tgz", "https://cloud.tsinghua.edu.cn/f/04c1318555fc4283862b/?dl=1"),
        ("realistic", "realistic.tgz", "https://cloud.tsinghua.edu.cn/f/2c0f7dacc73148cea593/?dl=1"),
        ("toy", "toy.tgz", "https://cloud.tsinghua.edu.cn/f/6327912a50374e20af95/?dl=1"),
    ]
    # Maps the ``task`` code to the prefix of its image-list file name.
    image_list = {
        "RL": "real",
        "RC": "realistic",
        "T": "toy"
    }
    # All selectable factors; a factor's position here is its index in the raw target.
    FACTORS = ('horizontal axis', 'vertical axis')

    def __init__(self, root: str, task: str, split: Optional[str] = 'train',
                 factors: Sequence[str] = ('horizontal axis', 'vertical axis'),
                 download: Optional[bool] = True, target_transform=None, **kwargs):
        assert task in self.image_list, \
            "task must be one of {}, got {!r}".format(list(self.image_list), task)
        assert split in ['train', 'test'], \
            "split must be 'train' or 'test', got {!r}".format(split)
        for factor in factors:
            assert factor in self.FACTORS, \
                "factor must be one of {}, got {!r}".format(self.FACTORS, factor)

        factor_index = [self.FACTORS.index(factor) for factor in factors]
        # Keep the caller's transform under its own name. Rebinding
        # ``target_transform`` to a closure that calls ``target_transform``
        # would make the closure call itself and recurse without end.
        user_transform = target_transform

        def select_and_scale(target):
            # Indexing with a list presumably relies on ``target`` being a
            # NumPy-style array -- TODO confirm against ImageRegression.
            selected = target[factor_index]
            if user_transform is not None:
                selected = user_transform(selected)
            # NOTE(review): 39 looks like the largest raw factor value, which
            # would scale targets into [0, 1] -- confirm against the dataset.
            return selected / 39.

        data_list_file = os.path.join(root, "image_list", "{}_{}.txt".format(self.image_list[task], split))

        if download:
            for entry in self.download_list:
                download_data(root, *entry)
        else:
            # Only the first field of each entry is needed for the existence check.
            for file_name, _archive_name, _url in self.download_list:
                check_exits(root, file_name)

        super(MPI3D, self).__init__(root, factors, data_list_file=data_list_file,
                                    target_transform=select_and_scale, **kwargs)