欧美bbbwbbbw肥妇,免费乱码人妻系列日韩,一级黄片

PyTorch實(shí)現(xiàn)聯(lián)邦學(xué)習(xí)的基本算法FedAvg

 更新時(shí)間:2022年05月11日 14:02:29   作者:Cyril_KI  
這篇文章主要為大家介紹了PyTorch實(shí)現(xiàn)聯(lián)邦學(xué)習(xí)的基本算法FedAvg,有需要的朋友可以借鑒參考下,希望能夠有所幫助,祝大家多多進(jìn)步,早日升職加薪

I. 前言

在之前的一篇博客聯(lián)邦學(xué)習(xí)基本算法FedAvg的代碼實(shí)現(xiàn)中利用numpy手搭神經(jīng)網(wǎng)絡(luò)實(shí)現(xiàn)了FedAvg,手搭的神經(jīng)網(wǎng)絡(luò)效果已經(jīng)很好了,不過(guò)這還是屬于自己造輪子,建議優(yōu)先使用PyTorch來(lái)實(shí)現(xiàn)。

II. 數(shù)據(jù)介紹

聯(lián)邦學(xué)習(xí)中存在多個(gè)客戶(hù)端,每個(gè)客戶(hù)端都有自己的數(shù)據(jù)集,這個(gè)數(shù)據(jù)集他們是不愿意共享的。

本文選用的數(shù)據(jù)集為中國(guó)北方某城市十個(gè)區(qū)/縣從2016年到2019年三年的真實(shí)用電負(fù)荷數(shù)據(jù),采集時(shí)間間隔為1小時(shí),即每一天都有24個(gè)負(fù)荷值。

我們假設(shè)這10個(gè)地區(qū)的電力部門(mén)不愿意共享自己的數(shù)據(jù),但是他們又想得到一個(gè)由所有數(shù)據(jù)統(tǒng)一訓(xùn)練得到的全局模型。

除了電力負(fù)荷數(shù)據(jù)以外,還有一個(gè)備選數(shù)據(jù)集:風(fēng)功率數(shù)據(jù)集。兩個(gè)數(shù)據(jù)集通過(guò)參數(shù)type指定:type == 'load’表示負(fù)荷數(shù)據(jù),'wind’表示風(fēng)功率數(shù)據(jù)。

特征構(gòu)造

用某一時(shí)刻前24個(gè)時(shí)刻的負(fù)荷值以及該時(shí)刻的相關(guān)氣象數(shù)據(jù)(如溫度、濕度、壓強(qiáng)等)來(lái)預(yù)測(cè)該時(shí)刻的負(fù)荷值。

對(duì)于風(fēng)功率數(shù)據(jù),同樣使用某一時(shí)刻前24個(gè)時(shí)刻的風(fēng)功率值以及該時(shí)刻的相關(guān)氣象數(shù)據(jù)來(lái)預(yù)測(cè)該時(shí)刻的風(fēng)功率值。

各個(gè)地區(qū)應(yīng)該就如何制定特征集達(dá)成一致意見(jiàn),本文使用的各個(gè)地區(qū)上的數(shù)據(jù)的特征是一致的,可以直接使用。

III. 聯(lián)邦學(xué)習(xí)

1. 整體框架

原始論文中提出的FedAvg的框架為:

在這里插入圖片描述

客戶(hù)端模型采用PyTorch搭建:

class ANN(nn.Module):
    def __init__(self, input_dim, name, B, E, type, lr):
        super(ANN, self).__init__()
        self.name = name
        self.B = B
        self.E = E
        self.len = 0
        self.type = type
        self.lr = lr
        self.loss = 0
        self.fc1 = nn.Linear(input_dim, 20)
        self.relu = nn.ReLU()
        self.sigmoid = nn.Sigmoid()
        self.dropout = nn.Dropout()
        self.fc2 = nn.Linear(20, 20)
        self.fc3 = nn.Linear(20, 20)
        self.fc4 = nn.Linear(20, 1)
    def forward(self, data):
        x = self.fc1(data)
        x = self.sigmoid(x)
        x = self.fc2(x)
        x = self.sigmoid(x)
        x = self.fc3(x)
        x = self.sigmoid(x)
        x = self.fc4(x)
        x = self.sigmoid(x)
        return x

2. 服務(wù)器端

服務(wù)器端執(zhí)行以下步驟:

簡(jiǎn)單來(lái)說(shuō),每一輪通信時(shí)都只是選擇部分客戶(hù)端,這些客戶(hù)端利用本地的數(shù)據(jù)進(jìn)行參數(shù)更新,然后將更新后的參數(shù)傳給服務(wù)器,服務(wù)器匯總客戶(hù)端更新后的參數(shù)形成最新的全局參數(shù)。下一輪通信時(shí),服務(wù)器端將最新的參數(shù)分發(fā)給被選中的客戶(hù)端,進(jìn)行下一輪更新。

