Shortcuts

Source code for common.vision.datasets.openset

"""
@author: Junguang Jiang
@contact: [email protected]
"""
from ..imagelist import ImageList
from ..office31 import Office31
from ..officehome import OfficeHome
from ..visda2017 import VisDA2017

from typing import Optional, ClassVar, Sequence
from copy import deepcopy


__all__ = ['Office31', 'OfficeHome', "VisDA2017"]


[docs]def open_set(dataset_class: ClassVar, public_classes: Sequence[str], private_classes: Optional[Sequence[str]] = ()) -> ClassVar: """ Convert a dataset into its open-set version. In other words, those samples which doesn't belong to `private_classes` will be marked as "unknown". Be aware that `open_set` will change the label number of each category. Args: dataset_class (class): Dataset class. Only subclass of ``ImageList`` can be open-set. public_classes (sequence[str]): A sequence of which categories need to be kept in the open-set dataset.\ Each element of `public_classes` must belong to the `classes` list of `dataset_class`. private_classes (sequence[str], optional): A sequence of which categories need to be marked as "unknown" \ in the open-set dataset. Each element of `private_classes` must belong to the `classes` list of \ `dataset_class`. Default: (). Examples:: >>> public_classes = ['back_pack', 'bike', 'calculator', 'headphones', 'keyboard'] >>> private_classes = ['laptop_computer', 'monitor', 'mouse', 'mug', 'projector'] >>> # create a open-set dataset class which has classes >>> # 'back_pack', 'bike', 'calculator', 'headphones', 'keyboard' and 'unknown'. >>> OpenSetOffice31 = open_set(Office31, public_classes, private_classes) >>> # create an instance of the open-set dataset >>> dataset = OpenSetDataset(root="data/office31", task="A") """ if not (issubclass(dataset_class, ImageList)): raise Exception("Only subclass of ImageList can be openset") class OpenSetDataset(dataset_class): def __init__(self, **kwargs): super(OpenSetDataset, self).__init__(**kwargs) samples = [] all_classes = list(deepcopy(public_classes)) + ["unknown"] for (path, label) in self.samples: class_name = self.classes[label] if class_name in public_classes: samples.append((path, all_classes.index(class_name))) elif class_name in private_classes: samples.append((path, all_classes.index("unknown"))) self.samples = samples self.classes = all_classes self.class_to_idx = {cls: idx for idx, cls in enumerate(self.classes)} return OpenSetDataset
[docs]def default_open_set(dataset_class: ClassVar, source: bool) -> ClassVar: """ Default open-set used in some paper. Args: dataset_class (class): Dataset class. Currently, dataset_class must be one of :class:`~common.vision.datasets.office31.Office31`, :class:`~common.vision.datasets.officehome.OfficeHome`, :class:`~common.vision.datasets.visda2017.VisDA2017`, source (bool): Whether the dataset is used for source domain or not. """ if dataset_class == Office31: public_classes = Office31.CLASSES[:20] if source: private_classes = () else: private_classes = Office31.CLASSES[20:] elif dataset_class == OfficeHome: public_classes = sorted(OfficeHome.CLASSES)[:25] if source: private_classes = () else: private_classes = sorted(OfficeHome.CLASSES)[25:] elif dataset_class == VisDA2017: public_classes = ('bicycle', 'bus', 'car', 'motorcycle', 'train', 'truck') if source: private_classes = () else: private_classes = ('aeroplane', 'horse', 'knife', 'person', 'plant', 'skateboard') else: raise NotImplementedError("Unknown openset domain adaptation dataset: {}".format(dataset_class.__name__)) return open_set(dataset_class, public_classes, private_classes)

Docs

Access comprehensive documentation for Transfer Learning Library

View Docs

Tutorials

Get started for Transfer Learning Library

Get Started