购买
下载掌阅APP,畅读海量书库
立即打开
畅读海量书库
扫码下载掌阅APP

第二章

K近邻算法

在本章中,我们将要学习K近邻算法的思想和原理,了解图像相似度的计算方法,并使用sklearn库实现K近邻算法,最后我们会简要分析K近邻算法的缺点。

一、算法思想

顾名思义,K近邻(K-Nearest Neighbors)算法就是找到K个最近的邻居,即给定一个已知类别的数据集,对一个新的数据,在已知数据集中找到与该数据最邻近的K个数据,这K个数据中的多数属于哪个类,就把该数据分到这个类中。这就类似于现实生活中少数服从多数的思想。下图是一个K近邻算法分类的示例,蓝色、黄色与黑色是已知的分类。如果出现一个新的数据点(红色正方形),首先,我们要计算它与其他数据的相似度。对于平面上的点,我们用点之间的距离来表示相似度,距离越近,相似度越高。然后,我们找到与红色相似度最高的一些点,如果我们设定K=3,那么就要找到距离新的数据点最近的3个点。发现最近的3个点中,有2个黄色和1个黑色,最后,我们将其分类为黄色。

K近邻算法图解

K近邻算法中K的取值问题没有完美的解答,它往往是由经验决定的。在机器学习算法中,常常有一些参数需要算法设计者指定,它们被称为超参数。这里的K就是K近邻算法里的超参数。下面我们看一下K的不同取值对算法结果产生的影响。

下页图展示了K取不同值时的分类结果。图中红色和蓝色的圆圈是已知类别的数据点,背景为红色则代表这片区域的点会被分类为红色,背景为蓝色则被分为蓝色。当K比较小的时候,我们只会对数据点周围很小的一片区域感兴趣,而忽略其他数据。由于我们只关注很小的一片区域,受随机性的影响就会比较大。如左图中,蓝色区域有一小块红色,这是因为这里落入了零星几个红色点,这些红色点可能是由于噪音或其他错误导致的特例。但是因为这些特例的存在,会导致它们周围一片区域都被划分为红色。而事实上,我们可能更希望它们被分为蓝色,因为显然周围蓝色点的数量更多。而当K比较大的时候,我们关注的区域较大,受随机性的影响减小,数据点的微小变化不会引起分类的变化,但与此同时,我们可能会忽略掉某些局部细节。

不同的K产生的分类界限

我们再来看一个例子。下图是使用K近邻算法进行手写数字识别的示例。图像右侧显示了与左侧手写数字最相似的3张已知图像。如果选择K=1,那么左边的手写数字就会被错误地分类为5,如果选择K=3,那么由于最近的3个邻居中有2个3,它会被正确地分类为3。在实践中,我们往往会考虑将K取值为3或5。

K近邻手写数字识别示例

二、图像相似度

通过“一、算法思想”中的例子,我们了解到平面中两个点之间的相似度可以用它们的距离表达,距离越小代表相似度越高。点之间的距离可以通过勾股定理得到,平面上a和b两个点的距离可以由下面的公式计算得出:

其中a1,b1是a,b两点的x轴坐标,a2,b2是a,b两点的y轴坐标。

同理,图像间的相似度也可以利用图像间的距离进行表达。我们知道一张灰度图是由一个矩阵来表示的,那灰度图之间的相似度,则可以用两矩阵间的距离表示。计算两个矩阵之间的距离与计算两个点的距离十分类似,只需要将矩阵对应位置的数字相减(所谓对应位置就是矩阵A的第一行第一列对应矩阵B的第一行第一列,矩阵A的第一行第二列对应矩阵B的第一行第二列),将对应数字相减后的平方求和,再取根号,即为两个矩阵间的距离。这个距离又被称为两个矩阵间的欧式距离。我们用Aij来代表矩阵第i行第j列的数字,矩阵A和B之间的欧式距离下面的公式计算:

两张图像对应矩阵的欧式距离越小,那么它们就越相似。

思考与实践

2.1 用这种方式度量两张图像的距离好吗?如果不好,你能举例说明吗?

三、使用sklearn库中的K近邻算法

