购买
下载掌阅APP,畅读海量书库
立即打开
畅读海量书库
扫码下载掌阅APP

2.2 安装PyTorch 2.0

Python运行环境调试完毕后,接下来的任务便是安装本书的核心组件——PyTorch 2.0。PyTorch作为当下热门的深度学习框架,为研究者和开发者提供了灵活且高效的工具来构建和训练神经网络。其2.0版本的推出,更是带来了诸多新特性和性能优化,进一步提升了用户体验。

2.2.1 NVIDIA 10/20/30/40系列显卡选择的GPU版本

目前市场上有NVIDIA 10/20/30/40系列显卡,对于需要调用专用编译器的PyTorch来说,不同的显卡需要安装不同的依赖计算包。作者在此总结了不同显卡的PyTorch版本以及CUDA和cuDNN的对应关系,如表2-1所示,推荐读者使用20及以上系列的显卡。

表2-1 NVIDIA 10/20/30/40系列显卡的版本对比

注意

这里的区别主要在于显卡运算库CUDA与cuDNN的区别,当在20/30/40系列显卡上使用PyTorch时,可以安装CUDA11.6版本以上以及cuDNN8.1版本以上的库。而在10系列版本的显卡上,建议优先使用2.0版本以前的PyTorch。

下面以PyTorch 2.0为例,演示完整的CUDA和cuDNN的安装步骤,不同版本的安装过程基本一致。

2.2.2 PyTorch 2.0 GPU NVIDIA运行库的安装

本小节讲解PyTorch 2.0 GPU版本的前置软件的安装。对于GPU版本的PyTorch来说,由于调用了NVIDIA显卡作为其代码运行的主要工具,因此额外需要NVIDIA提供的运行库作为运行基础。

我们选择PyTorch 2.0.1版本进行讲解。对于PyTorch 2.0的安装来说,最好的方法是根据官方提供的安装命令进行安装,具体参考官方文档https://pytorch.org/get-started/previous-versions/。从页面上可以看到,针对Windows版本的PyTorch 2.0.1,官方提供了几种安装模式,分别对应CUDA 11.7、CUDA 11.8和CPU only。使用conda安装的命令如下:

下面以CUDA 11.8+cuDNN 8.9为例讲解安装的方法。

(1)首先是CUDA的安装。在百度搜索CUDA 11.8 download,进入官方下载页面,选择适合的操作系统安装方式(推荐使用exe(local)本地化安装方式),如图2-15所示。

图2-15 CUDA 11.8下载页面

此时下载下来的是一个EXE文件,读者可自行安装,不要修改其中的路径信息,完全使用默认路径安装即可。

(2)下载和安装对应的cuDNN文件。要下载cuDNN,需要先注册,相信读者可以很快完成,之后直接进入下载页面,如图2-16所示。

注意

不要选择错误的版本,一定要找到对应CUDA的版本号。另外,如果使用的是Windows 64位的操作系统,需要下载x86_64版本的cuDNN。

图2-16 cuDNN 8.9下载页面

(3)下载的cuDNN是一个压缩文件,将它解压并把所有的目录复制到CUDA安装主目录中(直接覆盖原来的目录)。CUDA安装主目录如图2-17所示。

图2-17 CUDA安装主目录

(4)确认PATH环境变量,这里需要将CUDA的运行路径加载到环境变量的PATH路径中。安装CUDA时,安装向导能自动加入这个环境变量值,确认一下即可,如图2-18所示。

图2-18 将CUDA路径加载到环境变量PATH中

(5)最后完成PyTorch 2.0.1 GPU版本的安装,只需在Miniconda Prompt窗口中执行本小节开始给出的PyTorch安装命令即可。

     # CUDA 11.8
     conda install pytorch==2.0.1 torchvision==0.15.2 torchaudio==2.0.2 pytorch-cuda=11.8
-c pytorch -c nvidia

2.2.3 Hello PyTorch

至此,我们已经完成了PyTorch 2.0的安装。下面使用PyTorch 2.0做一个小练习——Hello PyTorch。打开Miniconda Prompt窗口,执行python命令并依次输入如下命令,验证安装是否成功。

import torch
result = torch.tensor(1) + torch.tensor(2.0)
result

结果如图2-19所示。

图2-19 验证安装是否成功

或者打开前面安装的PyCharm IDE,先新建一个项目,再新建一个hello_pytorch.py文件,输入如下代码:

import torch
result = torch.tensor(1) + torch.tensor(2.0)
print(result)

最终结果请读者自行验证。 Pg3BHjHVv/+SLNfxMxYJ/oyouD/7PdiQNiJKMdzlLFgM6gh00xQMYbUbR9N9//rx

点击中间区域
呼出菜单
上一章
目录
下一章
×

打开