购买
下载掌阅APP,畅读海量书库
立即打开
畅读海量书库
扫码下载掌阅APP

2.3 用PyTorch Lightning实现神经网络实例

PyTorch Lightning是类似于Keras的一个库,是为PyTorch深度学习框架提供高级抽象的轻量级库。它旨在简化训练和推理过程的开发和管理,使用户能够更专注于模型设计和实验。

PyTorch Lightning基于PyTorch,并提供了一组模板代码和工具,使得构建训练循环、日志记录、自动分布式训练、灵活的配置和模型验证等任务更加容易。特别是神经网络,使其更易于理解,同时为创建可扩展的深度学习模型提供了广泛的可能性,这些模型可以很容易地在分布式硬件上运行。

使用PyTorch Lightning,可以避免编写样板代码并重复实现训练循环。只需定义模型、数据加载器和优化器,PyTorch Lightning将会自动处理训练和验证过程。此外,可以通过简单的继承和重写来定制化和扩展功能,使其适应特定的项目需求。

PyTorch Lightning的安装代码如下:

PyTorch Lightning的主要优势如下。

1)高级抽象。PyTorch Lightning提供了一组高级抽象的模块,如LightningModule和LightningDataModule,可以帮助用户快速构建模型和数据加载器。用户只需要定义模型的核心逻辑,而不再需要编写整个训练循环的代码。

2)组织逻辑清晰。PyTorch Lightning通过将训练过程分解为训练步骤、验证步骤和测试步骤等模块,使得代码组织更加清晰。模型训练逻辑放在training_step中,验证逻辑放在validation_step中,测试逻辑放在test_step中。这样的分解可以提高代码的可读性和可维护性。

3)自动优化。PyTorch Lightning提供了自动优化的功能,可以根据用户定义的优化器、损失函数和学习率调度器等进行自动的梯度计算、参数更新和学习率调整。用户只需要在模型中指定相关参数,而不需要手动编写优化和更新过程的代码。

4)分布式训练支持。PyTorch Lightning对分布式训练提供了良好的支持。用户可以通过设置训练器的accelerator参数为ddp、dp或ddp2等来启用分布式训练,框架会自动处理数据并行、模型并行和梯度同步等细节。

5)提供实用工具。PyTorch Lightning提供了一些实用工具,如自动模型检查点保存、可视化训练过程、集成TensorBoard和Comet等,可以简化训练过程的管理和监控。

总之,PyTorch Lightning是一个旨在简化和加速PyTorch模型训练流程的框架,它提供了高级抽象、自动优化、分布式训练支持等功能,帮助用户更加方便地开发和维护深度学习模型。

接下来将2.2节的实例用PyTorch Lightning来实现,使用的数据及网络结构完全一致,只是添加了保存及恢复训练模型的部分代码。

(1)导入必要的模块

导入pytorch-lighting(简称为pl)模块及Trainer。

(2)定义LightningModule

构建模型的步骤和优化器的配置是由模型类中的方法来定义的,而不是编写循环,在这种情况下,可以定义数据加载程序的工作。PyTorch Lightning是一个非常灵活的工具。

其中,def_init_(self)用于定义网络架构;def forward(self, x)用于定义推理、预测的正向传播;def training_step(self, batch, batch_idx)用于定义训练循环部分;def configure_optimizers(self)用于定义优化器。Lightning Module定义的是一个系统而不是单纯的网络架构。

(3)创建pl.Trainer对象

接下来创建一个pl.Trainer对象,用于配置训练器的参数,如使用的GPU数量、最大训练轮数以及模型保存回调等。最后,通过调用trainer.fit()和trainer.test()等方法开始训练和测试流程。

trainer是自动化的,包括:

● 循环迭代Epoch and batch iteration。

● 自动调用optimizer.step()、backward、zero_grad()。

● 自动调用.eval()、enabling/disabling grads。

● 权重加载、保存模型及日志等。

● 支持单机多卡、多机多卡等方式。

(4)装载模型

装载保存的模型代码如下:

运行结果如图2-6所示。

图2-6 装载保存的模型测试结果

(5)把PyTorch转换为PL格式的代码

前面介绍了PyTorch两种编码方式,一种是较详细的编程方法,另一种是类似Keras的简约方式,图2-7说明如何把PyTorch代码转换为PL格式的代码。

图2-7 把PyTorch转换为PL代码的对应关系 8VPcjUdElumB3S++GlZMKiJupzKmNmb5jmGNypvRNqqToY78D/yIx9v6V7bJjsVZ

点击中间区域
呼出菜单
上一章
目录
下一章
×