# Source code for common.vision.transforms.segmentation
"""
@author: Junguang Jiang
@contact: [email protected]
"""
from PIL import Image
import random
import math
from typing import ClassVar, Sequence, List, Tuple
from torch import Tensor
import torch
import torchvision.transforms.functional as F
import torchvision.transforms.transforms as T
import torch.nn as nn
from . import MultipleApply as MultipleApplyBase, NormalizeAndTranspose as NormalizeAndTransposeBase
def wrapper(transform: type):
    """ Wrap a transform for classification to a transform for segmentation.

    The returned class subclasses ``transform``. Calling one of its instances as
    ``instance(image, label)`` applies ``transform`` to ``image`` only and returns
    the segmentation label unchanged.

    Args:
        transform (class): transform class for classification, whose instances are
            called as ``transform_instance(image)``.

    Returns:
        A subclass of ``transform`` whose instances return ``(transformed_image, label)``.
    """
    class WrapperTransform(transform):
        def __call__(self, image, label):
            image = super().__call__(image)
            return image, label

    # Present the subclass under the wrapped class's identity. Reprs that use
    # ``self.__class__.__name__`` then show e.g. ``ColorJitter(...)`` instead of
    # ``WrapperTransform(...)``. pickle can also locate the class when it is bound
    # to a module-level name equal to the wrapped class's name. Otherwise the
    # qualname would be the unpicklable ``wrapper.<locals>.WrapperTransform``.
    WrapperTransform.__name__ = transform.__name__
    WrapperTransform.__qualname__ = transform.__qualname__
    WrapperTransform.__doc__ = transform.__doc__
    return WrapperTransform
# Segmentation versions of image-only transforms. Each one transforms the image
# and passes the segmentation label through untouched (see ``wrapper``).
ColorJitter, Normalize, ToTensor, ToPILImage = (
    wrapper(T.ColorJitter),
    wrapper(T.Normalize),
    wrapper(T.ToTensor),
    wrapper(T.ToPILImage),
)
MultipleApply, NormalizeAndTranspose = (
    wrapper(MultipleApplyBase),
    wrapper(NormalizeAndTransposeBase),
)
class Compose:
    """Chain several joint (image, target) transforms into one.

    Each transform is called as ``t(image, target)`` and must return the new
    ``(image, target)`` pair, which is handed to the next transform in order.

    Args:
        transforms (list): transforms to apply, in order.

    Example:
        >>> Compose([
        ...     Resize((512, 512)),
        ...     RandomHorizontalFlip()
        ... ])
    """

    def __init__(self, transforms):
        super().__init__()
        self.transforms = transforms

    def __call__(self, image, target):
        for step in self.transforms:
            # Unpack immediately so a transform returning the wrong arity fails here.
            image, target = step(image, target)
        return image, target
class Resize(nn.Module):
    """Resize an image and its segmentation label to fixed sizes.

    The image and label should be PIL Images. The image is resampled bicubically.
    The label uses nearest-neighbour sampling, so its values are copied rather
    than interpolated.

    Args:
        image_size (sequence): target image size in pixels, as a 2-tuple (width, height).
        label_size (sequence, optional): target label size in pixels, as a 2-tuple
            (width, height). Falls back to ``image_size`` when None. Default: None.
    """

    def __init__(self, image_size, label_size=None):
        super().__init__()
        self.image_size = image_size
        self.label_size = image_size if label_size is None else label_size

    def forward(self, image, label):
        """
        Args:
            image (PIL Image): image to be scaled.
            label (PIL Image): segmentation label to be scaled.

        Returns:
            The rescaled image and the rescaled segmentation label.
        """
        resized_image = image.resize(self.image_size, Image.BICUBIC)
        resized_label = label.resize(self.label_size, Image.NEAREST)
        return resized_image, resized_label
class RandomCrop(nn.Module):
    """Crop an image and its segmentation label at the same random location.

    The image and label should be PIL Images of the same size.

    Args:
        size (sequence): desired output size of the crop, as (width, height).
    """

    def __init__(self, size):
        super(RandomCrop, self).__init__()
        self.size = size

    def forward(self, image, label):
        """
        Args:
            image (PIL Image): image to be cropped.
            label (PIL Image): segmentation label to be cropped.

        Returns:
            Cropped image, cropped segmentation label (cropped with the same box).

        Raises:
            ValueError: if the image is smaller than the crop size in either dimension.
        """
        crop_width, crop_height = self.size[0], self.size[1]
        image_width, image_height = image.size[0], image.size[1]
        max_left = image_width - crop_width
        max_upper = image_height - crop_height
        if max_left < 0 or max_upper < 0:
            raise ValueError(
                "RandomCrop: crop size ({}, {}) is larger than image size ({}, {})".format(
                    crop_width, crop_height, image_width, image_height))
        # random.randint is inclusive on both ends. Every valid offset 0..max is
        # reachable, including the single offset 0 when image size == crop size.
        left = random.randint(0, max_left)
        upper = random.randint(0, max_upper)
        box = (left, upper, left + crop_width, upper + crop_height)
        return image.crop(box), label.crop(box)
class RandomHorizontalFlip(nn.Module):
    """Mirror an image and its segmentation label left-to-right, together, with probability ``p``.

    Args:
        p (float): probability that the pair is flipped. Default: 0.5
    """

    def __init__(self, p=0.5):
        super().__init__()
        self.p = p

    def forward(self, image, label):
        """
        Args:
            image (PIL Image): image to be flipped.
            label (PIL Image): segmentation label to be flipped.

        Returns:
            The image and label, either both flipped or both unchanged.
        """
        # A single draw decides for both, so image and label stay aligned.
        should_flip = random.random() < self.p
        if not should_flip:
            return image, label
        return F.hflip(image), F.hflip(label)
