Paper

UNet Structure

Transposed Convolution

A sliding-window convolution can be rewritten as a matrix multiplication: arrange the pixels of the input image into a column vector, ordered left to right and top to bottom; each window position of the kernel then becomes a row vector, so the whole kernel corresponds to a matrix whose number of rows equals the number of convolution windows and whose number of columns equals the number of input pixels.
For example, convolving a $4\times4$ input image with a $3\times3$ kernel, with padding=0 and stride=1, gives a $2\times2$ output. Written as a matrix multiplication:

\scriptsize
\begin{bmatrix} y_{00} \\ y_{01} \\ y_{10} \\ y_{11} \end{bmatrix}
=
\left[\begin{array}{cccc|cccc|cccc|cccc}
w_{0,0} & w_{0,1} & w_{0,2} & 0 & w_{1,0} & w_{1,1} & w_{1,2} & 0 & w_{2,0} & w_{2,1} & w_{2,2} & 0 & 0 & 0 & 0 & 0 \\
0 & w_{0,0} & w_{0,1} & w_{0,2} & 0 & w_{1,0} & w_{1,1} & w_{1,2} & 0 & w_{2,0} & w_{2,1} & w_{2,2} & 0 & 0 & 0 & 0 \\
0 & 0 & 0 & 0 & w_{0,0} & w_{0,1} & w_{0,2} & 0 & w_{1,0} & w_{1,1} & w_{1,2} & 0 & w_{2,0} & w_{2,1} & w_{2,2} & 0 \\
0 & 0 & 0 & 0 & 0 & w_{0,0} & w_{0,1} & w_{0,2} & 0 & w_{1,0} & w_{1,1} & w_{1,2} & 0 & w_{2,0} & w_{2,1} & w_{2,2}
\end{array}\right]
\begin{bmatrix} x_{00} \\ x_{01} \\ x_{02} \\ x_{03} \\ \vdots \\ x_{30} \\ x_{31} \\ x_{32} \\ x_{33} \end{bmatrix}

Each row of the matrix above corresponds to one convolution window: its element-wise product with the (flattened) input image, summed up, gives the convolution output at that window position, as illustrated below:

\scriptsize
\begin{bmatrix}
\textcolor{red}{y_{00} }\\ y_{01} \\ y_{10} \\ y_{11}
\end{bmatrix}
=
\begin{matrix}
\textcolor{red}{
\begin{bmatrix}
w_{0,0} & w_{0,1} & w_{0,2} & 0 \\
w_{1,0} & w_{1,1} & w_{1,2} & 0 \\
w_{2,0} & w_{2,1} & w_{2,2} & 0 \\
0 & 0 & 0 & 0
\end{bmatrix}}
\\
\begin{bmatrix}
0 & w_{0,0} & w_{0,1} & w_{0,2} \\
0 & w_{1,0} & w_{1,1} & w_{1,2} \\
0 & w_{2,0} & w_{2,1} & w_{2,2} \\
0 & 0 & 0 & 0
\end{bmatrix}
\\
\begin{bmatrix}
0 & 0 & 0 & 0 \\
w_{0,0} & w_{0,1} & w_{0,2} & 0 \\
w_{1,0} & w_{1,1} & w_{1,2} & 0 \\
w_{2,0} & w_{2,1} & w_{2,2} & 0
\end{bmatrix}
\\
\begin{bmatrix}
0 & 0 & 0 & 0 \\
0 & w_{0,0} & w_{0,1} & w_{0,2} \\
0 & w_{1,0} & w_{1,1} & w_{1,2} \\
0 & w_{2,0} & w_{2,1} & w_{2,2}
\end{bmatrix}
\end{matrix}
*
\textcolor{red}{
\begin{bmatrix}
x_{00} & x_{01} & x_{02} & x_{03} \\
x_{10} & x_{11} & x_{12} & x_{13} \\
x_{20} & x_{21} & x_{22} & x_{23} \\
x_{30} & x_{31} & x_{32} & x_{33}
\end{bmatrix}}
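To make this concrete, here is a minimal sketch in PyTorch (with a randomly chosen input and kernel, and names of my own) that builds the $4\times16$ matrix above row by row and checks that the matrix-vector product matches F.conv2d:

import torch
import torch.nn.functional as F

torch.manual_seed(0)
x = torch.randn(4, 4)   # input image
w = torch.randn(3, 3)   # convolution kernel

# Build the 4x16 matrix: each row is one sliding window of w,
# placed on a 4x4 canvas and flattened to align with the flattened input.
C = torch.zeros(4, 16)
for row, (i, j) in enumerate([(0, 0), (0, 1), (1, 0), (1, 1)]):
    canvas = torch.zeros(4, 4)
    canvas[i:i+3, j:j+3] = w
    C[row] = canvas.flatten()

y_matmul = (C @ x.flatten()).reshape(2, 2)
y_conv = F.conv2d(x.view(1, 1, 4, 4), w.view(1, 1, 3, 3)).view(2, 2)
print(torch.allclose(y_matmul, y_conv))  # True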

A transposed convolution takes the matrix-multiplication form of convolution above and transposes the transformation matrix. The original linear map $\overset{4\times4}{16}\rightarrow \overset{2\times2}{4}$ thus becomes a linear map $\overset{2\times2}{4}\rightarrow \overset{4\times4}{16}$, which enlarges the image, i.e. performs upsampling.

\scriptsize
\begin{bmatrix} \textcolor{red}{y'_{00}} \\ y'_{01} \\ y'_{02} \\ y'_{03} \\ \vdots \\ y'_{30} \\ y'_{31} \\ y'_{32} \\ y'_{33} \end{bmatrix}
=
\left[\begin{array}{cccc}
\textcolor{red}{w_{0,0}} & \textcolor{red}{0} & \textcolor{red}{0} & \textcolor{red}{0} \\
w_{0,1} & w_{0,0} & 0 & 0 \\
w_{0,2} & w_{0,1} & 0 & 0 \\
0 & w_{0,2} & 0 & 0 \\ \hline
w_{1,0} & 0 & w_{0,0} & 0 \\
w_{1,1} & w_{1,0} & w_{0,1} & w_{0,0} \\
w_{1,2} & w_{1,1} & w_{0,2} & w_{0,1} \\
0 & w_{1,2} & 0 & w_{0,2} \\ \hline
w_{2,0} & 0 & w_{1,0} & 0 \\
w_{2,1} & w_{2,0} & w_{1,1} & w_{1,0} \\
w_{2,2} & w_{2,1} & w_{1,2} & w_{1,1} \\
0 & w_{2,2} & 0 & w_{1,2} \\ \hline
0 & 0 & w_{2,0} & 0 \\
0 & 0 & w_{2,1} & w_{2,0} \\
0 & 0 & w_{2,2} & w_{2,1} \\
0 & 0 & 0 & w_{2,2}
\end{array}\right]
\textcolor{red}{\begin{bmatrix} x'_{00} \\ x'_{01} \\ x'_{10} \\ x'_{11} \end{bmatrix}}

Each column of the transposed matrix corresponds to one convolution window; the product above amounts to weighting each window by the corresponding input pixel and summing, as illustrated below:

\tiny
\begin{split}
\begin{bmatrix}
y'_{00} & y'_{01} & y'_{02} & y'_{03}\\
y'_{10} & y'_{11} & y'_{12} & y'_{13}\\
y'_{20} & y'_{21} & y'_{22} & y'_{23}\\
y'_{30} & y'_{31} & y'_{32} & y'_{33}
\end{bmatrix}
&=
\left\{
\textcolor{red}{
\begin{bmatrix}
w_{0,0} & w_{0,1} & w_{0,2} & 0 \\
w_{1,0} & w_{1,1} & w_{1,2} & 0 \\
w_{2,0} & w_{2,1} & w_{2,2} & 0 \\
0 & 0 & 0 & 0
\end{bmatrix}},
\begin{bmatrix}
0 & w_{0,0} & w_{0,1} & w_{0,2} \\
0 & w_{1,0} & w_{1,1} & w_{1,2} \\
0 & w_{2,0} & w_{2,1} & w_{2,2} \\
0 & 0 & 0 & 0
\end{bmatrix},
\begin{bmatrix}
0 & 0 & 0 & 0 \\
w_{0,0} & w_{0,1} & w_{0,2} & 0 \\
w_{1,0} & w_{1,1} & w_{1,2} & 0 \\
w_{2,0} & w_{2,1} & w_{2,2} & 0
\end{bmatrix},
\begin{bmatrix}
0 & 0 & 0 & 0 \\
0 & w_{0,0} & w_{0,1} & w_{0,2} \\
0 & w_{1,0} & w_{1,1} & w_{1,2} \\
0 & w_{2,0} & w_{2,1} & w_{2,2}
\end{bmatrix}
\right\}
\begin{bmatrix}
\textcolor{red}{x'_{00}}\\ x'_{01}\\ x'_{10}\\ x'_{11}
\end{bmatrix}\\
&=
\begin{bmatrix}
w_{0,0} & w_{0,1} & w_{0,2} & 0 \\
w_{1,0} & w_{1,1} & w_{1,2} & 0 \\
w_{2,0} & w_{2,1} & w_{2,2} & 0 \\
0 & 0 & 0 & 0
\end{bmatrix}x'_{00}+
\begin{bmatrix}
0 & w_{0,0} & w_{0,1} & w_{0,2} \\
0 & w_{1,0} & w_{1,1} & w_{1,2} \\
0 & w_{2,0} & w_{2,1} & w_{2,2} \\
0 & 0 & 0 & 0
\end{bmatrix}x'_{01}+
\begin{bmatrix}
0 & 0 & 0 & 0 \\
w_{0,0} & w_{0,1} & w_{0,2} & 0 \\
w_{1,0} & w_{1,1} & w_{1,2} & 0 \\
w_{2,0} & w_{2,1} & w_{2,2} & 0
\end{bmatrix}x'_{10}+
\begin{bmatrix}
0 & 0 & 0 & 0 \\
0 & w_{0,0} & w_{0,1} & w_{0,2} \\
0 & w_{1,0} & w_{1,1} & w_{1,2} \\
0 & w_{2,0} & w_{2,1} & w_{2,2}
\end{bmatrix}x'_{11}\\
&=
\left\{
\begin{matrix}
\textcolor{red}{
\begin{bmatrix}
w_{0,0} & w_{0,1} & w_{0,2} & 0 \\
w_{1,0} & w_{1,1} & w_{1,2} & 0 \\
w_{2,0} & w_{2,1} & w_{2,2} & 0 \\
0 & 0 & 0 & 0
\end{bmatrix}}&
\begin{bmatrix}
0 & w_{0,0} & w_{0,1} & w_{0,2} \\
0 & w_{1,0} & w_{1,1} & w_{1,2} \\
0 & w_{2,0} & w_{2,1} & w_{2,2} \\
0 & 0 & 0 & 0
\end{bmatrix}\\
\begin{bmatrix}
0 & 0 & 0 & 0 \\
w_{0,0} & w_{0,1} & w_{0,2} & 0 \\
w_{1,0} & w_{1,1} & w_{1,2} & 0 \\
w_{2,0} & w_{2,1} & w_{2,2} & 0
\end{bmatrix}&
\begin{bmatrix}
0 & 0 & 0 & 0 \\
0 & w_{0,0} & w_{0,1} & w_{0,2} \\
0 & w_{1,0} & w_{1,1} & w_{1,2} \\
0 & w_{2,0} & w_{2,1} & w_{2,2}
\end{bmatrix}
\end{matrix}
\right\}*
\begin{bmatrix}
\textcolor{red}{x'_{00}} & x'_{01} \\ x'_{10} & x'_{11}
\end{bmatrix}
\end{split}
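Assuming the same construction of the matrix $C$ as in the sketch above (random $3\times3$ kernel), we can check numerically that applying $C^T$ to a flattened $2\times2$ input agrees with F.conv_transpose2d with stride 1 and no padding:

import torch
import torch.nn.functional as F

torch.manual_seed(0)
w = torch.randn(3, 3)

# The same 4x16 convolution matrix C as before.
C = torch.zeros(4, 16)
for row, (i, j) in enumerate([(0, 0), (0, 1), (1, 0), (1, 1)]):
    canvas = torch.zeros(4, 4)
    canvas[i:i+3, j:j+3] = w
    C[row] = canvas.flatten()

xp = torch.randn(2, 2)  # the small input x'
y_matmul = (C.t() @ xp.flatten()).reshape(4, 4)
y_tconv = F.conv_transpose2d(xp.view(1, 1, 2, 2), w.view(1, 1, 3, 3)).view(4, 4)
print(torch.allclose(y_matmul, y_tconv))  # True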

This is equivalent to using $\left[\begin{smallmatrix}x'_{00}&x'_{01}\\x'_{10}&x'_{11}\end{smallmatrix}\right]$ as weights, distributed over the $4\times4$ output in the following pattern:

\scriptsize
\left[
\begin{array}{ccc}
x'_{00} & \fbox{$\begin{matrix} x'_{00}+x'_{01} \end{matrix}$} & x'_{01}\\\\
\fbox{$\begin{matrix} x'_{00}\\+\\x'_{10} \end{matrix}$} &
\fbox{$\begin{matrix} x'_{00}+x'_{01}\\+\\x'_{10}+x'_{11} \end{matrix}$} &
\fbox{$\begin{matrix} x'_{01}\\+\\x'_{11} \end{matrix}$}\\\\
x'_{10} & \fbox{$\begin{matrix} x'_{10}+x'_{11} \end{matrix}$} & x'_{11}\\
\end{array}
\right]

This result is equivalent to flipping the original kernel left-right and up-down and using it to convolve the input image padded with two layers of zeros:

\fbox{$\begin{matrix} w_{2,2} & w_{2,1} & w_{2,0} \\ w_{1,2} & w_{1,1} & w_{1,0} \\ w_{0,2} & w_{0,1} & w_{0,0} \end{matrix}$}
\begin{bmatrix}
0 & 0 & 0 & 0 & 0 & 0 \\
0 & 0 & 0 & 0 & 0 & 0 \\
0 & 0 & x'_{00} & x'_{01} & 0 & 0 \\
0 & 0 & x'_{10} & x'_{11} & 0 & 0 \\
0 & 0 & 0 & 0 & 0 & 0 \\
0 & 0 & 0 & 0 & 0 & 0 \\
\end{bmatrix}
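This equivalence can also be checked numerically; a small sketch with a random kernel and input, assuming stride 1 and a $3\times3$ kernel (hence two layers of zero padding):

import torch
import torch.nn.functional as F

torch.manual_seed(0)
xp = torch.randn(1, 1, 2, 2)   # x'
w = torch.randn(1, 1, 3, 3)    # original kernel

y_tconv = F.conv_transpose2d(xp, w)                    # 1x1x4x4
w_flipped = torch.flip(w, dims=[2, 3])                 # flip up-down and left-right
y_conv = F.conv2d(F.pad(xp, (2, 2, 2, 2)), w_flipped)  # two layers of zero padding
print(torch.allclose(y_tconv, y_conv))  # True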

Does transposed convolution recover the image that was convolved?

Note that a transposed convolution does not recover the matrix (image) that went through the convolution; it only restores its dimensions, i.e. it maps the convolved image from the low-dimensional space back to the high-dimensional one.

y = Ax \xLeftrightarrow{\text{A is invertible}} A^{-1}y = x,\quad A^T y \neq x
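A quick numerical illustration of this point (a sketch with a random image and kernel; only the shape comes back, not the values):

import torch
import torch.nn.functional as F

torch.manual_seed(0)
x = torch.randn(1, 1, 4, 4)
w = torch.randn(1, 1, 3, 3)

y = F.conv2d(x, w)                # 4x4 -> 2x2
x_rec = F.conv_transpose2d(y, w)  # 2x2 -> 4x4: the dimensions are restored
print(x_rec.shape)                # torch.Size([1, 1, 4, 4])
print(torch.allclose(x_rec, x))   # False: the pixel values are not recovered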

We can look at this through the generalized inverse: given $Y = AX$, does there exist a transformation matrix $B$ such that $YB = X$?

All solutions for $B$ are of the form

B = Y^gX + [I - Y^gY]w

where $Y^g$ is any generalized inverse of $Y$ and $w$ is an arbitrary matrix.

A solution exists if and only if $Y^gX$ is itself a solution, i.e. if and only if $YY^gX = X$.

Does a transformation exist that maps $Y$ back to $X$?

For $Y = AX$, does there exist a $B$ such that $BY = BAX = X$?

\begin{bmatrix}x_1 \\ \vdots \\ x_n\end{bmatrix} \xrightarrow{AX} \begin{bmatrix}y_1 \\ \vdots \\ y_m\end{bmatrix}

\begin{bmatrix}x_1 \\ \vdots \\ x_n\end{bmatrix} \leftarrow \begin{bmatrix} \frac{x_1}{y_1} & 0 & \cdots & 0 \\ \vdots & \vdots & \ddots & \vdots \\ \frac{x_n}{y_1} & 0 & \cdots & 0 \end{bmatrix}_{n\times m} \begin{bmatrix}y_1 \\ \vdots \\ y_m\end{bmatrix}
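Note that this $B$ is built from the particular pair $(x, y)$ itself (and requires $y_1 \neq 0$), so it maps this one $y$ back to this one $x$ but is not a universal inverse. A minimal sketch of that observation, with a random $A$ and $x$ chosen purely for illustration:

import torch

torch.manual_seed(0)
n, m = 16, 4
x = torch.randn(n)
A = torch.randn(m, n)
y = A @ x                      # forward map; assume y[0] != 0

# B built from this particular (x, y): first column is x / y_1, the rest is zero.
B = torch.zeros(n, m)
B[:, 0] = x / y[0]
print(torch.allclose(B @ y, x))  # True, but B depends on the very x we want to recover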

ConvTranspose2d

torch.nn.ConvTranspose2d

import torch

# Generate a random input tensor (a batch of "images")
input = torch.randn(1,   # batch_size
                    32,  # in_channels
                    28,  # height
                    28)  # width

# Define a transposed convolution
# 4.5 No zero padding, non-unit strides, transposed
# k=2, s=2, p=0 => o' = s(i'-1) + k = 2i'
# The output size is twice the input size: 28x28 -> 56x56
transposed_conv = torch.nn.ConvTranspose2d(
    in_channels=32, out_channels=16,
    kernel_size=2, stride=2
)

# Apply the transposed convolution defined above to the input
output = transposed_conv(input)

print(input.shape)                   # torch.Size([1, 32, 28, 28])
print(output.shape)                  # torch.Size([1, 16, 56, 56])
print(transposed_conv.weight.shape)  # torch.Size([32, 16, 2, 2])
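As a sanity check of the printed shapes, plugging the numbers from the comment ($k=2$, $s=2$, $p=0$, $i'=28$) into the output-size relation:

o' = s(i'-1) + k = 2\times(28-1) + 2 = 56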

PyTorch Implementation

PyTorch Image Segmentation Tutorial with U-NET: everything from scratch baby | Aladdin Persson | YouTube

import torch
from torch import nn
from torchvision.transforms import functional as F


class DualConv(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(DualConv, self).__init__()
        self.dualconv = nn.Sequential(
            # 1st convolution: in_channels -> out_channels
            nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            # 2nd convolution: out_channels -> out_channels
            nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        return self.dualconv(x)


class UNet(nn.Module):
    # Note that in_channels is 1 in the paper but 3 here:
    # the paper uses grey-scale images, while ours are colored (RGB).
    # Likewise, out_channels is 1 instead of the paper's 2;
    # for a binary output, 1 channel and 2 channels are essentially equivalent.
    def __init__(self, in_channels=3, out_channels=1, features=[64, 128, 256, 512]):
        super(UNet, self).__init__()
        self.downs = nn.ModuleList()
        self.ups = nn.ModuleList()
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        # e.g. a 161 x 161 input comes back as 160 x 160 (see the resize in forward)

        # Down part of UNet
        # 3->64->64; 64->128->128; 128->256->256; 256->512->512.
        for feature in features:
            self.downs.append(DualConv(in_channels, feature))
            in_channels = feature

        # Up part of UNet
        for feature in reversed(features):
            self.ups.append(
                # 4.5 No zero padding, non-unit strides, transposed
                # k=2, s=2, p=0 => o' = s(i'-1) + k = 2i'
                nn.ConvTranspose2d(in_channels=feature*2,
                                   out_channels=feature,
                                   kernel_size=2, stride=2)
            )
            self.ups.append(DualConv(feature*2, feature))

        # 512->1024
        self.bottleneck = DualConv(features[-1], features[-1]*2)
        # conv 1x1
        self.final_conv = nn.Conv2d(features[0], out_channels, kernel_size=1)

        print(f"{len(self.ups) = }")

    def forward(self, x):
        skip_connections = []

        for down in self.downs:
            x = down(x)
            skip_connections.append(x)
            x = self.pool(x)

        x = self.bottleneck(x)
        # reverse the entries in skip_connections
        skip_connections = skip_connections[::-1]  # [-1:-len(a)-1:-1]

        for idx in range(0, len(self.ups), 2):
            # len(self.ups) = 8
            # idx: 0, 2, 4, 6
            # idx//2: 0, 1, 2, 3

            # up-conv 2x2
            x = self.ups[idx](x)
            skip_connection = skip_connections[idx//2]

            if x.shape != skip_connection.shape:
                x = F.resize(x, size=skip_connection.shape[2:], antialias=True)

            concat_skip = torch.cat((skip_connection, x), dim=1)
            # then: conv 3x3, ReLU (DualConv)
            x = self.ups[idx+1](concat_skip)

        return self.final_conv(x)


if __name__ == "__main__":
    # batch, channel, height, width
    # B, C, H, W = 8, 3, 161, 161
    B, C, H, W = 8, 3, 512, 512
    class_num = 6
    x = torch.randn(size=[B, C, H, W], dtype=torch.float32)
    print(x.shape, x.dtype)
    model = UNet(in_channels=C, out_channels=class_num)
    preds = model(x)
    print(preds.shape, preds.dtype)

    assert preds.shape == torch.Size([B, class_num, H, W]), "something wrong!"

What does the Encoder do?

For an input image of size 1x572x572, U-Net applies 64 convolutional kernels to it, extracting 64 feature maps and giving 64x570x570 (as in the original paper) or 64x572x572 (with padding=1). It then extracts 64 feature maps again.
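A small sketch of these two options in PyTorch (the layer names here are mine, for illustration only):

import torch
from torch import nn

x = torch.randn(1, 1, 572, 572)                         # grey-scale input, as in the paper
conv_valid = nn.Conv2d(1, 64, kernel_size=3)            # padding=0, as in the original paper
conv_same = nn.Conv2d(1, 64, kernel_size=3, padding=1)  # "same"-size variant
print(conv_valid(x).shape)  # torch.Size([1, 64, 570, 570])
print(conv_same(x).shape)   # torch.Size([1, 64, 572, 572])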

What does the Decoder do?

Convolution and Fully Connected Layer

Convolution can be seen as a special kind of fully connected layer, one that connects the neurons between the front and back layers only selectively. In addition, every output neuron shares the same set of parameters/weights of one kernel (represented by the different colored lines).

1\times4\times4 \xrightarrow[\text{(kernel\_size=2, stride=1, padding=0)}]{\text{Conv2d(in\_channels=1, out\_channels=1)}} 1\times3\times3 \xrightarrow[\text{(kernel\_size=2, stride=1, padding=0)}]{\text{Conv2d(in\_channels=1, out\_channels=1)}} 1\times2\times2

Convolution-Illustration
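To illustrate the point, here is a sketch (with a hypothetical $1\times4\times4$ input and a $2\times2$ kernel) that reproduces an nn.Conv2d with an nn.Linear whose weight matrix is mostly zero and reuses the same four kernel weights in every row:

import torch
from torch import nn

torch.manual_seed(0)
x = torch.randn(1, 1, 4, 4)
conv = nn.Conv2d(1, 1, kernel_size=2, bias=False)  # 1x4x4 -> 1x3x3

# Equivalent fully connected layer: a 9x16 weight matrix in which each output
# neuron connects only to its own 2x2 window, and all rows share the kernel weights.
fc = nn.Linear(16, 9, bias=False)
with torch.no_grad():
    W = torch.zeros(9, 16)
    k = conv.weight.view(2, 2)
    for row, (i, j) in enumerate((i, j) for i in range(3) for j in range(3)):
        canvas = torch.zeros(4, 4)
        canvas[i:i+2, j:j+2] = k
        W[row] = canvas.flatten()
    fc.weight.copy_(W)

print(torch.allclose(conv(x).flatten(), fc(x.flatten())))  # True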

When you increase the number of output channels of a convolution, you are effectively increasing the number of neurons in the output layer.

1\times4\times4 \xrightarrow[\text{(kernel\_size=2, stride=1, padding=0)}]{\text{Conv2d(in\_channels=1, out\_channels=3)}} 3\times3\times3 \xrightarrow[\text{(kernel\_size=2, stride=1, padding=0)}]{\text{Conv2d(in\_channels=3, out\_channels=1)}} 1\times2\times2

Convolution-Illustration-Multi-Channels
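The corresponding shape check in PyTorch (module names are illustrative):

import torch
from torch import nn

x = torch.randn(1, 1, 4, 4)
conv1 = nn.Conv2d(in_channels=1, out_channels=3, kernel_size=2)  # 3 kernels -> 3 feature maps
conv2 = nn.Conv2d(in_channels=3, out_channels=1, kernel_size=2)

h = conv1(x)
print(h.shape)         # torch.Size([1, 3, 3, 3]): 27 output neurons
print(conv2(h).shape)  # torch.Size([1, 1, 2, 2]): 4 output neurons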

Transposed convolution does something similar to convolution, but in reverse: it also connects neurons selectively and shares the kernel's parameters.

Convolution: every output neuron shares one set of parameters (weights) of the same kernel.
Transposed convolution: every input neuron shares one set of parameters (weights) of the same kernel.

1\times2\times2 \xrightarrow[\text{(kernel\_size=2, stride=1, padding=0)}]{\text{ConvTranspose2d(in\_channels=1, out\_channels=1)}} 1\times3\times3 \xrightarrow[\text{(kernel\_size=2, stride=1, padding=0)}]{\text{ConvTranspose2d(in\_channels=1, out\_channels=1)}} 1\times4\times4

Transposed-Convolution-Illustration
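And the mirrored shape check for the transposed case (again, the variable names are just for illustration):

import torch
from torch import nn

x = torch.randn(1, 1, 2, 2)
up = nn.ConvTranspose2d(in_channels=1, out_channels=1, kernel_size=2, stride=1, padding=0)

h = up(x)
print(h.shape)      # torch.Size([1, 1, 3, 3])
print(up(h).shape)  # torch.Size([1, 1, 4, 4])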