AutoGrad

forward

考虑一个简单的情形——复合函数的反向传播。例如考虑 f()=()2f(\cdot) = (\cdot)^2g()=e()g(\cdot) = e^{(\cdot)} 的复合函数

y=f(g(f(x)))=(ex2)2y = f(g(f(x))) = (e^{x^2})^2

如何通过反向传播计算 y/x\partial{y} / \partial{x}

首先我们考虑正向 forward 的过程,对于一个输入 x 经过复合函数后得到输出 y:

x1()2x2=a2e()ea=b3()2b2=yx \xrightarrow[1]{(\cdot)^2} x^2 = a \xrightarrow[2]{e^{(\cdot)}} e^{a} = b \xrightarrow[3]{(\cdot)^2} b^2 = y

在复合函数 forward 的过程中会产生很多临时的“中间结果”,我们按照上面公式中所示,将对应的中间结果设为 a 和 b。

在程序设计上,首先定义了一个 Variable 类用于封装数据,其内部实际参与运算的数据类型为 np.ndarray
并且设计了一个 Function 类作为 SquareExp “算子”的基类。在 Function 类中我们定义了一个 forward 方法,Function 类中的 forward 方法并没有实际进行定义,只是作为“接口”,可以认为它是虚的(类比于C++中的虚函数)。实际的 forward 定义是写在子类中的,如果子类中没有定义 forward 方法,那么当子类对象调用 forward 时会默认调用基类 Function 中的 forward 方法,所以我们在基类 forward 方法中加了一个 raise NotImplementedError("...") 用于提醒用户需要在子类(实际的算子类)中定义 forward 方法。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
import numpy as np

class Variable:
def __init__(self, data) -> None:
self.data = data
# self.grad = None

class Function:
def __call__(self, input: Variable) -> Variable:
# self.input = input # store its input for backward()
x = input.data
y = self.forward(x)
output = Variable(y)
return output

def forward(self, x: np.ndarray) -> np.ndarray:
"""forward: x -[f]-> y"""
raise NotImplementedError("forward() have not been implemented in baseclass, "\
"and it must be implemented in a subclass")

class Square(Function):
def forward(self, x: np.ndarray) -> np.ndarray:
return x**2

class Exp(Function):
def forward(self, x: np.ndarray) -> np.ndarray:
return np.exp(x)

if __name__ == '__main__':

data = np.array(0.5) # <class 'numpy.ndarray'>
x = Variable(data) # <class '__main__.Variable'>

f1 = Square()
g = Exp()
f2 = Square()

'''x -[f1]-> a -[g]-> b -[f2]-> y'''
a = f1(x) # x -[f1]-> a = x^2
b = g(a) # a -[g ]-> b = e^(x^2)
y = f2(b) # b -[f2]-> y = (e^(x^2))^2

print(y.data)

backward

在 forward 的基础上,根据链式法则可以推导出

yx=ax(ba(yb(yy)))\frac{\partial y}{\partial x} = \frac{\partial a}{\partial x} \cdot \left( \frac{\partial b}{\partial a} \cdot \left( \frac{\partial y}{\partial b} \cdot \left(\frac{\partial y}{\partial y}\right) \right) \right)

其中的 y/y=1\partial y / \partial y = 1,这是为了后面编程中统一表示的方便而加上的。将这个求导公式换成 backward 的表达形式

yx1f(x)×ya2g(a)×yb3f(b)×yy\frac{\partial y}{\partial x}\xleftarrow[1]{f'(x)\times} \frac{\partial y}{\partial a}\xleftarrow[2]{g'(a)\times} \frac{\partial y}{\partial b}\xleftarrow[3]{f'(b)\times} \frac{\partial y}{\partial y}

其中的箭头可以看作是一种函数——backward,其输入为 最终结果forward 输出 的偏导,输出为 最终结果forward 输入 的偏导。例如

  • 在 forward 3 中我们有

    b3fy=f(b)b\xrightarrow[3]{f}y = f(b)

    forward 3 的输入为 b,输出为 y,并且 y 为最终结果。那么对应 backward 3 的输入为 最终结果 y 到 forward 3 输出 y 的偏导 y/y=1\partial y / \partial y = 1. 在 backward 3 中,对其输入 y/y\partial y / \partial y 乘以一个 ffforward 3 输入bb 的偏导 f(t)tt=b\frac{\partial f(t)}{\partial t} |_{t=b}

    yb=f(b)yyf(b)×yy\frac{\partial y}{\partial b} = f'(b)\cdot \frac{\partial y}{\partial y} \xleftarrow{f'(b)\times} \frac{\partial y}{\partial y}

  • 在 forward 2 中我们有

    a2gb=g(a)a\xrightarrow[2]{g}b = g(a)

    forward 2 的输入为 a,输出为 b。那么对应 backward 2 的输入为 最终结果 y 到 forward 2 输出 b 的偏导 y/b\partial y / \partial b. 在 backward 2 中,其对输入 y/b\partial y / \partial b 乘以一个 ggforward 2 输入 aa 处的偏导 g(t)tt=a\frac{\partial g(t)}{\partial t} |_{t=a}

    ya=g(a)ybg(a)×yb\frac{\partial y}{\partial a} = g'(a)\cdot \frac{\partial y}{\partial b} \xleftarrow{g'(a)\times} \frac{\partial y}{\partial b}

  • 在 forward 1 中我们有

    x1fa=f(x)x\xrightarrow[1]{f}a = f(x)

    forward 1 的输入为 x,输出为 a。那么对应 backward 1 的输入为 最终结果 y 到 forward 1 输出 a 的偏导 y/a\partial y / \partial a. 在 backward 1 中,其对输入 y/a\partial y / \partial a 乘以一个 ffforward 1 输入 xx 处的偏导 f(t)tt=x\frac{\partial f(t)}{\partial t} |_{t=x}

    yx=f(x)yaf(x)×ya\frac{\partial y}{\partial x} = f'(x)\cdot \frac{\partial y}{\partial a} \xleftarrow{f'(x)\times} \frac{\partial y}{\partial a}

