矩阵乘积 MatMul 的反向传播

有公式 \mathbf{y} = \mathbf{x}W ,其中 \mathbf{x} 是 D * M 矩阵,W 是 M * N 权重矩阵;另有损失函数 L 是对 \mathbf{y} 的函数,假设 L 对 y 的偏导已知(反向传播时是这样的),求 L 关于矩阵 \mathbf{x} 的偏导

答案见下式,非常简洁;求一个标量对于矩阵的偏导,这个问题一度困惑了我很长一段时间;在学微积分的时候,求的一直都是 y 对标量 x 的导数或者偏导(多个自变量),对矩阵的偏导该如何算,不知啊;看了普林斯顿的微积分读本,托马斯微积分也看了,都没提到

\frac{\partial L}{\partial \mathbf{x}}=\frac{\partial L}{\partial \mathbf{y}}W^T

这里的关键在于如何理解 \frac{\partial L}{\partial \mathbf{x}} ,其实就是一种记法,也就是分别计算 L 对 x 中所有项的偏导,然后写成矩阵形式;为了表述方便,我们令上式右边为 A , 那么对于 \mathbf{x} 中的第 ij 项(第 i 行第 j 列), 则必有\frac{\partial L}{\partial x_{ij}} = A_{ij} ,我们只要能证明这一点就可以了

根据链式法则(可参考附录), 要计算 \frac{\partial L}{\partial x_{ij}} ,我们先计算 L 对 y 的偏导(已知项),然后乘以 y 对 x 的偏导;注意并不需要考虑 y 中的所有项,因为按照矩阵乘法定义,x_{ij} 只参与了 y 第 i 行 (y_{i1}, y_{i2},...y_{in}) 的计算,其中 y_{ik} = \sum\limits_{l=1}^Mx_{il}W_{lk}

\begin{split} \frac{\partial L}{\partial x_{ij}}&=\sum_{k=1}^N\frac{\partial L}{\partial y_{ik}}\frac{\partial y_{ik}}{\partial x_{ij}}\\ &=\sum_{k=1}^N\frac{\partial L}{\partial y_{ik}}W_{jk} \text{$\qquad (\frac{\partial y_{ik}}{\partial x_{ij}}=W_{jk})$}\\ &=\sum_{k=1}^N\frac{\partial L}{\partial y_{ik}}W^T_{kj} \text { $\qquad(W_{jk}=W^T_{kj}$)} \end{split}

也就是 L 对 x_{ij} 的偏导等于 L 对 y 第 i 行的偏导(可视为向量)与 W^T 第 j 列(向量)的点积,根据矩阵乘法定义(矩阵 AB的第 ij 项等于A的第 i 行与 B 的第 j 列的点积),可得上述答案

现在我们来计算 L 关于权重矩阵 W 的偏导

同样按照链式法则,我们先计算 L 对 y 的偏导(已知项),然后乘以 y 对 w 的偏导;按照矩阵乘法 w_{ij} 参与了 y 第 j 列所有项的计算,其中 y_{kj} = \sum\limits_{l=1}^Mx_{kl}W_{lj}

\begin{split} \frac{\partial L}{\partial w_{ij}}&=\sum_{k=1}^D\frac{\partial L}{\partial y_{kj}}\frac{\partial y_{kj}}{\partial w_{ij}}\\ &=\sum_{k=1}^D\frac{\partial L}{\partial y_{kj}}x_{ki} \text{$\qquad (\frac{\partial y_{kj}}{\partial w_{ij}}=x_{ki})$}\\ &=\sum_{k=1}^Dx^T_{ik}\frac{\partial L}{\partial y_{kj}} \end{split}

也就是 L 对 W_{ij} 的偏导等于 x^T 第 i 行与L 对 y 第 j 列项的偏导的点积,按照矩阵乘法定义可得

\frac{\partial L}{\partial W} = x^T\frac{\partial L}{\partial y}

附录:

链式法则 如果函数 w = f(x, y) 有连续的偏导数 f_x 和f_y 并且 x = x(t) , y = y(t) 可微,那么有

\frac{dw}{dt}=\frac{\partial f}{\partial x}\frac{dx}{dt}+\frac{\partial f}{\partial y}\frac{dy}{dt}

参考 托马斯微积分第 11 版,14.4 节 链式法则 Chain Rule