关于反向传播的实现 | 字数总计: 2.1k | 阅读时长: 10分钟 | 阅读量:
AutoGrad
forward
考虑一个简单的情形——复合函数的反向传播。例如考虑 f ( ⋅ ) = ( ⋅ ) 2 f(\cdot) = (\cdot)^2 f ( ⋅ ) = ( ⋅ ) 2 和 g ( ⋅ ) = e ( ⋅ ) g(\cdot) = e^{(\cdot)} g ( ⋅ ) = e ( ⋅ ) 的复合函数
y = f ( g ( f ( x ) ) ) = ( e x 2 ) 2 y = f(g(f(x))) = (e^{x^2})^2
y = f ( g ( f ( x ))) = ( e x 2 ) 2
如何通过反向传播计算 ∂ y / ∂ x \partial{y} / \partial{x} ∂ y / ∂ x 。
首先我们考虑正向 forward 的过程,对于一个输入 x 经过复合函数后得到输出 y:
x → 1 ( ⋅ ) 2 x 2 = a → 2 e ( ⋅ ) e a = b → 3 ( ⋅ ) 2 b 2 = y x \xrightarrow[1]{(\cdot)^2} x^2 = a \xrightarrow[2]{e^{(\cdot)}} e^{a} = b \xrightarrow[3]{(\cdot)^2} b^2 = y
x ( ⋅ ) 2 1 x 2 = a e ( ⋅ ) 2 e a = b ( ⋅ ) 2 3 b 2 = y
在复合函数 forward 的过程中会产生很多临时的“中间结果”,我们按照上面公式中所示,将对应的中间结果设为 a 和 b。
在程序设计上,首先定义了一个 Variable
类用于封装数据,其内部实际参与运算的数据类型为 np.ndarray
。
并且设计了一个 Function
类作为 Square
和 Exp
“算子”的基类。在 Function
类中我们定义了一个 forward
方法,Function
类中的 forward
方法并没有实际进行定义,只是作为“接口”,可以认为它是虚的(类比于C++中的虚函数)。实际的 forward
定义是写在子类中的,如果子类中没有定义 forward
方法,那么当子类对象调用 forward
时会默认调用基类 Function
中的 forward
方法,所以我们在基类 forward
方法中加了一个 raise NotImplementedError("...")
用于提醒用户需要在子类(实际的算子类)中定义 forward
方法。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 import numpy as npclass Variable : def __init__ (self, data ) -> None : self.data = data class Function : def __call__ (self, input : Variable ) -> Variable: x = input .data y = self.forward(x) output = Variable(y) return output def forward (self, x: np.ndarray ) -> np.ndarray: """forward: x -[f]-> y""" raise NotImplementedError("forward() have not been implemented in baseclass, " \ "and it must be implemented in a subclass" ) class Square (Function ): def forward (self, x: np.ndarray ) -> np.ndarray: return x**2 class Exp (Function ): def forward (self, x: np.ndarray ) -> np.ndarray: return np.exp(x) if __name__ == '__main__' : data = np.array(0.5 ) x = Variable(data) f1 = Square() g = Exp() f2 = Square() '''x -[f1]-> a -[g]-> b -[f2]-> y''' a = f1(x) b = g(a) y = f2(b) print (y.data)
backward
在 forward 的基础上,根据链式法则可以推导出
∂ y ∂ x = ∂ a ∂ x ⋅ ( ∂ b ∂ a ⋅ ( ∂ y ∂ b ⋅ ( ∂ y ∂ y ) ) ) \frac{\partial y}{\partial x} =
\frac{\partial a}{\partial x} \cdot
\left(
\frac{\partial b}{\partial a} \cdot
\left(
\frac{\partial y}{\partial b} \cdot
\left(\frac{\partial y}{\partial y}\right)
\right)
\right)
∂ x ∂ y = ∂ x ∂ a ⋅ ( ∂ a ∂ b ⋅ ( ∂ b ∂ y ⋅ ( ∂ y ∂ y ) ) )
其中的 ∂ y / ∂ y = 1 \partial y / \partial y = 1 ∂ y / ∂ y = 1 ,这是为了后面编程中统一表示的方便而加上的。将这个求导公式换成 backward 的表达形式
∂ y ∂ x ← 1 f ′ ( x ) × ∂ y ∂ a ← 2 g ′ ( a ) × ∂ y ∂ b ← 3 f ′ ( b ) × ∂ y ∂ y \frac{\partial y}{\partial x}\xleftarrow[1]{f'(x)\times}
\frac{\partial y}{\partial a}\xleftarrow[2]{g'(a)\times}
\frac{\partial y}{\partial b}\xleftarrow[3]{f'(b)\times}
\frac{\partial y}{\partial y}
∂ x ∂ y f ′ ( x ) × 1 ∂ a ∂ y g ′ ( a ) × 2 ∂ b ∂ y f ′ ( b ) × 3 ∂ y ∂ y
其中的箭头可以看作是一种函数——backward,其输入为 最终结果 到 forward 输出 的偏导,输出为 最终结果 到 forward 输入 的偏导。例如
在 forward 3 中我们有
b → 3 f y = f ( b ) b\xrightarrow[3]{f}y = f(b)
b f 3 y = f ( b )
forward 3 的输入为 b,输出为 y,并且 y 为最终结果。那么对应 backward 3 的输入为 最终结果 y 到 forward 3 输出 y 的偏导 ∂ y / ∂ y = 1 \partial y / \partial y = 1 ∂ y / ∂ y = 1 . 在 backward 3 中,对其输入 ∂ y / ∂ y \partial y / \partial y ∂ y / ∂ y 乘以一个 f f f 在 forward 3 输入 处 b b b 的偏导 ∂ f ( t ) ∂ t ∣ t = b \frac{\partial f(t)}{\partial t} |_{t=b} ∂ t ∂ f ( t ) ∣ t = b 有
∂ y ∂ b = f ′ ( b ) ⋅ ∂ y ∂ y ← f ′ ( b ) × ∂ y ∂ y \frac{\partial y}{\partial b} = f'(b)\cdot \frac{\partial y}{\partial y} \xleftarrow{f'(b)\times} \frac{\partial y}{\partial y}
∂ b ∂ y = f ′ ( b ) ⋅ ∂ y ∂ y f ′ ( b ) × ∂ y ∂ y
在 forward 2 中我们有
a → 2 g b = g ( a ) a\xrightarrow[2]{g}b = g(a)
a g 2 b = g ( a )
forward 2 的输入为 a,输出为 b。那么对应 backward 2 的输入为 最终结果 y 到 forward 2 输出 b 的偏导 ∂ y / ∂ b \partial y / \partial b ∂ y / ∂ b . 在 backward 2 中,其对输入 ∂ y / ∂ b \partial y / \partial b ∂ y / ∂ b 乘以一个 g g g 在 forward 2 输入 a a a 处的偏导 ∂ g ( t ) ∂ t ∣ t = a \frac{\partial g(t)}{\partial t} |_{t=a} ∂ t ∂ g ( t ) ∣ t = a 有
∂ y ∂ a = g ′ ( a ) ⋅ ∂ y ∂ b ← g ′ ( a ) × ∂ y ∂ b \frac{\partial y}{\partial a} = g'(a)\cdot \frac{\partial y}{\partial b} \xleftarrow{g'(a)\times} \frac{\partial y}{\partial b}
∂ a ∂ y = g ′ ( a ) ⋅ ∂ b ∂ y g ′ ( a ) × ∂ b ∂ y
在 forward 1 中我们有
x → 1 f a = f ( x ) x\xrightarrow[1]{f}a = f(x)
x f 1 a = f ( x )
forward 1 的输入为 x,输出为 a。那么对应 backward 1 的输入为 最终结果 y 到 forward 1 输出 a 的偏导 ∂ y / ∂ a \partial y / \partial a ∂ y / ∂ a . 在 backward 1 中,其对输入 ∂ y / ∂ a \partial y / \partial a ∂ y / ∂ a 乘以一个 f f f 在 forward 1 输入 x x x 处的偏导 ∂ f ( t ) ∂ t ∣ t = x \frac{\partial f(t)}{\partial t} |_{t=x} ∂ t ∂ f ( t ) ∣ t = x 有
∂ y ∂ x = f ′ ( x ) ⋅ ∂ y ∂ a ← f ′ ( x ) × ∂ y ∂ a \frac{\partial y}{\partial x} = f'(x)\cdot \frac{\partial y}{\partial a} \xleftarrow{f'(x)\times} \frac{\partial y}{\partial a}
∂ x ∂ y = f ′ ( x ) ⋅ ∂ a ∂ y f ′ ( x ) × ∂ a ∂ y
在程序设计上,首先为 Variable
类添加一个 grad
成员,用于存储最终结果到 Variable
对象的梯度。并且在基类 Function
中添加一个成员 self.input
,用于存储在 forward 过程中每个算子的输入,目的是为后面进行 backward 时求算子在输入处的梯度。
额外需要注意的是,针对在 forward 过程中复用的算子,需要创建不同的对象。例如 Square
在 forward 过程中被使用了两次,分别创建了 f1
和 f2
两个对象。
x → f a b → f y ∂ y ∂ x ← f ′ ( x ) × ∂ y ∂ a ∂ y ∂ b ← f ′ ( b ) × ∂ y ∂ y \begin{array}{cc}
x \xrightarrow{f} a & b \xrightarrow{f} y\\
\frac{\partial y}{\partial x}\xleftarrow{f'(x)\times} \frac{\partial y}{\partial a} &
\frac{\partial y}{\partial b}\xleftarrow{f'(b)\times}\frac{\partial y}{\partial y}
\end{array}
x f a ∂ x ∂ y f ′ ( x ) × ∂ a ∂ y b f y ∂ b ∂ y f ′ ( b ) × ∂ y ∂ y
f1
和 f2
有相同点和不同点,相同点在于它们都是 Square
对象,它们执行同一段 forward 的计算逻辑(函数映射)。但是不同点在于 f1
和 f2
具有不同的输入和输出,f1
的输入是 x
输出是 a
,f2
的输入是 b
输出是 y
,它们对应的“算子结点”在内部保存的输入值不一样,f1.input.data = x
而 f2.input.data = b
。这样在计算 backward 时,f1.backward
在内部需要计算 f'(x)
的值,而 f2.backward
在内部需要计算 f'(b)
的值。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 import numpy as npfrom typing import Any class Variable : def __init__ (self, data ) -> None : self.data = data self.grad = None class Function : def __call__ (self, input : Variable ) -> Variable: self.input = input x = input .data y = self.forward(x) output = Variable(y) return output def forward (self, x: np.ndarray ) -> np.ndarray: """forward: x -[f]-> y""" raise NotImplementedError("forward() have not been implemented in baseclass, " \ "and it must be implemented in a subclass" ) def backward (self, gy: np.ndarray ) -> np.ndarray: """ forward: x -[ f (·) ]-> y backward: ∂(.)/∂x <-[ f'(x) ]- ∂(.)/∂y \_∂f/∂x *_/ """ raise NotImplementedError("backward() have not been implemented in baseclass, " \ "and it must be implemented in a subclass" ) class Square (Function ): def forward (self, x: np.ndarray ) -> np.ndarray: return x**2 def backward (self, gy: np.ndarray ) -> np.ndarray: x = self.input .data gx = 2 * x * gy return gx class Exp (Function ): def forward (self, x: np.ndarray ) -> np.ndarray: return np.exp(x) def backward (self, gy: np.ndarray ) -> np.ndarray: x = self.input .data gx = np.exp(x) * gy return gx if __name__ == '__main__' : data = np.array(0.5 ) x = Variable(data) f1 = Square() g = Exp() f2 = Square() '''x -[f1]-> a -[g]-> b -[f2]-> y''' a = f1(x) b = g(a) y = f2(b) ''' ∂y/∂x = ∂a/∂x * (∂b/∂a * (∂y/∂b * (∂y/∂y))) ∂y/∂x <-[f1'(x)]- ∂y/∂a <-[g'(a)]- ∂y/∂b <-[f2'(b)]- ∂y/∂y \_∂a/∂x*_/ \_∂b/∂a*_/ \_∂y/∂b*_/ ''' y.grad = np.array(1.0 ) b.grad = f1.backward(y.grad) a.grad = g.backward(b.grad) x.grad = f2.backward(a.grad) print (x.grad)