Implementation of several useful tools.
Please copy this script to the target path.
import numpy as np
import random
import torch
def mnist_noniid(dataset, num_users):
    """
    Sample non-I.I.D client data from MNIST dataset.

    The sample indices are sorted by label and cut into ``3 * num_users``
    equally sized shards; every user then receives three randomly drawn
    shards (without replacement), so each user only sees a few labels.

    The shard size is derived from the actual number of labels instead of
    a hard-coded 60000, so the function also works when the dataset size is
    not 60000 or is not divisible by the number of shards. Samples that do
    not fit into a whole shard (the tail of the dataset) are left unassigned.

    :param dataset: object exposing ``targets`` (a tensor with ``.numpy()``,
        or any array-like of integer labels)
    :param num_users: number of clients to split the data between (>= 1)
    :return: dict mapping user id -> int64 array of sample indices
    :raises ValueError: if ``num_users`` is smaller than 1
    """
    if num_users < 1:
        raise ValueError("num_users must be a positive integer")

    targets = dataset.targets
    labels = targets.numpy() if hasattr(targets, "numpy") else np.asarray(targets)

    shards_per_user = 3
    num_shards = int(num_users * shards_per_user)
    # Integer division drops the remainder so every shard has equal size.
    num_imgs = len(labels) // num_shards
    num_used = num_shards * num_imgs

    idx_shard = list(range(num_shards))
    dict_users = {i: np.array([], dtype='int64') for i in range(num_users)}

    # Sort sample indices by label. Only the first `num_used` labels take
    # part, which keeps both rows of the stacked array the same length.
    idxs = np.arange(num_used)
    idxs_labels = np.vstack((idxs, labels[:num_used]))
    idxs_labels = idxs_labels[:, idxs_labels[1, :].argsort()]
    idxs = idxs_labels[0, :]

    # Divide and assign: draw three unused shards per user.
    for i in range(num_users):
        rand_set = set(np.random.choice(idx_shard, shards_per_user, replace=False))
        idx_shard = list(set(idx_shard) - rand_set)
        for rand in rand_set:
            shard = idxs[rand * num_imgs:(rand + 1) * num_imgs]
            dict_users[i] = np.concatenate((dict_users[i], shard), axis=0)
    return dict_users
def gaussian_noise(data_shape, s, sigma, device=None):
    """
    Gaussian noise for CDP-FedAVG-LS Algorithm.

    Draws a zero-mean normal tensor of shape ``data_shape`` whose standard
    deviation is ``sigma * s`` and moves it to ``device``.

    :param data_shape: shape of the noise tensor to generate
    :param s: scale factor multiplied with ``sigma`` (presumably the
        sensitivity / clipping bound -- confirm against the caller)
    :param sigma: noise multiplier
    :param device: target device passed to ``Tensor.to``
    :return: tensor of shape ``data_shape`` holding the sampled noise
    """
    noise_std = sigma * s
    noise = torch.normal(0, noise_std, data_shape)
    return noise.to(device)