PyTorch源码解析--torchvision.transforms(数据预处理、数据增强)
PyTorch框架中有一个很常用的包:torchvision
torchvision主要由3个子包构成:torchvision.datasets、torchvision.models和torchvision.transforms。
详细内容可参考:http://pytorch.org/docs/master/torchvision/index.html
GitHub:https://github.com/pytorch/vision/tree/master/torchvision。
这篇主要介绍torchvision.transforms。PyTorch中常见的数据预处理操作(resize、crop、normalize等)和数据增强(data augmentation)操作,基本上都可以通过该接口实现。
torchvision.transforms主要涉及两个文件:transforms.py和functional.py。transforms.py中定义了各种data augmentation的类,每个类通过调用functional.py中对应的函数来完成具体的data augmentation操作。
$ vim /home/lwp/.local/lib/python2.7/site-packages/torchvision/transforms/transforms.py
使用示例:下面是Re-ID MGN模型实现代码中的一段(https://github.com/lwplw/re-id_mgn/blob/master/pytorch_MGN/data/init.py),其中用到了Resize、RandomHorizontalFlip、ToTensor和Normalize:
from importlib import import_module
from torchvision import transforms
from utils.random_erasing import RandomErasing
from data.sampler import RandomSampler
from torch.utils.data import dataloader
class Data:
    """Build the Re-ID train/test/query datasets and their DataLoaders from ``args``.

    Attributes set on the instance:
        trainset, train_loader: training split and its loader. When
            ``args.test_only`` is true, ``trainset`` is not created and
            ``train_loader`` is None.
        testset, test_loader: the 'test' split and its loader.
        queryset, query_loader: the 'query' split and its loader.

    Raises:
        ValueError: if ``args.data_test`` is not a supported test dataset
            (currently only 'Market1501').
    """

    def __init__(self, args):
        # Training pipeline: resize -> random horizontal flip -> tensor -> normalize.
        # interpolation=3 is presumably PIL.Image.BICUBIC -- confirm against Pillow.
        train_list = [
            transforms.Resize((args.height, args.width), interpolation=3),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ]
        if args.random_erasing:
            # Appended after Normalize, so RandomErasing sees the normalized tensor.
            train_list.append(RandomErasing(probability=args.probability, mean=[0.0, 0.0, 0.0]))
        train_transform = transforms.Compose(train_list)

        # Evaluation pipeline: the training pipeline without random augmentation.
        test_transform = transforms.Compose([
            transforms.Resize((args.height, args.width), interpolation=3),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ])

        if not args.test_only:
            module_train = import_module('data.' + args.data_train.lower())
            self.trainset = getattr(module_train, args.data_train)(args, train_transform, 'train')
            # The batch size must match what the sampler yields per batch:
            # batchid groups of batchimage images (see RandomSampler's arguments).
            self.train_loader = dataloader.DataLoader(
                self.trainset,
                sampler=RandomSampler(self.trainset, args.batchid, batch_image=args.batchimage),
                batch_size=args.batchid * args.batchimage,
                num_workers=args.nThread,
            )
        else:
            self.train_loader = None

        if args.data_test in ['Market1501']:
            # Resolve the module from the *test* dataset name so the lookup of
            # ``args.data_test`` below succeeds even when train and test differ.
            module = import_module('data.' + args.data_test.lower())
            self.testset = getattr(module, args.data_test)(args, test_transform, 'test')
            self.queryset = getattr(module, args.data_test)(args, test_transform, 'query')
        else:
            raise ValueError('Unsupported test dataset: {0}'.format(args.data_test))

        self.test_loader = dataloader.DataLoader(self.testset, batch_size=args.batchtest, num_workers=args.nThread)
        self.query_loader = dataloader.DataLoader(self.queryset, batch_size=args.batchtest, num_workers=args.nThread)
各种操作的类定义在transforms.py文件中。文件通过from . import functional as F导入了functional.py中具体的data augmentation函数;__all__列表定义了该模块对外公开的类名,即from ... import *时会导出的名字。
from __future__ import division

import collections
import math
import numbers
import random
import types
import warnings

try:
    from collections.abc import Iterable
except ImportError:  # Python 2.7: the ABCs live directly in ``collections``
    from collections import Iterable

import numpy as np
import torch
from PIL import Image, ImageOps, ImageEnhance

try:
    import accimage
except ImportError:
    accimage = None

from . import functional as F
# Public names exported by ``from ... import *``.
__all__ = [
    "Compose", "ToTensor", "ToPILImage", "Normalize", "Resize", "Scale",
    "CenterCrop", "Pad", "Lambda", "RandomApply", "RandomChoice", "RandomOrder",
    "RandomCrop", "RandomHorizontalFlip", "RandomVerticalFlip",
    "RandomResizedCrop", "RandomSizedCrop", "FiveCrop", "TenCrop",
    "LinearTransformation", "ColorJitter", "RandomRotation", "RandomAffine",
    "Grayscale", "RandomGrayscale",
]

# Lookup from PIL interpolation constants to the printable names used by the
# ``__repr__`` of transforms that take an ``interpolation`` argument.
_pil_interpolation_to_str = dict(
    (getattr(Image, _name), 'PIL.Image.' + _name)
    for _name in ('NEAREST', 'BILINEAR', 'BICUBIC', 'LANCZOS')
)
Compose()
用来把多个transform组合在一起。其__call__方法按列表顺序依次对输入img执行每个transform,前一个transform的输出作为后一个transform的输入。
class Compose(object):
    """Chain several transforms into a single callable.

    Each transform receives the output of the previous one, in list order.

    Args:
        transforms (list of ``Transform`` objects): the transforms to chain.

    Example::

        transforms.Compose([
            transforms.CenterCrop(10),
            transforms.ToTensor(),
        ])
    """

    def __init__(self, transforms):
        self.transforms = transforms

    def __call__(self, img):
        result = img
        for transform in self.transforms:
            result = transform(result)
        return result

    def __repr__(self):
        # One line per transform between "Name(" and the closing ")".
        pieces = [self.__class__.__name__ + '(']
        pieces.extend(' {0}'.format(transform) for transform in self.transforms)
        return '\n'.join(pieces) + '\n)'
ToTensor()
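Convert a PIL Image or numpy.ndarray to tensor.
它会把取值范围为[0, 255]、形状为(H x W x C)的图像,转换成取值范围为[0.0, 1.0]、形状为(C x H x W)的torch.FloatTensor。在做数据归一化(Normalize)之前,必须先把PIL Image转成Tensor;resize、crop等操作则直接作用于PIL Image,不需要先转换。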
class ToTensor(object):
    """Turn a ``PIL Image`` or ``numpy.ndarray`` into a tensor.

    An (H x W x C) image with values in [0, 255] becomes a torch.FloatTensor
    of shape (C x H x W) with values in [0.0, 1.0].
    """

    def __call__(self, pic):
        """Convert ``pic``.

        Args:
            pic (PIL Image or numpy.ndarray): image to convert.

        Returns:
            Tensor: the converted image.
        """
        converted = F.to_tensor(pic)
        return converted

    def __repr__(self):
        return '{0}()'.format(type(self).__name__)
ToPILImage()
Convert a tensor or an ndarray to PIL Image.
即ToTensor()的反向操作。
Normalize()
数据归一化处理:对每个通道计算 (input[channel] - mean[channel]) / std[channel]。调用前数据需先处理成Tensor。
class Normalize(object):
    """Per-channel standardization of a tensor image.

    For ``n`` channels with means ``(M1,...,Mn)`` and standard deviations
    ``(S1,..,Sn)``, each channel of the input ``torch.*Tensor`` is mapped as
    ``input[channel] = (input[channel] - mean[channel]) / std[channel]``.

    Args:
        mean (sequence): per-channel means.
        std (sequence): per-channel standard deviations.
    """

    def __init__(self, mean, std):
        self.mean, self.std = mean, std

    def __call__(self, tensor):
        """Normalize ``tensor``.

        Args:
            tensor (Tensor): image tensor of size (C, H, W).

        Returns:
            Tensor: the normalized image tensor.
        """
        mean, std = self.mean, self.std
        return F.normalize(tensor, mean, std)

    def __repr__(self):
        return '{0}(mean={1}, std={2})'.format(type(self).__name__, self.mean, self.std)
Resize()
对PIL Image实现resize操作:
- 如果输入为单个int值,则将输入图像的短边resize到这个int数,长边按相同比例缩放,图像长宽比保持不变。
- 如果输入为(h, w),且h、w均为int,则直接将输入图像resize到(h, w)尺寸,图像的长宽比可能会发生变化。
在__call__方法中调用了functional.py中的resize函数来完成resize操作。因为输入是PIL Image,所以resize函数基本上是在调用Image的各种方法。
class Resize(object):
    """Resize the input PIL Image to the given size.

    Args:
        size (sequence or int): Desired output size. If size is a sequence like
            (h, w), output size will be matched to this. If size is an int,
            smaller edge of the image will be matched to this number.
            i.e, if height > width, then image will be rescaled to
            (size * height / width, size)
        interpolation (int, optional): Desired interpolation. Default is
            ``PIL.Image.BILINEAR``

    Raises:
        AssertionError: if ``size`` is neither an int nor a length-2 iterable
            (the check is skipped when Python runs with ``-O``).
    """

    def __init__(self, size, interpolation=Image.BILINEAR):
        # ``Iterable`` resolves to ``collections.abc.Iterable`` where available;
        # the old ``collections.Iterable`` alias no longer exists on Python 3.10+.
        assert isinstance(size, int) or (isinstance(size, Iterable) and len(size) == 2), \
            'size should be an int or a sequence of length 2, got {0!r}'.format(size)
        self.size = size
        self.interpolation = interpolation

    def __call__(self, img):
        """Resize ``img``.

        Args:
            img (PIL Image): Image to be scaled.

        Returns:
            PIL Image: Rescaled image.
        """
        return F.resize(img, self.size, self.interpolation)

    def __repr__(self):
        # Interpolation codes missing from the lookup table are shown as their
        # raw value instead of raising KeyError.
        interpolate_str = _pil_interpolation_to_str.get(self.interpolation, str(self.interpolation))
        return self.__class__.__name__ + '(size={0}, interpolation={1})'.format(self.size, interpolate_str)
CenterCrop()
以输入图像img的中心作为中心点,进行指定size的crop操作。在数据增强中一般不会使用该方法,因为当size固定时,对同一张img做N次CenterCrop,结果都是一样的。
size可以是单个int值(此时会被转换成(int(size), int(size))的正方形尺寸),也可以是(h, w)形式的序列。
class CenterCrop(object):
    """Crop the given PIL Image around its center.

    Args:
        size (sequence or int): output size of the crop. A single number
            produces a square crop ``(int(size), int(size))``; otherwise the
            value (e.g. ``(h, w)``) is stored as given.
    """

    def __init__(self, size):
        is_scalar = isinstance(size, numbers.Number)
        self.size = (int(size), int(size)) if is_scalar else size

    def __call__(self, img):
        """Center-crop ``img``.

        Args:
            img (PIL Image): image to crop.

        Returns:
            PIL Image: the cropped image.
        """
        target = self.size
        return F.center_crop(img, target)

    def __repr__(self):
        return '{0}(size={1})'.format(type(self).__name__, self.size)
RandomCrop()
RandomCrop相比前面的CenterCrop要更加常用一些,两者的区别在于RandomCrop在crop时的中心点坐标是随机的,不再是输入图像的中心坐标,因此基本上每次crop生成的图像都是不同的。
class RandomCrop(object):
    """Crop the given PIL Image at a random location.

    Args:
        size (sequence or int): Desired output size of the crop. If size is an
            int instead of sequence like (h, w), a square crop (size, size) is
            made.
        padding (int or sequence, optional): Optional padding on each border
            of the image. Default is 0, i.e no padding. If a sequence of length
            4 is provided, it is used to pad left, top, right, bottom borders
            respectively.
        pad_if_needed (boolean): It will pad the image if smaller than the
            desired size to avoid raising an exception.
    """

    def __init__(self, size, padding=0, pad_if_needed=False):
        if isinstance(size, numbers.Number):
            self.size = (int(size), int(size))
        else:
            self.size = size
        self.padding = padding
        self.pad_if_needed = pad_if_needed

    @staticmethod
    def get_params(img, output_size):
        """Get parameters for ``crop`` for a random crop.

        Args:
            img (PIL Image): Image to be cropped; ``img.size`` is (width, height).
            output_size (tuple): Expected output size of the crop, as (h, w).

        Returns:
            tuple: params (i, j, h, w) to be passed to ``crop`` for random crop,
            where (i, j) is the top-left corner (row, column).

        Raises:
            ValueError: if the image is smaller than ``output_size`` in either
                dimension.
        """
        w, h = img.size
        th, tw = output_size
        if w == tw and h == th:
            return 0, 0, h, w
        if w < tw or h < th:
            raise ValueError(
                'Requested crop size {0} is bigger than input size {1}'.format((th, tw), (h, w)))
        i = random.randint(0, h - th)
        j = random.randint(0, w - tw)
        return i, j, th, tw

    def __call__(self, img):
        """Randomly crop ``img``, padding it first if configured.

        Args:
            img (PIL Image): Image to be cropped.

        Returns:
            PIL Image: Cropped image.
        """
        # ``padding`` may be a number or a sequence (see class docstring).
        # Comparing a sequence with 0 raises TypeError on Python 3, so only
        # numbers are range-checked; non-empty sequences are forwarded as-is.
        if isinstance(self.padding, numbers.Number):
            needs_padding = self.padding > 0
        else:
            needs_padding = bool(self.padding)
        if needs_padding:
            img = F.pad(img, self.padding)
        # img.size is (width, height). Pad amount is ceil(missing / 2); presumably
        # F.pad applies a 2-tuple to both sides of each axis -- confirm in functional.py.
        # pad the width if needed
        if self.pad_if_needed and img.size[0] < self.size[1]:
            img = F.pad(img, (int((1 + self.size[1] - img.size[0]) / 2), 0))
        # pad the height if needed
        if self.pad_if_needed and img.size[1] < self.size[0]:
            img = F.pad(img, (0, int((1 + self.size[0] - img.size[1]) / 2)))

        i, j, h, w = self.get_params(img, self.size)
        return F.crop(img, i, j, h, w)

    def __repr__(self):
        return self.__class__.__name__ + '(size={0}, padding={1})'.format(self.size, self.padding)
RandomHorizontalFlip()
图像随机水平翻转,翻转概率由参数p指定,默认为0.5。
class RandomHorizontalFlip(object):
    """Mirror a PIL Image left-to-right with probability ``p``.

    Args:
        p (float): chance that the image is flipped. Defaults to 0.5.
    """

    def __init__(self, p=0.5):
        self.p = p

    def __call__(self, img):
        """Possibly flip ``img``.

        Args:
            img (PIL Image): image that may be flipped.

        Returns:
            PIL Image: the flipped image, or ``img`` itself when not flipped.
        """
        should_flip = random.random() < self.p
        return F.hflip(img) if should_flip else img

    def __repr__(self):
        return '{0}(p={1})'.format(type(self).__name__, self.p)
RandomVerticalFlip()
图像随机垂直翻转,翻转概率由参数p指定,默认为0.5。
class RandomVerticalFlip(object):
    """Mirror a PIL Image top-to-bottom with probability ``p``.

    Args:
        p (float): chance that the image is flipped. Defaults to 0.5.
    """

    def __init__(self, p=0.5):
        self.p = p

    def __call__(self, img):
        """Possibly flip ``img``.

        Args:
            img (PIL Image): image that may be flipped.

        Returns:
            PIL Image: the flipped image, or ``img`` itself when not flipped.
        """
        should_flip = random.random() < self.p
        return F.vflip(img) if should_flip else img

    def __repr__(self):
        return '{0}(p={1})'.format(type(self).__name__, self.p)
RandomResizedCrop()
CenterCrop和RandomCrop在crop时使用固定的size,RandomResizedCrop则是random size的crop。
该类源码需要3个参数:size、scale和ratio,这里我在使用中将接口中的size修改成了size_h, size_w。处理方法是先crop,再resize到指定尺寸。
crop区域的位置和宽高由get_params方法得到。首先在scale限定的数值范围内随机生成一个数,用这个数乘以输入图像的面积,作为crop后图像的面积。然后在ratio限定的数值范围内随机生成一个数,表示宽高比。根据这两个值就可以得到crop图像的宽高,并以0.5的概率交换宽和高。crop区域的位置和RandomCrop类一样是随机生成的。如果连续10次尝试得到的宽高都超出了原图范围,则退化为以原图短边为边长的中心正方形crop。
class RandomResizedCrop(object):
    """Crop the given PIL Image to random size and aspect ratio, then resize it.

    A crop covering a random fraction of the original area (default: 0.08 to
    1.0) with a random width/height aspect ratio (default: 3/4 to 4/3; width
    and height are additionally swapped with probability 0.5) is made. This
    crop is finally resized to the given size.
    This is popularly used to train the Inception networks.

    Args:
        size_h: expected output height. When ``size_w`` is omitted, ``size_h``
            may instead be an int (square output) or an ``(h, w)`` sequence.
        size_w: expected output width. Optional; see ``size_h``.
        scale: range of the fraction of the original area that is cropped
        ratio: range of the width/height aspect ratio of the crop
        interpolation: Default: PIL.Image.BILINEAR

    Raises:
        ValueError: if ``size_w`` is omitted and ``size_h`` is a sequence whose
            length is not 2.
    """

    def __init__(self, size_h, size_w=None, scale=(0.08, 1.0), ratio=(3. / 4., 4. / 3.),
                 interpolation=Image.BILINEAR):
        if size_w is not None:
            # Two-argument form: keep the values exactly as given.
            self.size = (size_h, size_w)
        elif isinstance(size_h, numbers.Number):
            self.size = (int(size_h), int(size_h))
        else:
            self.size = tuple(size_h)
            if len(self.size) != 2:
                raise ValueError('size should be an int or a sequence of length 2, '
                                 'got {0!r}'.format(size_h))
        self.interpolation = interpolation
        self.scale = scale
        self.ratio = ratio

    @staticmethod
    def get_params(img, scale, ratio):
        """Get parameters for ``crop`` for a random sized crop.

        Args:
            img (PIL Image): Image to be cropped; ``img.size`` is (width, height).
            scale (tuple): range of the fraction of the original area cropped
            ratio (tuple): range of the width/height aspect ratio of the crop

        Returns:
            tuple: params (i, j, h, w) to be passed to ``crop`` for a random
            sized crop, where (i, j) is the top-left corner (row, column).
        """
        width, height = img.size
        area = width * height  # loop-invariant
        for _ in range(10):
            target_area = random.uniform(*scale) * area
            aspect_ratio = random.uniform(*ratio)
            w = int(round(math.sqrt(target_area * aspect_ratio)))
            h = int(round(math.sqrt(target_area / aspect_ratio)))
            # Swapping with probability 0.5 also samples the inverse ratios.
            if random.random() < 0.5:
                w, h = h, w
            if w <= width and h <= height:
                i = random.randint(0, height - h)
                j = random.randint(0, width - w)
                return i, j, h, w

        # Fallback after 10 failed attempts: largest centered square.
        w = min(width, height)
        i = (height - w) // 2
        j = (width - w) // 2
        return i, j, w, w

    def __call__(self, img):
        """Randomly crop ``img`` and resize the crop to ``self.size``.

        Args:
            img (PIL Image): Image to be cropped and resized.

        Returns:
            PIL Image: Randomly cropped and resized image.
        """
        i, j, h, w = self.get_params(img, self.scale, self.ratio)
        return F.resized_crop(img, i, j, h, w, self.size, self.interpolation)

    def __repr__(self):
        # Interpolation codes missing from the lookup table are shown as their
        # raw value instead of raising KeyError.
        interpolate_str = _pil_interpolation_to_str.get(self.interpolation, str(self.interpolation))
        format_string = self.__class__.__name__ + '(size={0}'.format(self.size)
        format_string += ', scale={0}'.format(tuple(round(s, 4) for s in self.scale))
        format_string += ', ratio={0}'.format(tuple(round(r, 4) for r in self.ratio))
        format_string += ', interpolation={0})'.format(interpolate_str)
        return format_string
更多推荐








所有评论(0)