3. 客戶(hù)端

客戶(hù)端沒(méi)什么可說(shuō)的,就是利用本地?cái)?shù)據(jù)對(duì)神經(jīng)網(wǎng)絡(luò)模型的參數(shù)進(jìn)行更新。

IV. 代碼實(shí)現(xiàn)

1. 初始化

class FedAvg:
    def __init__(self, options):
        self.C = options['C']
        self.E = options['E']
        self.B = options['B']
        self.K = options['K']
        self.r = options['r']
        self.input_dim = options['input_dim']
        self.type = options['type']
        self.lr = options['lr']
        self.clients = options['clients']
        self.nn = ANN(input_dim=self.input_dim, name='server', B=B, E=E, type=self.type, lr=self.lr).to(device)
        self.nns = []
        for i in range(K):
            temp = copy.deepcopy(self.nn)
            temp.name = self.clients[i]
            self.nns.append(temp)

參數(shù):

  • K,客戶(hù)端數(shù)量,本文為10個(gè),也就是10個(gè)地區(qū)。
  • C:選擇率,每一輪通信時(shí)都只是選擇C * K個(gè)客戶(hù)端。
  • E:客戶(hù)端更新本地模型的參數(shù)時(shí),在本地?cái)?shù)據(jù)集上訓(xùn)練E輪。
  • B:客戶(hù)端更新本地模型的參數(shù)時(shí),本地?cái)?shù)據(jù)集batch大小為B
  • r:服務(wù)器端和客戶(hù)端一共進(jìn)行r輪通信。
  • clients:客戶(hù)端集合。
  • type:指定數(shù)據(jù)類(lèi)型,負(fù)荷預(yù)測(cè)or風(fēng)功率預(yù)測(cè)。
  • lr:學(xué)習(xí)率。
  • input_dim:數(shù)據(jù)輸入維度。
  • nn:全局模型。
  • nns: 客戶(hù)端模型集合。

2. 服務(wù)器端

服務(wù)器端代碼如下:

def server(self):
     for t in range(self.r):
          print('第', t + 1, '輪通信:')
          m = np.max([int(self.C * self.K), 1])
          # sampling
          index = random.sample(range(0, self.K), m)
          # dispatch
          self.dispatch(index)
          # local updating
          self.client_update(index)
          # aggregation
          self.aggregation(index)
     # return global model
     return self.nn

其中client_update(index):

def client_update(self, index):  # update nn
     for k in index:
          self.nns[k] = train(self.nns[k])

aggregation(index):

def aggregation(self, index):
     s = 0
     for j in index:
          # normal
          s += self.nns[j].len
     params = {}
     with torch.no_grad():
          for k, v in self.nns[0].named_parameters():
               params[k] = copy.deepcopy(v)
               params[k].zero_()
     for j in index:
          with torch.no_grad():
               for k, v in self.nns[j].named_parameters():
                    params[k] += v * (self.nns[j].len / s)
     with torch.no_grad():
          for k, v in self.nn.named_parameters():
               v.copy_(params[k])

dispatch(index):

def dispatch(self, index):
     params = {}
     with torch.no_grad():
          for k, v in self.nn.named_parameters():
               params[k] = copy.deepcopy(v)
     for j in index:
          with torch.no_grad():
               for k, v in self.nns[j].named_parameters():
                    v.copy_(params[k])

下面對(duì)重要代碼進(jìn)行分析:

客戶(hù)端的選擇

m = np.max([int(self.C * self.K), 1])
index = random.sample(range(0, self.K), m)

index中存儲(chǔ)中m個(gè)0~10間的整數(shù),表示被選中客戶(hù)端的序號(hào)。

客戶(hù)端的更新

for k in index:
    self.client_update(self.nns[k])

服務(wù)器端匯總客戶(hù)端模型的參數(shù)

關(guān)于模型匯總方式,可以參考一下我的另一篇文章:對(duì)FedAvg中模型聚合過(guò)程的理解。

當(dāng)然,這只是一種很簡(jiǎn)單的匯總方式,還有一些其他類(lèi)型的匯總方式。

論文Electricity Consumer Characteristics Identification: A Federated Learning Approach中總結(jié)了三種匯總方式:

normal:原始論文中的方式,即根據(jù)樣本數(shù)量來(lái)決定客戶(hù)端參數(shù)在最終組合時(shí)所占比例。

LA:根據(jù)客戶(hù)端模型的損失占所有客戶(hù)端損失和的比重來(lái)決定最終組合時(shí)參數(shù)所占比例。

