购买
下载掌阅APP,畅读海量书库
立即打开
畅读海量书库
扫码下载掌阅APP

2.4 图像降噪:手把手实战第一个深度学习模型

2.3节的程序读者可能感觉过于简单,直接调用库,再调用模型及其方法,即可完成所需要的功能。然而真正的深度学习程序设计不会这么简单,为了给读者建立一个使用PyTorch进行深度学习的总体印象,在这里准备了一个实战案例,手把手地演示进行深度学习任务所需要的整体流程,读者在这里不需要熟悉程序设计和编写,只需要了解整体步骤和每个步骤所涉及的内容即可。

2.4.1 MNIST数据集的准备

HelloWorld是任何一种编程语言入门的基础程序,任何一位初学者在开始编程学习时,打印的第一句话往往就是HelloWorld。在深度学习编程中也有其特有的“HelloWorld”,一般指的是采用MNIST完成一项特定的深度学习项目。

对于好奇的读者来说,一定有一个疑问,MNIST究竟是什么?

实际上,MNIST是一个手写数字图片的数据集,它有60 000个训练样本集和10 000个测试样本集。打开后,MNIST数据集如图2-24所示。

图2-24 MNIST数据集

读者可直接使用本书配套源码中提供的MNIST数据集,保存在dataset文件夹中,如图2-25所示。

图2-25 本书配套源码中提供的MNIST数据集

之后使用NumPy数据库进行数据读取,代码如下:

     import numpy as np
     x_train = np.load("./dataset/mnist/x_train.npy")
     y_train_label = np.load("./dataset/mnist/y_train_label.npy")

读者也可以在百度搜索MNIST,直接下载train-images-idx3-ubyte.gz、train-labels-idx1-ubyte.gz等4个文件,如图2-26所示。

图2-26 MNIST文件中包含的数据集

下载这4个文件并解压缩。解压缩后可以发现这些文件并不是标准的图像格式,而是二进制格式,包括一个训练图片集、一个训练标签集、一个测试图片集以及一个测试标签集。其中训练图片集的内容如图2-27所示。

图2-27 MNIST文件的二进制表示(部分)

MNIST训练集内部的文件结构如图2-28所示。

图2-28 MNIST文件结构图

如图2-26所示是训练集的文件结构,其中有60 000个实例。也就是说这个文件包含60 000个标签内容,每个标签的值为一个0~9的数。这里我们先解析每个属性的含义。首先,该数据是以二进制格式存储的,我们读取的时候要以rb方式读取;其次,真正的数据只有[value]这一项,其他的[type]等只是用来描述的,并不真正在数据文件中。

也就是说,在读取真实数据之前,要读取4个32位整数。由[offset]可以看出,真正的像从0016开始,每个像素占用一个int 32位。因此,在读取像素之前,要读取4个32位整数,也就是magic number、number of images、number of rows和number of columns。

结合图2-26的文件结构和图2-25的原始二进制数据内容可以看到,图2-25起始的4字节数0000 0803对应图2-26中列表的第一行,类型是magic number(魔数),这个数字的作用为文件校验数,用来确认这个文件是不是MNIST里面的train-images-idx3-ubyte文件。而图2-25中的0000 ea60对应图2-26图列表的第二行,转化为十进制为60000,这是文件总的容量数。

下面依次对应。图2-25中从第8个字节开始有一个4字节数0000 001c十进制值为28,也就是表示每幅图片的行数。同样地,从第12个字节开始的0000 001c表示每幅图片的列数,值也为28。而从第16个字节开始则是依次每幅图片像素值的具体内容。

这里使用每784(28×28)字节代表一幅图片,如图2-29所示。

图2-29 每个手写体被分成28×28个像素

2.4.2 MNIST数据集的特征和标签介绍

对于数据库的获取,前面介绍了两种不同的MNIST数据集的获取方式,本小节推荐使用本书配套源码包中的MNIST数据集进行数据的读取,代码如下:

     import numpy as np
     x_train = np.load("./dataset/mnist/x_train.npy")
     y_train_label = np.load("./dataset/mnist/y_train_label.npy")

这里numpy库函数会根据输入的地址对数据进行处理,并自动将其分解成训练集和验证集。打印训练集的维度如下:

     (60000, 28, 28)
     (60000, )

这是进行数据处理的第一步,有兴趣的读者可以进一步完成数据的训练集和测试集的划分。

