购买
下载掌阅APP,畅读海量书库
立即打开
畅读海量书库
扫码下载掌阅APP

2.4 实现线性回归、多项式回归和逻辑回归

2.4.1 PyTorch实现线性回归

线性回归是机器学习算法中最简单、直观的算法,本节借助PyTorch实现这一算法。实现算法之前,有必要对线性回归算法的思想做简单介绍。假设数据集 D ={( x 1 , y 1 ),( x 2 , y 2 ),( x 3 , y 3 ),…,( x n , y n )},由 n 个数据对组成,线性回归任务希望训练出一个函数 f ( x ),使 f ( x i )= Wx i + b y i 的误差尽可能小。最关键的就是求取 W b ,其求取过程和前面讲的一样,需要找出一个合适的损失值,在线性回归中一般使用MSE作为损失函数, ,接下来要做的事情就是找出 W b ,使其满足:

求解方法也特别简单,因为要使误差最小化,实际上就是求上式的最小值,套用数学中的方法就是求偏导数,令 W b 的偏导数等于0,便可求出 W b

通过上面的式子可以求出 W b 。在一般情况下,我们不使用这种直接计算的方式求 W b ,因为计算速度特别慢,而且有些情况根本无法求解,特别是基于矩阵计算的时候。因此,在做线性方程参数训练时,通常使用梯度下降等优化算法不断地迭代和优化,最终求解 W b

然后随机初始化一个二维数据集,数据中只包含 X 轴和 Y 轴坐标,然后使用PyTorch训练一个一元线性回归模型模拟这些数据集,使误差最小。

随机生成的二维数据可视化结果如图2.20所示。

图2.20 随机生成的二维数据可视化结果

现在要做的事情是找出一条直线,最大限度地逼近这些点,使误差最小。要达到这个目的需要借助PyTorch训练出一个一元线性模型,首先使用from_numpy方法将上面生成的数据转换成Tensor。

然后借助PyTorch的nn.Modlue搭建线性模型,新建LinearRegression类继承nn.Module。

因为随机生成的输入数据x和输出数据y都是一维的,所以在__init__方法中声明的linear模型的输入和输出都是1。在forward方法中简单地调用linear模型,将x传入即可。下面开始创建模型,优化器采用SGD,损失函数采用MSELoss。

每迭代5个epoch打印出拟合的直线图,如图2.21所示。

从图2.21可以看出,随着迭代次数的增加,损失值越来越小,直线拟合得越来越好,至此就完成了简单的一元线性回归的任务。一元线性回归比较简单,2.4.2节介绍的多项式回归比较复杂。在参考代码之前,请读者先独立思考如何实现多项式回归。

图2.21 一元线性回归模型的优化过程

2.4.2 PyTorch实现多项式回归

一元线性回归模型虽然能拟合出一条直线,但是精度仍然欠佳,如2.4.1节所示,拟合的直线并不能穿过每个点,对于复杂的拟合任务,我们可以采用多项式回归提高拟合的精度。多项式回归其实就是将特征的次数提高,如前面一元线性回归的 x 是一次的,实际上我们可以采用二次、三次,甚至更高的次数进行拟合,如采用 y = x 2 + x 4 等更复杂的方式拟合数据。增加模型的复杂度必然会带来过拟合的风险,因此需要采用正则化损失的方式减少过拟合,提升模型的泛化能力。

接下来便实现一个多项式回归,用模型拟合一个复杂的多项式方程,方程如下:

f ( x )=-1.13 x -2.14 x 2 +3.15 x 3 -0.01 x 4 +0.512

借助该公式随机产生50个数据点,可视化图像如图2.22所示。

图2.22 多项式回归的可视化图像

我们期望通过PyTorch多项式回归学习上面方程中的权重参数 w 1 w 2 w 3 w 4 ,现在输入不再是一元函数的一维,而是现在的四维。输入变成一个矩阵形式:

式中, 表示第 n 个样本的第四个特征值。

下面编写一个辅助方法,将输入数据拼接成如上所示的矩阵形式。

有了标准的输入矩阵 X ,还缺少 y 的值,而 y 是使用 f ( x )方程计算出来的,这里把它固定成一个方法,用输入数据 x 计算标准的 y 值。

上面的代码使用了Tensor上的mm方法,表示矩阵相乘(Matrix Multiplication)。辅助方法都构建好之后,就用上面的两个辅助方法批量生成训练数据,用生成的数据训练模型,然后用于预测上面生成的随机数据,了解模型与上面生成的数据的吻合程度。

下面创建多项式回归模型,其实实现起来很简单,仍然使用前面实现线性回归的方式,使用torch.nn.Linear模型,区别是输入为4,对应输入[ x 1 , x 2 , x 3 , x 4 ]向量。

模型新建好之后就可以开始训练模型,为了动态地显示模型训练的效果,在程序中约定每训练1000个epoch,就对测试数据进行一次预测,并将预测的误差及预测的输出值可视化展示。

多项式回归的拟合过程如图2.23所示。

