Shortcuts

Source code for dalib.translation.cyclegan.transform

"""
@author: Junguang Jiang
@contact: [email protected]
"""
import torch
import torch.nn as nn
import torchvision.transforms as T

from common.vision.transforms import Denormalize


[docs]class Translation(nn.Module): """ Image Translation Transform Module Args: generator (torch.nn.Module): An image generator, e.g. :meth:`~dalib.translation.cyclegan.resnet_9_generator` device (torch.device): device to put the generator. Default: 'cpu' mean (tuple): the normalized mean for image std (tuple): the normalized std for image Input: - image (PIL.Image): raw image in shape H x W x C Output: raw image in shape H x W x 3 """ def __init__(self, generator, device=torch.device("cpu"), mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)): super(Translation, self).__init__() self.generator = generator.to(device) self.device = device self.pre_process = T.Compose([ T.ToTensor(), T.Normalize(mean, std) ]) self.post_process = T.Compose([ Denormalize(mean, std), T.ToPILImage() ]) def forward(self, image): image = self.pre_process(image.copy()) # C x H x W image = image.to(self.device) generated_image = self.generator(image.unsqueeze(dim=0)).squeeze(dim=0).cpu() return self.post_process(generated_image)

Docs

Access comprehensive documentation for Transfer Learning Library

View Docs

Tutorials

Get started for Transfer Learning Library

Get Started