Flower是由英国剑桥大学等组织提出的一个新颖的端到端轻量级的联邦学习框架,它可以无缝地从模拟实验研究过渡到对大量真实边缘设备的系统研究。Flower在两个领域(即模拟和现实世界设备)提供了各自的优势,并提供实验(实现在探索和开发过程中,根据需要在两个极端之间迁移的能力)。本节将介绍Flower的系统架构、框架安装和算法开发流程,以及一个实战示例。
服务器端的事务逻辑,包括客户端选择、实验配置、参数更新聚合,以及全局和本地模型的评估,都可以通过“策略(Strategy)”抽象类来表示。“策略”抽象类的实现表示一个联邦学习算法,Flower提供了一些流行且经过测试的算法实现,例如FedAvg和FedYogi等。Flower核心框架实现了大规模运行这些工作负载所需的基础设施。Flower框架的服务器端主要涉及三个组件:客户端管理器(Client Manager)、联邦学习循环(FL Loop),以及用户定制的联邦学习策略,如图2-4所示。服务器组件从客户端管理器中采样选择客户端,它管理一组客户端代理(Client Proxy)对象,每个对象代表一个连接到服务器的客户端,负责发送和接收真实的客户端的Flower协议信息通信。联邦学习循环协调整个联邦学习过程,但是它不会决定如何进行,因为这些决定被委托给用户定制的联邦学习策略的配置来实现。
● 图2-4 Flower框架架构图
本质上,联邦学习可以描述为全局计算(模型聚合)和局部计算(模型训练)之间的相互配合。服务器端负责执行全局计算,以及编排全局模型在一组可用客户端上的训练过程。客户端则负责进行局部计算,即通过本地数据进行本地模型训练。Flower的核心框架架构(见图2-4)反映了这些思想,使得开发者或者研究人员可以通过类似搭积木的形式(包括服务器和客户端)进行开发或实验最新的研究学习方法。
也就是说,联邦学习循环要求联邦学习策略配置下一轮训练,将这些配置发送到对应的客户端,并且从客户端接收经过客户端本地训练得到的模型更新(或者故障报告),最后将模型聚合委托给联邦学习策略。至于客户端就更加简单,它只需要等待来自服务器的消息,然后通过调用用户提供的训练和评估函数,对收到的消息做出反馈。
Flower框架内置了虚拟客户端引擎(Virtual Client Engine),它可以实现Flower客户端的虚拟化,以最大限度地利用可用硬件。在给定客户端池,各自的计算和内存预算(例如CPU数量、VRAM要求等),以及特定的联邦学习超参数(例如每轮的客户端数量),虚拟客户端引擎以资源感知(Resource-Aware)的方式启动Flower客户端,并以对用户和Flower服务器端透明的方式调度、实例化和运行Flower客户端。这一属性极大地简化了工作负载的并行化,确保可用硬件被充分利用,并且无须重新配置即可将相同的联邦学习实验移植到各种设置:台式机、单个GPU机器或多节点GPU集群。
Flower是一套开源、可扩展、与框架和设备无关的联邦学习框架,支持不同的深度学习框架,包括TensorFlow和PyTorch。一些适用于轻量级联邦学习工作负载的设备(例如树莓派等)需要最少的配置或不需要特殊配置。一方面,一些支持Python的嵌入式设备可以很容易地用作Flower客户端;另一方面,像智能手机等设备则需要专业的计算芯片完成机器学习负载。为了克服这个限制,Flower框架提供了一个通过直接在客户端处理Flower协议消息的低层级集成方式,即边缘客户端引擎(Edge Client Engine)。
在这一小节将简单介绍一下基于Pytorch和Flower框架的联邦学习算法开发流程。在上一小节中,介绍了联邦学习算法主要包括两个部分:客户端本地模型训练和服务器端模型聚合。
首先,安装Flower框架。安装Flower之前,需要先确保机器安装了3.7或更高版本的Python。使用下面的第一行命令直接安装稳定版本的Flower,也可以使用下面的第二行命令直接从GitHub上安装最新版本的Flower:
然后介绍一下完成联邦学习任务应该如何部署Flower,一共有哪些步骤。Flower的联邦学习系统主要由一个服务器和多个客户端组成。在联邦学习的过程中,服务器将全局模型参数发送给客户端,客户端使用从服务器接收到的参数来更新本地模型。然后它在本地数据上训练模型(在本地更改模型参数)并将更新的模型参数发送回服务器。在客户端,需要通过两个辅助函数,使用从服务器接收到的模型参数来更新本地模型,并从本地模型获取更新后的模型参数:set_parameters和get_parameters。以PyTorch为例,下面是这两个函数的具体实现:通过state_dict函数访问PyTorch的模型参数张量,然后将它们与NumPy(Flower支持对NumPy的序列化)进行相互转换(如代码2-5所示)。
代码2-5 Flower设置Pytorch模型参数
在Flower中,通过实现flwr.client.Client或flwr.client.NumPyClient的子类来创建客户端(如代码2-6所示)。接下来以flwr.client.NumPyClient实现一个Flower的客户端。其中,需要实现三个方法get_parameters、fit和evaluate。其中fit函数从服务器接收模型参数,在本地数据上训练模型参数,并将(更新的)模型参数返回给服务器;evaluate函数从服务器接收模型参数,在本地数据上评估模型参数,并将评估结果返回服务器。fit和evaluate函数中对应的训练和测试(train和test)函数需要用户定义,在之后的实战示例中,可以看到一个基础的CIFAR-10数据集合上的例子。
代码2-6 Flower的客户端
用户可以使用三种方式自定义Flower在服务器端编排联邦学习过程的方式。
第一种,使用现有策略,例如FedAvg;第二种,使用回调函数自定义现有策略;第三种,实施新策略。Flower允许通过Strategy的抽象类来控制联邦学习过程。Flower的核心框架中提供了许多内置策略。
启动服务器端,只需要调用flower框架提供的接口(如代码2-7所示),同时,可以定义配置信息,以及联邦学习策略等。其中,配置信息目前只需要定义联邦学习的通信次数。
代码2-7 Flower启动服务器
接下来将通过Flower提供的一个简单嵌入式设备的例子 进行实战演示。在本次实战中,将演示通过三个客户端进行CIFAR-10数据库的联邦学习,共同训练一个ResNet-18神经网络。我们将使用一台Windows计算机作为服务器,一个Raspberry PI、一台Windows计算机,以及一台MacBook笔记本计算机作为三个客户端,如图2-5所示。
● 图2-5 Flower框架实战演练用例
首先,在每台设备上准备训练数据。通过torch.utils.data中的random_split将训练数据分成3份,并得到对应的训练集、验证集和测试集(如代码2-8所示)。
代码2-8 CIFAR10的Pytorch数据集实例
另外,在客户端定义和中心化的机器学习相同的训练和测试函数(如代码2-9所示)。
代码2-9 Flower客户端训练和测试代码示例
定义模型并创建一个客户端实例。在这里使用一个简单的卷积神经网络,以及创建一个fl.client.NumPyClient的子类(代码如2-10所示)。
代码2-10 定义模型和CIFAR10的客户端
接下来开始进行联邦学习过程。首先需要开启一个联邦学习服务器,然后创建客户端实例,并连接到服务器加入联邦学习(如代码2-11所示)。
代码2-11 Flower联邦学习启动服务器和客户端