Skip to content
Permalink
957d3c957c
Switch branches/tags

Name already in use

A tag already exists with the provided branch name. Many Git commands accept both tag and branch names, so creating this branch may cause unexpected behavior. Are you sure you want to create this branch?
Go to file
 
 
Cannot retrieve contributors at this time
231 lines (194 sloc) 9.76 KB
# Federated Learning Model in PyTorch
import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader, TensorDataset
from utils import gaussian_noise
from tensorflow_privacy.privacy.analysis.compute_noise_from_budget_lib import compute_noise
import numpy as np
import copy
class FLClient(nn.Module):
    """Client of the federated learning framework.

    1. Receive the global model from the server
    2. Perform local training (per-sample clipped gradients plus Gaussian noise)
    3. Expose the local model (``self.model``) so the server can read it back
    """

    def __init__(self, model, output_size, data, lr, E, batch_size, q, clip, sigma, optimizer, thred_k, device=None):
        """
        :param model: callable invoked as ``model(input_size, output_size)``; must return the nn.Module to train
        :param output_size: second argument handed to ``model``
        :param data: (tuple) ``(features, labels)``; all data on the client side is used as training data
        :param lr: learning rate of the local SGD step
        :param E: number of local epochs performed by one call to ``update``
        :param batch_size: mini-batch size used by the data loaders
        :param q: probability with which each sample is drawn into a training "lot"
        :param clip: maximum L2 norm allowed for each per-sample gradient
        :param sigma: DP noise level forwarded to ``gaussian_noise``
        :param optimizer: name of the local update rule; ``update`` reacts to 'Fed_SGD_HT' and 'Fed_GD_HT'
        :param thred_k: stored on the client; not read anywhere else in this class
        :param device: device the model and the batches are moved to
        """
        super(FLClient, self).__init__()
        self.device = device
        self.BATCH_SIZE = batch_size
        self.torch_dataset = TensorDataset(torch.tensor(data[0]),
                                           torch.tensor(data[1]))
        self.data_size = len(self.torch_dataset)
        self.data_loader = DataLoader(
            dataset=self.torch_dataset,
            batch_size=self.BATCH_SIZE,
            shuffle=True
        )
        self.sigma = sigma  # DP noise level
        self.lr = lr
        self.E = E
        self.clip = clip
        self.q = q
        # The input size is taken from the second dimension of the feature array.
        self.model = model(data[0].shape[1], output_size).to(self.device)
        self.optimizer = optimizer
        self.thred_k = thred_k

    def recv(self, model_param):
        """Receive the global model from the aggregator (server).

        :param model_param: state dict to load; it is deep-copied first so the
            client never shares tensors with the sender
        """
        self.model.load_state_dict(copy.deepcopy(model_param))

    def _clipped_grad_sum(self, loader, lot_size):
        """Accumulate the clipped per-sample gradients of every sample in ``loader``.

        :param loader: iterable yielding ``(batch_x, batch_y)`` pairs
        :param lot_size: divisor applied to every per-sample gradient
        :return: dict mapping parameter name to the accumulated gradient
        """
        criterion = nn.CrossEntropyLoss(reduction='none')
        clipped_grads = {name: torch.zeros_like(param) for name, param in self.model.named_parameters()}
        for batch_x, batch_y in loader:
            batch_x, batch_y = batch_x.to(self.device), batch_y.to(self.device)
            pred_y = self.model(batch_x.float())
            loss = criterion(pred_y, batch_y.long())
            # bound l2 sensitivity (gradient clipping):
            # every sample of the "lot" is back-propagated and clipped on its own
            for i in range(loss.size()[0]):
                # the graph is shared by all samples of the batch, so keep it alive
                loss[i].backward(retain_graph=True)
                torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=self.clip)
                for name, param in self.model.named_parameters():
                    clipped_grads[name] += param.grad / lot_size
                self.model.zero_grad()
        return clipped_grads

    def update(self):
        """Local model update.

        For each of the ``E`` local epochs a lot is drawn by Poisson sampling
        (every sample is kept with probability ``q``), the per-sample gradients
        are clipped and averaged over the lot, Gaussian noise is added and one
        SGD step is taken.

        Nothing is trained when ``self.optimizer`` is neither 'Fed_SGD_HT' nor
        'Fed_GD_HT'. An epoch whose lot turns out empty is skipped, because the
        average over zero samples is undefined.

        NOTE(review): the two optimizer branches of the previous version
        contained the same statements (the 'Fed_GD_HT' one additionally tried
        to iterate over a single batch instead of unpacking it), so both names
        share this code path. Confirm that one noised step per local epoch is
        the intended behaviour for both.
        """
        self.model.train()
        if self.optimizer not in ('Fed_SGD_HT', 'Fed_GD_HT'):
            return
        optimizer = torch.optim.SGD(self.model.parameters(), lr=self.lr, momentum=0.0)
        for _ in range(self.E):
            # randomly select q fraction samples from data
            # according to the privacy analysis of moments accountant
            # training "Lots" are sampled by poisson sampling
            idx = np.where(np.random.rand(self.data_size) < self.q)[0]
            lot_size = len(idx)
            if lot_size == 0:
                # empty lot: there is no gradient, and noise / 0 would be inf
                continue
            lot_x, lot_y = self.torch_dataset[idx]
            sample_data_loader = DataLoader(
                dataset=TensorDataset(lot_x, lot_y),
                batch_size=self.BATCH_SIZE,
                shuffle=True
            )
            optimizer.zero_grad()
            clipped_grads = self._clipped_grad_sum(sample_data_loader, lot_size)
            # add gaussian noise, scaled like the gradients, and hand the result to SGD
            for name, param in self.model.named_parameters():
                clipped_grads[name] += gaussian_noise(clipped_grads[name].shape, self.clip, self.sigma, device=self.device) / lot_size
                param.grad = clipped_grads[name]
            optimizer.step()
