Skip to content
Permalink
Browse files
Add files via upload
  • Loading branch information
jib10001 committed Dec 11, 2021
1 parent 6607204 commit 71d1ecfe42afb815e55ffa91e308f38112ab6d47
Showing with 231 additions and 0 deletions.
  1. +231 −0 FLModel.py
@@ -0,0 +1,231 @@
# Federated Learning Model in PyTorch
import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader, TensorDataset
from utils import gaussian_noise
from tensorflow_privacy.privacy.analysis.compute_noise_from_budget_lib import compute_noise

import numpy as np
import copy


class FLClient(nn.Module):
""" Client of Federated Learning framework.
1. Receive global model from server
2. Perform local training (compute gradients)
3. Return local model (gradients) to server
"""
def __init__(self, model, output_size, data, lr, E, batch_size, q, clip, sigma, optimizer, thred_k, device=None):
"""
:param model: ML model's training process should be implemented
:param data: (tuple) dataset, all data in client side is used as training data
:param lr: learning rate
:param E: epoch of local update
"""
super(FLClient, self).__init__()
self.device = device
self.BATCH_SIZE = batch_size
self.torch_dataset = TensorDataset(torch.tensor(data[0]),
torch.tensor(data[1]))
self.data_size = len(self.torch_dataset)
self.data_loader = DataLoader(
dataset=self.torch_dataset,
batch_size=self.BATCH_SIZE,
shuffle=True
)
self.sigma = sigma # DP noise level
self.lr = lr
self.E = E
self.clip = clip
self.q = q
self.model = model(data[0].shape[1], output_size).to(self.device)
self.optimizer = optimizer
self.thred_k = thred_k

def recv(self, model_param):
"""receive global model from aggregator (server)"""
self.model.load_state_dict(copy.deepcopy(model_param))

def update(self):
"""local model update"""
self.model.train()
criterion = nn.CrossEntropyLoss(reduction='none')
optimizer = torch.optim.SGD(self.model.parameters(), lr=self.lr, momentum=0.0)

for e in range(self.E):
# randomly select q fraction samples from data
# according to the privacy analysis of moments accountant
# training "Lots" are sampled by poisson sampling
idx = np.where(np.random.rand(len(self.torch_dataset[:][0])) < self.q)[0]

sampled_dataset = TensorDataset(self.torch_dataset[idx][0], self.torch_dataset[idx][1])

if self.optimizer == 'Fed_SGD_HT':
sample_data_loader = DataLoader(
dataset=sampled_dataset,
batch_size=self.BATCH_SIZE,
shuffle=True
)
clipped_grads = {name: torch.zeros_like(param) for name, param in self.model.named_parameters()}
optimizer.zero_grad()
for batch_x, batch_y in sample_data_loader:
batch_x, batch_y = batch_x.to(self.device), batch_y.to(self.device)
pred_y = self.model(batch_x.float())
loss = criterion(pred_y, batch_y.long())

# bound l2 sensitivity (gradient clipping)
# clip each of the gradient in the "Lot"
for i in range(loss.size()[0]):
loss[i].backward(retain_graph=True)
torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=self.clip)
for name, param in self.model.named_parameters():
clipped_grads[name] += param.grad / len(idx)
self.model.zero_grad()
# add gaussian noise
for name, param in self.model.named_parameters():
clipped_grads[name] += gaussian_noise(clipped_grads[name].shape, self.clip, self.sigma, device=self.device) / len(idx)


for name, param in self.model.named_parameters():
param.grad = clipped_grads[name]
optimizer.step()

if self.optimizer == 'Fed_GD_HT':


sample_data_loader = DataLoader(
dataset=sampled_dataset,
batch_size=self.BATCH_SIZE,
shuffle=True
)

# Reset the data_iter
data_iter = iter(sample_data_loader)


clipped_grads = {name: torch.zeros_like(param) for name, param in self.model.named_parameters()}
for _ in range(len(data_iter)):
optimizer.zero_grad()
for batch_x, batch_y in next(data_iter):
batch_x, batch_y = batch_x.to(self.device), batch_y.to(self.device)
pred_y = self.model(batch_x.float())
loss = criterion(pred_y, batch_y.long())

