diff --git a/test.py b/test.py new file mode 100644 index 0000000..6eb5aff --- /dev/null +++ b/test.py @@ -0,0 +1,91 @@ +# Application of FL task + +from MLModel import * +from FLModel import * +from utils import * + +from torchvision import datasets, transforms +import torch +import numpy as np +import os + +os.environ["CUDA_VISIBLE_DEVICES"] = "0" +device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") + + + +def choose_two_digit(split_data_lst): + available_digit = [] + for i, digit in enumerate(split_data_lst): + if len(digit) > 0: + available_digit.append(i) + try: + lst = np.random.choice(available_digit, 2, replace=False).tolist() + except: + print(available_digit) + return lst + + + +def load_cnn_mnist(num_users): + data_train = datasets.MNIST(root="~/data/", train=True, transform=transforms.Compose([ + transforms.ToTensor(), + transforms.Normalize((0.1307,), (0.3081,)) + ])) + + data_test = datasets.MNIST(root="~/data/", train=False, transform=transforms.Compose([ + transforms.ToTensor(), + transforms.Normalize((0.1307,), (0.3081,)) + ])) + + # split MNIST (training set) into non-iid data sets + non_iid = [] + user_dict = mnist_noniid(data_train, num_users) + for i in range(num_users): + idx = user_dict[i] + d = data_train.data[idx].float().unsqueeze(1) + targets = data_train.targets[idx].float() + non_iid.append((d, targets)) + non_iid.append((data_test.data.float().unsqueeze(1), data_test.targets.float())) + return non_iid + +""" +1. load_data +2. generate clients (step 3) +3. generate aggregator +4. training +""" +client_num = 10 +d = load_cnn_mnist(client_num) + +lr = 0.01 +fl_param = { + 'output_size': 10, + 'client_num': client_num, + 'model': NeuralNet, + 'data': d, + 'lr': lr, + 'E': 1, + 'C': 1, + 'eps': 4.0, + 'delta': 1e-5, + 'q': 0.03, + 'clip': 32, + 'tot_T': 150, + 'batch_size': 128, + 'optimizer': 'Fed_GD_HT', #'Fed_SGD_HT' or 'Fed_GD_HT' + 'thred_k': 10000, + 'device': device +} +import warnings +warnings.filterwarnings("ignore") +fl_entity = FLServer(fl_param).to(device) + +print("mnist") +acc = [] +for e in range(150): + #if e+1 % 10 == 0: + # lr *= 0.1 + fl_entity.set_lr(lr/np.sqrt(e+1)) + acc += [fl_entity.global_update()] + print("global epochs = {:d}, acc = {:.4f}".format(e+1, acc[-1])) \ No newline at end of file