欧美bbbwbbbw肥妇,免费乱码人妻系列日韩,一级黄片

聯(lián)邦學(xué)習(xí)神經(jīng)網(wǎng)絡(luò)FedAvg算法實(shí)現(xiàn)

 更新時(shí)間:2022年05月11日 11:49:05   作者:Cyril_KI  
這篇文章主要為大家介紹了聯(lián)邦學(xué)習(xí)神經(jīng)網(wǎng)絡(luò)FedAvg算法實(shí)現(xiàn),有需要的朋友可以借鑒參考下,希望能夠有所幫助,祝大家多多進(jìn)步,早日升職加薪

I. 前言

聯(lián)邦學(xué)習(xí)(Federated Learning) 是人工智能的一個(gè)新的分支,這項(xiàng)技術(shù)是谷歌2016年于論文

Communication-Efficient Learning of Deep Networks from Decentralized Data中首次提出。

在我的另一篇博文聯(lián)邦學(xué)習(xí):《Communication-Efficient Learning of Deep Networks from Decentralized Data中詳細(xì)解析了該篇論文,而本篇博文的目的是利用這篇解讀文章對原始論文中的FedAvg方法進(jìn)行復(fù)現(xiàn)。

因此,閱讀本文前建議先閱讀聯(lián)邦學(xué)習(xí):《Communication-Efficient Learning of Deep Networks from Decentralized Data。

II. 數(shù)據(jù)介紹

聯(lián)邦學(xué)習(xí)中存在多個(gè)客戶端,每個(gè)客戶端都有自己的數(shù)據(jù)集,這個(gè)數(shù)據(jù)集他們是不愿意共享的。

本文選用的數(shù)據(jù)集為中國北方某城市十個(gè)區(qū)/縣從2016年到2019年三年的真實(shí)用電負(fù)荷數(shù)據(jù),采集時(shí)間間隔為1小時(shí),即每一天都有24個(gè)負(fù)荷值。

我們假設(shè)這10個(gè)地區(qū)的電力部門不愿意共享自己的數(shù)據(jù),但是他們又想得到一個(gè)由所有數(shù)據(jù)統(tǒng)一訓(xùn)練得到的全局模型。

除了電力負(fù)荷數(shù)據(jù)意外,還有風(fēng)功率數(shù)據(jù),兩個(gè)數(shù)據(jù)通過參數(shù)type指定:type == 'load’表示負(fù)荷數(shù)據(jù),'wind’表示風(fēng)功率數(shù)據(jù)。

1. 特征構(gòu)造

用某一時(shí)刻前24個(gè)時(shí)刻的負(fù)荷值以及該時(shí)刻的相關(guān)氣象數(shù)據(jù)(如溫度、濕度、壓強(qiáng)等)來預(yù)測該時(shí)刻的負(fù)荷值。

對于風(fēng)功率數(shù)據(jù),同樣使用某一時(shí)刻前24個(gè)時(shí)刻的風(fēng)功率值以及該時(shí)刻的相關(guān)氣象數(shù)據(jù)來預(yù)測該時(shí)刻的風(fēng)功率值。

各個(gè)地區(qū)應(yīng)該就如何制定特征集達(dá)成一致意見,本文使用的各個(gè)地區(qū)上的數(shù)據(jù)的特征是一致的,可以直接使用。

III. 聯(lián)邦學(xué)習(xí)

1. 整體框架

原始論文中提出的FedAvg的框架為:

由于本文中需要利用各個(gè)客戶端的模型參數(shù)來對服務(wù)器端的模型參數(shù)進(jìn)行更新,因此本文決定采用numpy搭建一個(gè)四層的神經(jīng)網(wǎng)絡(luò)模型。模型的具體搭建過程可以參考上一篇博文:從矩陣鏈?zhǔn)角髮?dǎo)的角度來深入理解BP算法(原理+代碼)。在這篇博文里面我詳細(xì)得介紹了神經(jīng)網(wǎng)絡(luò)參數(shù)更新的過程,這將有助于理解本文中的模型參數(shù)更新過程。

神經(jīng)網(wǎng)絡(luò)由1個(gè)輸入層、3個(gè)隱藏層以及1個(gè)輸出層組成,激活函數(shù)全部采用Sigmoid函數(shù)。

網(wǎng)絡(luò)各層間的運(yùn)算關(guān)系,也就是前向傳播過程如下所示:

因此,客戶端參數(shù)更新實(shí)際上就是更新四個(gè) w。

2. 服務(wù)器端

服務(wù)器端執(zhí)行以下步驟:

