Skip to content

Commit

Permalink
Browse files Browse the repository at this point in the history
The start of the GAN. It doesn't work in this stage, because of an er…
…ror saying that it's expecting a 3 demonsional input.
  • Loading branch information
rjm11010 committed Dec 15, 2017
1 parent 5106626 commit ec3967e
Showing 1 changed file with 109 additions and 0 deletions.
109 changes: 109 additions & 0 deletions gan/gan.py
@@ -1,2 +1,111 @@
# Import
import os
# Data Mangement
import numpy as np
import pandas as pd
from sklearn.cross_validation import train_test_split
import matplotlib.pyplot as plt
# Nuerual Net Building
from keras import layers
from keras.layers import Input, Dense, Dropout, InputLayer, Reshape
from keras.models import Sequential, Model
from keras.optimizers import Adam
from keras.utils.generic_utils import Progbar
# Third party keras gan tool
from keras_adversarial import AdversarialModel, simple_gan, gan_targets
from keras_adversarial import AdversarialOptimizerSimultaneous, normal_latent_sampling

# Load data
base_path = "../mkdataset/datasets/gan_datasets/"
# label_file_name = "tham_human_and_mouse_dataset.csv"
file_name = "new.csv"
all_data = pd.read_csv(os.path.join(base_path, file_name))
# labeled_data = pd.read_csv(os.path.join(label_file_name, file_name))



# Output
output_base_path = './'

# Prepare data
x_train = all_data.iloc[:, np.arange(20)]

# column_start_index_of_genes = 2
# class_label_column_index = 1
# features = all_data.iloc[:, np.arange(column_start_index_of_genes, df.shape[1])]
# labels = all_data.iloc[:, class_label_column_index]

# column_start_index_of_genes = 2
# features = all_data.iloc[:, np.arange(20)]
# labels = labeled_data.iloc[:, class_label_column_index]

# x_train, x_test, y_train, y_test = train_test_split(features, labels, test_size=0.33, random_state=42)


# Data Variables
input_dimension = x_train.shape[1] # Number of features (e.g. genes)

gen_input_shape = (1, input_dimension)
discr_input_shape = (1, input_dimension)

epochs = 10
batch_size = x_train.shape[0]

# Build Generative model
generative_model = Sequential()
# generative_model.add(InputLayer(input_shape=gen_input_shape))
generative_model.add(Dense(units=int(1.2*input_dimension), activation='relu', input_dim=input_dimension))
generative_model.add(Dropout(rate=0.2, noise_shape=None, seed=15))
generative_model.add(Dense(units=int(0.2*input_dimension), activation='relu'))
generative_model.add(Dense(units=input_dimension, activation='relu'))
generative_model.add(Reshape(discr_input_shape))

# Build Discriminator model
discriminator_model = Sequential()
discriminator_model.add(InputLayer(input_shape=discr_input_shape))
discriminator_model.add(Dense(units=int(1.2*input_dimension), activation='relu'))
discriminator_model.add(Dropout(rate=0.2, noise_shape=None, seed=75))
discriminator_model.add(Dense(units=int(0.2*input_dimension), activation='relu'))
discriminator_model.add(Dense(units=1, activation='sigmoid'))

# Build GAN
gan = simple_gan(generative_model, discriminator_model, normal_latent_sampling((input_dimension, )))
model = AdversarialModel(base_model=gan,
player_params=[generative_model.trainable_weights,
discriminator_model.trainable_weights],
player_names=['generator', 'discriminator'])
# Other optimizer to try AdversarialOptimizerAlternating
model.adversarial_compile(adversarial_optimizer=AdversarialOptimizerSimultaneous(),
player_optimizers=['adam', 'adam'], loss='binary_crossentropy')

# Print Summary of Models
generative_model.summary()
discriminator_model.summary()
gan.summary()

# Train
# gan_targets takes as inputs the # of samples
training_record = model.fit(x=x_train, y=gan_targets(x_train.shape[0]), epochs=epochs,
batch_size=batch_size)

# Diplay plot of loss over training
# plt.plot(history.history['player_0_loss'])
# plt.plot(history.history['player_1_loss'])
# plt.plot(history.history['loss'])

# Predict (i.e. produce new samples)
zsamples = np.random.normal(size=(1, input_dimension))
pred = generator.predict(zsamples)
print(pred)

# Save new samples to file
# new_samples = pd.DataFrame(pred)
# new_samples.to_csv(os.path.join(output_base_path, 'new_samples.csv'))

# # save training_record
# df = pd.DataFrame(training_record.history)
# df.to_csv(os.path.join(output_base_path, 'training_record.csv'))
#
# # save models
# generator.save(os.path.join(output_base_path, 'generator.h5'))
# discriminator.save(os.path.join(output_base_path, "discriminator.h5"))

0 comments on commit ec3967e

Please sign in to comment.