循环神经网络Recurrent Neural Networks

RNN的数学描述
参考零基础入门深度学习(5) - 循环神经网络
输入层
网络的输入是一串m维向量序列 x1,x2,⋯,xt,⋯
x1=x11x21⋮xm1,x2=x12x22⋮xm2,⋯,xt=x1tx2t⋮xmt,⋯
循环层
网络的状态是一串n维向量序列 s0,s1,s2⋯,st,⋯
s1ts2t⋮snt=fu11u21⋮un1u12u22⋮un2⋯⋯⋱⋯u1mu2m⋮unmx1tx2t⋮xmt+w11w21⋮wn1w12w22⋮wn2⋯⋯⋱⋯w1nw2n⋮wnns1t−1s2t−1⋮snt−1+b1Rb2R⋮bnRt=1,2,⋯
输出层
网络的输出是一串m维的向量序列 o1,o2,⋯,ot,⋯
o1to2t⋮omt=gv11v21⋮vm1v12v22⋮vm2⋯⋯⋱⋯v1nv2n⋮vmns1ts2t⋮snt+b1Ob2O⋮bmOt=1,2,⋯
网络的输出
网络在 t 时刻的输出 ot 由前面各时刻的输入 xt,xt−1,⋯,x1和初始状态 s0 决定
(下面的推导式中省略了偏置项 b)
ot=g(Vst)=g(Vf(Uxt+Wst−1))=g(Vf(Uxt+Wf(Uxt−1+Wst−2)))⋮=g(Vf(Uxt+Wf(Uxt−1+Wf(Uxt−2+⋯+Wf(Ux1+Ws0)))))
网络输出的误差
网络在每个 t 时刻的输出 ot 都对应一个目标向量 tt (target), 每个时刻都对应一个误差, 用Et来表示 , Et 是关于 ot和 tt 的函数, 例如采用二范数的平方表示误差, 误差函数如下计算
Et=21∥ot−tt∥22=21i=1∑m(oit−tit)2
梯度的计算(Back Propagate Through Time, BPTT)
循环层到输出层
记输出层 t 时刻的输入向量为 ξt
o1to2t⋮omt=gξ1tξ2t⋮ξmt,ξ1tξ2t⋮ξmt=v11v21⋮vm1v12v22⋮vm2⋯⋯⋱⋯v1nv2n⋮vmns1ts2t⋮snt+b1Ob2O⋮bmO
∂vij∂Et∂biO∂Et=∂ξit∂Et⋅∂vij∂ξit=∂ξit∂Et⋅sjt=∂ξit∂Et⋅∂biO∂ξit=∂ξit∂Et⋅1i=1,⋯,mj=1,⋯,n
向量化计算梯度
∂bO∂Et=∂ξ1t∂Et∂ξ2t∂Et⋮∂ξmt∂Et,∂V∂Et=∂ξ1t∂Et∂ξ2t∂Et⋮∂ξmt∂Et[s1ts2t⋯snt]=∂ξ1t∂Ets1t∂ξ2t∂Ets1t⋮∂ξmt∂Ets1t∂ξ1t∂Ets2t∂ξ2t∂Ets2t⋮∂ξmt∂Ets2t⋯⋯⋱⋯∂ξ1t∂Etsnt∂ξ2t∂Etsnt⋮∂ξmt∂Etsnt
输入层到循环层
记循环层 t 时刻的输入向量为 ηt
s1ts2t⋮snt=fη1tη2t⋮ηnt,η1tη2t⋮ηnt=u11u21⋮un1u12u22⋮un2⋯⋯⋱⋯u1mu2m⋮unmx1tx2t⋮xmt+w11w21⋮wn1w12w22⋮wn2⋯⋯⋱⋯w1nw2n⋮wnns1t−1s2t−1⋮snt−1+b1Rb2R⋮bnR
关于矩阵U的偏导
由上面的记号, t 时刻循环层的输入为ηt, ηt 是网络在 t 时刻的输入 xt 和 上一时刻的状态 st−1 的线性变换
ηt=Uxt+Wst−1+bRst−1=f(ηt−1)
下面的公式推导出一个 ∂Et/∂U 关于时间的递推式, 我们记 ∂U∂Et(t) 为 t 时刻网络输出的误差 E 关于
∂U∂Et(ηt=Uxt+Wst−1+bR)→将∂ηt∂Et乘进括号中去→(∂ηt−1∂Wst−1=∂ηt−1∂ηt)→=∂ηt∂Et∂U∂ηt=∂ηt∂Et(∂U∂Uxt+∂U∂Wst−1)=∂ηt∂Et(∂U∂Uxt+W∂ηt−1∂st−1∂U∂ηt−1)=∂ηt∂Et∂U∂Uxt+∂ηt∂Et∂ηt−1∂Wst−1∂U∂ηt−1=∂ηt∂Et∂U∂Uxt+∂ηt∂Et∂ηt−1∂ηt∂U∂ηt−1=∂ηt∂Et∂U∂Uxt+∂ηt−1∂Et∂U∂ηt−1
由这个递推式可以得到
∂U∂Et=∂ηt∂Et∂U∂ηt=∂ηt∂Et∂U∂Uxt+∂ηt−1∂Et∂U∂ηt−1=∂ηt∂Et∂U∂Uxt+∂ηt−1∂Et∂U∂Uxt−1+∂ηt−2∂Et∂U∂ηt−2=∂ηt∂Et∂U∂Uxt+∂ηt−1∂Et∂U∂Uxt−1+∂ηt−2∂Et∂U∂Uxt−2+⋯+∂η2∂Et∂U∂Ux2+∂η1∂Et∂U∂Ux1
计算∂ηk∂Et∂U∂Uxk
计算∂ηt∂Et
∂ηt∂Et记为=∂ξt∂Et∂st∂ξt∂ηt∂st=[∂ξ1t∂Et∂ξ2t∂Et⋯∂ξmt∂Et]∂s1t∂ξ1t∂s1t∂ξ2t⋮∂s1t∂ξmt∂s2t∂ξ1t∂s2t∂ξ2t⋮∂s2t∂ξmt⋯⋯⋱⋯∂snt∂ξ1t∂snt∂ξ2t⋮∂snt∂ξmt∂η1t∂s1t∂η1t∂s2t⋮∂η1t∂snt∂s2t∂s1t∂η2t∂s2t⋮∂η2t∂snt⋯⋯⋱⋯∂snt∂s1t∂ηnt∂s2t⋮∂ηnt∂snt=[∂ξ1t∂Et∂ξ2t∂Et⋯∂ξmt∂Et]v11v21⋮vm1v12v22⋮vm2⋯⋯⋱⋯v1nv2n⋮vmn∂η1t∂s1t0⋮00∂η2t∂s2t⋮0⋯⋯⋱⋯00⋮∂ηnt∂snt=[∂η1t∂s1ti=1∑m(∂ξit∂Etvi1),∂η2t∂s2ti=1∑m(∂ξit∂Etvi2),⋯,∂ηnt∂snti=1∑m(∂ξit∂Etvin)]=[δ1ttδ2tt⋯δntt]
∂ηt∂Et 的结果记为 δtt, 称为循环层 t 时刻(第二个 t)的输入的误差项 (网络 t 时刻输出的误差关于循环层 t 时刻输入的偏导数)
计算∂ηk∂Et
∂ηt−1∂ηt=∂ηt−1∂Wst−1=W∂ηt−1∂st−1=W∂η1t−1∂s1t−1∂η1t−1∂s2t−1⋮∂η1t−1∂snt−1∂η2t−1∂s1t−1∂η2t−1∂s2t−1⋮∂η2t−1∂snt−1⋯⋯⋱⋯∂ηnt−1∂s1t−1∂ηnt−1∂s2t−1⋮∂ηnt−1∂snt−1=W∂η1t−1∂s1t−10⋮00∂η2t−1∂s2t−1⋮0⋯⋯⋱⋯00⋮∂ηnt−1∂snt−1=Wf′(η1t−1)0⋮00f′(η2t−1)⋮0⋯⋯⋱⋯00⋮f′(ηnt−1)
∂ηk∂Et记为=∂ξt∂Et∂st∂ξt∂ηt∂st(∂ηt−1∂ηt⋯∂ηk∂ηk+1)=[δ1ttδ2tt⋯δntt]i=(t−1)∏kWf′(η1i)⋮0⋯⋱⋯0⋮f′(ηni)=[δ1tkδ2tk⋯δntk](t≥k≥1)
∂ηk∂Et 的结果记为 δtk, 称为循环层 k 时刻输入的误差项 (网络 t 时刻输出的误差关于循环层 k 时刻输入的偏导数)
实际计算中我们会一步一步地计算δtt,δt(t−1),⋯,δt1, 而不是使用连乘运算
[δ1t(t−1)δ2t(t−1)⋯δnt(t−1)][δ1t(t−2)δ2t(t−2)⋯δnt(t−2)][δ1t1δ2t1⋯δnt1]=[δ1tkδ2tk⋯δntk]Wf′(η1t−1)⋮0⋯⋱⋯0⋮f′(ηnt−1)=[δ1t(t−1)δ2t(t−1)⋯δnt(t−1)]Wf′(η1t−2)⋮0⋯⋱⋯0⋮f′(ηnt−2)⋮=[δ1t(2)δ2t(2)⋯δnt(2)]Wf′(η11)⋮0⋯⋱⋯0⋮f′(ηn1)
计算∂U∂Uxk
∂U∂Uxk=∂u11∂η1k⋮∂un1∂η1k⋯⋱⋯∂u1m∂η1k⋮∂unm∂η1k⋮∂u11∂ηik⋮∂un1∂ηik⋯⋱⋯∂u1m∂ηik⋮∂unm∂ηik⋮∂u11∂ηnk⋮∂un1∂ηnk⋯⋱⋯∂u1m∂ηnk⋮∂unm∂ηnk=x1k0⋮0x2k0⋮0⋯⋯⋱⋯xmk0⋮0⋮0⋮x1k⋮00⋮x2k⋮0⋯⋯⋯0⋮xmk⋮01⋮i⋮n⋮0⋮0x1k0⋮0x2k⋯⋱⋯⋯0⋮0xmk(t≥k≥1)
计算∂ηk∂Et∂U∂Uxk
∂ηk∂Et⋅∂U∂Uxk=[δ1tkδ2tk⋯δntk]x1k0⋮0x2k0⋮0⋯⋯⋱⋯xmk0⋮0⋮0⋮x1k⋮00⋮x2k⋮0⋯⋯⋯0⋮xmk⋮01⋮i⋮n⋮0⋮0x1k0⋮0x2k⋯⋱⋯⋯0⋮0xmk=δ1tkδ2tk⋮δntk[x1kx2k⋯xmk](t≥k≥1)
最后结果U的梯度
∂U∂Et=k=1∑tδ1tkδ2tk⋮δntk[x1kx2k⋯xmk]
关于矩阵W的偏导
∂W∂Et(ηt=Uxt+Wst−1+bR)→(莱布尼茨法则)→(∂ηt−1∂Wst−1=∂ηt−1∂ηt)→=∂ηt∂Et∂W∂ηt=∂ηt∂Et(∂W∂Wst−1)=∂ηt∂Et(∂W∂Wst−1+W∂W∂st−1)=∂ηt∂Et∂W∂Wst−1+∂ηt∂EtW∂ηt−1∂st−1∂W∂ηt−1=∂ηt∂Et∂W∂Wst−1+∂ηt∂Et∂ηt−1∂ηt∂W∂ηt−1=∂ηt∂Et∂W∂Wst−1+∂ηt−1∂Et∂W∂ηt−1
∂W∂Et=∂ηt∂Et∂W∂ηt=∂ηt∂Et∂W∂Wst−1+∂ηt−1∂Et∂W∂ηt−1=∂ηt∂Et∂W∂Wst−1+∂ηt−1∂Et∂W∂Wst−2+∂ηt−2∂Et∂W∂ηt−2=∂ηt∂Et∂W∂Wst−1+∂ηt−1∂Et∂W∂Wst−2+∂ηt−2∂Et∂W∂Wst−3+⋯+∂η2∂Et∂W∂Ws1+∂η1∂Et∂W∂Ws0
计算∂ηk∂Et∂W∂Wsk−1
计算∂W∂W
∂W∂W=∂(w11⋮wn1⋯⋱⋯wn1⋮wnn)∂(w11⋮wn1⋯⋱⋯wn1⋮wnn)=10⋮000⋮0⋯⋯⋱⋯00⋮0⋮00⋮100⋮0⋯⋯⋱⋯00⋮0⋯⋱⋯00⋮000⋮0⋯⋯⋱⋯10⋮0⋮00⋮000⋮0⋯⋯⋱⋯00⋮1
计算∂ηk∂Et∂W∂Wsk−1
∂ηk∂Et∂W∂Wsk−1=[δ1tkδ2tk⋯δntk]10⋮000⋮0⋯⋯⋱⋯00⋮0⋮00⋮100⋮0⋯⋯⋱⋯00⋮0⋯⋱⋯00⋮000⋮0⋯⋯⋱⋯10⋮0⋮00⋮000⋮0⋯⋯⋱⋯00⋮1s1k−1s2k−1⋮snk−1=δ1tkδ2tk⋮δntk[s1k−1s2k−1⋯snk−1](t≥k≥1)
最后结果W的梯度
∂W∂Et=k=1∑tδ1tkδ2tk⋮δntk[s1k−1s2k−1⋯snk−1]
关于偏置项bR的偏导
∂bR∂Et(ηt=Uxt+Wst−1+bR)→(∂ηt−1∂Wst−1=∂ηt−1∂ηt)→=∂ηt∂Et∂bR∂ηt=∂ηt∂Et(∂bR∂bR+∂bR∂Wst−1)=∂ηt∂Et∂bR∂bR+∂ηt∂Et∂ηt−1∂Wst−1∂bR∂ηt−1=∂ηt∂Et∂bR∂bR+∂ηt∂Et∂ηt−1∂ηt∂bR∂ηt−1=∂ηt∂Et∂bR∂bR+∂ηt−1∂Et∂bR∂ηt−1
∂bR∂Et=∂ηt∂Et∂bR∂ηt=∂ηt∂Et∂bR∂bR+∂ηt−1∂Et∂bR∂ηt−1=∂ηt∂Et∂bR∂bR+∂ηt−1∂Et∂bR∂bR+∂ηt−2∂Et∂bR∂ηt−2=∂ηt∂Et∂bR∂bR+∂ηt−1∂Et∂bR∂bR+∂ηt−2∂Et∂bR∂bR+⋯+∂η1∂Et∂bR∂bR
计算∂ηk∂Et∂bR∂bR
∂ηk∂Et∂bR∂bR=∂ηk∂Et⋅Inn=∂ηk∂Et=δ1tkδ2tk⋮δntk
最后结果 bR 的梯度
∂bR∂Et=k=1∑tδ1tkδ2tk⋮δntk