Transforms

torchvision | pytorch doc

Transforming and augmenting images | pytorch doc

Demo

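Reading an image with PIL yields a PIL Image, which has to be converted to a torch tensor before use. Note that converting a single-channel annotation (ann) image with PILToTensor(), ToImage(), or ToTensor() produces a tensor of shape [1, H, W], not the [H, W] that is needed.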

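In image segmentation, CrossEntropyLoss() expects the target to be single-channel class-index data: shape [N, H, W] (no channel dimension) and dtype torch.long.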
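The shape requirement can be checked with a few random tensors. This is a minimal sketch: the sizes and the random data are made up for illustration, and only the shapes and dtypes matter.

```python
import torch
from torch import nn

# Hypothetical sizes, chosen only for illustration.
batch, num_classes, height, width = 2, 3, 4, 4

# Logits: one score per class for every pixel -> [N, C, H, W], float.
logits = torch.randn(batch, num_classes, height, width)

# An annotation batch that still carries a channel dimension: [N, 1, H, W], uint8.
ann = torch.randint(0, num_classes, (batch, 1, height, width), dtype=torch.uint8)

# Drop the channel dimension and cast to long -> [N, H, W], int64.
target = ann.squeeze(1).to(torch.long)

loss = nn.CrossEntropyLoss()(logits, target)
print(target.shape, target.dtype)
print(loss.dim())  # 0: the default reduction='mean' returns a scalar
```

For a single annotation of shape [1, H, W], the same idea applies with squeeze(0).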

import torch
from torchvision.transforms import v2
from PIL import Image

path2img = "path/to/img"
path2ann = "path/to/ann"
img = Image.open(path2img).convert('RGB')
ann = Image.open(path2ann).convert('L')

transform = v2.Compose([
    # PILToTensor() converts a PIL Image with shape [H, W, C] and type uint8
    # to a torch tensor with shape [C, H, W] and type torch.uint8
    v2.PILToTensor(),
    # ToDtype() converts a uint8 tensor in the range [0, 255] to
    # a float32 tensor in the range [0.0, 1.0]
    v2.ToDtype(torch.float32, scale=True),
    # Normalize() normalizes a tensor using the equation below:
    # output[channel] = (input[channel] - mean[channel]) / std[channel]
    v2.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225))
])

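# apply the transform to the image and the annotation together
img_trans, ann_trans = transform(img, ann)

# ToDtype() and Normalize() are not applied to ann; only PILToTensor() is.
# After PILToTensor() both inputs are plain tensors, and v2 transforms treat
# only the first plain tensor as the image and pass the rest through.
print(torch.equal(v2.PILToTensor()(ann), ann_trans)) # True

# squeeze the annotation from [1, H, W] to [H, W]
# and convert it from uint8 to long
ann_trans = ann_trans.squeeze(0).to(torch.long)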

Reading an image with cv2 yields a numpy array, which is easier to convert to a torch tensor.

from torchvision.transforms import v2
import torch
import cv2

path2img = "path/to/img"
path2ann = "path/to/ann"

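# cv2.IMREAD_COLOR_RGB needs a recent OpenCV release; on older versions, read with
# cv2.IMREAD_COLOR (BGR order) and convert with cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
img = cv2.imread(path2img, cv2.IMREAD_COLOR_RGB)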
ann = cv2.imread(path2ann, cv2.IMREAD_GRAYSCALE)

# convert the numpy array [H, W, C] to a torch tensor [C, H, W]
img = torch.from_numpy(img).permute(2, 0, 1)
# convert the ann dtype from uint8 to long (int64)
ann = torch.from_numpy(ann).to(torch.long)

transform = v2.Compose([
    # ToDtype() converts a uint8 tensor in the range [0, 255] to
    # a float32 tensor in the range [0.0, 1.0]
    v2.ToDtype(torch.float32, scale=True),
    # Normalize() normalizes a tensor using the equation below:
    # output[channel] = (input[channel] - mean[channel]) / std[channel]
    v2.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225))
])

# apply the transform; ann is the second plain tensor, so it is passed through unchanged
img_trans, ann_trans = transform(img, ann)
print(type(img_trans), img_trans.shape, img_trans.dtype)
# <class 'torch.Tensor'> torch.Size([3, 512, 512]) torch.float32
print(type(ann_trans), ann_trans.shape, ann_trans.dtype)
# <class 'torch.Tensor'> torch.Size([512, 512]) torch.int64

PIL To Tensor

There are two ways to convert a PIL.Image.Image to a Tensor: v2.ToTensor() and v2.PILToTensor(). The difference is that v2.ToTensor() scales the image, converting uint8 data in the range [0, 255] to torch.float32 data in the range [0, 1], whereas v2.PILToTensor() does not scale: the output tensor stays in the range [0, 255] with dtype torch.uint8.

from PIL import Image
from torchvision.transforms import v2
import torch

path2img = "path/to/img"
img = Image.open(path2img).convert('RGB')

img_totensor = v2.ToTensor()(img) # shape: [3, H, W]
img_piltotensor = v2.PILToTensor()(img) # shape: [3, H, W]
print(img_totensor.dtype) # torch.float32
print(img_piltotensor.dtype) # torch.uint8

print(torch.equal(img_totensor, img_piltotensor/255)) # True

ToTensor() is deprecated

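In recent versions of torchvision, using ToTensor() triggers the warning below. It says that ToTensor() is deprecated and will be removed in a future release, and recommends replacing it with v2.Compose([v2.ToImage(), v2.ToDtype(torch.float32, scale=True)]).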

UserWarning: The transform ToTensor() is deprecated and will be removed in a future release. Instead, please use v2.Compose([v2.ToImage(), v2.ToDtype(torch.float32, scale=True)]).Output is equivalent up to float precision.

However, v2.Compose([v2.ToImage(), v2.ToDtype(torch.float32, scale=True)]) behaves a little oddly: its result is not exactly equal to that of ToTensor(). The two differ by very small amounts:

from PIL import Image
from torchvision.transforms import v2
import torch

path2img = "path/to/img"
img = Image.open(path2img).convert('RGB')

ToTensor = v2.Compose([
    v2.ToTensor(),
])

ToDtype = v2.Compose([
    v2.ToImage(),
    # v2.PILToTensor(),
    v2.ToDtype(torch.float32, scale=True),
])

img_totensor = ToTensor(img)
img_todtype = ToDtype(img)

print(type(img_totensor), img_totensor.dtype)
# <class 'torch.Tensor'> torch.float32
print(type(img_todtype), img_todtype.dtype)
# <class 'torchvision.tv_tensors._image.Image'> torch.float32

print(torch.equal(img_totensor, img_todtype)) # False
print(img_totensor == img_todtype)
# tensor([[[False, False, False, ..., True, True, True],
# [False, False, False, ..., True, True, False],
# [ True, False, False, ..., True, True, False],
# ...,
# [ True, True, True, ..., False, False, False]]])
print(img_totensor - img_todtype)
# tensor([[[-1.4901e-08, -1.4901e-08, -1.4901e-08, ..., 0.0000e+00, 0.0000e+00],
# [-1.4901e-08, -1.4901e-08, -1.4901e-08, ..., 0.0000e+00, -1.4901e-08],
# [ 0.0000e+00, -1.4901e-08, -1.4901e-08, ..., 0.0000e+00, -1.4901e-08],
# ...,
# [ 0.0000e+00, 0.0000e+00, 0.0000e+00, ..., -2.9802e-08, -2.9802e-08]]])

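I have not checked the source code of ToDtype(), but my guess is that the tiny difference comes from precision loss in a float64-to-float32 conversion: the scale operation inside ToDtype() may produce a float64 result, which then loses precision when it is converted to float32. This is discussed in the issue "Update torchvision transforms -> transforms.v2 #701".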
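Whatever the exact cause, differences of this size are ordinary float32 rounding, and they appear whenever the same scaling is computed in two slightly different ways. The sketch below uses numpy only and compares dividing by 255 with multiplying by 1/255. That pair of operations is an assumption chosen for illustration, not something taken from torchvision's source.

```python
import numpy as np

# All 256 possible uint8 pixel values, as float32.
x = np.arange(256, dtype=np.float32)

# Two mathematically equivalent ways to scale to [0, 1]
# (assumed for illustration; not taken from torchvision's source).
divided = x / np.float32(255)
multiplied = x * np.float32(1.0 / 255.0)

mismatches = int((divided != multiplied).sum())
max_diff = float(np.abs(divided - multiplied).max())

print(mismatches > 0)                    # True: some values differ in the last bit
print(max_diff < 1e-7)                   # True: the gap is at most one rounding step
print(np.allclose(divided, multiplied))  # True: equal up to float precision
```

For the tensors in the demo above, the matching check is torch.allclose(img_totensor, img_todtype) rather than torch.equal(), which agrees with the warning's statement that the output is equivalent up to float precision.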

albumentations

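albumentations can achieve the same result, though it needs slightly more involved type conversions: first convert the PIL.Image.Image to a numpy.ndarray, then apply the albumentations transform, and finally convert the result to a torch.tensor for the later steps.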

import numpy
import torch
import albumentations as A
from PIL import Image

path2img = "path/to/img"
path2ann = "path/to/ann"
img = Image.open(path2img).convert('RGB')
ann = Image.open(path2ann).convert('L')

transform = A.Compose([
    A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225), max_pixel_value=255)
])
# PIL image to numpy array
img, ann = numpy.array(img), numpy.array(ann)
# apply transform
transformed = transform(image=img, mask=ann)
img_trans, ann_trans = transformed['image'], transformed['mask']
# Normalize() has not been applied to ann
print(numpy.array_equal(ann, ann_trans)) # True
# numpy array [H, W, C] to torch tensor [C, H, W]
img_trans = torch.tensor(img_trans, dtype=torch.float32).permute(2, 0, 1)
ann_trans = torch.tensor(ann_trans, dtype=torch.long)