class RandomResizedCrop(T.RandomResizedCrop):
    """Crop an image and its segmentation label to a random size and aspect ratio,
    then resize both to ``size``.

    The image and label should be PIL Images of the same size.

    A crop is taken at the same location from both. It covers a random fraction
    (default: 0.5 to 1.0) of the original area and has a random aspect ratio
    width / height (default: 3/4 to 4/3). Both crops are then resized.

    Args:
        size (int or sequence): expected output size. An int gives a square output
            ``(size, size)``; a sequence of length 1 is interpreted as
            ``(size[0], size[0])``. The size is passed directly to PIL's
            ``Image.resize``, which reads it as ``(width, height)``.
        scale (tuple of float): range of the crop area as a fraction of the original area.
        ratio (tuple of float): range of the crop aspect ratio (width / height).
        interpolation: resampling filter for the image. Default: ``PIL.Image.BICUBIC``.
            The label is always resized with ``PIL.Image.NEAREST``.
    """

    def __init__(self, size, scale=(0.5, 1.0), ratio=(3. / 4., 4. / 3.), interpolation=Image.BICUBIC):
        super(RandomResizedCrop, self).__init__(size, scale, ratio, interpolation)

    @staticmethod
    def get_params(
            img: Tensor, scale: List[float], ratio: List[float]
    ) -> Tuple[int, int, int, int]:
        """Get parameters for ``crop`` for a random sized crop.

        Args:
            img (PIL Image or Tensor): input image.
            scale (list): range of the crop area as a fraction of the original area.
            ratio (list): range of the crop aspect ratio (width / height).

        Returns:
            params (i, j, h, w): top, left, height and width of the crop.
        """
        # Read the size without torchvision's private ``F._get_image_size``, which
        # is not available in all torchvision releases.
        if isinstance(img, torch.Tensor):
            # torchvision convention: tensors are laid out as (..., H, W)
            height, width = int(img.shape[-2]), int(img.shape[-1])
        else:
            # PIL images report (width, height)
            width, height = img.size
        area = height * width

        # Loop-invariant: sample the aspect ratio uniformly in log space.
        log_ratio = (math.log(ratio[0]), math.log(ratio[1]))
        for _ in range(10):
            target_area = area * random.uniform(scale[0], scale[1])
            aspect_ratio = math.exp(random.uniform(log_ratio[0], log_ratio[1]))
            w = int(round(math.sqrt(target_area * aspect_ratio)))
            h = int(round(math.sqrt(target_area / aspect_ratio)))
            if 0 < w <= width and 0 < h <= height:
                i = random.randint(0, height - h)
                j = random.randint(0, width - w)
                return i, j, h, w

        # Fallback to central crop
        in_ratio = float(width) / float(height)
        if in_ratio < min(ratio):
            w = width
            h = int(round(w / min(ratio)))
        elif in_ratio > max(ratio):
            h = height
            w = int(round(h * max(ratio)))
        else:  # whole image
            w = width
            h = height
        i = (height - h) // 2
        j = (width - w) // 2
        return i, j, h, w

    def _pil_interpolation(self):
        """Return ``self.interpolation`` as a filter accepted by PIL's ``Image.resize``.

        NOTE(review): some torchvision versions convert an integer ``interpolation``
        argument into an ``InterpolationMode`` member, which PIL rejects. When
        torchvision exposes ``F.pil_modes_mapping``, map such members back to PIL's
        integer constants. Otherwise return the value unchanged.
        """
        mapping = getattr(F, "pil_modes_mapping", None)
        if mapping is not None and self.interpolation in mapping:
            return mapping[self.interpolation]
        return self.interpolation

    def forward(self, image, label):
        """
        Args:
            image (PIL Image): image to be cropped and resized.
            label (PIL Image): segmentation label to be cropped and resized.

        Returns:
            Randomly cropped and resized image, and the segmentation label cropped with
            the same box and resized with nearest-neighbour sampling.
        """
        top, left, height, width = self.get_params(image, self.scale, self.ratio)
        box = (left, top, left + width, top + height)
        image = image.crop(box).resize(self.size, self._pil_interpolation())
        label = label.crop(box).resize(self.size, Image.NEAREST)
        return image, label
class RandomChoice(T.RandomTransforms):
    """Apply one joint (image, label) transform, chosen uniformly at random from ``transforms``.
    """

    def __call__(self, image, label):
        return random.choice(self.transforms)(image, label)
class RandomApply(T.RandomTransforms):
    """Apply a list of joint (image, label) transformations, as a whole, with probability ``p``.

    Args:
        transforms (list or tuple): transformations applied in order. Each is called
            as ``t(image, label)`` and must return ``(image, label)``.
        p (float): probability of applying the transformations. Default: 0.5
    """

    def __init__(self, transforms, p=0.5):
        super(RandomApply, self).__init__(transforms)
        self.p = p

    def __call__(self, image, label):
        """
        Returns:
            ``(image, label)``: transformed if the transformations were applied,
            otherwise the inputs unchanged. The result always has two elements, so
            it can be unpacked by ``Compose``.
        """
        # For p in [0, 1], this skips with probability 1 - p.
        if self.p < random.random():
            return image, label
        for t in self.transforms:
            image, label = t(image, label)
        return image, label