购买
下载掌阅APP,畅读海量书库
立即打开
畅读海量书库
扫码下载掌阅APP

2.2 用PyTorch实现神经网络实例

前面介绍了使用PyTorch构建神经网络的一些组件、常用方法和主要步骤等,本节通过利用神经网络对手写数字进行识别的实例,来说明如何借助nn工具箱来实现一个神经网络,并对神经网络有一个直观的了解。在这个基础上,后续我们将对nn的各模块进行详细介绍。实例环境使用PyTorch 2.0,使用GPU或CPU,源数据集为MNIST。主要步骤如下。

● 利用PyTorch内置函数MNIST下载数据。

● 利用torchvision对数据进行预处理,调用torch.utils建立一个数据迭代器。

● 可视化源数据。

● 利用nn工具箱构建神经网络模型。

● 实例化模型,并定义损失函数及优化器。

● 训练模型。

● 可视化结果。

神经网络的结构如图2-3所示。

图2-3 神经网络结构

使用两个隐含层,每层使用ReLU激活函数,输出层使用softmax激活函数,最后使用torch.max(out,1)找出张量输出最大值对应索引作为预测值。

2.2.1 准备数据

1)导入必要的模块。

2)定义一些超参数。

3)下载数据并对数据进行预处理。

说明:

● transforms.Compose可以把一些转换函数组合在一起。

● Normalize([0.5], [0.5])对张量进行归一化,这里两个0.5分别表示对张量进行归一化的全局平均值和方差。因图像是灰色的,则只有一个通道,如果有多个通道,需要有多个数字,如三个通道,应该是Normalize([m1,m2,m3], [n1,n2,n3])。

● download参数控制是否需要下载,如果./data目录下已有MNIST,可选择False。

● 用DataLoader得到生成器,可节省内存。

● torchvision及data为PyTorch的数据预处理工具。

2.2.2 可视化源数据

对数据集中的部分数据进行可视化,代码如下:

运行结果如图2-4所示。

图2-4 MNIST源数据示例

2.2.3 构建模型

数据预处理之后,我们开始构建网络来创建模型。

(1)构建网络

(2)实例化网络

2.2.4 训练模型

这里使用for循环进行迭代来训练模型。首先,用训练数据来训练模型,然后用测试数据来验证模型的准确性。

(1)训练模型

最后5次迭代的结果如下:

这个神经网络的结构比较简单,只用了两层,且没有使用dropout层,迭代20次,测试准确率达到98%左右,效果还不错,但仍有提升空间。如果采用cnn、dropout等层,应该还可以提升模型性能。

(2)可视化训练及测试损失值

运行结果如图2-5所示。

图2-5 MNIST数据集训练的损失值 XLgYsQp+uQdoLD94IgZF2Pq6PT5fkgQbLdngJocgyyRPlHF/G3I+hmk2uVoKicu4

点击中间区域
呼出菜单
上一章
目录
下一章
×