# Source code for common.vision.datasets.openset
"""
@author: Junguang Jiang
@contact: [email protected]
"""
from ..imagelist import ImageList
from ..office31 import Office31
from ..officehome import OfficeHome
from ..visda2017 import VisDA2017
from typing import Optional, ClassVar, Sequence
from copy import deepcopy
# Names exported by a star-import of this module.
# NOTE(review): `open_set` and `default_open_set`, both defined in this file,
# are not listed here -- confirm whether that omission is intentional.
__all__ = ['Office31', 'OfficeHome', "VisDA2017"]
def open_set(dataset_class: ClassVar, public_classes: Sequence[str],
             private_classes: Optional[Sequence[str]] = ()) -> ClassVar:
    """
    Convert a dataset into its open-set version.

    Samples whose category belongs to `public_classes` keep their category,
    samples whose category belongs to `private_classes` are marked as
    "unknown", and samples whose category is in neither sequence are dropped.

    Be aware that `open_set` will change the label number of each category:
    labels follow the order of `public_classes`, with "unknown" appended last.

    Args:
        dataset_class (class): Dataset class. Only subclass of ``ImageList`` can be open-set.
        public_classes (sequence[str]): A sequence of which categories need to be kept in the
            open-set dataset. Each element of `public_classes` must belong to the `classes`
            list of `dataset_class`.
        private_classes (sequence[str], optional): A sequence of which categories need to be
            marked as "unknown" in the open-set dataset. Each element of `private_classes`
            must belong to the `classes` list of `dataset_class`. ``None`` is treated like
            an empty sequence. Default: ().

    Returns:
        A new subclass of `dataset_class` (named ``OpenSetDataset``) whose constructor
        accepts the same keyword arguments and relabels the samples after the parent
        constructor has run.

    Raises:
        Exception: If `dataset_class` is not a subclass of ``ImageList``.

    Examples::

        >>> public_classes = ['back_pack', 'bike', 'calculator', 'headphones', 'keyboard']
        >>> private_classes = ['laptop_computer', 'monitor', 'mouse', 'mug', 'projector']
        >>> # create a open-set dataset class which has classes
        >>> # 'back_pack', 'bike', 'calculator', 'headphones', 'keyboard' and 'unknown'.
        >>> OpenSetOffice31 = open_set(Office31, public_classes, private_classes)
        >>> # create an instance of the open-set dataset
        >>> dataset = OpenSetOffice31(root="data/office31", task="A")
    """
    if not (issubclass(dataset_class, ImageList)):
        raise Exception("Only subclass of ImageList can be openset")

    class OpenSetDataset(dataset_class):
        def __init__(self, **kwargs):
            super(OpenSetDataset, self).__init__(**kwargs)

            # New label space: public classes in the given order, then "unknown".
            all_classes = list(public_classes) + ["unknown"]

            # Name -> first position in `all_classes`; `setdefault` keeps the
            # first occurrence, which is what `list.index` would return.
            label_of = {}
            for idx, name in enumerate(all_classes):
                label_of.setdefault(name, idx)
            unknown_label = label_of["unknown"]

            # Sets give constant-time membership tests inside the sample loop.
            public = set(public_classes)
            # The signature allows None, so treat it as "no private classes".
            private = set(private_classes) if private_classes is not None else set()

            samples = []
            for (path, label) in self.samples:
                class_name = self.classes[label]
                if class_name in public:
                    samples.append((path, label_of[class_name]))
                elif class_name in private:
                    samples.append((path, unknown_label))
                # Samples in neither sequence are intentionally left out.

            self.samples = samples
            self.classes = all_classes
            self.class_to_idx = {cls: idx
                                 for idx, cls in enumerate(self.classes)}

    return OpenSetDataset
def default_open_set(dataset_class: ClassVar, source: bool) -> ClassVar:
    """
    Build the open-set variant of a supported dataset using a fixed class split.

    The split depends on the dataset:

    - ``Office31``: the first 20 entries of ``Office31.CLASSES`` are public,
      the remaining entries are private.
    - ``OfficeHome``: ``OfficeHome.CLASSES`` is sorted; the first 25 entries
      are public, the remaining entries are private.
    - ``VisDA2017``: a hard-coded list of six public and six private classes.

    Args:
        dataset_class (class): Dataset class. Currently, dataset_class must be one of
            :class:`~common.vision.datasets.office31.Office31`,
            :class:`~common.vision.datasets.officehome.OfficeHome`,
            :class:`~common.vision.datasets.visda2017.VisDA2017`.
        source (bool): Whether the dataset is used for source domain or not.
            When true, no private classes are passed on to :func:`open_set`.

    Returns:
        The class returned by ``open_set(dataset_class, public_classes, private_classes)``.

    Raises:
        NotImplementedError: If `dataset_class` is none of the supported datasets.
    """
    if dataset_class == Office31:
        known = Office31.CLASSES[:20]
        unknown = Office31.CLASSES[20:]
    elif dataset_class == OfficeHome:
        ordered = sorted(OfficeHome.CLASSES)
        known, unknown = ordered[:25], ordered[25:]
    elif dataset_class == VisDA2017:
        known = ('bicycle', 'bus', 'car', 'motorcycle', 'train', 'truck')
        unknown = ('aeroplane', 'horse', 'knife', 'person', 'plant', 'skateboard')
    else:
        raise NotImplementedError("Unknown openset domain adaptation dataset: {}".format(dataset_class.__name__))

    # The source domain carries no private categories.
    hidden = () if source else unknown
    return open_set(dataset_class, known, hidden)