回到MNIST数据集,每个MNIST实例数据单元也是由两部分构成的,分别是一幅包含手写数字的图片和一个与其相对应的标签。可以将其中的标签特征设置成y,而图片特征矩阵以x来代替,所有的训练集和测试集中都包含x和y。

图2-30用更为一般化的形式解释了MNIST数据实例的展开形式。在这里,图片数据被展开成矩阵的形式,矩阵的大小为28×28。至于如何处理这个矩阵,常用的方法是将其展开,而展开的方式和顺序并不重要,只需要将其按同样的方式展开即可。

图2-30 图片转换为向量模式

下面回到对数据的读取,前面已经介绍了,MNIST数据集实际上就是一个包含着60 000幅图片的60 000×28×28大小的矩阵张量[60000,28,28],如图2-31所示。

图2-31 MNIST数据集的矩阵表示

矩阵中行数指的是图片的索引,用以对图片进行提取,而后面的28×28个向量用以对图片特征进行标注。实际上,这些特征向量就是图片中的像素点,每幅手写图片是[28,28]的大小,每个像素转化为一个0~1的浮点数,构成矩阵。

2.4.3 模型的准备和介绍

对于使用PyTorch进行深度学习的项目来说,一个非常重要的内容是模型的设计,模型用于决定在深度学习项目中采用哪种方式完成目标的主体设计。在本例中,我们的目的是输入一幅图像之后对其进行去噪处理。

对于模型的选择,一个非常简单的思路是,图像输出的大小就应该是输入的大小,在这里选择使用Unet(一种卷积神经网络)作为设计的主要模型。

注意: 对于模型的选择现在还不是读者需要考虑的问题,随着你对本书学习的深入,见识到更多处理问题的方法后,对模型的选择自然会心领神会。

我们可以整体看一下Unet的结构(读者目前只需要知道Unet的输入和输出大小是同样的维度即可),如图2-32所示。

图2-32 Unet的结构

可以看到,对于整体模型架构来说,其通过若干模块(block)与直连(residual)进行数据处理。这部分内容在后面的章节会讲到,目前读者只需要知道模型有这种结构即可。Unet模型的整体代码如下:

上面倒数第1~3行的代码段表示只有在本文件作为脚本直接执行时才会被执行,而在本文件import到其他脚本中(代码重用)时这段代码不会被执行。

2.4.4 对目标的逼近——模型的损失函数与优化函数

除了深度学习模型外,要完成一个深度学习项目,另一个非常重要的内容是设定模型的损失函数与优化函数。初学者对这两部分内容可能不太熟悉,在这里只需要知道有这部分内容即可。

首先是对于损失函数的选择,在这里选用MSELoss作为损失函数,MSELoss函数的中文名字为均方损失函数。

MSELoss的作用是计算预测值和真实值之间的欧式距离。预测值和真实值越接近,两者的均方差就越小,均方差函数常用于线性回归模型的计算。在PyTorch中,使用MSELoss的代码如下:

     loss = torch.nn.MSELoss(reduction="sum")(pred, y_batch)

下面是优化函数的设定,在这里采用Adam优化器。对于Adam优化函数,请读者自行查找资料学习,在这里只提供使用Adam优化器的代码,如下所示:

     optimizer = torch.optim.Adam(model.parameters(), lr=2e-5)

2.4.5 基于深度学习的模型训练

前面介绍了深度学习的数据准备、模型、损失函数以及优化函数,本小节使用PyTorch训练出一个可以实现去噪性能的深度学习整理模型,完整代码如下(代码文件参看本书配套代码):

在这里展示了完整的模型训练过程,首先传入数据,然后使用模型对数据进行计算,计算结果与真实值的误差被回传到模型中,最后PyTorch框架根据回传的误差对整体模型参数进行修正。训练流程如图2-33所示。

图2-33 训练流程

从图2-33中可以很清楚地看到,随着训练的进行,模型逐渐学会对输入的数据进行整形和输出,此时从输出结果来看,模型已经能够很好地对输入的图形细节进行修正,读者可以自行运行代码测试一下。 UGni2RKHHFdV0R+Q+ioTliqb3lPK34JBQ9lbCw8hlMvFkLA5Z6TfH1oHvX0BbhzT

点击中间区域
呼出菜单
上一章
目录
下一章
×