聯(lián)邦學習論文解讀分散數(shù)據(jù)的深層網(wǎng)絡通信
前言
聯(lián)邦學習(Federated Learning) 是人工智能的一個新的分支,這項技術是谷歌于2016年首次提出,本篇論文第一次描述了這一概念。
Abstract
現(xiàn)代移動設備可以訪問到大量數(shù)據(jù),這些數(shù)據(jù)訓練后反過來可以大大提高用戶體驗。例如,語言模型可以改善語音識別和文本輸入,圖像模型可以自動選擇好的照片。但是,這些豐富的數(shù)據(jù)通常對隱私敏感、數(shù)量眾多或兩者兼而有之,這可能會妨礙使用常規(guī)方法進行訓練。于是我們提出將訓練數(shù)據(jù)分發(fā)在移動設備上的替代訓練方案,并通過聚合本地計算的更新來學習共享模型,我們稱這種分散的學習方法為聯(lián)邦學習。
簡而言之,當下移動設備產(chǎn)生了大量的數(shù)據(jù),我們需要利用這些數(shù)據(jù)來訓練一些模型,這些模型將會提升用戶實驗。傳統(tǒng)的訓練方式:收集所有客戶端的數(shù)據(jù),然后利用這些數(shù)據(jù)訓練一個模型,最后分發(fā)給所有客戶端。存在的問題:我們沒法直接收集所有設備的數(shù)據(jù)來統(tǒng)一訓練(隱私要求),于是提出了一種新的不需要共享客戶端數(shù)據(jù)的模型訓練方式。
Introduction
聯(lián)邦學習中,學習任務由中央服務器協(xié)調(diào),每個客戶端都有一個本地訓練數(shù)據(jù)集,該數(shù)據(jù)集永遠不會上傳到服務器(即隱私不會被泄露)。
本文主要貢獻:
- 將移動設備分散數(shù)據(jù)的訓練問題確定為重要的研究方向
- 提出了解決該問題的具體算法
- 對所提出的算法進行了驗證
更具體地說,我們引入了聯(lián)邦平均算法(FederatedAveraging algorithm)。
Federated Learning
聯(lián)邦學習的問題具有以下屬性:
- 對來自移動設備的數(shù)據(jù)進行訓練,與對數(shù)據(jù)中心通??捎玫拇頂?shù)據(jù)進行訓練相比,具有明顯的優(yōu)勢。
- 該數(shù)據(jù)是隱私敏感的或者大規(guī)模的(與模型的大小相比),因此最好不要純粹出于模型訓練的目的將其記錄到數(shù)據(jù)中心(隱私的)
- 對于監(jiān)督任務,可以從用戶交互中自然推斷出數(shù)據(jù)上的標簽。
作為兩個例子,我們考慮圖像分類和語言模型。圖像分類:例如預測哪些照片將來最有可能被多次查看或共享;語言模型:下一個單詞的預測甚至預測整個回復來改善觸摸屏鍵盤上的語音識別和文本輸入。這兩項任務的潛在訓練數(shù)據(jù)(用戶拍攝的所有照片以及他們在移動鍵盤上鍵入的所有照片,包括密碼,URL,消息等)都可能對隱私敏感。
Privacy
與數(shù)據(jù)中心對持久數(shù)據(jù)的訓練相比,聯(lián)邦學習具有明顯的隱私優(yōu)勢。但是即使是“匿名”數(shù)據(jù)集,也可能通過與其他數(shù)據(jù)結合而使用戶隱私面臨風險。
Federated Optimization
我們將聯(lián)邦學習中的優(yōu)化問題稱為聯(lián)邦優(yōu)化(Federated Optimization)。聯(lián)邦優(yōu)化具有幾個關鍵屬性,可將其與典型的分布式優(yōu)化問題區(qū)分開:
- Non-IID:給定客戶端上的訓練數(shù)據(jù)通?;谔囟ㄓ脩魧σ苿釉O備的使用,因此任何特定用戶的本地數(shù)據(jù)集將不代表總體分布。
- Unbalanced:一些用戶將比其他用戶更重地使用服務或應用程序,導致不同數(shù)量的本地培訓數(shù)據(jù)。簡而言之,每個用戶產(chǎn)生的數(shù)據(jù)量不一樣。
- Massively distributed:預計參與優(yōu)化的客戶端數(shù)量將遠遠大于每個客戶端的平均示例數(shù)量。
- 移動設備經(jīng)常脫機或連接緩慢或昂貴
本文重點是非IID和不平衡屬性的優(yōu)化,以及通信約束的關鍵性質。
我們假設一個同步更新方案在幾輪通訊中進行。有一組固定的K個客戶端,每個客戶端都有一個固定的本地數(shù)據(jù)集。在每輪開始時,隨機選擇一部分客戶端,服務器將當前全局算法狀態(tài)發(fā)送給這些客戶端中的每一個(例如,當前模型參數(shù))。然后,每個選定的客戶端根據(jù)全局狀態(tài)及其本地數(shù)據(jù)集執(zhí)行本地計算,并向服務器發(fā)送更新。然后,服務器將這些更新應用于其全局狀態(tài),并重復該過程。
問題的一般形式:
在數(shù)據(jù)中心優(yōu)化中,通信成本相對較小,計算成本占主導地位,最近的重點是使用GPU來降低這些成本。相比之下,在聯(lián)邦優(yōu)化通信成本中占主導地位。
因此,我們的目標是使用額外的計算來減少訓練模型所需的通信輪數(shù)。我們可以添加計算的兩種主要方法:
增加并行性。使用更多客戶端在每個通信周期之間獨立工作。
增加對每個客戶端的計算。即每個客戶端在每個通信回合之間執(zhí)行更復雜的計算。
以上內(nèi)容下文都將有更加詳細的介紹!
The FederatedAveraging Algorithm
深度學習的眾多成功應用幾乎完全依賴于隨機梯度下降(SGD)的變體進行優(yōu)化。
在聯(lián)邦學習中,我們使用大批量同步SGD,已有相關論文證明,它是優(yōu)于異步方法的。
為了在聯(lián)邦學習中應用這種方法,我們在每輪中選擇一部分客戶端,并計算這些客戶端持有的所有數(shù)據(jù)的損失梯度。參數(shù)C控制全局塊大小,其中C=1對應于全批(非隨機)梯度下降。我們將此算法稱為FederatedSGD(orFedSGD)。
FedSGD的一種典型的實現(xiàn)方式:C=1(非SGD),學習率 η \eta η固定,每一個客戶端算出自己所有數(shù)據(jù)損失的梯度(平均梯度),然后傳遞給中央服務器,中央服務器整合所有梯度,來更新全局的參數(shù) w t w_t wt?。
計算量由三個參數(shù)控制:
- C:每一輪執(zhí)行計算的客戶端比例(只有一部分客戶端參與更新)
- E:每一輪更新時,每個客戶端對其本地參數(shù)進行更新的次數(shù)
- B:客戶端每一次更新參數(shù)時所用本地數(shù)據(jù)量的大小
該算法更加詳細的描述如下:
參數(shù)介紹:K表示客戶端的個數(shù), B表示每一次本地更新時的數(shù)據(jù)量,E表示本地更新的次數(shù), η表示學習率。
首先是服務器執(zhí)行以下步驟:
Experimental Results
Table1:
表1描述的是圖像分類任務:參數(shù)C對E=1的MNIST 2NN和E=5的CNN的影響。其中C=0表示每次選擇一個客戶端的數(shù)據(jù)進行更新。對于MINST 2NN來說,總的客戶端數(shù)量為100,即五行分別表示1,10,20,50,100個客戶端。
每個表格條目給出了實現(xiàn)2NN的97%和CNN的99%的測試集精度所需的通信輪數(shù),以及相對于C=0這一baseline的加速比。 比如對于第三行 B = ∞ B=\infty B=∞這一情況( B = ∞ B=\infty B=∞表示每一次都用全部數(shù)據(jù)進行本地參數(shù)更新),中央服務器需要與客戶端進行1658次通信,才能使得模型在測試集上的精度達到97%。
Table2:
表2描述的是語言模型:LSTM語言模型,該模型在讀取一行中的每個字符后預測下一個字符。該模型以一系列字符作為輸入,并將每個字符嵌入到8維空間中,然后通過2個LSTM層處理嵌入的字符,每個層具有256個節(jié)點。
表2的含義同表1:在某一參數(shù)環(huán)境下,F(xiàn)edSGD要達到目標精度所需要進行的通訊次數(shù)。
SGD對學習率參數(shù)η的調(diào)整很敏感,本文的 η \eta η是基于網(wǎng)格搜索法找到的。
Increasing parallelism
增加并行性: 即增加客戶端數(shù)量。
上圖給出了特定參數(shù)設置下要達到閾值精度(圖中灰線)所需要進行的通訊輪數(shù)。
然后,使用形成曲線的離散點之間的線性插值來計算曲線穿過目標精度的輪數(shù)。
Increasing computation per client
增加每個客戶端的計算量。C=0.1固定,減小B,或者增加E,或者減小B的同時增加E。
還是上面這張圖:
可以看到,隨著B減小或者E增加,達到目標精度所需的通訊次數(shù)是減小的,也就是說:每輪添加更多本地SGD更新可以顯著降低通信成本。
Can we over-optimize on the client datasets?
本地數(shù)據(jù)集上進行更新時可以過度優(yōu)化嗎?即E特別大,進行很多次的本地更新。
上圖給出了E特別大時的實驗結果:對于大的E值,收斂速度并沒有顯著的下降。
Conclusions and Future Work
聯(lián)邦學習可以變得切實可行,因為可以使用相對較少的通信輪次來訓練高質量模型。聯(lián)邦學習將是未來比較熱門的一個方向!
以上就是論文解讀分散數(shù)據(jù)的深層網(wǎng)絡通信有效學習的詳細內(nèi)容,更多關于分散數(shù)據(jù)深層網(wǎng)絡通信的資料請關注腳本之家其它相關文章!
相關文章
提示“處理URL時服務器出錯”和“HTTP 500錯誤“的解決方法
關于提示“處理URL時服務器出錯”和“HTTP 500錯誤“的解決方法,需要的朋友可以參考下。2009-11-11深入解析HetuEngine實現(xiàn)On Yarn原理
這篇文章主要介紹了HetuEngine實現(xiàn)On Yarn原理,介紹了HetuEngine On Yarn的原理,其實現(xiàn)主要是借助了Yarn Service提供的能力,感興趣的朋友一起通過本文學習下2022-01-01