在程序设计上,首先为 Variable 类添加一个 grad 成员,用于存储最终结果到 Variable 对象的梯度。并且在基类 Function 中添加一个成员 self.input,用于存储在 forward 过程中每个算子的输入,目的是为后面进行 backward 时求算子在输入处的梯度。

额外需要注意的是,针对在 forward 过程中复用的算子,需要创建不同的对象。例如 Square 在 forward 过程中被使用了两次,分别创建了 f1f2 两个对象。

xfabfyyxf(x)×yaybf(b)×yy\begin{array}{cc} x \xrightarrow{f} a & b \xrightarrow{f} y\\ \frac{\partial y}{\partial x}\xleftarrow{f'(x)\times} \frac{\partial y}{\partial a} & \frac{\partial y}{\partial b}\xleftarrow{f'(b)\times}\frac{\partial y}{\partial y} \end{array}

f1f2 有相同点和不同点,相同点在于它们都是 Square 对象,它们执行同一段 forward 的计算逻辑(函数映射)。但是不同点在于 f1f2 具有不同的输入和输出,f1 的输入是 x 输出是 af2 的输入是 b 输出是 y,它们对应的“算子结点”在内部保存的输入值不一样,f1.input.data = xf2.input.data = b。这样在计算 backward 时,f1.backward 在内部需要计算 f'(x) 的值,而 f2.backward 在内部需要计算 f'(b) 的值。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
import numpy as np
from typing import Any

class Variable:
def __init__(self, data) -> None:
self.data = data
self.grad = None

class Function:
def __call__(self, input: Variable) -> Variable:
self.input = input # store its input for backward()
x = input.data
y = self.forward(x)
output = Variable(y)
return output

def forward(self, x: np.ndarray) -> np.ndarray:
"""forward: x -[f]-> y"""
raise NotImplementedError("forward() have not been implemented in baseclass, "\
"and it must be implemented in a subclass")

def backward(self, gy: np.ndarray) -> np.ndarray: # gradent y
"""
forward: x -[ f (·) ]-> y
backward: ∂(.)/∂x <-[ f'(x) ]- ∂(.)/∂y
\_∂f/∂x *_/
"""
raise NotImplementedError("backward() have not been implemented in baseclass, "\
"and it must be implemented in a subclass")

class Square(Function):
def forward(self, x: np.ndarray) -> np.ndarray:
return x**2

def backward(self, gy: np.ndarray) -> np.ndarray:
x = self.input.data
gx = 2 * x * gy
return gx

class Exp(Function):
def forward(self, x: np.ndarray) -> np.ndarray:
return np.exp(x)

def backward(self, gy: np.ndarray) -> np.ndarray:
x = self.input.data
gx = np.exp(x) * gy
return gx

if __name__ == '__main__':

data = np.array(0.5) # <class 'numpy.ndarray'>
x = Variable(data) # <class '__main__.Variable'>

f1 = Square()
g = Exp()
f2 = Square()

'''x -[f1]-> a -[g]-> b -[f2]-> y'''
a = f1(x) # x -[f1]-> a = x^2
b = g(a) # a -[g ]-> b = e^(x^2)
y = f2(b) # b -[f2]-> y = (e^(x^2))^2

'''
∂y/∂x = ∂a/∂x * (∂b/∂a * (∂y/∂b * (∂y/∂y)))
∂y/∂x <-[f1'(x)]- ∂y/∂a <-[g'(a)]- ∂y/∂b <-[f2'(b)]- ∂y/∂y
\_∂a/∂x*_/ \_∂b/∂a*_/ \_∂y/∂b*_/
'''
y.grad = np.array(1.0) # ∂y/∂y
b.grad = f1.backward(y.grad) # ∂y/∂b
a.grad = g.backward(b.grad) # ∂y/∂a
x.grad = f2.backward(a.grad) # ∂y/∂x
print(x.grad)