# Federated Learning Model in PyTorch
#
# Implements a simple differentially-private federated-learning setup:
#   * FLClient performs local DP-SGD-style updates (per-sample gradient
#     clipping + calibrated Gaussian noise) on its private data.
#   * FLServer aggregates client models (FedAvg), hard-thresholds the
#     aggregate to keep only the k largest-magnitude parameters, and
#     evaluates accuracy on a held-out test split.
import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader, TensorDataset
from utils import gaussian_noise
from tensorflow_privacy.privacy.analysis.compute_noise_from_budget_lib import compute_noise

import numpy as np
import copy


class FLClient(nn.Module):
    """Client of the Federated Learning framework.

    1. Receive the global model from the server.
    2. Perform local DP training (clipped, noised gradients).
    3. Return the local model to the server for aggregation.
    """

    def __init__(self, model, output_size, data, lr, E, batch_size, q, clip, sigma, optimizer, thred_k, device=None):
        """
        :param model: model constructor; called as model(input_size, output_size)
        :param output_size: number of output classes
        :param data: (tuple) (X, y); all data on the client side is used for training
        :param lr: learning rate for local SGD
        :param E: number of local epochs per communication round
        :param batch_size: mini-batch size for local training
        :param q: Poisson-subsampling rate (each sample kept with prob. q)
        :param clip: per-sample L2 gradient-clipping bound (DP sensitivity)
        :param sigma: Gaussian noise multiplier for DP
        :param optimizer: local update scheme, 'Fed_SGD_HT' or 'Fed_GD_HT'
        :param thred_k: hard-thresholding parameter (kept for server-side use)
        :param device: torch device for tensors and the model
        """
        super(FLClient, self).__init__()
        self.device = device
        self.BATCH_SIZE = batch_size
        self.torch_dataset = TensorDataset(torch.tensor(data[0]),
                                           torch.tensor(data[1]))
        self.data_size = len(self.torch_dataset)
        self.data_loader = DataLoader(
            dataset=self.torch_dataset,
            batch_size=self.BATCH_SIZE,
            shuffle=True
        )
        self.sigma = sigma      # DP noise level
        self.lr = lr
        self.E = E
        self.clip = clip
        self.q = q
        self.model = model(data[0].shape[1], output_size).to(self.device)
        self.optimizer = optimizer
        self.thred_k = thred_k

    def recv(self, model_param):
        """Receive the global model parameters from the aggregator (server)."""
        self.model.load_state_dict(copy.deepcopy(model_param))

    def update(self):
        """Run E local epochs of differentially-private training.

        Each epoch draws a Poisson-subsampled "lot" (every sample kept
        independently with probability q, as required by the moments-accountant
        privacy analysis), clips each per-sample gradient to L2 norm
        ``self.clip``, averages over the lot, adds Gaussian noise scaled by
        ``self.sigma``, and applies SGD step(s).
        """
        self.model.train()
        criterion = nn.CrossEntropyLoss(reduction='none')
        optimizer = torch.optim.SGD(self.model.parameters(), lr=self.lr, momentum=0.0)

        for e in range(self.E):
            # Poisson subsampling of the training "lot".
            idx = np.where(np.random.rand(len(self.torch_dataset[:][0])) < self.q)[0]
            if len(idx) == 0:
                # Empty lot this epoch: nothing to train on, and averaging by
                # len(idx) below would divide by zero.
                continue

            sampled_dataset = TensorDataset(self.torch_dataset[idx][0], self.torch_dataset[idx][1])
            sample_data_loader = DataLoader(
                dataset=sampled_dataset,
                batch_size=self.BATCH_SIZE,
                shuffle=True
            )

            if self.optimizer == 'Fed_SGD_HT':
                # One noisy SGD step per epoch: accumulate clipped per-sample
                # gradients over the whole lot, noise once, then step.
                clipped_grads = {name: torch.zeros_like(param)
                                 for name, param in self.model.named_parameters()}
                optimizer.zero_grad()
                for batch_x, batch_y in sample_data_loader:
                    batch_x, batch_y = batch_x.to(self.device), batch_y.to(self.device)
                    pred_y = self.model(batch_x.float())
                    loss = criterion(pred_y, batch_y.long())

                    # Bound the L2 sensitivity: clip each sample's gradient
                    # individually before it enters the average.
                    for i in range(loss.size()[0]):
                        loss[i].backward(retain_graph=True)
                        torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=self.clip)
                        for name, param in self.model.named_parameters():
                            clipped_grads[name] += param.grad / len(idx)
                        self.model.zero_grad()

                # Add Gaussian noise calibrated to the clipping bound.
                for name, param in self.model.named_parameters():
                    clipped_grads[name] += gaussian_noise(clipped_grads[name].shape, self.clip,
                                                          self.sigma, device=self.device) / len(idx)

                for name, param in self.model.named_parameters():
                    param.grad = clipped_grads[name]
                optimizer.step()

            if self.optimizer == 'Fed_GD_HT':
                # One noisy SGD step per mini-batch of the lot.
                data_iter = iter(sample_data_loader)

                # NOTE(review): clipped_grads is intentionally kept across
                # batches here (gradients and noise accumulate over steps);
                # confirm this matches the intended Fed_GD_HT algorithm.
                clipped_grads = {name: torch.zeros_like(param)
                                 for name, param in self.model.named_parameters()}
                for _ in range(len(sample_data_loader)):
                    optimizer.zero_grad()
                    # Fix: take the next (x, y) batch; the original iterated
                    # *inside* a single batch, unpacking tensor rows.
                    batch_x, batch_y = next(data_iter)
                    batch_x, batch_y = batch_x.to(self.device), batch_y.to(self.device)
                    pred_y = self.model(batch_x.float())
                    loss = criterion(pred_y, batch_y.long())

                    # Per-sample clipping, as in the Fed_SGD_HT branch.
                    for i in range(loss.size()[0]):
                        loss[i].backward(retain_graph=True)
                        torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=self.clip)
                        for name, param in self.model.named_parameters():
                            clipped_grads[name] += param.grad / len(idx)
                        self.model.zero_grad()

                    # Add Gaussian noise calibrated to the clipping bound.
                    for name, param in self.model.named_parameters():
                        clipped_grads[name] += gaussian_noise(clipped_grads[name].shape, self.clip,
                                                              self.sigma, device=self.device) / len(idx)

                    for name, param in self.model.named_parameters():
                        param.grad = clipped_grads[name]
                    optimizer.step()


