Optimizing Federated Learning on Non-IID Data with Reinforcement Learning

文章链接:https://ieeexplore.ieee.org/document/9155494/

作者:Hao WangZakhary KaplanDi NiuBaochun Li

发表:IEEE INFOCOM 2020 - IEEE Conference on Computer Communications

截止当前(2021.04.16)被引次数:25

FL的两大挑战

SEMI - 2020 - Applying Deep Reinforcement Learning Techniques in Federated Learning:https://www.youtube.com/watch?v=JlvizFBFCTw

  • 减少开销的可能方法

    • 减少通信轮数:本地更新
    • 减少每一轮传输的信息大小

  • 统计异构性 non-id 数据

    ML算法假设训练数据是iid。 FL算法的训练数据基于non-iid data (non-iid 数据向training中引入了bias,导致更慢的收敛

上面的解决方案带来两个问题:1)如何获得真正的共享数据【因为所有的数据都在本地】;2)实际上增加了更多的通讯开销, 用于下载共享数据

==» client selection,选择一部分设备来减少设备之间的开销。

Abstract

痛点:1) 移动设备有限的网络连接, 使得FL在所有参与设备上并行执行模型更新与聚合不实际;2)non-iid data对FL的收敛与训练速度增加了额外的挑战

本文提出 FAVOR,一种经验驱动的控制(experience-driven control)框架 智能选择客户端设备参与每一轮的联邦学习,以抵消non-iid data引入的偏差,加快收敛速度。

an implicit connection between the distribution of training data on a device and the model weights trained based on those data 发现在这些实验数据中,设备上训练数据的分布与模型权重之间存在 隐含的联系,使得我们可以根据上传的模型权重 to profile the data distribution(来分析设备上的数据分布) ==» states:本地模型权重和共享的全局模型

提出基于dqn的一种机制在每轮通信中选择一个设备集合 来最大化奖励值,促进了验证准确率的增加,并惩罚(减少)了更多通信轮数的使用。

实验: PyTorch,dataset:MNIST,FashionMNIST,CIFAR-10, 与FedAvg算法对比

Introduction

已有研究指出FL的性能,尤其是FedAvg,因为non-iid data的出现而严重下降

FedAvg随机选择一个设备子集合,并将他们的本地模型权重平均后更新全局模型。 从全局来看,随机选择的本地数据集可能不会影响真实数据分布,但一定 引入bias到全局模型更新中。 non-iid data设备之间很大不同,聚合分散模型减慢了收敛继而降低了模型精度

FAVOR, aim to accelerate and stabilize the federated learning process 基于RL通过每个通信轮主动选择最佳的设备集合抵消non-iid data引入的偏差。

DRL for Client Selection

训练DRL智能体的目标是:使FL尽可能快的收敛到目标准确率(target accuracy)。

在此框架中,智能体不必收集 或 检查任何来自移动设备数据样本,只需要传输模型权重 ==» 因此origin FL一样保护了样本级的隐私。 框架只依赖模型权重信息来决定 哪个设备可能对全局模型的提升最大, 因为在设备上的数据分布和在那些数据上执行SGD获得的本地模型权重隐含的联系

The Agent based on Deep Q-Network

考虑到 limited available traces from federated learning jobs, 相比策略梯度方法与actor-critic方法,DQN训练更高效,而且能高效重复利用数据。

  • State

    $s_t = (w_t,(w_t)^{(1)},…,(w_t)^{(N)} )$ , $w_t$ 表示t轮训练后全局模型的权重,$(w_t)^{(k)}$ 表示第k个设备的本地模型权重

    没有引入额外的通信开销给设备, 因为只有设备k被选中作为client训练时,才会更新$w^{(k)}$

    为解决巨大状态空间问题(CNN模型包含百万个权重),采用高效且轻量的 降维技术 。如本节第三部分

  • Action

    client selection可能导致巨大的动作空间$C_K^N$ ,这使得RL training复杂化了。

    ==» a trick :基于DQN每一轮FL训练 智能体从N台设备中只选出一台设备。 DQN智能体学习最优动作值函数$Q^*(s_t,a)$ 的一个近似器(approximator),用于评估从$s_t$开始的最大化预计收益的action 。 ==» 因此动作空间减少为{${1,2,…,N}$ } ,a=i表示选择设备i参与FL训练

    每个动作值 代表智能体在状态$s_t$时选择一个特定动作a 获得的最大化预计收益。然后选择K台设备,每台设备对应一个不同的动作a,因此得到**$Q^*{(s_t,a)}$的top-K values**

  • Reward

    $r_t = \Xi^{(w_t-\Omega)} -1,t = 1,…,T$,其中$w_t$是全局模型在held-out验证集上经过t轮验证后达到的测试精度(testing accuracy),$\Omega$是目标精度(target accuracy),$\Xi$ 是正常数 在测试精度$w_t$下确保$r_t$呈指数式增长。$r_t \in (-1,0], 0\leq w_t \leq \Omega \leq 1$。当$w_t = \Omega 时,$此时$r_t$达到其最大值0。

    训练DQN智能体 来最大化累计折扣奖励的期望 $R=\sum^T_{t=1}\gamma ^{t-1}r_t=\sum^T_{t=1}\gamma^{t-1}(\Xi ^{(w_t-\Omega)}-1)$ ,其中折扣因子 $\gamma \in (0,1]$ 。

    $r_t$中的两个术语 $\Xi^{(w_t - \Omega)}$ 和 $-1$ motivations,

    前者激励智能体选择设备达到更高的测试精度$w_t$,$\Xi$ 用$w_t$控制奖励$r_t$的增长速度。通常,ML训练过程中,模型精度以更慢的pace增长,意味着轮数t增加时,$|w_t-w_{t-1}|$ 下降。因此,我们使用**指数项**来放大随着FL进展到后期的边缘精度增加。在本实验中 $\Xi$ 设置为64。

    后者-1,鼓励智能体以更少的轮数完成训练,因为消耗越多的轮数,智能体获得的累计奖励越少。

Workflow

上图为FAVOR在每一轮用DRL智能体选择设备执行FL的步骤。

  • Step1:FL服务器检查所有N台合格的设备
  • Step2:没太设备从服务器下载初始随机模型权重$w_{init}$,在每个回合执行本地SGD,然后将结果模型权重 ${w_1^{(k)}, k \in [N]}$返回给服务器
  • Step3:在第t轮($t=1,2,…,$),接收到上传的本地权重后,更新存储在服务器上的本地模型权重的对应副本。DQN智能体计算所有设备a=1,…,N的 $Q(s_t,a;\theta)$
  • Step4:DQN智能体选择K台设备对应top-K values,$Q(s_t,a;\theta)$,a=1,…,N。被选的K台设备下载最新的全局模型权重$w_t$,然后在本地执行一轮SGD来获得{${ w_{t+1}^{k} k \in [K] }$ }
  • Step5:上传{${ w_{t+1}^{k} k \in [K] }$ }到服务器,基于FEDAVG计算$w_{t+1}$。进入t+1轮并重复Step3-5

Dimension Reduction

PCA提取两个主成分,将状态空间映射到横纵坐标为这两个主成分的平面空间上

Evaluation

代码研究

Conclusion Remarks