簡單來說,每一輪通信時(shí)都只是選擇部分客戶端,這些客戶端利用本地的數(shù)據(jù)進(jìn)行參數(shù)更新,然后將更新后的參數(shù)傳給服務(wù)器,服務(wù)器匯總客戶端更新后的參數(shù)形成最新的全局參數(shù)。下一輪通信時(shí),服務(wù)器端將最新的參數(shù)分發(fā)給被選中的客戶端,進(jìn)行下一輪更新。

3. 客戶端

客戶端沒什么可說的,就是利用本地?cái)?shù)據(jù)對神經(jīng)網(wǎng)絡(luò)模型的參數(shù)進(jìn)行更新。

4. 代碼實(shí)現(xiàn)

4.1 初始化

參數(shù):

  • K,客戶端數(shù)量,本文為10個(gè),也就是10個(gè)地區(qū)。
  • C:選擇率,每一輪通信時(shí)都只是選擇C * K個(gè)客戶端。
  • E:客戶端更新本地模型的參數(shù)時(shí),在本地?cái)?shù)據(jù)集上訓(xùn)練E輪。
  • B:客戶端更新本地模型的參數(shù)時(shí),本地?cái)?shù)據(jù)集batch大小為B
  • r:服務(wù)器端和客戶端一共進(jìn)行r輪通信。
  • clients:客戶端集合。
  • type:指定數(shù)據(jù)類型,負(fù)荷預(yù)測or風(fēng)功率預(yù)測。
  • lr:學(xué)習(xí)率。
  • input_dim:數(shù)據(jù)輸入維度。
  • nn:全局模型。
  • nns: 客戶端模型集合。

代碼實(shí)現(xiàn):

class FedAvg:
    def __init__(self, options):
        self.C = options['C']
        self.E = options['E']
        self.B = options['B']
        self.K = options['K']
        self.r = options['r']
        self.clients = options['clients']
        self.type = options['type']
        self.lr = options['lr']
        self.input_dim = options['input_dim']
        self.nn = BP(file_name='server', B=B, E=E, input_dim=self.input_dim, type=self.type, lr=self.lr)
        self.nns = []
        # distribution
        for i in range(self.K):
            s = copy.deepcopy(self.nn)
            s.file_name = self.clients[i]
            self.nns.append(s)

其中 self.nn是服務(wù)器端初始化的全局參數(shù),由于服務(wù)器端不需要進(jìn)行反向傳播更新參數(shù),因此不需要定義各個(gè)隱層以及輸出。

4.2 服務(wù)器端

服務(wù)器端代碼如下:

def server(self):
     for t in range(self.r):
          print('第', t + 1, '輪通信:')
          m = np.max([int(self.C * self.K), 1])
          # sampling
          index = random.sample(range(0, self.K), m)
          # dispatch
          self.dispatch(index)
          # local updating
          self.client_update(index)
          # aggregation
          self.aggregation(index)

     # return global model
     return self.nn

其中client_update(index):

def client_update(self, index):  # update nn
     for k in index:
          self.nns[k] = train(self.nns[k])

aggregation(index):

def aggregation(self, index):
     # update w
     s = 0
     for j in index:
          # normal
          s += self.nns[j].len
          
     w1 = np.zeros_like(self.nn.w1)
     w2 = np.zeros_like(self.nn.w2)
     w3 = np.zeros_like(self.nn.w3)
     w4 = np.zeros_like(self.nn.w4)
     
     for j in index:
          # normal
          w1 += self.nns[j].w1 * (self.nns[j].len / s)
          w2 += self.nns[j].w2 * (self.nns[j].len / s)
          w3 += self.nns[j].w3 * (self.nns[j].len / s)
          w4 += self.nns[j].w4 * (self.nns[j].len / s)
     # update server
     self.nn.w1, self.nn.w2, self.nn.w3, self.nn.w4 = w1, w2, w3, w4

dispatch(index):

def aggregation(self, index):
     # update w
     s = 0
     for j in index:
          # normal
          s += self.nns[j].len
          
     w1 = np.zeros_like(self.nn.w1)
     w2 = np.zeros_like(self.nn.w2)
     w3 = np.zeros_like(self.nn.w3)
     w4 = np.zeros_like(self.nn.w4)
     
     for j in index:
          # normal
          w1 += self.nns[j].w1 * (self.nns[j].len / s)
          w2 += self.nns[j].w2 * (self.nns[j].len / s)
          w3 += self.nns[j].w3 * (self.nns[j].len / s)
          w4 += self.nns[j].w4 * (self.nns[j].len / s)
     # update server
     self.nn.w1, self.nn.w2, self.nn.w3, self.nn.w4 = w1, w2, w3, w4

下面對重要代碼進(jìn)行分析:

