diff --git a/gan/gan.py b/gan/gan.py index 72bbee0..f1788dd 100644 --- a/gan/gan.py +++ b/gan/gan.py @@ -4,7 +4,7 @@ import numpy as np import pandas as pd from sklearn.cross_validation import train_test_split -import matplotlib.pyplot as plt +# import matplotlib.pyplot as plt # Nuerual Net Building from keras import layers from keras.layers import Input, Dense, Dropout, InputLayer, Reshape @@ -45,8 +45,8 @@ # Data Variables input_dimension = x_train.shape[1] # Number of features (e.g. genes) -gen_input_shape = (1, input_dimension) -discr_input_shape = (1, input_dimension) +gen_input_shape = (input_dimension,) +discr_input_shape = (input_dimension,) epochs = 10 batch_size = x_train.shape[0] @@ -95,7 +95,7 @@ # Predict (i.e. produce new samples) zsamples = np.random.normal(size=(1, input_dimension)) -pred = generator.predict(zsamples) +pred = generative_model.predict(zsamples) print(pred) # Save new samples to file