PyTorch實(shí)現(xiàn)FedProx聯(lián)邦學(xué)習(xí)算法
I. 前言
FedProx的原理請(qǐng)見(jiàn):FedAvg聯(lián)邦學(xué)習(xí)FedProx異質(zhì)網(wǎng)絡(luò)優(yōu)化實(shí)驗(yàn)總結(jié)
聯(lián)邦學(xué)習(xí)中存在多個(gè)客戶端,每個(gè)客戶端都有自己的數(shù)據(jù)集,這個(gè)數(shù)據(jù)集他們是不愿意共享的。
數(shù)據(jù)集為某城市十個(gè)地區(qū)的風(fēng)電功率,我們假設(shè)這10個(gè)地區(qū)的電力部門(mén)不愿意共享自己的數(shù)據(jù),但是他們又想得到一個(gè)由所有數(shù)據(jù)統(tǒng)一訓(xùn)練得到的全局模型。
III. FedProx
算法偽代碼:
1. 模型定義
客戶端的模型為一個(gè)簡(jiǎn)單的四層神經(jīng)網(wǎng)絡(luò)模型:
# -*- coding:utf-8 -*- """ @Time: 2022/03/03 12:23 @Author: KI @File: model.py @Motto: Hungry And Humble """ from torch import nn class ANN(nn.Module): def __init__(self, args, name): super(ANN, self).__init__() self.name = name self.len = 0 self.loss = 0 self.fc1 = nn.Linear(args.input_dim, 20) self.relu = nn.ReLU() self.sigmoid = nn.Sigmoid() self.dropout = nn.Dropout() self.fc2 = nn.Linear(20, 20) self.fc3 = nn.Linear(20, 20) self.fc4 = nn.Linear(20, 1) def forward(self, data): x = self.fc1(data) x = self.sigmoid(x) x = self.fc2(x) x = self.sigmoid(x) x = self.fc3(x) x = self.sigmoid(x) x = self.fc4(x) x = self.sigmoid(x) return x
2. 服務(wù)器端
服務(wù)器端和FedAvg一致,即重復(fù)進(jìn)行客戶端采樣、參數(shù)傳達(dá)、參數(shù)聚合三個(gè)步驟:
# -*- coding:utf-8 -*- """ @Time: 2022/03/03 12:50 @Author: KI @File: server.py @Motto: Hungry And Humble """ import copy import random import numpy as np import torch from model import ANN from client import train, test class FedProx: def __init__(self, args): self.args = args self.nn = ANN(args=self.args, name='server').to(args.device) self.nns = [] for i in range(self.args.K): temp = copy.deepcopy(self.nn) temp.name = self.args.clients[i] self.nns.append(temp) def server(self): for t in range(self.args.r): print('round', t + 1, ':') # sampling m = np.max([int(self.args.C * self.args.K), 1]) index = random.sample(range(0, self.args.K), m) # st # dispatch self.dispatch(index) # local updating self.client_update(index, t) # aggregation self.aggregation(index) return self.nn def aggregation(self, index): s = 0 for j in index: # normal s += self.nns[j].len params = {} for k, v in self.nns[0].named_parameters(): params[k] = torch.zeros_like(v.data) for j in index: for k, v in self.nns[j].named_parameters(): params[k] += v.data * (self.nns[j].len / s) for k, v in self.nn.named_parameters(): v.data = params[k].data.clone() def dispatch(self, index): for j in index: for old_params, new_params in zip(self.nns[j].parameters(), self.nn.parameters()): old_params.data = new_params.data.clone() def client_update(self, index, global_round): # update nn for k in index: self.nns[k] = train(self.args, self.nns[k], self.nn, global_round) def global_test(self): model = self.nn model.eval() for client in self.args.clients: model.name = client test(self.args, model)
3. 客戶端更新
FedProx中客戶端需要優(yōu)化的函數(shù)為:
作者在FedAvg損失函數(shù)的基礎(chǔ)上,引入了一個(gè)proximal term,我們可以稱之為近端項(xiàng)。引入近端項(xiàng)后,客戶端在本地訓(xùn)練后得到的模型參數(shù) w將不會(huì)與初始時(shí)的服務(wù)器參數(shù)wt偏離太多。
對(duì)應(yīng)的代碼為:
def train(args, model, server, global_round): model.train() Dtr, Dte = nn_seq_wind(model.name, args.B) model.len = len(Dtr) global_model = copy.deepcopy(server) if args.weight_decay != 0: lr = args.lr * pow(args.weight_decay, global_round) else: lr = args.lr if args.optimizer == 'adam': optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=args.weight_decay) else: optimizer = torch.optim.SGD(model.parameters(), lr=lr, momentum=0.9, weight_decay=args.weight_decay) print('training...') loss_function = nn.MSELoss().to(args.device) loss = 0 for epoch in range(args.E): for (seq, label) in Dtr: seq = seq.to(args.device) label = label.to(args.device) y_pred = model(seq) optimizer.zero_grad() # compute proximal_term proximal_term = 0.0 for w, w_t in zip(model.parameters(), global_model.parameters()): proximal_term += (w - w_t).norm(2) loss = loss_function(y_pred, label) + (args.mu / 2) * proximal_term loss.backward() optimizer.step() print('epoch', epoch, ':', loss.item()) return model
我們?cè)谠蠱SE損失函數(shù)的基礎(chǔ)上加上了一個(gè)近端項(xiàng):
for w, w_t in zip(model.parameters(), global_model.parameters()): proximal_term += (w - w_t).norm(2)
然后再反向傳播求梯度,然后優(yōu)化器step更新參數(shù)。
原始論文中還提出了一個(gè)不精確解的概念:
不過(guò)值得注意的是,我并沒(méi)有在原始論文的實(shí)驗(yàn)部分找到如何選擇 γ \gamma γ的說(shuō)明。查了一下資料后發(fā)現(xiàn)是涉及到了近端梯度下降的知識(shí),本文代碼并沒(méi)有考慮不精確解,后期可能會(huì)補(bǔ)上。
IV. 完整代碼
鏈接:https://pan.baidu.com/s/1hj2EOcqIUmM-C6R1cyjE5Q
提取碼:fghp
項(xiàng)目結(jié)構(gòu):
其中:
- server.py為服務(wù)器端操作。
- client.py為客戶端操作。
- data_process.py為數(shù)據(jù)處理部分。
- model.py為模型定義文件。
- args.py為參數(shù)定義文件。
- main.py為主文件,如想要運(yùn)行此項(xiàng)目可直接運(yùn)行:
python main.py
以上就是PyTorch實(shí)現(xiàn)FedProx的聯(lián)邦學(xué)習(xí)算法的詳細(xì)內(nèi)容,更多關(guān)于PyTorch實(shí)現(xiàn)FedProx算法的資料請(qǐng)關(guān)注腳本之家其它相關(guān)文章!
相關(guān)文章
淺談python中的實(shí)例方法、類方法和靜態(tài)方法
本文主要介紹了python 實(shí)例方法、類方法和靜態(tài)方法的相關(guān)知識(shí)。具有很好的參考價(jià)值,下面跟著小編一起來(lái)看下吧2017-02-02詳解python關(guān)于多級(jí)包之間的引用問(wèn)題
本文主要介紹了python關(guān)于多級(jí)包之間的引用問(wèn)題,文中通過(guò)示例代碼介紹的非常詳細(xì),對(duì)大家的學(xué)習(xí)或者工作具有一定的參考學(xué)習(xí)價(jià)值,需要的朋友們下面隨著小編來(lái)一起學(xué)習(xí)學(xué)習(xí)吧2021-08-08python實(shí)現(xiàn)簡(jiǎn)易學(xué)生信息管理系統(tǒng)
這篇文章主要為大家詳細(xì)介紹了python實(shí)現(xiàn)簡(jiǎn)易學(xué)生信息管理系統(tǒng),文中示例代碼介紹的非常詳細(xì),具有一定的參考價(jià)值,感興趣的小伙伴們可以參考一下2019-09-09Python中處理字符串之islower()方法的使用簡(jiǎn)介
這篇文章主要介紹了Python中處理字符串之islower()方法的使用,是Python入門(mén)的基礎(chǔ)知識(shí),需要的朋友可以參考下2015-05-05Python異步處理返回進(jìn)度——使用Flask實(shí)現(xiàn)進(jìn)度條
這篇文章主要介紹了Python異步處理返回進(jìn)度——使用Flask實(shí)現(xiàn)進(jìn)度條,具有很好的參考價(jià)值,希望對(duì)大家有所幫助。如有錯(cuò)誤或未考慮完全的地方,望不吝賜教2022-05-05Python獲取網(wǎng)頁(yè)上圖片下載地址的方法
這篇文章主要介紹了Python獲取網(wǎng)頁(yè)上圖片下載地址的方法,涉及Python操作正則表達(dá)式匹配字符串的技巧,需要的朋友可以參考下2015-03-03