From 2da9ce379c940dbac3bfb32222c7c327cb4a431c Mon Sep 17 00:00:00 2001 From: Jinbo Bi Date: Sat, 11 Dec 2021 14:22:07 -0500 Subject: [PATCH] Add files via upload --- utils.py | 43 +++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 43 insertions(+) create mode 100644 utils.py diff --git a/utils.py b/utils.py new file mode 100644 index 0000000..6588990 --- /dev/null +++ b/utils.py @@ -0,0 +1,43 @@ +""" +Implementation of several useful tool +Please copy this script to target path +""" +import numpy as np +import random +import torch + + +def mnist_noniid(dataset, num_users): + """ + Sample non-I.I.D client data from MNIST dataset + :param dataset: + :param num_users: + :return: + """ + # num_shards, num_imgs = 30, 2000 + num_shards = int(num_users*3) + num_imgs = int(60000 / num_shards) + idx_shard = [i for i in range(num_shards)] + dict_users = {i: np.array([], dtype='int64') for i in range(num_users)} + idxs = np.arange(num_shards*num_imgs) + labels = dataset.targets.numpy() + + # sort labels + idxs_labels = np.vstack((idxs, labels)) + idxs_labels = idxs_labels[:,idxs_labels[1,:].argsort()] + idxs = idxs_labels[0,:] + + # divide and assign + for i in range(num_users): + rand_set = set(np.random.choice(idx_shard, 3, replace=False)) + idx_shard = list(set(idx_shard) - rand_set) + for rand in rand_set: + dict_users[i] = np.concatenate((dict_users[i], idxs[rand*num_imgs:(rand+1)*num_imgs]), axis=0) + return dict_users + + +def gaussian_noise(data_shape, s, sigma, device=None): + """ + Gaussian noise for CDP-FedAVG-LS Algorithm + """ + return torch.normal(0, sigma * s, data_shape).to(device) \ No newline at end of file