LS:根據(jù)損失與樣本數(shù)量的乘積所占的比重來(lái)決定。 將更新后的參數(shù)分發(fā)給被選中的客戶(hù)端

def dispatch(self, index):
     params = {}
     with torch.no_grad():
          for k, v in self.nn.named_parameters():
               params[k] = copy.deepcopy(v)
     for j in index:
          with torch.no_grad():
               for k, v in self.nns[j].named_parameters():
                    v.copy_(params[k])

3. 客戶(hù)端

客戶(hù)端只需要利用本地?cái)?shù)據(jù)來(lái)進(jìn)行更新就行了:

def client_update(self, index):  # update nn
     for k in index:
          self.nns[k] = train(self.nns[k])

其中train():

def train(ann):
    ann.train()
    # print(p)
    if ann.type == 'load':
        Dtr, Dte = nn_seq(ann.name, ann.B, ann.type)
    else:
        Dtr, Dte = nn_seq_wind(ann.named, ann.B, ann.type)
    ann.len = len(Dtr)
    # print(len(Dtr))
    loss_function = nn.MSELoss().to(device)
    loss = 0
    optimizer = torch.optim.Adam(ann.parameters(), lr=ann.lr)
    for epoch in range(ann.E):
        cnt = 0
        for (seq, label) in Dtr:
            cnt += 1
            seq = seq.to(device)
            label = label.to(device)
            y_pred = ann(seq)
            loss = loss_function(y_pred, label)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        print('epoch', epoch, ':', loss.item())
    return ann

4. 測(cè)試

def global_test(self):
     model = self.nn
     model.eval()
     c = clients if self.type == 'load' else clients_wind
     for client in c:
          model.name = client
          test(model)

V. 實(shí)驗(yàn)及結(jié)果

本次實(shí)驗(yàn)的參數(shù)選擇為:

KCEBr
100.550505
if __name__ == '__main__':
    K, C, E, B, r = 10, 0.5, 50, 50, 5
    type = 'load'
    input_dim = 30 if type == 'load' else 28
    _client = clients if type == 'load' else clients_wind
    lr = 0.08
    options = {'K': K, 'C': C, 'E': E, 'B': B, 'r': r, 'type': type, 'clients': _client,
               'input_dim': input_dim, 'lr': lr}
    fedavg = FedAvg(options)
    fedavg.server()
    fedavg.global_test()

各個(gè)客戶(hù)端單獨(dú)訓(xùn)練(訓(xùn)練50輪,batch大小為50)后在本地的測(cè)試集上的表現(xiàn)為:

客戶(hù)端編號(hào)12345678910
MAPE / %5.334.113.034.203.022.702.942.992.304.10

可以看到,由于各個(gè)客戶(hù)端的數(shù)據(jù)都十分充足,所以每個(gè)客戶(hù)端自己訓(xùn)練的本地模型的預(yù)測(cè)精度已經(jīng)很高了。

服務(wù)器與客戶(hù)端通信5輪后,服務(wù)器上的全局模型在10個(gè)客戶(hù)端測(cè)試集上的表現(xiàn)如下所示:

客戶(hù)端編號(hào)12345678910
MAPE / %6.844.543.565.113.754.474.303.903.154.58

可以看到,經(jīng)過(guò)聯(lián)邦學(xué)習(xí)框架得到全局模型在各個(gè)客戶(hù)端上表現(xiàn)同樣很好ÿ0c;這是因?yàn)槭畟€(gè)地區(qū)上的數(shù)據(jù)分布類(lèi)似。

給出numpy和PyTorch的對(duì)比:

客戶(hù)端編號(hào)12345678910
本地5.334.113.034.203.022.702.942.992.304.10
numpy6.584.193.175.133.584.694.713.752.944.77
PyTorch6.844.543.565.113.754.474.303.903.154.58

同樣本地模型的效果是最好的,PyTorch搭建的網(wǎng)絡(luò)和numpy搭建的網(wǎng)絡(luò)效果差不多,但推薦使用PyTorch,不要造輪子。

VI. 源碼及數(shù)據(jù)

我把數(shù)據(jù)和代碼放在了GitHub上:源碼及數(shù)據(jù),原創(chuàng)不易,下載時(shí)請(qǐng)隨手給個(gè)follow和star,感謝!

以上就是PyTorch實(shí)現(xiàn)聯(lián)邦學(xué)習(xí)的基本算法FedAvg的詳細(xì)內(nèi)容,更多關(guān)于PyTorch實(shí)現(xiàn)FedAvg算法的資料請(qǐng)關(guān)注腳本之家其它相關(guān)文章!