了解了K近邻算法的原理以及如何计算图像相似度后,我们就可以开始编写代码实现一个简单的手写数字分类器了。这一节我们主要用到sklearn库。

【延伸阅读】

sklearn库

sklearn又称scikit-learn,是一个专注于机器学习任务的库。它包含很多常用的分类、回归和聚类算法,以及机器学习任务中常用的工具和经典的数据集。使用sklearn可以简单高效地进行数据预处理、数据分析和训练机器学习算法。sklearn为不同的算法提供了统一的训练和预测接口,非常方便上手。

首先,导入一些必要的库。

KNeighborsClassifier:sklearn库中提供的K近邻算法类。

mnist:Keras库中包含的一个方便下载MNIST数据集的类。

pyplot:Python中常用的画图工具matplotlib的画图类。

numpy:Python中用于处理各种数值计算的库。

利用Keras的mnist的模块加载MNIST数据集。KNeighborsClassifier要求输入数据是向量形式的,但是原输入数据是二维的图像,因此,我们使用reshape函数将二维的图像展开成一维的向量。

输出结果显示reshape后的训练数据的输入数据的数据形状为60000个784(28×28)维的向量。

接下来调用sklearn中的K近邻算法。K近邻算法会接受一个参数n_neighbors,也就是查看邻居的个数K。我们设定K的值为5,即每次寻找最近的5个邻居。

使用手写数字数据训练K近邻分类器。sklearn库提供了一个非常方便的函数fit,它需要两个参数,输入数据和对应的类别。我们把x_train和y_train作为参数传给函数fit,即可开始K近邻分类器的训练(训练过程需要花一些时间)。

训练完毕!我们使用predict方法对测试集中的10000个数据进行分类,并将分类结果存到y_predict中(预测也需要花一些时间)。

将分类结果y_predict与真实的类别y_test进行一一对比,统计分类正确的个数,并计算分类的准确度。

我们仅仅用了不到10行代码,就得到了一个分类准确度为96.88%的手写数字分类器!在机器学习领域,有非常多的个人和组织会将已有的算法实现并打包成库,供他人下载,降低使用算法的门槛。找到并使用这些库,是每个对机器学习感兴趣的人的必备技能。

刚才的10000张图像中,分错了322张,我们可以使用以下代码显示分错类的图像。

分类错误的图像

y_predict是K近邻算法的分类结果,y_test是正确分类。从中可以看出,分错类的图像基本都是一些歪歪扭扭,与正常写法有较大差异的数字。

四、K近邻算法的缺点

K近邻算法的原理非常简单易懂,效果也不错,但是K近邻算法也有一些明显的缺点。

1.分类速度慢。如果你运行一遍代码就会发现,在predict过程中,K近邻算法花费了大量时间。这是因为在每做一次分类时,K近邻算法都要在已知的数据中找到最相似的那些数据,当已知数据规模很大的时候,这个寻找的过程就非常慢。但一般情况下我们又希望已知数据越多越好,否则算法的准确性会降低。虽然有一些索引算法可以加速寻找过程,但总体来说,每做一次分类花费的时间还是不可接受。因此,K近邻算法不适用于一些需要快速出结果的场景。

2.分类效果依赖于相似度计算的方法。K近邻算法的准确性很大程度依赖于相似度计算方法的准确性,我们目前采用的是简单的欧式距离。这种方法对于规整的黑白手写数字分类效果尚可,但是对于复杂的图像,比如歪扭、不规整的图像或是CIFAR—10中的彩色物体图像,效果就会大打折扣。事实上,我们只要将图像旋转一个角度,目前的相似度计算方法就会出现问题。如下图所示,如果将数字7的图像旋转一个很小的角度,然后将其与未旋转的数字图像叠放在一起,淡灰色的部分就是它们不相同的地方。这时两张图像的欧式距离就不足以反应这两张图像内容上的相似性。

将图像旋转一定角度后产生的差异 5bgNU0ceNnvw5LoiHwFN2rAC2W5tksbocAKGni1NznC2hCKYSFN/z8EODblbA6M1

点击中间区域
呼出菜单
上一章
目录
下一章
×