Permalink
Cannot retrieve contributors at this time
Name already in use
A tag already exists with the provided branch name. Many Git commands accept both tag and branch names, so creating this branch may cause unexpected behavior. Are you sure you want to create this branch?
DP-Sparse-Learning/FLModel.py
Go to fileThis commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
231 lines (194 sloc)
9.76 KB
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Federated Learning Model in PyTorch
import torch | |
from torch import nn | |
from torch.utils.data import Dataset, DataLoader, TensorDataset | |
from utils import gaussian_noise | |
from tensorflow_privacy.privacy.analysis.compute_noise_from_budget_lib import compute_noise | |
import numpy as np | |
import copy | |
class FLClient(nn.Module):
    """Client of the federated learning framework.

    1. Receive the global model from the server (``recv``)
    2. Perform local training on noisy, clipped gradients (``update``)
    3. Expose the resulting local model as ``self.model`` for aggregation
    """

    # Values of ``optimizer`` for which ``update`` actually trains.
    _SUPPORTED_OPTIMIZERS = ('Fed_SGD_HT', 'Fed_GD_HT')

    def __init__(self, model, output_size, data, lr, E, batch_size, q, clip, sigma, optimizer, thred_k, device=None):
        """
        :param model: callable ``model(input_size, output_size)`` returning an ``nn.Module``
        :param output_size: second argument passed to ``model``
        :param data: (tuple) ``(features, labels)``; all client data is used as training data
        :param lr: learning rate of the local SGD optimizer
        :param E: number of local update epochs per call to ``update``
        :param batch_size: batch size used when iterating over a sampled lot
        :param q: per-sample probability of being included in a lot
        :param clip: max gradient norm applied to each per-sample gradient
        :param sigma: noise level forwarded to ``gaussian_noise``
        :param optimizer: (str) local update mode, ``'Fed_SGD_HT'`` or ``'Fed_GD_HT'``
        :param thred_k: stored on the client; not used by the client itself
        :param device: device the local model is moved to
        """
        super(FLClient, self).__init__()
        self.device = device
        self.BATCH_SIZE = batch_size
        self.torch_dataset = TensorDataset(torch.tensor(data[0]),
                                           torch.tensor(data[1]))
        self.data_size = len(self.torch_dataset)
        self.data_loader = DataLoader(
            dataset=self.torch_dataset,
            batch_size=self.BATCH_SIZE,
            shuffle=True
        )
        self.sigma = sigma  # DP noise level
        self.lr = lr
        self.E = E
        self.clip = clip
        self.q = q
        # The input size of the model is the number of feature columns.
        self.model = model(data[0].shape[1], output_size).to(self.device)
        self.optimizer = optimizer
        self.thred_k = thred_k

    def recv(self, model_param):
        """Receive the global model from the aggregator (server)."""
        self.model.load_state_dict(copy.deepcopy(model_param))

    def update(self):
        """Run ``E`` local epochs of noisy, per-sample-clipped gradient steps.

        Each epoch draws a lot by Poisson sampling (every sample is kept
        independently with probability ``q``), accumulates the clipped
        per-sample gradients of the lot, adds Gaussian noise and takes one SGD
        step. Nothing is trained when ``self.optimizer`` is not one of the
        supported mode names.

        NOTE(review): the indentation of this file was lost in extraction. The
        nesting used here (one noisy step per epoch, after the whole lot has
        been accumulated) is the reading consistent with normalising by the lot
        size; under it both mode names run the same procedure. Confirm that no
        per-mode difference was intended.
        """
        self.model.train()
        criterion = nn.CrossEntropyLoss(reduction='none')
        optimizer = torch.optim.SGD(self.model.parameters(), lr=self.lr, momentum=0.0)

        for _ in range(self.E):
            # randomly select q fraction samples from data
            # according to the privacy analysis of moments accountant
            # training "Lots" are sampled by poisson sampling
            idx = np.where(np.random.rand(self.data_size) < self.q)[0]
            if self.optimizer not in self._SUPPORTED_OPTIMIZERS:
                continue
            lot_size = len(idx)
            if lot_size == 0:
                # An empty lot cannot be batched (a shuffling DataLoader
                # rejects an empty dataset) and would divide by zero below,
                # so this epoch is skipped.
                continue

            lot_x, lot_y = self.torch_dataset[idx]
            lot_loader = DataLoader(
                dataset=TensorDataset(lot_x, lot_y),
                batch_size=self.BATCH_SIZE,
                shuffle=True
            )
            clipped_grads = self._accumulate_clipped_grads(lot_loader, criterion, lot_size)
            self._noisy_step(optimizer, clipped_grads, lot_size)

    def _accumulate_clipped_grads(self, lot_loader, criterion, lot_size):
        """Return the per-sample-clipped gradients of a lot, divided by ``lot_size``."""
        clipped_grads = {name: torch.zeros_like(param) for name, param in self.model.named_parameters()}
        self.model.zero_grad()
        for batch_x, batch_y in lot_loader:
            batch_x, batch_y = batch_x.to(self.device), batch_y.to(self.device)
            pred_y = self.model(batch_x.float())
            loss = criterion(pred_y, batch_y.long())
            # bound l2 sensitivity (gradient clipping)
            # clip each of the gradient in the "Lot"
            for i in range(loss.size(0)):
                # The graph is reused for every sample of the batch.
                loss[i].backward(retain_graph=True)
                torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=self.clip)
                for name, param in self.model.named_parameters():
                    if param.grad is not None:
                        clipped_grads[name] += param.grad / lot_size
                self.model.zero_grad()
        return clipped_grads

    def _noisy_step(self, optimizer, clipped_grads, lot_size):
        """Add Gaussian noise to the accumulated gradients and take one optimizer step."""
        for name, param in self.model.named_parameters():
            noise = gaussian_noise(clipped_grads[name].shape, self.clip, self.sigma, device=self.device)
            param.grad = clipped_grads[name] + noise / lot_size
        optimizer.step()