# bound l2 sensitivity (gradient clipping)
# clip each of the gradient in the "Lot"
for i in range(loss.size()[0]):
loss[i].backward(retain_graph=True)
torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=self.clip)
for name, param in self.model.named_parameters():
clipped_grads[name] += param.grad / len(idx)
self.model.zero_grad()

# add gaussian noise
for name, param in self.model.named_parameters():
clipped_grads[name] += gaussian_noise(clipped_grads[name].shape, self.clip, self.sigma, device=self.device) / len(idx)


for name, param in self.model.named_parameters():
param.grad = clipped_grads[name]
optimizer.step()


class FLServer(nn.Module):
""" Server of Federated Learning
1. Receive model (or gradients) from clients
2. Aggregate local models (or gradients)
3. Compute global model, broadcast global model to clients
"""
def __init__(self, fl_param):
super(FLServer, self).__init__()
self.device = fl_param['device']
self.client_num = fl_param['client_num']
self.C = fl_param['C'] # (float) C in [0, 1]
self.clip = fl_param['clip']
self.T = fl_param['tot_T'] # total number of global iterations (communication rounds)
self.thred_k = fl_param['thred_k']

self.data = []
self.target = []
for sample in fl_param['data'][self.client_num:]:
self.data += [torch.tensor(sample[0]).to(self.device)] # test set
self.target += [torch.tensor(sample[1]).to(self.device)] # target label

self.input_size = int(self.data[0].shape[1])
self.lr = fl_param['lr']

# compute noise using moments accountant
self.sigma = compute_noise(1, fl_param['q'], fl_param['eps'], fl_param['E']*fl_param['tot_T'], fl_param['delta'], 1e-5)

self.clients = [FLClient(fl_param['model'],
fl_param['output_size'],
fl_param['data'][i],
fl_param['lr'],
fl_param['E'],
fl_param['batch_size'],
fl_param['q'],
fl_param['clip'],
self.sigma,
fl_param['optimizer'],
fl_param['thred_k'],
self.device)
for i in range(self.client_num)]
self.global_model = fl_param['model'](self.input_size, fl_param['output_size']).to(self.device)
self.weight = np.array([client.data_size * 1.0 for client in self.clients])
self.broadcast(self.global_model.state_dict())

def aggregated(self, idxs_users):
"""FedAvg"""
model_par = [self.clients[idx].model.state_dict() for idx in idxs_users]
new_par = copy.deepcopy(model_par[0])
for name in new_par:
new_par[name] = torch.zeros(new_par[name].shape).to(self.device)
for idx, par in enumerate(model_par):
w = self.weight[idxs_users[idx]] / np.sum(self.weight[:])
for name in new_par:
# new_par[name] += par[name] * (self.weight[idxs_users[idx]] / np.sum(self.weight[idxs_users]))
new_par[name] += par[name] * (w / self.C)

holder = []
for name in new_par:
holder.append(new_par[name])

holder = [ abs(item) for sublist in holder for item in sublist]

thred_val = holder.sort()[ -self.thred_k]
for name in new_par:
new_par[name] = new_par[name]*(new_par[name]>thred_val or new_par[name]< -thred_val)


self.global_model.load_state_dict(copy.deepcopy(new_par))
return self.global_model.state_dict().copy()

def broadcast(self, new_par):
"""Send aggregated model to all clients"""
for client in self.clients:
client.recv(new_par.copy())

def test_acc(self):
self.global_model.eval()
correct = 0
tot_sample = 0
for i in range(len(self.data)):
t_pred_y = self.global_model(self.data[i])
_, predicted = torch.max(t_pred_y, 1)
correct += (predicted == self.target[i]).sum().item()
tot_sample += self.target[i].size(0)
acc = correct / tot_sample
return acc

def global_update(self):
idxs_users = np.random.choice(range(len(self.clients)), int(self.C * len(self.clients)), replace=False)
for idx in idxs_users:
self.clients[idx].update()
self.broadcast(self.aggregated(idxs_users))
acc = self.test_acc()
torch.cuda.empty_cache()
return acc

def set_lr(self, lr):
for c in self.clients:
c.lr = lr

0 comments on commit 71d1ecf

Please sign in to comment.