Skip to content

Commit

Permalink
fix training error
Browse files Browse the repository at this point in the history
  • Loading branch information
Qinqing Liu committed Dec 20, 2021
1 parent b9be5f6 commit 6bde8f6
Show file tree
Hide file tree
Showing 8 changed files with 135 additions and 24 deletions.
Binary file modified .DS_Store
Binary file not shown.
Binary file modified tensorflow/.DS_Store
Binary file not shown.
9 changes: 5 additions & 4 deletions tensorflow/script/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,10 @@
_C.SOLVER.logdir = 'logs' # Directory where to write event logs
_C.SOLVER.ckpt = '' # Restore weights from checkpoint file
_C.SOLVER.run = 'train' # Choose from train or test
_C.SOLVER.type = 'adam' # Choose from sgd or adam
_C.SOLVER.type = 'sgd' # Choose from sgd or adam
_C.SOLVER.max_iter = 160000 # Maximum training iterations
_C.SOLVER.test_iter = 100 # Test steps in testing phase
_C.SOLVER.test_every_iter = 4000 # Test model every n training steps
_C.SOLVER.test_every_iter = 1000 # Test model every n training steps
_C.SOLVER.lr_type = 'step' # Learning rate type: step or cos
_C.SOLVER.learning_rate = 0.01 # Initial learning rate
_C.SOLVER.gamma = 0.1 # Learning rate step-wise decay
Expand All @@ -24,11 +24,12 @@
_C.SOLVER.var_name = ('_name',)# Variable names used for finetuning
_C.SOLVER.ignore_var_name = ('_name',)# Ignore variable names when loading ckpt
_C.SOLVER.verbose = False # Whether to output some messages
_C.SOLVER.task = 'reg'
_C.SOLVER.task = 'class'


# DATA related parameters
_C.DATA = CN()
_C.DATA.name = 'CASF'
_C.DATA.train = CN()
_C.DATA.train.dtype = 'points' # The data type: points or octree
_C.DATA.train.x_alias = 'data' # The alias of the data
Expand Down Expand Up @@ -70,7 +71,7 @@
_C.MODEL = CN()
_C.MODEL.name = '' # The name of the model
_C.MODEL.depth = 5 # The input octree depth
_C.MODEL.depth_out = 2 # The output feature depth
_C.MODEL.depth_out = 5 # The output feature depth
_C.MODEL.channel = 3 # The input feature channel
_C.MODEL.factor = 1 # The factor used to widen the network
_C.MODEL.nout = 40 # The output feature channel
Expand Down
6 changes: 3 additions & 3 deletions tensorflow/script/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,9 +199,9 @@ def merge_octrees(octrees, labels):

dataset = tf.data.TFRecordDataset(record_names).take(take).repeat()
if shuffle_size > 1: dataset = dataset.shuffle(shuffle_size)
itr = dataset.map(self.parse_example, num_parallel_calls=8) \
.batch(batch_size).map(merge_octrees, num_parallel_calls=8) \
.prefetch(8).make_one_shot_iterator()
itr = dataset.map(self.parse_example, num_parallel_calls=36) \
.batch(batch_size).map(merge_octrees, num_parallel_calls=36) \
.prefetch(36).make_one_shot_iterator()
return itr if return_iter else itr.get_next()

class GridDataset:
Expand Down
4 changes: 2 additions & 2 deletions tensorflow/script/network_cls.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
# octree-based resnet55
def network_resnet(octree, flags, training=True, reuse=None):
depth = flags.depth
channels = [2048, 1024, 512, 256, 128, 64, 32, 16, 8]
channels = [2048, 1024, 512, 256, 128, 64, 32, 16, 8, 8, 8]
with tf.variable_scope("ocnn_resnet", reuse=reuse):
data = octree_property(octree, property_name="feature", dtype=tf.float32,
depth=depth, channel=flags.channel)
Expand Down Expand Up @@ -65,7 +65,7 @@ def network_resnet_grids(grids, flags, training=True, reuse=None):
# the ocnn in the paper
def network_ocnn(octree, flags, training=True, reuse=None):
depth = flags.depth
channels = [2048, 1024, 512, 256, 128, 64, 32, 16, 8] #[512, 256, 128, 64, 32, 16, 8, 4, 2] #[2048, 1024, 512, 256, 128, 64, 32, 32, 32]
channels = [2048, 1024, 512, 256, 128, 64, 32, 16, 8, 8, 8] #[512, 256, 128, 64, 32, 16, 8, 4, 2] #[2048, 1024, 512, 256, 128, 64, 32, 32, 32]
with tf.variable_scope("ocnn", reuse=reuse):
data = octree_property(octree, property_name="feature", dtype=tf.float32,
depth=depth, channel=flags.channel)
Expand Down
19 changes: 18 additions & 1 deletion tensorflow/script/network_factory.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,31 @@
import tensorflow as tf
from network_cls import network_ocnn, network_resnet, network_cnn_grids, network_resnet_grids
#from network_unet import network_unet
#from network_hrnet import HRNet
# from network_unet_scannet import network_unet34