客戶端的選擇

m = np.max([int(self.C * self.K), 1])
index = random.sample(range(0, self.K), m)

index中存儲中m個(gè)0~10間的整數(shù),表示被選中客戶端的序號。

客戶端的更新

for k in index:
    self.client_update(self.nns[k])

服務(wù)器端匯總客戶端模型的參數(shù)

關(guān)于模型匯總方式,可以參考一下我的另一篇文章:對FedAvg中模型聚合過程的理解

當(dāng)然,這只是一種很簡單的匯總方式,還有一些其他類型的匯總方式。論文Electricity Consumer Characteristics Identification: A Federated Learning Approach中總結(jié)了三種匯總方式:

  • normal:原始論文中的方式,即根據(jù)樣本數(shù)量來決定客戶端參數(shù)在最終組合時(shí)所占比例。
  • LA:根據(jù)客戶端模型的損失占所有客戶端損失和的比重來決定最終組合時(shí)參數(shù)所占比例。
  • LS:根據(jù)損失與樣本數(shù)量的乘積所占的比重來決定。

將更新后的參數(shù)分發(fā)給客戶端

def dispatch(self, inidex):
     # dispatch
     for i in index:
          self.nns[i].w1, self.nns[i].w2, self.nns[i].w3, self.nns[
               i].w4 = self.nn.w1, self.nn.w2, self.nn.w3, self.nn.w4

4.3 客戶端

客戶端只需要利用本地?cái)?shù)據(jù)來進(jìn)行更新就行了:

def client_update(self, index):  # update nn
     for k in index:
          self.nns[k] = train(self.nns[k])

其中train():

def train(nn):
    print('training...')
    if nn.type == 'load':
        train_x, train_y, test_x, test_y = nn_seq(nn.file_name, nn.B, nn.type)
    else:
        train_x, train_y, test_x, test_y = nn_seq_wind(nn.file_name, nn.B, nn.type)
    nn.len = len(train_x)
    batch_size = nn.B
    epochs = nn.E
    batch = int(len(train_x) / batch_size)
    for epoch in range(epochs):
        for i in range(batch):
            start = i * batch_size
            end = start + batch_size
            nn.forward_prop(train_x[start:end], train_y[start:end])
            nn.backward_prop(train_y[start:end])
        print('當(dāng)前epoch:', epoch, ' error:', np.mean(nn.loss))
    return nn

4.4 測試

def global_test(self):
     model = self.nn
     c = clients if self.type == 'load' else clients_wind
     for client in c:
          model.file_name = client
          test(model)

IV. 實(shí)驗(yàn)及結(jié)果

本次實(shí)驗(yàn)的參數(shù)選擇為:

KCEBr
100.550505
if __name__ == '__main__': K, C, E, B, r = 10, 0.5, 50, 50, 5 type = 'load' input_dim = 30 if type == 'load' else 28 _client = clients if type == 'load' else clients_wind lr = 0.08 options = {<!--{C}%3C!%2D%2D%20%2D%2D%3E-->'K': K, 'C': C, 'E': E, 'B': B, 'r': r, 'type': type, 'clients': _client, 'input_dim': input_dim, 'lr': lr} fedavg = FedAvg(options) fedavg.server() fedavg.global_test()if __name__ == '__main__':
    K, C, E, B, r = 10, 0.5, 50, 50, 5
    type = 'load'
    input_dim = 30 if type == 'load' else 28
    _client = clients if type == 'load' else clients_wind
    lr = 0.08
    options = {'K': K, 'C': C, 'E': E, 'B': B, 'r': r, 'type': type, 'clients': _client,
               'input_dim': input_dim, 'lr': lr}
    fedavg = FedAvg(options)
    fedavg.server()
    fedavg.global_test()

各個(gè)客戶端單獨(dú)訓(xùn)練(訓(xùn)練50輪,batch大小為50)后在本地的測試集上的表現(xiàn)為:

客戶端編號12345678910
MAPE / %5.796.736.185.825.494.556.239.594.845.49

可以看到,由于各個(gè)客戶端的數(shù)據(jù)都十分充足,所以每個(gè)客戶端自己訓(xùn)練的本地模型的預(yù)測精度已經(jīng)很高了。

服務(wù)器與客戶端通信5輪后,服務(wù)器上的全局模型在10個(gè)客戶端測試集上的表現(xiàn)如下所示:

客戶端編號12345678910
MAPE / %6.584.193.175.133.584.694.713.752.944.77

可以看到,經(jīng)過聯(lián)邦學(xué)習(xí)框架得到全局模型在各個(gè)客戶端上表現(xiàn)同樣很好,這是因?yàn)槭畟€(gè)地區(qū)上的數(shù)據(jù)是獨(dú)立同分布的。

