Shortcuts

Source code for common.vision.datasets.digits

"""
@author: Junguang Jiang, Baixu Chen
@contact: [email protected], [email protected]
"""
import os
from typing import Optional, Tuple, Any
from .imagelist import ImageList
from ._util import download as download_data, check_exits


class MNIST(ImageList):
    """`MNIST <http://yann.lecun.com/exdb/mnist/>`_ Dataset.

    Args:
        root (str): Root directory of dataset. The image list files (``image_list/mnist_train.txt`` and
            ``image_list/mnist_test.txt``) and the downloaded images are expected under this directory.
        mode (str): The channel mode for image. Choices include ``"L"``, ``"RGB"``. Default: ``"L"``
        split (str, optional): The dataset split, supports ``train``, or ``test``.
        download (bool, optional): If true, downloads the dataset from the internet and puts it in root directory.
            If dataset is already downloaded, it is not downloaded again. If false, only checks that the
            expected files exist under ``root``.
        transform (callable, optional): A function/transform that takes in a PIL image and returns a
            transformed version. E.g, ``transforms.RandomCrop``
    """
    # Each entry is (file name under root, archive name, download url).
    # Only one image archive is listed; a previously commented-out separate test entry
    # pointed to the same archive name and URL as the train entry.
    download_list = [
        ("image_list", "image_list.zip", "https://cloud.tsinghua.edu.cn/f/f18f1e115de94644b900/?dl=1"),
        ("mnist_train_image", "mnist_image.tar.gz", "https://cloud.tsinghua.edu.cn/f/fdf45c75d2e746acba93/?dl=1"),
    ]
    # Relative path (under root) of the image list file for each split.
    image_list = {
        "train": "image_list/mnist_train.txt",
        "test": "image_list/mnist_test.txt"
    }
    CLASSES = ['0 - zero', '1 - one', '2 - two', '3 - three', '4 - four',
               '5 - five', '6 - six', '7 - seven', '8 - eight', '9 - nine']

    def __init__(self, root: str, mode: str = "L", split: str = 'train', download: Optional[bool] = True,
                 **kwargs):
        assert split in ['train', 'test']
        # Validate the channel mode before doing any (possibly slow) downloading.
        assert mode in ['L', 'RGB']
        data_list_file = os.path.join(root, self.image_list[split])

        if download:
            for args in self.download_list:
                download_data(root, *args)
        else:
            # Each entry is a 3-tuple; only the file name is needed for the existence check.
            for file_name, _, _ in self.download_list:
                check_exits(root, file_name)

        self.mode = mode
        super(MNIST, self).__init__(root, MNIST.CLASSES, data_list_file=data_list_file, **kwargs)

    def __getitem__(self, index: int) -> Tuple[Any, int]:
        """
        Args:
            index (int): Index

        Returns:
            tuple: (image, target) where target is index of the target class.
        """
        path, target = self.samples[index]
        # Force the configured channel mode regardless of how the image is stored on disk.
        img = self.loader(path).convert(self.mode)
        if self.transform is not None:
            img = self.transform(img)
        if self.target_transform is not None and target is not None:
            target = self.target_transform(target)
        return img, target

    @classmethod
    def get_classes(cls):
        """Return the list of class names for this dataset."""
        return MNIST.CLASSES
class USPS(ImageList):
    """`USPS <https://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/multiclass.html#usps>`_ Dataset.
        The data-format is : [label [index:value ]*256 \\n] * num_lines, where ``label`` lies in ``[1, 10]``.
        The value for each pixel lies in ``[-1, 1]``. Here we transform the ``label`` into ``[0, 9]``
        and make pixel values in ``[0, 255]``.

    Args:
        root (str): Root directory of dataset to store ``USPS`` data files. The image list files
            (``image_list/usps_train.txt`` and ``image_list/usps_test.txt``) are expected under it.
        mode (str): The channel mode for image. Choices include ``"L"``, ``"RGB"``. Default: ``"L"``
        split (str, optional): The dataset split, supports ``train``, or ``test``.
        transform (callable, optional): A function/transform that takes in a PIL image and returns a
            transformed version. E.g, ``transforms.RandomCrop``
        download (bool, optional): If true, downloads the dataset from the internet and puts it in root directory.
            If dataset is already downloaded, it is not downloaded again. If false, only checks that the
            expected files exist under ``root``.
    """
    # Each entry is (file name under root, archive name, download url).
    # Only one image archive is listed; a previously commented-out separate test entry
    # pointed to the same archive name and URL as the train entry.
    download_list = [
        ("image_list", "image_list.zip", "https://cloud.tsinghua.edu.cn/f/10ddb319c24e40a08e58/?dl=1"),
        ("usps_train_image", "usps_image.tar.gz", "https://cloud.tsinghua.edu.cn/f/1d3d7e2540bd4392b715/?dl=1"),
    ]
    # Relative path (under root) of the image list file for each split.
    image_list = {
        "train": "image_list/usps_train.txt",
        "test": "image_list/usps_test.txt"
    }
    CLASSES = ['0 - zero', '1 - one', '2 - two', '3 - three', '4 - four',
               '5 - five', '6 - six', '7 - seven', '8 - eight', '9 - nine']

    def __init__(self, root: str, mode: str = "L", split: str = 'train', download: Optional[bool] = True,
                 **kwargs):
        assert split in ['train', 'test']
        # Validate the channel mode before doing any (possibly slow) downloading.
        assert mode in ['L', 'RGB']
        data_list_file = os.path.join(root, self.image_list[split])

        if download:
            for args in self.download_list:
                download_data(root, *args)
        else:
            # Each entry is a 3-tuple; only the file name is needed for the existence check.
            for file_name, _, _ in self.download_list:
                check_exits(root, file_name)

        self.mode = mode
        super(USPS, self).__init__(root, USPS.CLASSES, data_list_file=data_list_file, **kwargs)

    def __getitem__(self, index: int) -> Tuple[Any, int]:
        """
        Args:
            index (int): Index

        Returns:
            tuple: (image, target) where target is index of the target class.
        """
        path, target = self.samples[index]
        # Force the configured channel mode regardless of how the image is stored on disk.
        img = self.loader(path).convert(self.mode)
        if self.transform is not None:
            img = self.transform(img)
        if self.target_transform is not None and target is not None:
            target = self.target_transform(target)
        return img, target
