Skip to content
Permalink
main
Switch branches/tags
Go to file
 
 
Cannot retrieve contributors at this time
"""
Implementation of several useful tools.
Please copy this script to the target path.
"""
import numpy as np
import random
import torch
def mnist_noniid(dataset, num_users):
"""
Sample non-I.I.D client data from MNIST dataset
:param dataset:
:param num_users:
:return:
"""
# num_shards, num_imgs = 30, 2000
num_shards = int(num_users*3)
num_imgs = int(60000 / num_shards)
idx_shard = [i for i in range(num_shards)]
dict_users = {i: np.array([], dtype='int64') for i in range(num_users)}
idxs = np.arange(num_shards*num_imgs)
labels = dataset.targets.numpy()
# sort labels
idxs_labels = np.vstack((idxs, labels))
idxs_labels = idxs_labels[:,idxs_labels[1,:].argsort()]
idxs = idxs_labels[0,:]
# divide and assign
for i in range(num_users):
rand_set = set(np.random.choice(idx_shard, 3, replace=False))
idx_shard = list(set(idx_shard) - rand_set)
for rand in rand_set:
dict_users[i] = np.concatenate((dict_users[i], idxs[rand*num_imgs:(rand+1)*num_imgs]), axis=0)
return dict_users
def gaussian_noise(data_shape, s, sigma, device=None):
    """
    Gaussian noise for CDP-FedAVG-LS Algorithm.

    Draw a tensor of i.i.d. samples from N(0, (sigma * s)^2).

    :param data_shape: shape of the noise tensor (tuple of ints or torch.Size).
    :param s: scale factor multiplied into the standard deviation.
    :param sigma: noise multiplier; the standard deviation is ``sigma * s``.
    :param device: device to create the noise on; ``None`` uses torch's
        default device.
    :return: float tensor of shape ``data_shape`` on ``device``.
    """
    # Sample directly on the target device instead of building the tensor on
    # the CPU and copying it with .to(device): this saves a host allocation
    # and a host-to-device transfer on every call.
    return torch.normal(0, sigma * s, data_shape, device=device)