Skip to content
Switch branches/tags
Go to file
Cannot retrieve contributors at this time
"""Implementation of several useful tools.
Please copy this script to the target path.
"""
import numpy as np
import random
import torch
def mnist_noniid(dataset, num_users, shards_per_user=3):
    """Sample non-I.I.D. client data from an MNIST-style dataset.

    Sample indices are sorted by label and cut into
    ``num_users * shards_per_user`` equally sized contiguous shards. Each
    user then receives ``shards_per_user`` distinct shards drawn at random
    with ``np.random``, so every user only sees a few labels. Samples left
    over after the equal-size split (``len(labels) % num_shards``) are not
    assigned to anyone.

    :param dataset: object exposing a ``targets`` attribute with one label
        per sample (e.g. a CPU ``torch`` tensor, NumPy array or list).
    :param num_users: number of clients to partition the data among (>= 1).
    :param shards_per_user: number of label-sorted shards given to each
        user (default 3).
    :return: dict mapping user id ``0..num_users-1`` to an int64 NumPy array
        of sample indices into ``dataset``.
    :raises ValueError: if ``num_users`` or ``shards_per_user`` is < 1.
    """
    if num_users < 1:
        raise ValueError(f"num_users must be >= 1, got {num_users}")
    if shards_per_user < 1:
        raise ValueError(f"shards_per_user must be >= 1, got {shards_per_user}")

    # Use the real dataset size rather than a hard-coded 60000: stacking a
    # fixed-length index range against the labels crashed whenever the two
    # lengths differed (e.g. 60000 not divisible by the shard count).
    labels = np.asarray(dataset.targets)
    num_shards = num_users * shards_per_user
    num_imgs = len(labels) // num_shards

    # Sample indices ordered by label (same argsort as sorting an
    # (index, label) table by its label row).
    idxs = labels.argsort()

    idx_shard = list(range(num_shards))
    dict_users = {i: np.array([], dtype='int64') for i in range(num_users)}
    # Divide and assign: draw shards without replacement across users.
    for i in range(num_users):
        rand_set = set(np.random.choice(idx_shard, shards_per_user, replace=False))
        idx_shard = list(set(idx_shard) - rand_set)
        pieces = [idxs[rand * num_imgs:(rand + 1) * num_imgs] for rand in rand_set]
        dict_users[i] = np.concatenate([dict_users[i], *pieces], axis=0)
    return dict_users
def gaussian_noise(data_shape, s, sigma, device=None):
    """Draw zero-mean Gaussian noise for the CDP-FedAVG-LS algorithm.

    The standard deviation is the product ``sigma * s`` (presumably a noise
    multiplier times a sensitivity / clipping bound — confirm with callers).

    :param data_shape: shape of the noise tensor to sample.
    :param s: factor multiplied into the standard deviation.
    :param sigma: factor multiplied into the standard deviation.
    :param device: device to move the sample to; the noise is drawn with
        torch's default generator first, then transferred via ``.to(device)``.
    :return: tensor of shape ``data_shape`` with i.i.d. N(0, (sigma*s)^2) entries.
    """
    std = sigma * s
    noise = torch.normal(0, std, size=data_shape)
    return noise.to(device)