
Source code for common.vision.transforms.segmentation

"""
@author: Junguang Jiang
@contact: [email protected]
"""
from PIL import Image
import random
import math
from typing import ClassVar, Sequence, List, Tuple
from torch import Tensor
import torch
import torchvision.transforms.functional as F
import torchvision.transforms.transforms as T
import torch.nn as nn
from . import MultipleApply as MultipleApplyBase, NormalizeAndTranspose as NormalizeAndTransposeBase


def wrapper(transform: ClassVar):
    """Wrap a classification transform into a transform for segmentation.
    Note that the segmentation label is passed through unchanged by the wrapped transform.

    Args:
        transform (class, callable): transform for classification

    Returns:
        transform for segmentation
    """
    class WrapperTransform(transform):
        def __call__(self, image, label):
            image = super().__call__(image)
            return image, label

    return WrapperTransform

ColorJitter = wrapper(T.ColorJitter)
Normalize = wrapper(T.Normalize)
ToTensor = wrapper(T.ToTensor)
ToPILImage = wrapper(T.ToPILImage)
MultipleApply = wrapper(MultipleApplyBase)
NormalizeAndTranspose = wrapper(NormalizeAndTransposeBase)

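# Illustrative usage sketch for a wrapped transform (assumes `img` and `mask`
# are PIL Images of the same size; these names are not defined in this module):
#
#   >>> jitter = ColorJitter(brightness=0.3, contrast=0.3)
#   >>> img_aug, mask_out = jitter(img, mask)   # the label passes through unchanged
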
class Compose:
    """Composes several transforms together.

    Args:
        transforms (list): list of transforms to compose.

    Example:
        >>> Compose([
        >>>     Resize((512, 512)),
        >>>     RandomHorizontalFlip()
        >>> ])
    """
    def __init__(self, transforms):
        super(Compose, self).__init__()
        self.transforms = transforms

    def __call__(self, image, target):
        for t in self.transforms:
            image, target = t(image, target)
        return image, target

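# Illustrative usage sketch for Compose (assumes `img` and `mask` are PIL Images
# of the same size): each transform receives and returns the (image, label)
# pair, so image and label stay spatially aligned through the pipeline:
#
#   >>> transform = Compose([
#   >>>     Resize((1024, 512)),
#   >>>     RandomHorizontalFlip(),
#   >>> ])
#   >>> img_aug, mask_aug = transform(img, mask)
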
class Resize(nn.Module):
    """Resize the input image and the corresponding label to the given size.
    The image should be a PIL Image.

    Args:
        image_size (sequence): The requested image size in pixels, as a 2-tuple: (width, height).
        label_size (sequence, optional): The requested segmentation label size in pixels, as a
            2-tuple: (width, height). The same as image_size if None. Default: None.
    """
    def __init__(self, image_size, label_size=None):
        super(Resize, self).__init__()
        self.image_size = image_size
        if label_size is None:
            self.label_size = image_size
        else:
            self.label_size = label_size
    def forward(self, image, label):
        """
        Args:
            image (PIL Image): Image to be scaled.
            label (PIL Image): Segmentation label to be scaled.

        Returns:
            Rescaled image, rescaled segmentation label.
        """
        # resize
        image = image.resize(self.image_size, Image.BICUBIC)
        label = label.resize(self.label_size, Image.NEAREST)
        return image, label

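# Illustrative note on Resize (assumes `img` and `mask` are PIL Images): the
# sizes follow PIL's (width, height) convention, and the label is resized with
# NEAREST so class indices are never interpolated:
#
#   >>> resize = Resize(image_size=(1024, 512))   # width=1024, height=512
#   >>> img_r, mask_r = resize(img, mask)
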
class RandomCrop(nn.Module):
    """Crop the given image at a random location. The image can be a PIL Image.

    Args:
        size (sequence): Desired output size of the crop.
    """
    def __init__(self, size):
        super(RandomCrop, self).__init__()
        self.size = size
    def forward(self, image, label):
        """
        Args:
            image (PIL Image): Image to be cropped.
            label (PIL Image): Segmentation label to be cropped.

        Returns:
            Cropped image, cropped segmentation label.
        """
        # random crop: sample a top-left corner so that the crop window stays inside
        # the image (the offsets may be 0 when the image is exactly the crop size)
        left = image.size[0] - self.size[0]
        upper = image.size[1] - self.size[1]
        left = random.randint(0, max(left, 0))
        upper = random.randint(0, max(upper, 0))
        right = left + self.size[0]
        lower = upper + self.size[1]
        image = image.crop((left, upper, right, lower))
        label = label.crop((left, upper, right, lower))
        return image, label

class RandomHorizontalFlip(nn.Module):
    """Horizontally flip the given PIL Image randomly with a given probability.

    Args:
        p (float): probability of the image being flipped. Default value is 0.5
    """
    def __init__(self, p=0.5):
        super(RandomHorizontalFlip, self).__init__()
        self.p = p
    def forward(self, image, label):
        """
        Args:
            image (PIL Image): Image to be flipped.
            label (PIL Image): Segmentation label to be flipped.

        Returns:
            Randomly flipped image, randomly flipped segmentation label.
        """
        if random.random() < self.p:
            return F.hflip(image), F.hflip(label)
        return image, label

class RandomResizedCrop(T.RandomResizedCrop):
    """Crop the given image to a random size and aspect ratio. The image can be a PIL Image.

    A crop of random size (default: of 0.5 to 1.0) of the original size and a random aspect
    ratio (default: of 3/4 to 4/3) of the original aspect ratio is made. This crop is finally
    resized to the given size.

    Args:
        size (int or sequence): expected output size of each edge. If size is an int instead
            of a sequence like (h, w), a square output size ``(size, size)`` is made.
            If provided a tuple or list of length 1, it will be interpreted as (size[0], size[0]).
        scale (tuple of float): range of size of the origin size cropped
        ratio (tuple of float): range of aspect ratio of the origin aspect ratio cropped.
        interpolation: Default: ``PIL.Image.BICUBIC``
    """
    def __init__(self, size, scale=(0.5, 1.0), ratio=(3. / 4., 4. / 3.), interpolation=Image.BICUBIC):
        super(RandomResizedCrop, self).__init__(size, scale, ratio, interpolation)
    @staticmethod
    def get_params(
            img: Tensor, scale: List[float], ratio: List[float]
    ) -> Tuple[int, int, int, int]:
        """Get parameters for ``crop`` for a random sized crop.

        Args:
            img (PIL Image): Input image.
            scale (list): range of scale of the origin size cropped
            ratio (list): range of aspect ratio of the origin aspect ratio cropped

        Returns:
            params (i, j, h, w) to be passed to ``crop`` for a random sized crop.
        """
        width, height = F._get_image_size(img)
        area = height * width

        for _ in range(10):
            target_area = area * random.uniform(scale[0], scale[1])
            log_ratio = torch.log(torch.tensor(ratio))
            aspect_ratio = math.exp(random.uniform(log_ratio[0], log_ratio[1]))

            w = int(round(math.sqrt(target_area * aspect_ratio)))
            h = int(round(math.sqrt(target_area / aspect_ratio)))

            if 0 < w <= width and 0 < h <= height:
                i = random.randint(0, height - h)
                j = random.randint(0, width - w)
                return i, j, h, w

        # Fallback to central crop
        in_ratio = float(width) / float(height)
        if in_ratio < min(ratio):
            w = width
            h = int(round(w / min(ratio)))
        elif in_ratio > max(ratio):
            h = height
            w = int(round(h * max(ratio)))
        else:  # whole image
            w = width
            h = height
        i = (height - h) // 2
        j = (width - w) // 2
        return i, j, h, w
    def forward(self, image, label):
        """
        Args:
            image (PIL Image): Image to be cropped and resized.
            label (PIL Image): Segmentation label to be cropped and resized.

        Returns:
            Randomly cropped and resized image, randomly cropped and resized segmentation label.
        """
        top, left, height, width = self.get_params(image, self.scale, self.ratio)
        image = image.crop((left, top, left + width, top + height))
        image = image.resize(self.size, self.interpolation)
        label = label.crop((left, top, left + width, top + height))
        label = label.resize(self.size, Image.NEAREST)
        return image, label

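# Illustrative usage sketch for RandomResizedCrop (assumes `img` and `mask` are
# PIL Images of the same size): one crop window is sampled and applied to both
# inputs; the image is resized with the configured interpolation and the label
# with NEAREST, so the pair remains spatially aligned:
#
#   >>> rrc = RandomResizedCrop(size=(512, 512), scale=(0.5, 1.0))
#   >>> img_aug, mask_aug = rrc(img, mask)
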
class RandomChoice(T.RandomTransforms):
    """Apply a single transformation randomly picked from a list."""

    def __call__(self, image, label):
        t = random.choice(self.transforms)
        return t(image, label)

class RandomApply(T.RandomTransforms):
    """Randomly apply a list of transformations with a given probability.

    Args:
        transforms (list or tuple or torch.nn.Module): list of transformations
        p (float): probability
    """
    def __init__(self, transforms, p=0.5):
        super(RandomApply, self).__init__(transforms)
        self.p = p

    def __call__(self, image, label):
        if self.p < random.random():
            return image, label
        for t in self.transforms:
            image, label = t(image, label)
        return image, label
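A minimal end-to-end sketch of how these transforms can be combined for a segmentation training pipeline. The module path follows the title of this page; the names `img` and `mask` are illustrative assumptions (any pair of equally sized PIL Images works):

    import common.vision.transforms.segmentation as T_seg

    train_transform = T_seg.Compose([
        T_seg.Resize((1280, 720)),
        T_seg.RandomCrop((512, 512)),
        T_seg.RandomApply([T_seg.ColorJitter(brightness=0.3, contrast=0.3)], p=0.5),
        T_seg.RandomHorizontalFlip(p=0.5),
    ])

    # img and mask are PIL Images of the same size (assumed)
    img_aug, mask_aug = train_transform(img, mask)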
