聯(lián)邦學(xué)習(xí)神經(jīng)網(wǎng)絡(luò)FedAvg算法實(shí)現(xiàn)
I. 前言
聯(lián)邦學(xué)習(xí)(Federated Learning) 是人工智能的一個(gè)新的分支,這項(xiàng)技術(shù)是谷歌2016年于論文
Communication-Efficient Learning of Deep Networks from Decentralized Data中首次提出。
在我的另一篇博文聯(lián)邦學(xué)習(xí):《Communication-Efficient Learning of Deep Networks from Decentralized Data中詳細(xì)解析了該篇論文,而本篇博文的目的是利用這篇解讀文章對原始論文中的FedAvg方法進(jìn)行復(fù)現(xiàn)。
因此,閱讀本文前建議先閱讀聯(lián)邦學(xué)習(xí):《Communication-Efficient Learning of Deep Networks from Decentralized Data。
II. 數(shù)據(jù)介紹
聯(lián)邦學(xué)習(xí)中存在多個(gè)客戶端,每個(gè)客戶端都有自己的數(shù)據(jù)集,這個(gè)數(shù)據(jù)集他們是不愿意共享的。
本文選用的數(shù)據(jù)集為中國北方某城市十個(gè)區(qū)/縣從2016年到2019年三年的真實(shí)用電負(fù)荷數(shù)據(jù),采集時(shí)間間隔為1小時(shí),即每一天都有24個(gè)負(fù)荷值。
我們假設(shè)這10個(gè)地區(qū)的電力部門不愿意共享自己的數(shù)據(jù),但是他們又想得到一個(gè)由所有數(shù)據(jù)統(tǒng)一訓(xùn)練得到的全局模型。
除了電力負(fù)荷數(shù)據(jù)意外,還有風(fēng)功率數(shù)據(jù),兩個(gè)數(shù)據(jù)通過參數(shù)type指定:type == 'load’表示負(fù)荷數(shù)據(jù),'wind’表示風(fēng)功率數(shù)據(jù)。
1. 特征構(gòu)造
用某一時(shí)刻前24個(gè)時(shí)刻的負(fù)荷值以及該時(shí)刻的相關(guān)氣象數(shù)據(jù)(如溫度、濕度、壓強(qiáng)等)來預(yù)測該時(shí)刻的負(fù)荷值。
對于風(fēng)功率數(shù)據(jù),同樣使用某一時(shí)刻前24個(gè)時(shí)刻的風(fēng)功率值以及該時(shí)刻的相關(guān)氣象數(shù)據(jù)來預(yù)測該時(shí)刻的風(fēng)功率值。
各個(gè)地區(qū)應(yīng)該就如何制定特征集達(dá)成一致意見,本文使用的各個(gè)地區(qū)上的數(shù)據(jù)的特征是一致的,可以直接使用。
III. 聯(lián)邦學(xué)習(xí)
1. 整體框架
原始論文中提出的FedAvg的框架為:
由于本文中需要利用各個(gè)客戶端的模型參數(shù)來對服務(wù)器端的模型參數(shù)進(jìn)行更新,因此本文決定采用numpy搭建一個(gè)四層的神經(jīng)網(wǎng)絡(luò)模型。模型的具體搭建過程可以參考上一篇博文:從矩陣鏈?zhǔn)角髮?dǎo)的角度來深入理解BP算法(原理+代碼)。在這篇博文里面我詳細(xì)得介紹了神經(jīng)網(wǎng)絡(luò)參數(shù)更新的過程,這將有助于理解本文中的模型參數(shù)更新過程。
神經(jīng)網(wǎng)絡(luò)由1個(gè)輸入層、3個(gè)隱藏層以及1個(gè)輸出層組成,激活函數(shù)全部采用Sigmoid函數(shù)。
網(wǎng)絡(luò)各層間的運(yùn)算關(guān)系,也就是前向傳播過程如下所示:
因此,客戶端參數(shù)更新實(shí)際上就是更新四個(gè) w。
2. 服務(wù)器端
服務(wù)器端執(zhí)行以下步驟:
簡單來說,每一輪通信時(shí)都只是選擇部分客戶端,這些客戶端利用本地的數(shù)據(jù)進(jìn)行參數(shù)更新,然后將更新后的參數(shù)傳給服務(wù)器,服務(wù)器匯總客戶端更新后的參數(shù)形成最新的全局參數(shù)。下一輪通信時(shí),服務(wù)器端將最新的參數(shù)分發(fā)給被選中的客戶端,進(jìn)行下一輪更新。
3. 客戶端
客戶端沒什么可說的,就是利用本地?cái)?shù)據(jù)對神經(jīng)網(wǎng)絡(luò)模型的參數(shù)進(jìn)行更新。
4. 代碼實(shí)現(xiàn)
4.1 初始化
參數(shù):
- K,客戶端數(shù)量,本文為10個(gè),也就是10個(gè)地區(qū)。
- C:選擇率,每一輪通信時(shí)都只是選擇C * K個(gè)客戶端。
- E:客戶端更新本地模型的參數(shù)時(shí),在本地?cái)?shù)據(jù)集上訓(xùn)練E輪。
- B:客戶端更新本地模型的參數(shù)時(shí),本地?cái)?shù)據(jù)集batch大小為B
- r:服務(wù)器端和客戶端一共進(jìn)行r輪通信。
- clients:客戶端集合。
- type:指定數(shù)據(jù)類型,負(fù)荷預(yù)測or風(fēng)功率預(yù)測。
- lr:學(xué)習(xí)率。
- input_dim:數(shù)據(jù)輸入維度。
- nn:全局模型。
- nns: 客戶端模型集合。
代碼實(shí)現(xiàn):
class FedAvg: def __init__(self, options): self.C = options['C'] self.E = options['E'] self.B = options['B'] self.K = options['K'] self.r = options['r'] self.clients = options['clients'] self.type = options['type'] self.lr = options['lr'] self.input_dim = options['input_dim'] self.nn = BP(file_name='server', B=B, E=E, input_dim=self.input_dim, type=self.type, lr=self.lr) self.nns = [] # distribution for i in range(self.K): s = copy.deepcopy(self.nn) s.file_name = self.clients[i] self.nns.append(s)
其中 self.nn是服務(wù)器端初始化的全局參數(shù),由于服務(wù)器端不需要進(jìn)行反向傳播更新參數(shù),因此不需要定義各個(gè)隱層以及輸出。
4.2 服務(wù)器端
服務(wù)器端代碼如下:
def server(self): for t in range(self.r): print('第', t + 1, '輪通信:') m = np.max([int(self.C * self.K), 1]) # sampling index = random.sample(range(0, self.K), m) # dispatch self.dispatch(index) # local updating self.client_update(index) # aggregation self.aggregation(index) # return global model return self.nn
其中client_update(index):
def client_update(self, index): # update nn for k in index: self.nns[k] = train(self.nns[k])
aggregation(index):
def aggregation(self, index): # update w s = 0 for j in index: # normal s += self.nns[j].len w1 = np.zeros_like(self.nn.w1) w2 = np.zeros_like(self.nn.w2) w3 = np.zeros_like(self.nn.w3) w4 = np.zeros_like(self.nn.w4) for j in index: # normal w1 += self.nns[j].w1 * (self.nns[j].len / s) w2 += self.nns[j].w2 * (self.nns[j].len / s) w3 += self.nns[j].w3 * (self.nns[j].len / s) w4 += self.nns[j].w4 * (self.nns[j].len / s) # update server self.nn.w1, self.nn.w2, self.nn.w3, self.nn.w4 = w1, w2, w3, w4
dispatch(index):
def aggregation(self, index): # update w s = 0 for j in index: # normal s += self.nns[j].len w1 = np.zeros_like(self.nn.w1) w2 = np.zeros_like(self.nn.w2) w3 = np.zeros_like(self.nn.w3) w4 = np.zeros_like(self.nn.w4) for j in index: # normal w1 += self.nns[j].w1 * (self.nns[j].len / s) w2 += self.nns[j].w2 * (self.nns[j].len / s) w3 += self.nns[j].w3 * (self.nns[j].len / s) w4 += self.nns[j].w4 * (self.nns[j].len / s) # update server self.nn.w1, self.nn.w2, self.nn.w3, self.nn.w4 = w1, w2, w3, w4
下面對重要代碼進(jìn)行分析:
客戶端的選擇
m = np.max([int(self.C * self.K), 1]) index = random.sample(range(0, self.K), m)
index中存儲中m個(gè)0~10間的整數(shù),表示被選中客戶端的序號。
客戶端的更新
for k in index: self.client_update(self.nns[k])
服務(wù)器端匯總客戶端模型的參數(shù)
關(guān)于模型匯總方式,可以參考一下我的另一篇文章:對FedAvg中模型聚合過程的理解。
當(dāng)然,這只是一種很簡單的匯總方式,還有一些其他類型的匯總方式。論文Electricity Consumer Characteristics Identification: A Federated Learning Approach中總結(jié)了三種匯總方式:
- normal:原始論文中的方式,即根據(jù)樣本數(shù)量來決定客戶端參數(shù)在最終組合時(shí)所占比例。
- LA:根據(jù)客戶端模型的損失占所有客戶端損失和的比重來決定最終組合時(shí)參數(shù)所占比例。
- LS:根據(jù)損失與樣本數(shù)量的乘積所占的比重來決定。
將更新后的參數(shù)分發(fā)給客戶端
def dispatch(self, inidex): # dispatch for i in index: self.nns[i].w1, self.nns[i].w2, self.nns[i].w3, self.nns[ i].w4 = self.nn.w1, self.nn.w2, self.nn.w3, self.nn.w4
4.3 客戶端
客戶端只需要利用本地?cái)?shù)據(jù)來進(jìn)行更新就行了:
def client_update(self, index): # update nn for k in index: self.nns[k] = train(self.nns[k])
其中train():
def train(nn): print('training...') if nn.type == 'load': train_x, train_y, test_x, test_y = nn_seq(nn.file_name, nn.B, nn.type) else: train_x, train_y, test_x, test_y = nn_seq_wind(nn.file_name, nn.B, nn.type) nn.len = len(train_x) batch_size = nn.B epochs = nn.E batch = int(len(train_x) / batch_size) for epoch in range(epochs): for i in range(batch): start = i * batch_size end = start + batch_size nn.forward_prop(train_x[start:end], train_y[start:end]) nn.backward_prop(train_y[start:end]) print('當(dāng)前epoch:', epoch, ' error:', np.mean(nn.loss)) return nn
4.4 測試
def global_test(self): model = self.nn c = clients if self.type == 'load' else clients_wind for client in c: model.file_name = client test(model)
IV. 實(shí)驗(yàn)及結(jié)果
本次實(shí)驗(yàn)的參數(shù)選擇為:
K | C | E | B | r |
---|---|---|---|---|
10 | 0.5 | 50 | 50 | 5 |
if __name__ == '__main__': K, C, E, B, r = 10, 0.5, 50, 50, 5 type = 'load' input_dim = 30 if type == 'load' else 28 _client = clients if type == 'load' else clients_wind lr = 0.08 options = {<!--{C}%3C!%2D%2D%20%2D%2D%3E-->'K': K, 'C': C, 'E': E, 'B': B, 'r': r, 'type': type, 'clients': _client, 'input_dim': input_dim, 'lr': lr} fedavg = FedAvg(options) fedavg.server() fedavg.global_test()if __name__ == '__main__': K, C, E, B, r = 10, 0.5, 50, 50, 5 type = 'load' input_dim = 30 if type == 'load' else 28 _client = clients if type == 'load' else clients_wind lr = 0.08 options = {'K': K, 'C': C, 'E': E, 'B': B, 'r': r, 'type': type, 'clients': _client, 'input_dim': input_dim, 'lr': lr} fedavg = FedAvg(options) fedavg.server() fedavg.global_test()
各個(gè)客戶端單獨(dú)訓(xùn)練(訓(xùn)練50輪,batch大小為50)后在本地的測試集上的表現(xiàn)為:
客戶端編號 | 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 |
---|---|---|---|---|---|---|---|---|---|---|
MAPE / % | 5.79 | 6.73 | 6.18 | 5.82 | 5.49 | 4.55 | 6.23 | 9.59 | 4.84 | 5.49 |
可以看到,由于各個(gè)客戶端的數(shù)據(jù)都十分充足,所以每個(gè)客戶端自己訓(xùn)練的本地模型的預(yù)測精度已經(jīng)很高了。
服務(wù)器與客戶端通信5輪后,服務(wù)器上的全局模型在10個(gè)客戶端測試集上的表現(xiàn)如下所示:
客戶端編號 | 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 |
---|---|---|---|---|---|---|---|---|---|---|
MAPE / % | 6.58 | 4.19 | 3.17 | 5.13 | 3.58 | 4.69 | 4.71 | 3.75 | 2.94 | 4.77 |
可以看到,經(jīng)過聯(lián)邦學(xué)習(xí)框架得到全局模型在各個(gè)客戶端上表現(xiàn)同樣很好,這是因?yàn)槭畟€(gè)地區(qū)上的數(shù)據(jù)是獨(dú)立同分布的。
V. 源碼及數(shù)據(jù)
我把數(shù)據(jù)和代碼放在了GitHub上:FedAvg
以上就是聯(lián)邦學(xué)習(xí)神經(jīng)網(wǎng)絡(luò)FedAvg算法實(shí)現(xiàn)的詳細(xì)內(nèi)容,更多關(guān)于神經(jīng)網(wǎng)絡(luò)FedAvg算法的資料請關(guān)注腳本之家其它相關(guān)文章!
相關(guān)文章
Python3通過Luhn算法快速驗(yàn)證信用卡卡號的方法
這篇文章主要介紹了Python3通過Luhn算法快速驗(yàn)證信用卡卡號的方法,涉及Python中Luhn算法的使用技巧,非常簡單實(shí)用,需要的朋友可以參考下2015-05-05Python tkinter之Bind(綁定事件)的使用示例
這篇文章主要介紹了Python tkinter之Bind(綁定事件)的使用詳解,幫助大家更好的理解和學(xué)習(xí)python的gui開發(fā),感興趣的朋友可以了解下2021-02-02在python image 中實(shí)現(xiàn)安裝中文字體
這篇文章主要介紹了在python image 中實(shí)現(xiàn)安裝中文字體,具有很好的參考價(jià)值,希望對大家有所幫助。一起跟隨小編過來看看吧2020-05-05Python對圖片進(jìn)行resize、裁剪、旋轉(zhuǎn)、翻轉(zhuǎn)問題
這篇文章主要介紹了Python對圖片進(jìn)行resize、裁剪、旋轉(zhuǎn)、翻轉(zhuǎn)問題,具有很好的參考價(jià)值,希望對大家有所幫助。如有錯(cuò)誤或未考慮完全的地方,望不吝賜教2023-05-05