Shortcuts

Source code for common.vision.datasets.regression.mpi3d

"""
@author: Junguang Jiang
@contact: [email protected]
"""
from typing import Optional, Sequence
import os
from .._util import download as download_data, check_exits
from .image_regression import ImageRegression


class MPI3D(ImageRegression):
    """`MPI3D <https://arxiv.org/abs/1906.03292>`_ Dataset.

    Args:
        root (str): Root directory of dataset
        task (str): The task (domain) to create dataset. Choices are the keys of
            :attr:`image_list`: ``'RL'``: real, ``'RC'``: realistic and ``'T'``: toy.
        split (str, optional): The dataset split, supports ``train``, or ``test``.
        factors (sequence[str]): Factors selected. Each entry must be one of :attr:`FACTORS`.
            Default: ('horizontal axis', 'vertical axis').
        download (bool, optional): If true, downloads the dataset from the internet and puts it
            in root directory. If dataset is already downloaded, it is not downloaded again.
        transform (callable, optional): A function/transform that takes in an PIL image and returns a
            transformed version. E.g, :class:`torchvision.transforms.RandomCrop`.
            Forwarded to the parent class through ``**kwargs``.
        target_transform (callable, optional): A function/transform that takes in the target
            (already restricted to the selected factors) and transforms it. Its result is
            then divided by 39.

    .. note:: In `root`, there will exist following files after downloading.
        ::

            real/
                ...
            realistic/
            toy/
            image_list/
                real_train.txt
                realistic_train.txt
                toy_train.txt
                real_test.txt
                realistic_test.txt
                toy_test.txt
    """
    # Each entry is (name under root, archive file name, download URL).
    download_list = [
        ("image_list", "image_list.zip", "https://cloud.tsinghua.edu.cn/f/f0ff24df967b42479d9e/?dl=1"),
        ("real", "real.tgz", "https://cloud.tsinghua.edu.cn/f/04c1318555fc4283862b/?dl=1"),
        ("realistic", "realistic.tgz", "https://cloud.tsinghua.edu.cn/f/2c0f7dacc73148cea593/?dl=1"),
        ("toy", "toy.tgz", "https://cloud.tsinghua.edu.cn/f/6327912a50374e20af95/?dl=1"),
    ]
    # Maps the task code accepted by ``__init__`` to the prefix of its image-list file.
    image_list = {
        "RL": "real",
        "RC": "realistic",
        "T": "toy"
    }
    # Names of the regression factors, in the order used to index a raw target.
    FACTORS = ('horizontal axis', 'vertical axis')

    def __init__(self, root: str, task: str, split: Optional[str] = 'train',
                 factors: Sequence[str] = ('horizontal axis', 'vertical axis'),
                 download: Optional[bool] = True, target_transform=None, **kwargs):
        # Validation keeps raising AssertionError so existing callers observe
        # the same exception type as before; only the messages are new.
        assert task in self.image_list, \
            "task must be one of {}, got {!r}".format(sorted(self.image_list), task)
        assert split in ['train', 'test'], \
            "split must be 'train' or 'test', got {!r}".format(split)
        for factor in factors:
            assert factor in self.FACTORS, \
                "factor must be one of {}, got {!r}".format(self.FACTORS, factor)

        factor_index = [self.FACTORS.index(factor) for factor in factors]

        # Keep the caller's transform under its own name. Rebinding
        # ``target_transform`` to a lambda that itself calls ``target_transform``
        # would make the lambda call itself and recurse without end.
        user_target_transform = target_transform

        def select_and_scale(x):
            # Indexing with a list of positions assumes ``x`` supports it
            # (e.g. a numpy array) -- TODO confirm against ImageRegression.
            selected = x[list(factor_index)]
            if user_target_transform is not None:
                selected = user_target_transform(selected)
            # NOTE(review): 39 is presumably the largest raw factor value, which
            # would scale targets into [0, 1] -- confirm against the dataset.
            return selected / 39.

        data_list_file = os.path.join(root, "image_list", "{}_{}.txt".format(self.image_list[task], split))

        if download:
            for args in self.download_list:
                download_data(root, *args)
        else:
            # Only the first element of each entry (its name under ``root``)
            # is needed for the existence check.
            for file_name, _, _ in self.download_list:
                check_exits(root, file_name)

        super(MPI3D, self).__init__(root, factors, data_list_file=data_list_file,
                                    target_transform=select_and_scale, **kwargs)

Docs

Access comprehensive documentation for the Transfer Learning Library

View Docs

Tutorials

Get started with the Transfer Learning Library

Get Started