V. 源碼及數(shù)據(jù)

我把數(shù)據(jù)和代碼放在了GitHub上:FedAvg

以上就是聯(lián)邦學(xué)習(xí)神經(jīng)網(wǎng)絡(luò)FedAvg算法實(shí)現(xiàn)的詳細(xì)內(nèi)容,更多關(guān)于神經(jīng)網(wǎng)絡(luò)FedAvg算法的資料請關(guān)注腳本之家其它相關(guān)文章!

相關(guān)文章

  • python垃圾回收機(jī)制(GC)原理解析

    python垃圾回收機(jī)制(GC)原理解析

    這篇文章主要介紹了python垃圾回收機(jī)制(GC)原理解析,文中通過示例代碼介紹的非常詳細(xì),對大家的學(xué)習(xí)或者工作具有一定的參考學(xué)習(xí)價(jià)值,需要的朋友可以參考下
    2019-12-12
  • Python3通過Luhn算法快速驗(yàn)證信用卡卡號的方法

    Python3通過Luhn算法快速驗(yàn)證信用卡卡號的方法

    這篇文章主要介紹了Python3通過Luhn算法快速驗(yàn)證信用卡卡號的方法,涉及Python中Luhn算法的使用技巧,非常簡單實(shí)用,需要的朋友可以參考下
    2015-05-05
  • python3實(shí)現(xiàn)彈彈球小游戲

    python3實(shí)現(xiàn)彈彈球小游戲

    這篇文章主要介紹了python3實(shí)現(xiàn)彈彈球小游戲,圖形用戶界面tkinter,文中示例代碼介紹的非常詳細(xì),具有一定的參考價(jià)值,感興趣的小伙伴們可以參考一下
    2019-11-11
  • Python tkinter之Bind(綁定事件)的使用示例

    Python tkinter之Bind(綁定事件)的使用示例

    這篇文章主要介紹了Python tkinter之Bind(綁定事件)的使用詳解,幫助大家更好的理解和學(xué)習(xí)python的gui開發(fā),感興趣的朋友可以了解下
    2021-02-02
  • 在python image 中實(shí)現(xiàn)安裝中文字體

    在python image 中實(shí)現(xiàn)安裝中文字體

    這篇文章主要介紹了在python image 中實(shí)現(xiàn)安裝中文字體,具有很好的參考價(jià)值,希望對大家有所幫助。一起跟隨小編過來看看吧
    2020-05-05
  • python實(shí)現(xiàn)購物車功能

    python實(shí)現(xiàn)購物車功能

    這篇文章主要為大家詳細(xì)介紹了python實(shí)現(xiàn)購物車功能,文中示例代碼介紹的非常詳細(xì),具有一定的參考價(jià)值,感興趣的小伙伴們可以參考一下
    2022-02-02
  • Python特效之?dāng)?shù)字成像方法詳解

    Python特效之?dāng)?shù)字成像方法詳解

    所謂數(shù)字成像,即將原圖片經(jīng)過python處理后,生成完全由純數(shù)字組成的圖像。本文將具體為大家介紹一下這一效果如何實(shí)現(xiàn),需要的可以參考一下
    2022-01-01
  • 使用Python通過簡單操作設(shè)置PDF文檔屬性

    使用Python通過簡單操作設(shè)置PDF文檔屬性

    PDF文檔屬性是嵌入在PDF文檔中的一些與文檔有關(guān)的信息,這篇文章主要為大家介紹了如何使用Python通過簡單的操作設(shè)置PDF文檔屬性,需要的可以參考下
    2023-11-11
  • Python對圖片進(jìn)行resize、裁剪、旋轉(zhuǎn)、翻轉(zhuǎn)問題

    Python對圖片進(jìn)行resize、裁剪、旋轉(zhuǎn)、翻轉(zhuǎn)問題

    這篇文章主要介紹了Python對圖片進(jìn)行resize、裁剪、旋轉(zhuǎn)、翻轉(zhuǎn)問題,具有很好的參考價(jià)值,希望對大家有所幫助。如有錯(cuò)誤或未考慮完全的地方,望不吝賜教
    2023-05-05
  • Selenium常見八大定位法總結(jié)

    Selenium常見八大定位法總結(jié)

    自動(dòng)化最基礎(chǔ)的就屬于定位元素了,元素不會定位,基本上已經(jīng)團(tuán)滅了,就不用再去考慮什么自動(dòng)化了,下面這篇文章主要給大家介紹了關(guān)于Selenium常見八大定位法的相關(guān)資料,需要的朋友可以參考下
    2023-02-02

最新評論