def cls_network(octree, flags, training, reuse=False):
if flags.name.lower() == 'vgg':
if flags.name.lower() == 'ocnn':
return network_ocnn(octree, flags, training, reuse)
elif flags.name.lower() == 'resnet':
return network_resnet(octree, flags, training, reuse)
elif flags.name.lower() == 'hrnet':
return HRNet(flags).network_cls(octree, training, reuse)
elif flags.name.lower() == 'cnn_grids':
return network_cnn_grids(octree, flags, training, reuse)
elif flags.name.lower() == 'resnet_grids':
return network_resnet_grids(octree, flags, training, reuse)
else:
print('Error, no network: ' + flags.name)

def seg_network(octree, flags, training, reuse=False, pts=None, mask=None):
if flags.name.lower() == 'unet':
return network_unet(octree, flags, training, reuse)
elif flags.name.lower() == 'hrnet':
return HRNet(flags).network_seg(octree, training, reuse, pts, mask)
# elif flags.name.lower() == 'unet_scannet':
# return network_unet34(octree, flags, training, reuse, pts, mask)
else:
print('Error, no network: ' + flags.name)


7 changes: 5 additions & 2 deletions tensorflow/script/test_reg_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,9 @@ def check_input(dataset='test', training=False, reuse=False, task = 'class'):
if __name__ == '__main__':
# solver = TFSolver(FLAGS.SOLVER, check_input)
# solver.check_grids()

#print(FLAGS.SOLVER)
solver = TFSolver(FLAGS, get_output)
solver.test_ave()
test_size_dic = {'CASF': 285, 'general_2019': 1146, 'refined_2019': 394, 'decoy':1460, 'training_15241':15241, 'training_15235': 15235}
solver.test_ave(test_size=test_size_dic[FLAGS.DATA.name])

#solver.test_ave()
114 changes: 102 additions & 12 deletions tensorflow/script/tfsolver.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ def train(self):
self.build_train_graph()

# qq: add
self.qq_set_update_after_k_round()
# self.qq_set_update_after_k_round()

# checkpoint
start_iter = 1
Expand Down Expand Up @@ -124,16 +124,16 @@ def train(self):
# training
# qq: revise the training, to update gradients after multiple iterations
# first 2 lines are original code.
# summary, _ = sess.run([self.summ_train, self.train_op])
# summary_writer.add_summary(summary, i)
if i == 0:
sess.run(self.zero_ops)
if i % 10 !=0 or i ==0:
sess.run(self.accum_ops)
else:
sess.run(self.accum_ops)
sess.run(self.train_step)
sess.run(self.zero_ops)
summary, _ = sess.run([self.summ_train, self.train_op])
summary_writer.add_summary(summary, i)
#if i == 0:
# sess.run(self.zero_ops)
#if i % 10 !=0 or i ==0:
# sess.run(self.accum_ops)
#else:
# sess.run(self.accum_ops)
# sess.run(self.train_step)
# sess.run(self.zero_ops)
# qq: end revise

# testing
Expand Down Expand Up @@ -205,7 +205,7 @@ def param_stats(self):
def test(self):
# build graph
self.build_test_graph()
self.qq_set_update_after_k_round()
#self.qq_set_update_after_k_round()

# checkpoint
assert(self.flags.ckpt) # the self.flags.ckpt should be provided
Expand Down Expand Up @@ -276,6 +276,7 @@ def test_ave(self, test_size = 285):
iter_test_result = sess.run(outputs)
test_logits.append(iter_test_result[0])
test_labels.append(iter_test_result[1])
# print(iter_test_result[0], iter_test_result[1])

all_preds = np.array(test_logits).reshape(test_size, -1)
all_labels = np.array(test_labels).reshape(test_size, -1)
Expand All @@ -286,10 +287,20 @@ def test_ave(self, test_size = 285):
all_labels_mean = all_labels.mean(axis=0)
all_preds_mean = all_preds.mean(axis=0)

#all_labels = all_labels.reshape(test_size,-1)
#all_preds = all_preds.reshape(test_size, -1)
#all_labels_mean = all_labels.mean(axis=1)
#all_preds_mean = all_preds.mean(axis=1)
# if abs(all_labels.std(axis=0).sum()) < 1e-4:
# print(all_labels.std(axis=0))
# print(all_labels)