相關(guān)文章

  • 使用python-pptx創(chuàng)建PPT演示文檔功能實(shí)踐

    使用python-pptx創(chuàng)建PPT演示文檔功能實(shí)踐

    這篇文章主要介紹了使用python-pptx創(chuàng)建PPT演示文檔功能實(shí)踐,具有很好的參考價(jià)值,希望對(duì)大家有所幫助。如有錯(cuò)誤或未考慮完全的地方,望不吝賜教
    2023-06-06
  • Python常用隊(duì)列全面詳細(xì)梳理

    Python常用隊(duì)列全面詳細(xì)梳理

    隊(duì)列是限制在兩端進(jìn)行插入和操作的線性表,允許存入操作的一段叫“隊(duì)尾”,刪除操作的一端叫“隊(duì)頭”,隊(duì)列的特點(diǎn):隊(duì)列只能在隊(duì)頭和隊(duì)尾進(jìn)行數(shù)據(jù)操作,隊(duì)列模型具有先進(jìn)先出的規(guī)律
    2023-01-01
  • Python PyQt5模塊實(shí)現(xiàn)窗口GUI界面代碼實(shí)例

    Python PyQt5模塊實(shí)現(xiàn)窗口GUI界面代碼實(shí)例

    這篇文章主要介紹了Python PyQt5模塊實(shí)現(xiàn)窗口GUI界面代碼實(shí)例,文中通過(guò)示例代碼介紹的非常詳細(xì),對(duì)大家的學(xué)習(xí)或者工作具有一定的參考學(xué)習(xí)價(jià)值,需要的朋友可以參考下
    2020-05-05
  • python獲取局域網(wǎng)占帶寬最大3個(gè)ip的方法

    python獲取局域網(wǎng)占帶寬最大3個(gè)ip的方法

    這篇文章主要介紹了python獲取局域網(wǎng)占帶寬最大3個(gè)ip的方法,涉及Python解析URL參數(shù)的相關(guān)技巧,具有一定參考借鑒價(jià)值,需要的朋友可以參考下
    2015-07-07
  • Python+PyQt5自制監(jiān)控小工具

    Python+PyQt5自制監(jiān)控小工具

    這篇文章主要為大家詳細(xì)介紹了如何通過(guò)使用python實(shí)現(xiàn)對(duì)計(jì)算機(jī)攝像頭的調(diào)用從而實(shí)現(xiàn)攝像監(jiān)控的功能,文中的示例代碼講解詳細(xì),需要的可以參考一下
    2023-03-03
  • Python中bytes字節(jié)串和string字符串之間的轉(zhuǎn)換方法

    Python中bytes字節(jié)串和string字符串之間的轉(zhuǎn)換方法

    python中字節(jié)字符串不能格式化,獲取到的網(wǎng)頁(yè)有時(shí)候是字節(jié)字符串,需要轉(zhuǎn)化后再解析,下面這篇文章主要給大家介紹了關(guān)于Python中bytes字節(jié)串和string字符串之間的轉(zhuǎn)換方法,需要的朋友可以參考下
    2022-01-01
  • OpenCV制作Mask圖像掩碼的案例

    OpenCV制作Mask圖像掩碼的案例

    這篇文章主要介紹了OpenCV制作Mask圖像掩碼的案例,本文通過(guò)實(shí)例代碼給大家介紹的非常詳細(xì),對(duì)大家的學(xué)習(xí)或工作具有一定的參考借鑒價(jià)值,需要的朋友可以參考下
    2023-02-02
  • Python實(shí)現(xiàn)享元模式的示例代碼

    Python實(shí)現(xiàn)享元模式的示例代碼

    享元模式是一種結(jié)構(gòu)型設(shè)計(jì)模式,旨在通過(guò)共享盡可能多的相似對(duì)象來(lái)減少內(nèi)存使用,提高性能,下面我們就來(lái)看看如何使用Python實(shí)現(xiàn)享元模式吧
    2024-02-02
  • 用Python中的wxPython實(shí)現(xiàn)最基本的瀏覽器功能

    用Python中的wxPython實(shí)現(xiàn)最基本的瀏覽器功能

    這篇文章主要介紹了用Python中的wxPython實(shí)現(xiàn)基本的瀏覽器功能,本文來(lái)自于IBM官方網(wǎng)站開(kāi)發(fā)者文檔,需要的朋友可以參考下
    2015-04-04
  • Python的缺點(diǎn)和劣勢(shì)分析

    Python的缺點(diǎn)和劣勢(shì)分析

    在本篇文章里小編給大家整理了關(guān)于Python的缺點(diǎn)和劣勢(shì)總結(jié),有興趣的朋友們可以學(xué)習(xí)下。
    2019-11-11

最新評(píng)論