diff --git a/gan/gan.py b/gan/gan.py
index df39e65..72bbee0 100644
--- a/gan/gan.py
+++ b/gan/gan.py
@@ -1,2 +1,111 @@
+# Import
+import os
+# Data Management
 import numpy as np
 import pandas as pd
+from sklearn.model_selection import train_test_split
+import matplotlib.pyplot as plt
+# Neural Net Building
+from keras import layers
+from keras.layers import Input, Dense, Dropout, InputLayer, Reshape
+from keras.models import Sequential, Model
+from keras.optimizers import Adam
+from keras.utils.generic_utils import Progbar
+# Third party keras gan tool
+from keras_adversarial import AdversarialModel, simple_gan, gan_targets
+from keras_adversarial import AdversarialOptimizerSimultaneous, normal_latent_sampling
+
+# Load data
+base_path = "../mkdataset/datasets/gan_datasets/"
+# label_file_name = "tham_human_and_mouse_dataset.csv"
+file_name = "new.csv"
+all_data = pd.read_csv(os.path.join(base_path, file_name))
+# labeled_data = pd.read_csv(os.path.join(base_path, label_file_name))
+
+
+
+# Output
+output_base_path = './'
+
+# Prepare data
+x_train = all_data.iloc[:, np.arange(20)]
+
+# column_start_index_of_genes = 2
+# class_label_column_index = 1
+# features = all_data.iloc[:, np.arange(column_start_index_of_genes, all_data.shape[1])]
+# labels = all_data.iloc[:, class_label_column_index]
+
+# column_start_index_of_genes = 2
+# features = all_data.iloc[:, np.arange(20)]
+# labels = labeled_data.iloc[:, class_label_column_index]
+
+# x_train, x_test, y_train, y_test = train_test_split(features, labels, test_size=0.33, random_state=42)
+
+
+# Data Variables
+input_dimension = x_train.shape[1]  # Number of features (e.g. genes)
+
+gen_input_shape = (1, input_dimension)
+discr_input_shape = (1, input_dimension)
+
+epochs = 10
+batch_size = x_train.shape[0]
+
+# Build Generative model
+generative_model = Sequential()
+# generative_model.add(InputLayer(input_shape=gen_input_shape))
+generative_model.add(Dense(units=int(1.2*input_dimension), activation='relu', input_dim=input_dimension))
+generative_model.add(Dropout(rate=0.2, noise_shape=None, seed=15))
+generative_model.add(Dense(units=int(0.2*input_dimension), activation='relu'))
+generative_model.add(Dense(units=input_dimension, activation='relu'))
+generative_model.add(Reshape(discr_input_shape))
+
+# Build Discriminator model
+discriminator_model = Sequential()
+discriminator_model.add(InputLayer(input_shape=discr_input_shape))
+discriminator_model.add(Dense(units=int(1.2*input_dimension), activation='relu'))
+discriminator_model.add(Dropout(rate=0.2, noise_shape=None, seed=75))
+discriminator_model.add(Dense(units=int(0.2*input_dimension), activation='relu'))
+discriminator_model.add(Dense(units=1, activation='sigmoid'))
+
+# Build GAN
+gan = simple_gan(generative_model, discriminator_model, normal_latent_sampling((input_dimension, )))
+model = AdversarialModel(base_model=gan,
+                         player_params=[generative_model.trainable_weights,
+                                        discriminator_model.trainable_weights],
+                         player_names=['generator', 'discriminator'])
+# Another optimizer to try: AdversarialOptimizerAlternating
+model.adversarial_compile(adversarial_optimizer=AdversarialOptimizerSimultaneous(),
+                          player_optimizers=['adam', 'adam'], loss='binary_crossentropy')
+
+# Print Summary of Models
+generative_model.summary()
+discriminator_model.summary()
+gan.summary()
+
+# Train
+# gan_targets takes the number of samples and builds the real/fake target arrays for both players
+training_record = model.fit(x=x_train, y=gan_targets(x_train.shape[0]), epochs=epochs,
+                            batch_size=batch_size)
+
+# Display plot of loss over training
+# plt.plot(training_record.history['player_0_loss'])
+# plt.plot(training_record.history['player_1_loss'])
+# plt.plot(training_record.history['loss'])
+
+# Predict (i.e. produce new samples)
+zsamples = np.random.normal(size=(1, input_dimension))
+pred = generative_model.predict(zsamples)
+print(pred)
+
+# Save new samples to file
+# new_samples = pd.DataFrame(pred)
+# new_samples.to_csv(os.path.join(output_base_path, 'new_samples.csv'))
+
+# # save training_record
+# df = pd.DataFrame(training_record.history)
+# df.to_csv(os.path.join(output_base_path, 'training_record.csv'))
+#
+# # save models
+# generative_model.save(os.path.join(output_base_path, 'generator.h5'))
+# discriminator_model.save(os.path.join(output_base_path, "discriminator.h5"))
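
Beyond the single-sample prediction at the end of the script, a minimal sketch of drawing a batch of synthetic rows and writing them to CSV could look like the following. This is an assumption-laden illustration, not part of the patch: it presumes the trained generative_model, x_train, input_dimension, and output_base_path from the diff above are in scope, that the latent dimension equals input_dimension (as configured via normal_latent_sampling), and the file name synthetic_samples.csv is hypothetical.

# Sketch: draw a batch of synthetic samples from the trained generator
# (assumes generative_model, x_train, input_dimension, output_base_path from above).
n_synthetic = 100  # hypothetical number of rows to generate
latent = np.random.normal(size=(n_synthetic, input_dimension))
generated = generative_model.predict(latent)            # shape: (n_synthetic, 1, input_dimension) due to the Reshape layer
generated = generated.reshape(n_synthetic, input_dimension)
synthetic_df = pd.DataFrame(generated, columns=x_train.columns)
synthetic_df.to_csv(os.path.join(output_base_path, 'synthetic_samples.csv'), index=False)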