class FLServer(nn.Module):
    """Server of Federated Learning.

    1. Receive models (or gradients) from clients.
    2. Aggregate local models (FedAvg) and hard-threshold the result.
    3. Broadcast the global model back to all clients.
    """

    def __init__(self, fl_param):
        """
        :param fl_param: (dict) experiment configuration; keys used here:
            device, client_num, C, clip, tot_T, thred_k, data, lr, q, eps, E,
            delta, model, output_size, batch_size, optimizer.
        """
        super(FLServer, self).__init__()
        self.device = fl_param['device']
        self.client_num = fl_param['client_num']
        self.C = fl_param['C']              # fraction of clients per round, in [0, 1]
        self.clip = fl_param['clip']
        self.T = fl_param['tot_T']          # total number of communication rounds
        self.thred_k = fl_param['thred_k']  # keep only the k largest-magnitude weights

        # Entries of fl_param['data'] beyond the first client_num form the test set.
        self.data = []
        self.target = []
        for sample in fl_param['data'][self.client_num:]:
            self.data += [torch.tensor(sample[0]).to(self.device)]    # test features
            self.target += [torch.tensor(sample[1]).to(self.device)]  # test labels

        self.input_size = int(self.data[0].shape[1])
        self.lr = fl_param['lr']

        # Calibrate the DP noise multiplier with the moments accountant.
        self.sigma = compute_noise(1, fl_param['q'], fl_param['eps'],
                                   fl_param['E'] * fl_param['tot_T'],
                                   fl_param['delta'], 1e-5)

        self.clients = [FLClient(fl_param['model'],
                                 fl_param['output_size'],
                                 fl_param['data'][i],
                                 fl_param['lr'],
                                 fl_param['E'],
                                 fl_param['batch_size'],
                                 fl_param['q'],
                                 fl_param['clip'],
                                 self.sigma,
                                 fl_param['optimizer'],
                                 fl_param['thred_k'],
                                 self.device)
                        for i in range(self.client_num)]
        self.global_model = fl_param['model'](self.input_size, fl_param['output_size']).to(self.device)
        # Aggregation weights proportional to each client's dataset size.
        self.weight = np.array([client.data_size * 1.0 for client in self.clients])
        self.broadcast(self.global_model.state_dict())

    def aggregated(self, idxs_users):
        """FedAvg over the selected clients, then hard-threshold the aggregate.

        :param idxs_users: indices of the clients sampled this round
        :return: state dict of the updated global model
        """
        model_par = [self.clients[idx].model.state_dict() for idx in idxs_users]
        new_par = copy.deepcopy(model_par[0])
        for name in new_par:
            new_par[name] = torch.zeros(new_par[name].shape).to(self.device)
        for idx, par in enumerate(model_par):
            # Weight by dataset size over ALL clients, rescaled by the sampling
            # fraction C so the aggregate is unbiased.
            w = self.weight[idxs_users[idx]] / np.sum(self.weight[:])
            for name in new_par:
                new_par[name] += par[name] * (w / self.C)

        # Hard thresholding: zero out every parameter except the thred_k
        # entries with the largest absolute value across the whole model.
        # (Fix: the original flattened only the first tensor dimension and
        # indexed the None returned by list.sort().)
        flat_abs = torch.cat([new_par[name].reshape(-1).abs() for name in new_par])
        k = min(self.thred_k, flat_abs.numel())
        thred_val = torch.topk(flat_abs, k).values[-1]  # k-th largest magnitude
        for name in new_par:
            # Elementwise mask (tensor `or` is ambiguous); >= keeps the top-k
            # entries, ties at the threshold included.
            mask = (new_par[name].abs() >= thred_val).to(new_par[name].dtype)
            new_par[name] = new_par[name] * mask

        self.global_model.load_state_dict(copy.deepcopy(new_par))
        return self.global_model.state_dict().copy()

    def broadcast(self, new_par):
        """Send the aggregated global model to all clients."""
        for client in self.clients:
            client.recv(new_par.copy())

    def test_acc(self):
        """Return global-model accuracy on the held-out test split."""
        self.global_model.eval()
        correct = 0
        tot_sample = 0
        for i in range(len(self.data)):
            t_pred_y = self.global_model(self.data[i])
            _, predicted = torch.max(t_pred_y, 1)
            correct += (predicted == self.target[i]).sum().item()
            tot_sample += self.target[i].size(0)
        acc = correct / tot_sample
        return acc

    def global_update(self):
        """One communication round: sample clients, train, aggregate, evaluate.

        :return: test accuracy of the new global model
        """
        idxs_users = np.random.choice(range(len(self.clients)), int(self.C * len(self.clients)), replace=False)
        for idx in idxs_users:
            self.clients[idx].update()
        self.broadcast(self.aggregated(idxs_users))
        acc = self.test_acc()
        torch.cuda.empty_cache()
        return acc

    def set_lr(self, lr):
        """Update the learning rate on every client (e.g. for LR decay)."""
        for c in self.clients:
            c.lr = lr