class FLServer(nn.Module):
    """Server of the federated learning framework.

    1. Receive local models from the selected clients
    2. Aggregate them and hard-threshold the result to ``thred_k`` entries
    3. Load the result into the global model and broadcast it to all clients
    """

    def __init__(self, fl_param):
        """
        :param fl_param: (dict) configuration; the keys read here are
            ``device``, ``client_num``, ``C``, ``clip``, ``tot_T``, ``thred_k``,
            ``data``, ``lr``, ``q``, ``eps``, ``E``, ``delta``, ``model``,
            ``output_size``, ``batch_size`` and ``optimizer``.
            ``fl_param['data'][:client_num]`` holds one ``(features, labels)``
            pair per client; the remaining entries are used as test sets.
        """
        super(FLServer, self).__init__()
        self.device = fl_param['device']
        self.client_num = fl_param['client_num']
        self.C = fl_param['C']  # (float) C in [0, 1]
        self.clip = fl_param['clip']
        self.T = fl_param['tot_T']  # total number of global iterations (communication rounds)
        self.thred_k = fl_param['thred_k']

        self.data = []
        self.target = []
        for sample in fl_param['data'][self.client_num:]:
            self.data += [torch.tensor(sample[0]).to(self.device)]  # test set
            self.target += [torch.tensor(sample[1]).to(self.device)]  # target label
        self.input_size = int(self.data[0].shape[1])
        self.lr = fl_param['lr']

        # compute noise using moments accountant
        self.sigma = compute_noise(1, fl_param['q'], fl_param['eps'], fl_param['E']*fl_param['tot_T'], fl_param['delta'], 1e-5)

        self.clients = [FLClient(fl_param['model'],
                                 fl_param['output_size'],
                                 fl_param['data'][i],
                                 fl_param['lr'],
                                 fl_param['E'],
                                 fl_param['batch_size'],
                                 fl_param['q'],
                                 fl_param['clip'],
                                 self.sigma,
                                 fl_param['optimizer'],
                                 fl_param['thred_k'],
                                 self.device)
                        for i in range(self.client_num)]
        self.global_model = fl_param['model'](self.input_size, fl_param['output_size']).to(self.device)
        # Aggregation weight of a client is its number of training samples.
        self.weight = np.array([client.data_size * 1.0 for client in self.clients])
        self.broadcast(self.global_model.state_dict())

    def aggregated(self, idxs_users):
        """Weighted average of the selected clients' models followed by hard thresholding.

        Each selected model is weighted by its share of the data of *all*
        clients, rescaled by ``1 / C``. Of the averaged entries, only those
        whose magnitude is at least the ``thred_k``-th largest magnitude are
        kept (ties at the threshold are all kept); the rest are set to zero.

        :param idxs_users: indices into ``self.clients`` of the selected clients
        :return: shallow copy of the updated global model's state dict
        :raises ValueError: if ``thred_k`` is smaller than 1
        """
        client_states = [self.clients[idx].model.state_dict() for idx in idxs_users]
        total_weight = float(np.sum(self.weight))

        new_par = {name: torch.zeros(tensor.shape).to(self.device)
                   for name, tensor in client_states[0].items()}
        for user, state in zip(idxs_users, client_states):
            scale = float(self.weight[user]) / total_weight / self.C
            for name in new_par:
                new_par[name] += state[name] * scale

        keep = int(self.thred_k)
        if keep < 1:
            raise ValueError("thred_k must be at least 1, got {}".format(self.thred_k))
        # Magnitudes of every scalar entry, across all tensors of the model.
        magnitudes = torch.cat([tensor.abs().flatten() for tensor in new_par.values()])
        if keep < magnitudes.numel():
            # topk returns values in descending order, so the last one is the
            # keep-th largest magnitude.
            thred_val = torch.topk(magnitudes, keep).values[-1]
            for name in new_par:
                new_par[name] = new_par[name] * (new_par[name].abs() >= thred_val)

        self.global_model.load_state_dict(copy.deepcopy(new_par))
        return self.global_model.state_dict().copy()

    def broadcast(self, new_par):
        """Send aggregated model to all clients"""
        for client in self.clients:
            client.recv(new_par.copy())

    def test_acc(self):
        """Return the global model's accuracy over all held-out test sets."""
        self.global_model.eval()
        correct = 0
        tot_sample = 0
        # Evaluation needs no gradients; inputs are cast to float the same way
        # the clients cast their training batches.
        with torch.no_grad():
            for features, labels in zip(self.data, self.target):
                t_pred_y = self.global_model(features.float())
                _, predicted = torch.max(t_pred_y, 1)
                correct += (predicted == labels).sum().item()
                tot_sample += labels.size(0)
        acc = correct / tot_sample
        return acc

    def global_update(self):
        """Run one communication round and return the resulting test accuracy.

        A ``C`` fraction of the clients (at least one) is drawn without
        replacement, trained locally, aggregated, and the new global model is
        broadcast to every client.
        """
        n_clients = len(self.clients)
        # At least one client is needed, otherwise there is nothing to aggregate.
        n_selected = max(1, int(self.C * n_clients))
        idxs_users = np.random.choice(range(n_clients), n_selected, replace=False)
        for idx in idxs_users:
            self.clients[idx].update()
        self.broadcast(self.aggregated(idxs_users))
        acc = self.test_acc()
        torch.cuda.empty_cache()
        return acc

    def set_lr(self, lr):
        """Set the learning rate used by every client's next local update."""
        for c in self.clients:
            c.lr = lr