print(all_labels_mean)
#print(all_preds_mean)
import pandas as pd
df = pd.DataFrame({'label': all_labels_mean, 'pred': all_preds_mean})
df.to_csv('pred_label.csv')

def report_reg_metrics(all_labels, all_preds):
from scipy.stats import pearsonr, spearmanr, kendalltau
from sklearn.metrics import roc_curve, auc, mean_squared_error, mean_absolute_error, r2_score
Expand All @@ -311,8 +322,86 @@ def report_scatter(all_labels, all_probs):
sns.regplot(all_labels, all_probs)
plt.show()

def report_cluster_corr(all_labels, all_probs):
import pandas as pd
clusters = pd.read_excel(r'./predicted clusters.xlsx', engine='openpyxl')
clusters = clusters[:285]
clusters.at[15, 'PDB code'] = "1E66"
clusters.at[171, 'PDB code'] = "3E92"
clusters.at[172, 'PDB code'] = "3E93"

with open(r'./points_list_test_reg.txt', "r") as fid:
pred_list = []
for line in fid.readlines():
pred_list.append(line.strip('\n'))

#pred_values = []
#for i in clusters["PDB code"]: # loops through each protein for the respective PDB codes
# for j,item in enumerate(pred_list): # loops through each line of the prediction value text file
# item = item.upper() # changes each line to uppercase because the txt file PDB codes are lowercase and we need them in uppercase
# if item[18:22] == i: # j[18:22] is the PDB code for the pred value. This matches the PDB codes of the prediction and true values of proteins
# x = item[44:] # j[44:] is the prediction value
# x = float(x) # turns predicion value from string to float
# x = all_probs[j]
# pred_values.append(x)
#clusters["pred"] = pred_values # adds a column in the cluster dataframe for predicted values
#print(clusters)

#corr = clusters.groupby('Cluster ID')[['Binding constant','pred']].corr().iloc[0::2,-1]
import matplotlib.pyplot as plt
import seaborn as sns
#plt.figure()
#sns.distplot(corr, kde=False)
#plt.xlabel('Correlation')
#plt.ylabel('Count')
#plt.savefig('./cluster.png')

#mean_df = clusters.groupby('Cluster ID').mean()
#plt.figure()
#sns.regplot(mean_df['Binding constant'], mean_df['pred'])
#plt.xlabel('Cluster Mean Label')
#plt.ylabel('Cluster Mean Pred')
#plt.savefig('./cluster_mean.png')
#print('Inter cluster corr: {}'.format(np.corrcoef(mean_df['Binding constant'], mean_df['pred'])[0,1]))

print("Double Verify")
cluster_list = []
id_list = []
clusters = clusters.set_index('PDB code')
for j, item in enumerate(pred_list):
item = item.upper()
id = item[18:22]
cluster_list.append(clusters.loc[id, 'Cluster ID'])
id_list.append(id)
print(id, all_labels[j], all_probs[j], clusters.loc[id, 'Binding constant'])

new_df = pd.DataFrame({"pred": all_probs, "label": all_labels, "cluster": cluster_list, "id": id_list})
corr = new_df.groupby('cluster')[['label', 'pred']].corr().iloc[0::2, -1]
plt.figure()
sns.distplot(corr, kde=False)
plt.xlabel('Correlation')
plt.ylabel('Count')
plt.savefig('./cluster.png')
print('Corr: {}'.format(list(np.array(corr))))
#print(new_df)
new_df.to_csv('result.csv')

mean_df = new_df.groupby('cluster').mean()
plt.figure()
sns.regplot(mean_df['label'], mean_df['pred'])
plt.xlabel('Cluster Mean Label')
plt.ylabel('Cluster Mean Pred')
plt.savefig('./cluster_mean.png')
print('Inter cluster corr: {}'.format(np.corrcoef(mean_df['label'], mean_df['pred'])[0,1]))

print("<0: ", (corr<0).sum())
print(">0.8: ", (corr>=0.8).sum())
print(">0.9: ", (corr>=0.9).sum())
print('min; ', corr.min())

report_reg_metrics(all_labels_mean, all_preds_mean)
report_scatter(all_labels_mean, all_preds_mean)
report_cluster_corr(all_labels_mean, all_preds_mean)

def check_grids(self, test_size = 285):
# build graph
Expand Down Expand Up @@ -341,3 +430,4 @@ def check_grids(self, test_size = 285):

def run(self):
eval('self.{}()'.format(self.flags.run))

0 comments on commit 6bde8f6

Please sign in to comment.