class SVHN(ImageList):
    """`SVHN <http://ufldl.stanford.edu/housenumbers/>`_ Dataset.

    Note:
        The SVHN dataset assigns the label `10` to the digit `0`. However, in this Dataset,
        we assign the label `0` to the digit `0` to be compatible with PyTorch loss functions which
        expect the class labels to be in the range `[0, C-1]`

    .. note::
        NOTE(review): earlier documentation stated that `scipy <https://docs.scipy.org/doc/>`_ is needed
        to load data from `.mat` format. This class loads images through an image list file, so confirm
        whether scipy is still required.

    Args:
        root (str): Root directory of dataset. The image list file ``image_list/svhn_balanced.txt``
            and the downloaded images are expected under this directory.
        mode (str): The channel mode for image. Choices include ``"L"``, ``"RGB"``. Default: ``"L"``
            (use :class:`SVHNRGB` for RGB images).
        transform (callable, optional): A function/transform that takes in a PIL image and returns a
            transformed version. E.g, ``transforms.RandomCrop``
        download (bool, optional): If true, downloads the dataset from the internet and puts it in root directory.
            If dataset is already downloaded, it is not downloaded again. If false, only checks that the
            expected files exist under ``root``.
    """
    # Each entry is (file name under root, archive name, download url).
    download_list = [
        ("image_list", "image_list.zip", "https://cloud.tsinghua.edu.cn/f/80a8a06c4a324c59a5e4/?dl=1"),
        ("svhn_image", "svhn_image.tar.gz", "https://cloud.tsinghua.edu.cn/f/0e48a871e00345eb91a9/?dl=1")
    ]
    # The balanced list is used; ``image_list/svhn.txt`` was previously referenced as an alternative.
    image_list = "image_list/svhn_balanced.txt"
    CLASSES = ['0 - zero', '1 - one', '2 - two', '3 - three', '4 - four',
               '5 - five', '6 - six', '7 - seven', '8 - eight', '9 - nine']

    def __init__(self, root: str, mode: str = "L", download: Optional[bool] = True, **kwargs):
        # Validate the channel mode before doing any (possibly slow) downloading.
        assert mode in ['L', 'RGB']
        data_list_file = os.path.join(root, self.image_list)

        if download:
            for args in self.download_list:
                download_data(root, *args)
        else:
            # Each entry is a 3-tuple; only the file name is needed for the existence check.
            for file_name, _, _ in self.download_list:
                check_exits(root, file_name)

        self.mode = mode
        super(SVHN, self).__init__(root, SVHN.CLASSES, data_list_file=data_list_file, **kwargs)

    def __getitem__(self, index: int) -> Tuple[Any, int]:
        """
        Args:
            index (int): Index

        Returns:
            tuple: (image, target) where target is index of the target class.
        """
        path, target = self.samples[index]
        # Force the configured channel mode regardless of how the image is stored on disk.
        img = self.loader(path).convert(self.mode)
        if self.transform is not None:
            img = self.transform(img)
        if self.target_transform is not None and target is not None:
            target = self.target_transform(target)
        return img, target
class MNISTRGB(MNIST):
    """:class:`MNIST` variant that always loads images in ``"RGB"`` mode."""

    def __init__(self, root, **kwargs):
        super().__init__(root, mode='RGB', **kwargs)


class USPSRGB(USPS):
    """:class:`USPS` variant that always loads images in ``"RGB"`` mode."""

    def __init__(self, root, **kwargs):
        super().__init__(root, mode='RGB', **kwargs)


class SVHNRGB(SVHN):
    """:class:`SVHN` variant that always loads images in ``"RGB"`` mode."""

    def __init__(self, root, **kwargs):
        super().__init__(root, mode='RGB', **kwargs)

Docs

Access comprehensive documentation for the Transfer Learning Library

View Docs

Tutorials

Get started with the Transfer Learning Library

Get Started