神经网络的系数求导方法

By Z.H. Fu
切问录 www.fuzihao.org

神经网络在学习的时候通常采用误差反馈的方法,误差反馈的实质是一种梯度下降,由于梯度下降的速度较慢,所以又通常采用SGD等方法实现对计算效率的优化。本文给出了求神经网络似然函数对所有参数导数的方法,该方法基于微分的链式法则,利用该导数实现了基于BFGS的神经网络,取得了良好的效果。同时,除了神经网络之外,该方法对其他复合函数参数的求导也同样有效。

网络结构及函数定义

这里采用一个单隐层的神经网络作为例子,该神经网络有6个输入结点,三个隐层结点和一个输出结点,\(a_1\)是输入矩阵,矩阵的每行是一个样本,每列是一个特征,其数量就是输入层的结点数,这里一共有六列。结构如图所示:

NN 其传递关系如下: \[\begin{align}&z_2 = a_1 w_1+b_1 \\ & a_2 = h(z_2) \\ &z_3 = a_2 w_2+b_2\\ & a_3 = h(z_3)\\ &c = l(a_3)\\ &v = M:c \end{align}\]

其中 \[\begin{align}&h(x)=\frac{1}{1+e^{-x}} \\ &l(a)=\ln(a)\circ y+\ln(1-a)\circ (1-y) \\ &M=[\frac{1}{N},\frac{1}{N},\cdots,\frac{1}{N}]^T \end{align}\]

这里\(\circ\)是Hadamard product,表示左右矩阵(向量)对应元素相乘,左右元素大小相同,得出的结果大小不变。\(:\)是Frobenius product,表示对应元素相乘后再相加,它又可以表示为\(A:B=tr(A^T,B),\)\(h(x),l(a)\)表示对自变量的每一个元素应用该函数。

链式法则

