链式法则是一个数学定理,用于计算复合函数的导数。从数学上讲,深度学习模型就是复合函数。因此,理解其导数的计算过程对于训练它们非常重要,接下来的几章将详述这一点。
在数学上,这个定理看起来较为复杂,对于给定的值 ,我们有:
其中 只是一个伪变量,代表函数的输入。
当描述具有一个输入和一个输出的函数 的导数时,可以将代表该函数导数的函数表示为 。可以用其他伪变量替代 ,这样做并不会对结果造成影响,就像 和 表示同一个意思一样。
稍后,我们将处理包含 多个 输入(例如 和 )的函数。一旦碰到这种情况,区分 和 之间的不同含义就是有意义的。
这就是为什么在前面的公式中,我们在所有的导数中将 放在了底部: 和 都是接受一个输入并产生一个输出的函数,在这些情况下(有一个输入和一个输出的函数),我们将在导数符号中使用 。
对理解链式法则而言,本节中的数学公式不太直观。对此,盒子表示法会更有帮助。下面通过简单的 示例来解释导数“应该”是什么,如图 1-8 所示。
图 1-8:链式法则示意图
直观地说,使用图 1-8 中的示意图,复合函数的导数应该是其组成函数的导数的乘积。假设在第一个函数中输入 5,并且当 时,第一个函数的导数是 3,那么用公式表示就是 。
然后取第一个盒子中的函数值,假设它是 1,即 ,再计算第二个函数 在这个值上的导数,即计算 。如图 1-8 所示,这个值是 -2。
想象这些函数实际上是串在一起的,如果将盒子 2 对应的输入更改 1 单位会导致盒子 2 的输出产生 -2 单位的变化,将盒子 2 对应的输入更改 3 单位则会导致盒子 2 的输出变化 -6(-2×3)单位。这就是为什么在链式法则的公式中,最终结果是一个乘积: 。
利用数学和示意图这两个维度,我们可以通过使用链式法则来推断嵌套函数的输出相对于其输入的导数值。那么计算这个导数的代码如何编写呢?
下面对此进行编码,并证明按照这种方式计算的导数会产生“看起来正确”的结果。这里将使用
square
函数以及
sigmoid
函数
,后者在深度学习中非常重要:
def sigmoid(x: ndarray) -> ndarray:
'''
将sigmoid函数应用于输入ndarray中的每个元素。
'''
return 1 / (1 + np.exp(-x))
现在编写链式法则:
def chain_deriv_2(chain: Chain,
input_range: ndarray) -> ndarray:
'''
使用链式法则计算两个嵌套函数的导数:( f 2 f 1(x))′ = f 2′( f 1(x))*f 1′(x) 。
'''
assert len(chain) == 2, \
"This function requires 'Chain' objects of length 2"
assert input_range.ndim == 1, \
"Function requires a 1 dimensional ndarray as input_range"
f1 = chain[0]
f2 = chain[1]
# df1/dx
f1_of_x = f1(input_range)
# df1/du
df1dx = deriv(f1, input_range)
# df2/du(f1(x))
df2du = deriv(f2, f1(input_range))
# 在每一点上将这些量相乘
return df1dx * df2du
图 1-9 绘制了结果,并展示了链式法则的有效性:
PLOT_RANGE = np.arange(-3, 3, 0.01)
chain_1 = [square, sigmoid]
chain_2 = [sigmoid, square]
plot_chain(chain_1, PLOT_RANGE)
plot_chain_deriv(chain_1, PLOT_RANGE)
plot_chain(chain_2, PLOT_RANGE)
plot_chain_deriv(chain_2, PLOT_RANGE)
图 1-9:链式法则的有效性
链式法则似乎起作用了。当函数向上倾斜时,导数为正;当函数向下倾斜时,导数为负;当函数未发生倾斜时,导数为零。
因此,实际上只要各个函数本身是基本可微的,就可以通过数学公式和代码计算嵌套函数(或复合函数)的导数,例如 。
从数学上讲,深度学习模型是这些基本可微函数的长链。建议花时间手动执行稍长一点的详细示例(参见 1.5 节),这样有助于直观地理解链式法则,包括其运行方式以及在更复杂的模型中的应用。