图2.23 多项式回归的拟合过程

经过大量的迭代之后,模型能够非常好地拟合到测试数据,这说明多项式回归的拟合能力非常强。读者可以自行拟合如下所示的心形函数图像。

心形函数图像如图2.24所示。

图2.24 心形函数 f ( x )=13 cos( x )-5 cos(2 x )-2 cos(3 x )-cos(4 x )

下面介绍和神经网络具有千丝万缕联系的逻辑回归。

2.4.3 PyTorch实现逻辑回归

逻辑回归是非常经典的分类算法,适用于分类任务,如垃圾分类任务、情感分类任务等都可以使用逻辑回归。

逻辑回归发展自Logistic分布,它的累积分布函数和密度函数如下:

μ 影响中心对称点的位置, γ 越小表示中心点附近的增长速度越快。通常,在机器学习和神经网络中用到的Logistic分布函数 μ 的取值为0,而 γ 的取值为1。Sigmoid是一种特殊的函数(“S”形曲线函数),其图像如图2.25所示。

图2.25 Sigmoid(x)= 的图像

从图2.25可以看出,Sigmoid函数的图像关于(0,0.5)中心对称,在中心点附近变化得较快,当| x |≥6时,曲线基本上没有什么变化,以非常微小的速度无限接近0.0和1.0。那么Logistic分布和Logistic回归究竟有什么联系?

以二分类为例,假设输入特征向量 x R n 是线性可分的数据,我们总能找到一个超平面将两个类别的数据分开(可以通过移动截距项 b 及合适的权重 w 找出这个超平面),假设超平面方程为 ,对于这个超平面,令 f ( x )>0的为正类, f ( x )<0的为负类,这就形成了一个最简单的感知机模型,决策边界便是 f ( x ),是否可以在表示出分类的同时还能表达其概率的大小?从超平面的角度来看, f ( x )>0越大,是正类的概率就越大; f ( x )<0越小,是负类的概率就越小。而Sigmoid函数正好可以映射到一个概率,并且Sigmoid函数的性质非常符合“数值越大,概率越大;数值越小,概率越小”,因此早期的研究人员便采用了Sigmoid分布的累积分布函数的特殊形式,即Sigmoid函数。

通过Sigmoid函数可以直接将输入特征 x R n 与对应类别的概率直接联系起来,其表达形式如下:

这是通过 p ( y =0| x )+ p ( y =1| x )=1得到的。接下来引入一个叫作“几率”(Odds)的概念,表示一个事件发生的概率与不发生的概率的比值。例如,一个时间发生的概率为 p ,则几率为 ,取对数便可得到对数几率 。借助这个概念,可以得出 p ( y =0| x )的几率:

取对数几率为

log Odds( y =0| x )=log e wx + b = wx + b

从对数几率表达式可以看出,输出数据 Y =0的对数几率是输入数据 x 的线性函数。线性函数是连续的回归函数,这就是Logistic逻辑回归名称的由来。Logistic逻辑回归模型的思路是先拟合决策边界(不一定是线性函数,可以是其他非线性函数),再建立决策边界和概率之间的关系,从而得到不同分类的概率。

接下来使用逻辑回归模型完成一个二分类任务。首先构造训练数据,代码如下所示。

构造的正、负两类数据的可视化效果如图2.26所示。

同前面一样,接下来通过继承nn.Module创建LogisticRegression模型,并在模型中使用Linear构建线性模型。需要注意的是,这里Linear的输入数据为2,表示数据的两个维度( x (1) , x (2) );输出数据为1,表示数据的标签,要么为1,要么为0。这些操作和线性回归类似,逻辑回归中多了一个Sigmoid函数,对 wx + b 的结果做概率判断,这里以0.5为分界线,大于或等于0.5为正类,小于0.5为负类。

图2.26 二分类数据

定义好模型之后就需要创建模型实例、定义损失函数及优化器。因为是二分类任务,所以这里采用BCELoss方法(PyTorch中常见的损失函数及使用场景,在后面章节会详细介绍)和SGD优化算法。下面是训练的代码,为了便于可视化监测训练效果,约定每隔100次进行可视化展示,并绘制出分界线。

随着不断迭代,损失值不断降低,预测精度不断升高,经过50万次迭代后,最终预测精度为0.9753(见图2.27),简单的逻辑回归模型能够轻松地在线性不可分的数据集上达到这个精度,这说明了线性回归的有效性。

其实,逻辑回归和神经网络有很多类似的地方,借助神经网络超强的拟合能力,可以进一步提升上面数据的分类精度,甚至可以接近100%的精度。使用神经网络可以对数据进行分类,对此感兴趣的读者可自行尝试完成。神经网络在后面还会不断出现,因为它是本书要重点讲解的内容。

图2.27 逻辑回归分类 eqYHZ9LVvlIbz4hs01qZsMsMlg/U2+D871myUdZLB0ag2A+AP26uUox8TCm9qBFl

点击中间区域
呼出菜单
上一章
目录
下一章
×