这里链式法则其实就是微分的传递,我们首先给出导数矩阵的形式,若\(dY=A:dX\),则认为\(A\)\(X\)的导数。下面我们给出\(v\)的微分和参数微分的关系式: \[\begin{align} &dv=M:dc \\ &dc=l'(a_3)\circ d a_3\\ &da_3=h'(z_3) \circ d z_3 \\ &dz_3=a_2\cdot dw_2+da_2\cdot w_2+db_2\otimes 1_{n \times 1}\\ &da_2=h'(z_2)\circ dz_2\\ &dz_2=a_1\cdot d w_1+da_1\cdot w_1+db_1\otimes1_{n\times 1} \end{align}\] 其中,\(da_1\cdot w_1\)\(a_1\)是输入常数矩阵,故而微分为1。我们分别来看\(v\)的微分和四个参数\(w_2,b_2,w_1,b_1\)的关系。参数微分依赖关系如图:

1、\(dv\)\(d w_2\)的关系

我们将上面的式子依次带入得: \[dv=M:l'(a_3) \circ h'(z_3) \circ (a_2\cdot dw_2+da_2\cdot w_2+db_2\otimes 1_{n \times 1})\] 我们来看含有\(d w_2\)的项,为: \[\begin{align} &M:l'(a_3) \circ h'(z_3) \circ (a_2\cdot dw_2)\\ &=M\circ l'(a_3) \circ h'(z_3) : (a_2\cdot dw_2)\\ &=tr((M\circ l'(a_3) \circ h'(z_3))^T \cdot (a_2\cdot dw_2))\\ &=tr((a_2^T\cdot(M\circ l'(a_3) \circ h'(z_3)))^T \cdot dw_2)\\ &=a_2^T\cdot(M\circ l'(a_3) \circ h'(z_3)):dw_2 \end{align}\] 故而,\(w_2\)的导数为\(a_2^T\cdot(M\circ l'(a_3) \circ h'(z_3))\)

2、\(dv\)\(d b_2\)的关系

我们看上面\(dv\)式中含有\(d b_2\)的项,为: \[\begin{align} &M:l'(a_3) \circ h'(z_3) \circ (db_2\otimes 1_{n \times 1})\\ &=M\circ l'(a_3) \circ h'(z_3) : (db_2\otimes 1_{n \times 1})\\ &=1_{n \times 1}^T\cdot M\circ l'(a_3) \circ h'(z_3) : db_2\\ \end{align}\] 所以,\(b_2\)的导数为\(1_{n \times 1}^T\cdot M\circ l'(a_3) \circ h'(z_3)\)

3、\(dv\)\(d w_1\)的关系

我们接着带入\(dv\)表达式直到\(dw_1,db_1\)出现,得: \[\begin{align} dv&=M:l'(a_3) \circ h'(z_3) \circ (a_2\cdot dw_2+da_2\cdot w_2+db_2\otimes 1_{n \times 1}) \\ &=M:l'(a_3) \circ h'(z_3) \circ (a_2\cdot dw_2+(h'(z_2)\circ (a_1\cdot d w_1+da_1\cdot w_1+db_1\otimes1_{n\times 1}))\cdot w_2+db_2\otimes 1_{n \times 1}) \\ \end{align}\] 我们看其中含有\(dw_1\)的项,为: \[\begin{align} &M:l'(a_3) \circ h'(z_3) \circ ((h'(z_2)\circ (a_1\cdot d w_1))\cdot w_2) \\ &=M\circ l'(a_3) \circ h'(z_3) : ((h'(z_2)\circ (a_1\cdot d w_1))\cdot w_2) \\ \end{align}\] 为$A:B C $的形式,而 \[\begin{align} A:B\cdot C=tr(A^T\cdot B\cdot C)=tr(C \cdot A^T \cdot B)=tr((A\cdot C^T)^T\cdot B)=A\cdot C^T:B \end{align}\] 所以,上式 \[\begin{align} &M\circ l'(a_3) \circ h'(z_3) : ((h'(z_2)\circ (a_1\cdot d w_1))\cdot w_2) \\ &=(M\circ l'(a_3) \circ h'(z_3) \cdot w_2^T): (h'(z_2)\circ (a_1\cdot d w_1)) \\ &=((M\circ l'(a_3) \circ h'(z_3) \cdot w_2^T)\circ h'(z_2)): (a_1\cdot d w_1) \\ &=a_1^T\cdot ((M\circ l'(a_3) \circ h'(z_3) \cdot w_2^T)\circ h'(z_2)): d w_1 \\ \end{align}\] 所以,\(w_1\)的导数为\(a_1^T\cdot ((M\circ l'(a_3) \circ h'(z_3) \cdot w_2^T)\circ h'(z_2))\)

4、\(dv\)\(d b_1\)的关系

由上面的式子,我们写出\(dv\)表达式中含\(d b_1\)的项,为: \[\begin{align} &M:l'(a_3) \circ h'(z_3) \circ ((h'(z_2)\circ (db_1\otimes1_{n\times 1}))\cdot w_2) \\ &=M\circ l'(a_3) \circ h'(z_3) : ((h'(z_2)\circ (db_1\otimes1_{n\times 1}))\cdot w_2) \\ &=M\circ l'(a_3) \circ h'(z_3)\cdot w_2^T \circ h'(z_2): (db_1\otimes1_{n\times 1}) \\ &=(M\circ l'(a_3) \circ h'(z_3)\cdot w_2^T \circ h'(z_2))^T\cdot 1_{n\times 1} : db_1 \end{align}\] 由此,我们得到了\(db_1\)的导数,为$(Ml’(a_3) h’(z_3)w_2^T h’(z_2))^T1_{n1} $ \[\begin{align} \end{align}\]

总结

可以看到这种方法不仅课用于神经网络的参数求导,对于除神经网络以外的其他复合函数参数的导数也能够计算。随着网络层数的增加,表达式复杂度上升。