购买
下载掌阅APP,畅读海量书库
立即打开
畅读海量书库
扫码下载掌阅APP

2.1 认识TensorFlow数据流图

2.1.1 数据流图简介

TensorFlow是通过数据流图来完成数据处理的。可以借助TensorFlow API来实现数据流图。

数据流图是一张神经网络(图2-1),它能模拟人类大脑处理信息的过程,由节点(神经元)、连线(神经突触)和流经的张量(数据信息)三部分组成。节点代表对数据的某种处理,也就是对张量的运算操作。连线代表节点与节点之间的连接、依赖关系。要传递、处理的张量在编程代码中用数组或列表来表示。

读者可以试着想象一个动态的数据流图,若干张量通过连线汇聚到一个节点,按这个节点处的要求完成某种运算操作,这个节点的输出张量再通过连线流向其他节点。经过若干节点运算,张量最终流向输出节点,完成整个运算过程。

图2-1 神经网络

以图2-2为例,用数据流图完成如下简单计算过程。

已知输入节点“input”处的输入张量a为[5,2]。张量a流经上方的节点“add”,求出a的所有元素之和,作为张量b;与此同时,张量a流经下方的节点“max”,求出张量a中所有元素的最大值,作为张量c。张量b和张量c汇聚到节点“mul”处,求出b和c的乘积,作为张量d,这就是最终输出结果。也就是说,输入张量a流经此数据流图的运算结果即张量d。

图2-2 数据流图示例

2.1.2 实现数据流图

接下来看一看如何用代码实现上节中的数据流图。使用TensorFlow实现数据流图需要两个步骤:定义它、运行它。下面利用TensorFlow API来实现该数据流图。

首先打开PyCharm,按照1.7节中的步骤新建项目。在项目名称的右键快捷菜单中选择“New”→“Python File”命令,完成Python文件的创建,然后输入如下代码。

在代码编辑区右击并选择运行命令,可以看到代码运行结果为35。

2.1.3 数据流图代码解析

接下来解析上节中代码的含义。

1.导入TensorFlow库

这行代码的作用是导入TensorFlow库,并赋予它一个新的名字tf。这样,在后面用到TensorFlow库函数时就不用输入完整的名称,省时又省力。

2.创建数据流图

这部分代码定义了4个节点的运算操作。输入节点处的张量a是通过tf.constant()方法创建的常量形式的一维列表,元素值为5、2。在节点“add”处,通过tf.reduce_sum()方法对a中的所有元素求和,得到张量b。在节点“max”处,通过tf.reduce_max()方法获取a中元素的最大值,作为张量c。在最后一个节点处,用tf.multiply()方法算出张量b和张量c的乘积,作为数据流图最后的输出张量d。

需要注意的是,这里只有一个输入节点,输入的张量为一维列表[5, 2]。相比于采用两个输入节点,输入值分别为标量5和标量2,张量输入有如下优点。

(1)只需要将输入数据送给单个节点,简化了数据流图。

(2)后面的节点只依赖于一个输入节点而非多个。

(3)输入节点可以接收任意维度的列表(张量)。

3.运行数据流图

通过前面的代码定义了一个数据流图,完成了一个简单神经网络的创建。要想得到输出结果,还要将数据流图运行起来。

这里用构造方法tf.Session()创建了一个会话sess,通过sess.run(d)运行数据流图,让张量在数据流图的节点之间流动,最终求出张量d并打印出结果。

提示:其他节点的张量也可以通过sess.run()方法进行运算。试着添加下面这行代码,运行数据流图求出张量c并打印出结果。

4.run()方法详解

run()方法如下所示:

该方法有4个参数,其中fetches和feed_dict是需要重点关注的参数。参数的值为None,意味着该参数可以缺省。run()方法至少要有fetches这一个参数。重点参数的解释如下。

(1)fetches。

该参数可以是任意数据流图中的元素,如某个节点处的张量。

(2)feed_dict。

该参数用于覆盖数据流图中的张量值,其格式为Python字典对象。字典包含“Key”和“Value”两个元素,也就是程序员所熟悉的键值对。“Key”指向需要被替换的张量;“Value”的值用于替换原来的张量值,只要数据类型相同即可替换。

feed_dict非常重要,它可以用来更新、替换、覆盖输入张量的值,后续章节会对它进行详细介绍。 0YrcvjGz4PPukvVvzQbBbZxRipxuQOosS1nJ5b34fQvWuQXT4Fn+9fcjH/3FZhuY

点击中间区域
呼出菜单
上一章
目录
下一章
×