Skip to content

Commit

Permalink
Browse files Browse the repository at this point in the history
Add files via upload
  • Loading branch information
jib10001 committed Dec 11, 2021
1 parent d469dcb commit d4aacf7
Showing 1 changed file with 91 additions and 0 deletions.
91 changes: 91 additions & 0 deletions test.py
@@ -0,0 +1,91 @@
# Application of FL task

from MLModel import *
from FLModel import *
from utils import *

from torchvision import datasets, transforms
import torch
import numpy as np
import os

os.environ["CUDA_VISIBLE_DEVICES"] = "0"
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")



def choose_two_digit(split_data_lst):
available_digit = []
for i, digit in enumerate(split_data_lst):
if len(digit) > 0:
available_digit.append(i)
try:
lst = np.random.choice(available_digit, 2, replace=False).tolist()
except:
print(available_digit)
return lst



def load_cnn_mnist(num_users):
data_train = datasets.MNIST(root="~/data/", train=True, transform=transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.1307,), (0.3081,))
]))

data_test = datasets.MNIST(root="~/data/", train=False, transform=transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.1307,), (0.3081,))
]))

# split MNIST (training set) into non-iid data sets
non_iid = []
user_dict = mnist_noniid(data_train, num_users)
for i in range(num_users):
idx = user_dict[i]
d = data_train.data[idx].float().unsqueeze(1)
targets = data_train.targets[idx].float()
non_iid.append((d, targets))
non_iid.append((data_test.data.float().unsqueeze(1), data_test.targets.float()))
return non_iid

"""
1. load_data
2. generate clients (step 3)
3. generate aggregator
4. training
"""
client_num = 10
d = load_cnn_mnist(client_num)

lr = 0.01
fl_param = {
'output_size': 10,
'client_num': client_num,
'model': NeuralNet,
'data': d,
'lr': lr,
'E': 1,
'C': 1,
'eps': 4.0,
'delta': 1e-5,
'q': 0.03,
'clip': 32,
'tot_T': 150,
'batch_size': 128,
'optimizer': 'Fed_GD_HT', #'Fed_SGD_HT' or 'Fed_GD_HT'
'thred_k': 10000,
'device': device
}
import warnings
warnings.filterwarnings("ignore")
fl_entity = FLServer(fl_param).to(device)

print("mnist")
acc = []
for e in range(150):
#if e+1 % 10 == 0:
# lr *= 0.1
fl_entity.set_lr(lr/np.sqrt(e+1))
acc += [fl_entity.global_update()]
print("global epochs = {:d}, acc = {:.4f}".format(e+1, acc[-1]))

0 comments on commit d4aacf7

Please sign in to comment.