class FLServer(nn.Module):
    """Server of the federated learning framework.

    1. Receive models from clients
    2. Aggregate the local models
    3. Compute the global model and broadcast it to the clients
    """

    def __init__(self, fl_param):
        """
        :param fl_param: (dict) configuration. Keys read here: 'device',
            'client_num', 'C', 'clip', 'tot_T', 'thred_k', 'data', 'lr', 'q',
            'eps', 'E', 'delta', 'model', 'output_size', 'batch_size' and
            'optimizer'. The first ``client_num`` entries of 'data' become the
            clients' training sets, the remaining entries are kept as test sets.
        """
        super(FLServer, self).__init__()
        self.device = fl_param['device']
        self.client_num = fl_param['client_num']
        self.C = fl_param['C']  # (float) C in [0, 1]
        self.clip = fl_param['clip']
        self.T = fl_param['tot_T']  # total number of global iterations (communication rounds)
        self.thred_k = fl_param['thred_k']

        self.data = []
        self.target = []
        for sample in fl_param['data'][self.client_num:]:
            self.data += [torch.tensor(sample[0]).to(self.device)]  # test set
            self.target += [torch.tensor(sample[1]).to(self.device)]  # target label

        self.input_size = int(self.data[0].shape[1])
        self.lr = fl_param['lr']

        # compute noise using moments accountant
        # NOTE(review): arguments are positional; confirm their order against
        # the tensorflow_privacy version in use.
        self.sigma = compute_noise(1, fl_param['q'], fl_param['eps'], fl_param['E']*fl_param['tot_T'], fl_param['delta'], 1e-5)

        self.clients = [FLClient(fl_param['model'],
                                 fl_param['output_size'],
                                 fl_param['data'][i],
                                 fl_param['lr'],
                                 fl_param['E'],
                                 fl_param['batch_size'],
                                 fl_param['q'],
                                 fl_param['clip'],
                                 self.sigma,
                                 fl_param['optimizer'],
                                 fl_param['thred_k'],
                                 self.device)
                        for i in range(self.client_num)]

        self.global_model = fl_param['model'](self.input_size, fl_param['output_size']).to(self.device)
        # aggregation weight of a client is its number of training samples
        self.weight = np.array([client.data_size * 1.0 for client in self.clients])
        self.broadcast(self.global_model.state_dict())

    def _hard_threshold(self, params):
        """Zero, in place, every entry of ``params`` outside the ``thred_k`` largest magnitudes.

        The threshold is the ``thred_k``-th largest absolute value over all
        tensors together; entries whose magnitude reaches it are kept (ties at
        the threshold are all kept). ``params`` is left untouched when it is
        empty or when ``thred_k`` is None, not positive, or not smaller than
        the total number of entries.

        :param params: mapping from name to tensor
        """
        if not params or self.thred_k is None:
            return
        magnitudes = torch.cat([tensor.abs().flatten() for tensor in params.values()])
        k = int(self.thred_k)
        if k <= 0 or k >= magnitudes.numel():
            return
        # topk returns its values in descending order, so the last one is the k-th largest
        thred_val = torch.topk(magnitudes, k).values[-1]
        for name in params:
            # ">=" rather than ">" so that the k-th largest entry itself survives
            params[name] = params[name] * (params[name].abs() >= thred_val)

    def aggregated(self, idxs_users):
        """FedAvg followed by hard thresholding.

        Every selected client's state dict is scaled by
        ``(client weight / sum of all client weights) / C`` and summed; the sum
        is then sparsified by ``_hard_threshold`` and loaded into the global
        model.

        :param idxs_users: indices (into ``self.clients``) of the clients to aggregate
        :return: shallow copy of the global model's state dict
        """
        model_par = [self.clients[idx].model.state_dict() for idx in idxs_users]
        # deep copy of the first state dict provides the container to fill
        new_par = copy.deepcopy(model_par[0])
        for name in new_par:
            new_par[name] = torch.zeros(new_par[name].shape).to(self.device)
        total_weight = np.sum(self.weight[:])
        for idx, par in enumerate(model_par):
            w = self.weight[idxs_users[idx]] / total_weight
            scale = float(w / self.C)
            for name in new_par:
                new_par[name] += par[name] * scale
        self._hard_threshold(new_par)
        self.global_model.load_state_dict(copy.deepcopy(new_par))
        return self.global_model.state_dict().copy()

    def broadcast(self, new_par):
        """Send aggregated model to all clients.

        :param new_par: state dict handed to every client's ``recv``
        """
        for client in self.clients:
            client.recv(new_par.copy())

    def test_acc(self):
        """Return the global model's accuracy over all stored test sets.

        :return: correctly classified samples divided by total samples
        """
        self.global_model.eval()
        correct = 0
        tot_sample = 0
        # evaluation only: no autograd graph is needed
        with torch.no_grad():
            for data, target in zip(self.data, self.target):
                # same float cast the clients apply to their training batches
                t_pred_y = self.global_model(data.float())
                _, predicted = torch.max(t_pred_y, 1)
                correct += (predicted == target).sum().item()
                tot_sample += target.size(0)
        acc = correct / tot_sample
        return acc

    def global_update(self):
        """Run one communication round and return the resulting test accuracy.

        A fraction ``C`` of the clients (at least one) is drawn without
        replacement, each of them trains locally, and the aggregated model is
        broadcast to all clients.
        """
        # at least one client, otherwise there is nothing to aggregate
        num_selected = max(1, int(self.C * len(self.clients)))
        idxs_users = np.random.choice(range(len(self.clients)), num_selected, replace=False)
        for idx in idxs_users:
            self.clients[idx].update()
        self.broadcast(self.aggregated(idxs_users))
        acc = self.test_acc()
        torch.cuda.empty_cache()
        return acc

    def set_lr(self, lr):
        """Set the learning rate on the server and on every client.

        :param lr: new learning rate
        """
        self.lr = lr
        for c in self.clients:
            c.lr = lr