diff --git a/gan/gan.py b/gan/gan.py index 72bbee0..f1788dd 100644 --- a/gan/gan.py +++ b/gan/gan.py @@ -4,7 +4,7 @@ import os import numpy as np import pandas as pd from sklearn.cross_validation import train_test_split -import matplotlib.pyplot as plt +# import matplotlib.pyplot as plt # Nuerual Net Building from keras import layers from keras.layers import Input, Dense, Dropout, InputLayer, Reshape @@ -45,8 +45,8 @@ x_train = all_data.iloc[:, np.arange(20)] # Data Variables input_dimension = x_train.shape[1] # Number of features (e.g. genes) -gen_input_shape = (1, input_dimension) -discr_input_shape = (1, input_dimension) +gen_input_shape = (input_dimension,) +discr_input_shape = (input_dimension,) epochs = 10 batch_size = x_train.shape[0] @@ -95,7 +95,7 @@ training_record = model.fit(x=x_train, y=gan_targets(x_train.shape[0]), epochs=e # Predict (i.e. produce new samples) zsamples = np.random.normal(size=(1, input_dimension)) -pred = generator.predict(zsamples) +pred = generative_model.predict(zsamples) print(pred) # Save new samples to file