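Implementing the FedProx Federated Learning Algorithm in PyTorch
I. Preface
For the principles behind FedProx, see: "FedAvg Federated Learning: Summary of FedProx Heterogeneous-Network Optimization Experiments".
Federated learning involves multiple clients. Each client holds its own dataset, which it is unwilling to share.
The dataset consists of wind power output from ten districts of a city. We assume that the power utilities of these ten districts do not want to share their own data, yet they still want a global model trained on all of the data combined.
III. FedProx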
Algorithm pseudocode:
1. Model Definition
The client model is a simple four-layer neural network:
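For orientation, the whole FedProx loop can be sketched in a few lines of NumPy. This is a minimal illustration under stated assumptions, not the article's PyTorch implementation: each hypothetical client holds a toy linear-regression dataset with its own shift (standing in for heterogeneous districts), local training is full-batch gradient descent on the local loss plus the proximal term, and, as in this article's server code, clients are sampled uniformly and aggregated with weights proportional to their data size. The names `make_clients`, `local_update` and `fedprox` are made up for this sketch.

```python
import numpy as np


def make_clients(num_clients=10, n=50, dim=3, seed=0):
    """Toy stand-in for the ten districts: per-client linear-regression data.
    Each client gets its own shift of the true weights to mimic heterogeneity
    (an assumption for illustration, not the wind-power data)."""
    rng = np.random.default_rng(seed)
    w_true = rng.normal(size=dim)
    clients = []
    for _ in range(num_clients):
        X = rng.normal(size=(n, dim))
        shift = rng.normal(scale=0.5, size=dim)
        y = X @ (w_true + shift) + rng.normal(scale=0.1, size=n)
        clients.append((X, y))
    return clients


def local_update(w_global, X, y, mu, lr, epochs):
    """Approximately minimise F_k(w) + mu/2 * ||w - w_global||^2 by gradient descent."""
    w = w_global.copy()  # local training starts from the dispatched global model
    for _ in range(epochs):
        grad_f = X.T @ (X @ w - y) / len(y)  # gradient of 1/(2n) * ||Xw - y||^2
        grad_prox = mu * (w - w_global)       # gradient of the proximal term
        w -= lr * (grad_f + grad_prox)
    return w


def fedprox(clients, rounds=20, C=0.5, mu=0.1, lr=0.1, epochs=5, seed=0):
    rng = np.random.default_rng(seed)
    K = len(clients)
    w_global = np.zeros(clients[0][0].shape[1])
    for _ in range(rounds):
        m = max(int(C * K), 1)                          # at least one client per round
        index = rng.choice(K, size=m, replace=False)    # uniform client sampling
        local_ws, sizes = [], []
        for k in index:                                 # dispatch + local update
            X, y = clients[k]
            local_ws.append(local_update(w_global, X, y, mu, lr, epochs))
            sizes.append(len(y))
        # aggregation: average weighted by each client's number of samples
        w_global = np.average(local_ws, axis=0, weights=sizes)
    return w_global


clients = make_clients()
w_global = fedprox(clients, mu=0.1)
print(w_global)
```

Setting `mu=0.0` turns each local step into plain gradient descent on the client's own loss, which is a FedAvg-style round; increasing `mu` keeps the sampled clients closer to the current global weights.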
```python
# -*- coding:utf-8 -*-
"""
@Time: 2022/03/03 12:23
@Author: KI
@File: model.py
@Motto: Hungry And Humble
"""
from torch import nn


class ANN(nn.Module):
    def __init__(self, args, name):
        super(ANN, self).__init__()
        self.name = name  # client (district) name, used to load that client's data
        self.len = 0      # set in train() to len(Dtr); used as the aggregation weight
        self.loss = 0
        self.fc1 = nn.Linear(args.input_dim, 20)
        self.relu = nn.ReLU()        # defined but not used in forward()
        self.sigmoid = nn.Sigmoid()
        self.dropout = nn.Dropout()  # defined but not used in forward()
        self.fc2 = nn.Linear(20, 20)
        self.fc3 = nn.Linear(20, 20)
        self.fc4 = nn.Linear(20, 1)

    def forward(self, data):
        # four linear layers, each followed by a sigmoid (including the output layer)
        x = self.fc1(data)
        x = self.sigmoid(x)
        x = self.fc2(x)
        x = self.sigmoid(x)
        x = self.fc3(x)
        x = self.sigmoid(x)
        x = self.fc4(x)
        x = self.sigmoid(x)

        return x
```
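A NumPy mirror of `ANN` makes the layer shapes and one consequence of the design easy to check: because a sigmoid also follows the last linear layer, every prediction lies in (0, 1), so the wind-power labels need to lie in that range for the MSE loss to be driven toward zero (how data_process.py scales them is not shown, so this is an inference). `input_dim=30` and the random initialization are hypothetical; PyTorch's `nn.Linear` stores weights as (out, in) and initializes them differently, but the parameter count is the same.

```python
import numpy as np


def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))


def init_ann(input_dim, hidden=20, seed=0):
    """Random weights with the same layer sizes as ANN: input_dim -> 20 -> 20 -> 20 -> 1."""
    rng = np.random.default_rng(seed)
    sizes = [input_dim, hidden, hidden, hidden, 1]
    return [(rng.normal(scale=0.1, size=(n_in, n_out)), np.zeros(n_out))
            for n_in, n_out in zip(sizes[:-1], sizes[1:])]


def forward(layers, x):
    # every linear layer, including the last, is followed by a sigmoid, as in ANN.forward
    for W, b in layers:
        x = sigmoid(x @ W + b)
    return x


layers = init_ann(input_dim=30)
x = np.random.default_rng(1).normal(size=(8, 30))  # batch of 8 hypothetical feature vectors
y = forward(layers, x)
print(y.shape)  # (8, 1)
n_params = sum(W.size + b.size for W, b in layers)
print(n_params)  # 20*30+20 + 2*(20*20+20) + 20+1 = 1481
```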
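2. Server Side
The server side is the same as in FedAvg: it repeats three steps, namely client sampling, parameter dispatch, and parameter aggregation: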
```python
# -*- coding:utf-8 -*-
"""
@Time: 2022/03/03 12:50
@Author: KI
@File: server.py
@Motto: Hungry And Humble
"""
import copy
import random

import torch

from model import ANN
from client import train, test


class FedProx:
    def __init__(self, args):
        self.args = args
        # global (server) model
        self.nn = ANN(args=self.args, name='server').to(args.device)
        # one local model per client, initialised as copies of the global model
        self.nns = []
        for i in range(self.args.K):
            temp = copy.deepcopy(self.nn)
            temp.name = self.args.clients[i]
            self.nns.append(temp)

    def server(self):
        for t in range(self.args.r):
            print('round', t + 1, ':')
            # sampling: choose max(C*K, 1) clients uniformly at random
            m = max(int(self.args.C * self.args.K), 1)
            index = random.sample(range(0, self.args.K), m)
            # dispatch: copy the global parameters to the sampled clients
            self.dispatch(index)
            # local updating (with the proximal term)
            self.client_update(index, t)
            # aggregation: weighted average of the sampled clients' parameters
            self.aggregation(index)

        return self.nn

    def aggregation(self, index):
        # total weight of the sampled clients (each model's len = len(Dtr))
        s = 0
        for j in index:
            s += self.nns[j].len

        params = {}
        for k, v in self.nns[0].named_parameters():
            params[k] = torch.zeros_like(v.data)

        for j in index:
            for k, v in self.nns[j].named_parameters():
                params[k] += v.data * (self.nns[j].len / s)

        for k, v in self.nn.named_parameters():
            v.data = params[k].data.clone()

    def dispatch(self, index):
        for j in index:
            for old_params, new_params in zip(self.nns[j].parameters(), self.nn.parameters()):
                old_params.data = new_params.data.clone()

    def client_update(self, index, global_round):  # update nn
        for k in index:
            self.nns[k] = train(self.args, self.nns[k], self.nn, global_round)

    def global_test(self):
        model = self.nn
        model.eval()
        for client in self.args.clients:
            model.name = client
            test(self.args, model)
```
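The aggregation step is a weighted average of the sampled clients' parameters. A stripped-down NumPy version over plain dictionaries shows the arithmetic; the helper names `aggregate` and `num_sampled`, the parameter names and the sizes are hypothetical. One detail worth knowing: `model.len` is set to `len(Dtr)`, and assuming `Dtr` is a PyTorch `DataLoader` (as the batched loop in `train` suggests), its length is the number of batches, so the weights are proportional to batch counts, which is approximately proportional to sample counts when all clients share the same batch size.

```python
import numpy as np


def aggregate(client_params, client_sizes):
    """Weighted average of per-client parameter dicts with weights size / total size.

    client_params: list of {name: np.ndarray}, one dict per sampled client
    client_sizes:  list of numbers, playing the role of model.len in server.py
    """
    total = sum(client_sizes)
    return {
        name: sum(p[name] * (n / total) for p, n in zip(client_params, client_sizes))
        for name in client_params[0]
    }


def num_sampled(C, K):
    # same rule as server.py: a fraction C of the K clients, but at least one
    return max(int(C * K), 1)


params_a = {"fc.weight": np.array([1.0, 1.0]), "fc.bias": np.array([0.0])}
params_b = {"fc.weight": np.array([4.0, 4.0]), "fc.bias": np.array([3.0])}
agg = aggregate([params_a, params_b], [1, 3])  # weights 0.25 and 0.75
print(agg["fc.weight"])  # [3.25 3.25]
print(agg["fc.bias"])    # [2.25]
print(num_sampled(0.5, 10), num_sampled(0.05, 10))  # 5 1
```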
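3. Client Update
In FedProx, each client $k$ optimizes the following local objective instead of its plain local loss $F_k(w)$:

$$\min_{w} h_k(w; w^t) = F_k(w) + \frac{\mu}{2}\lVert w - w^t \rVert^2$$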
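Building on the FedAvg loss function, the authors introduce a proximal term $\frac{\mu}{2}\lVert w - w^t \rVert^2$. With this term, the model parameters $w$ that a client obtains from local training cannot drift too far from the server parameters $w^t$ it started from.
The client training code in client.py that implements this objective is: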
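To see what the proximal term does, take a toy local loss $F_k(w) = \frac{1}{2}\lVert w - a \rVert^2$ whose own optimum is a hypothetical point $a$. The gradient of the local objective is $(w - a) + \mu (w - w^t)$, so its minimizer is $w^* = (a + \mu w^t)/(1 + \mu)$: $\mu = 0$ recovers the client's own optimum (a FedAvg-style local solve), and a larger $\mu$ pulls the result toward $w^t$. The NumPy sketch below checks this by gradient descent; the quadratic loss and the function name `local_prox_solve` are assumptions for illustration.

```python
import numpy as np


def local_prox_solve(a, w_t, mu, lr=0.1, steps=500):
    """Gradient descent on h(w) = 0.5*||w - a||^2 + mu/2 * ||w - w_t||^2 (toy local loss)."""
    w = w_t.copy()  # local training starts from the server parameters
    for _ in range(steps):
        grad = (w - a) + mu * (w - w_t)  # local-loss gradient plus proximal-term gradient
        w -= lr * grad
    return w


a = np.array([4.0, -2.0])   # hypothetical optimum of the client's own loss
w_t = np.array([0.0, 0.0])  # hypothetical server parameters
for mu in [0.0, 1.0, 10.0]:
    w = local_prox_solve(a, w_t, mu)
    closed_form = (a + mu * w_t) / (1 + mu)
    print(mu, np.round(w, 3), np.allclose(w, closed_form))
```

With $a = (4, -2)$ and $w^t = 0$, $\mu = 1$ lands halfway at $(2, -1)$ and $\mu = 10$ at $a/11$, much closer to the server parameters.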
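```python
import copy

import torch
from torch import nn

from data_process import nn_seq_wind  # data loading helper (assumed to live in data_process.py)


def train(args, model, server, global_round):
    model.train()
    # load this client's training and test data
    Dtr, Dte = nn_seq_wind(model.name, args.B)
    model.len = len(Dtr)  # used by the server as the aggregation weight
    # frozen copy of the server parameters w^t for the proximal term
    global_model = copy.deepcopy(server)
    for p in global_model.parameters():
        p.requires_grad_(False)
    # per-round learning-rate decay; note that args.weight_decay doubles as the decay
    # factor here and is also passed to the optimizer as L2 weight decay
    if args.weight_decay != 0:
        lr = args.lr * pow(args.weight_decay, global_round)
    else:
        lr = args.lr
    if args.optimizer == 'adam':
        optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=args.weight_decay)
    else:
        optimizer = torch.optim.SGD(model.parameters(), lr=lr, momentum=0.9, weight_decay=args.weight_decay)

    print('training...')
    loss_function = nn.MSELoss().to(args.device)
    loss = 0
    for epoch in range(args.E):
        for (seq, label) in Dtr:
            seq = seq.to(args.device)
            label = label.to(args.device)
            y_pred = model(seq)
            optimizer.zero_grad()
            # proximal term: squared L2 distance between local and server parameters
            proximal_term = 0.0
            for w, w_t in zip(model.parameters(), global_model.parameters()):
                proximal_term += ((w - w_t) ** 2).sum()
            loss = loss_function(y_pred, label) + (args.mu / 2) * proximal_term
            loss.backward()
            optimizer.step()

        print('epoch', epoch, ':', loss.item())

    return model
```

On top of the original MSE loss, we add the proximal term, summed over all parameter tensors:

```python
for w, w_t in zip(model.parameters(), global_model.parameters()):
    proximal_term += ((w - w_t) ** 2).sum()
```

The difference is squared to match $\lVert w - w^t \rVert^2$ in the FedProx objective. We then call `loss.backward()` to compute the gradients and `optimizer.step()` to update the parameters.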
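In the FedProx objective the proximal term uses the squared norm $\lVert w - w^t \rVert^2$, and the square changes the gradient qualitatively. For $d = w - w^t$, the plain norm gives $\nabla \frac{\mu}{2}\lVert d \rVert = \frac{\mu}{2} d / \lVert d \rVert$, a pull of constant size $\mu/2$ however far the client has drifted (and undefined at $d = 0$), whereas the squared norm gives $\mu d$, a pull proportional to the drift. A small NumPy check with hand-derived gradients (function names and values are hypothetical):

```python
import numpy as np


def grad_plain_norm(w, w_t, mu):
    """Gradient of (mu/2) * ||w - w_t||_2 (not squared); undefined at w == w_t."""
    d = w - w_t
    return (mu / 2) * d / np.linalg.norm(d)


def grad_squared_norm(w, w_t, mu):
    """Gradient of (mu/2) * ||w - w_t||_2^2, the proximal term of the FedProx objective."""
    return mu * (w - w_t)


mu = 0.1
w_t = np.zeros(3)  # hypothetical server parameters
for scale in [0.01, 1.0, 100.0]:
    w = w_t + scale * np.ones(3)  # client parameters drifted by `scale` in every coordinate
    g_plain = np.linalg.norm(grad_plain_norm(w, w_t, mu))
    g_squared = np.linalg.norm(grad_squared_norm(w, w_t, mu))
    print(f"drift={scale:>6}: |grad| plain={g_plain:.4f}, squared={g_squared:.4f}")
```

Summing per-tensor squared norms, as the training loop does, equals the squared norm of all parameters concatenated into one vector; summing unsquared per-tensor norms does not. In PyTorch the squared term can be written as `((w - w_t) ** 2).sum()` or `(w - w_t).norm(2) ** 2`.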
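The original paper also introduces the notion of an inexact solution. For $h(w; w_0) = F(w) + \frac{\mu}{2}\lVert w - w_0 \rVert^2$ and $\gamma \in [0, 1]$, a point $w^*$ is a $\gamma$-inexact solution of $\min_w h(w; w_0)$ if

$$\lVert \nabla h(w^*; w_0) \rVert \le \gamma \lVert \nabla h(w_0; w_0) \rVert, \quad \text{where } \nabla h(w; w_0) = \nabla F(w) + \mu (w - w_0).$$

A smaller $\gamma$ means a more accurate local solve. Note, however, that I could not find any guidance in the experimental section of the original paper on how to choose $\gamma$. After looking into it, I found that this involves proximal gradient descent. The code in this article does not take inexact solutions into account; I may add this later.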
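The γ-inexact condition from the FedProx paper translates directly into a stopping rule for local training. A minimal sketch, assuming a toy quadratic local loss $F(w) = \frac{1}{2}\lVert w - a \rVert^2$ and plain gradient descent; `grad_h` and `gamma_inexact_update` are hypothetical names, and this only illustrates the definition, not a recipe for choosing γ.

```python
import numpy as np


def grad_h(w, w0, a, mu):
    # toy local loss F(w) = 0.5*||w - a||^2, so grad h(w; w0) = (w - a) + mu*(w - w0)
    return (w - a) + mu * (w - w0)


def gamma_inexact_update(w0, a, mu, gamma, lr=0.1, max_steps=10_000):
    """Gradient descent on h(.; w0), stopping once w is gamma-inexact:
    ||grad h(w; w0)|| <= gamma * ||grad h(w0; w0)||."""
    w = w0.copy()
    threshold = gamma * np.linalg.norm(grad_h(w0, w0, a, mu))
    steps = 0
    while np.linalg.norm(grad_h(w, w0, a, mu)) > threshold and steps < max_steps:
        w -= lr * grad_h(w, w0, a, mu)
        steps += 1
    return w, steps


w0 = np.zeros(2)           # hypothetical server parameters
a = np.array([4.0, -2.0])  # hypothetical optimum of the local loss
for gamma in [0.9, 0.5, 0.1, 0.01]:
    w, steps = gamma_inexact_update(w0, a, mu=1.0, gamma=gamma)
    print(gamma, steps)
```

With $\mu = 1$, step size 0.1 and $w_0 = 0$, each step shrinks the gradient norm by a factor of 0.8, so tighter γ values need more steps: 1 step for γ = 0.9, 4 for 0.5, 11 for 0.1 and 21 for 0.01.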
IV. Complete Code
Link: https://pan.baidu.com/s/1hj2EOcqIUmM-C6R1cyjE5Q
Extraction code: fghp
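Project structure:
The files are:
- server.py: server-side operations.
- client.py: client-side operations.
- data_process.py: data processing.
- model.py: model definition.
- args.py: argument definitions.
- main.py: the main entry point. To run the project, simply execute:
python main.py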