From 6bde8f68ea5df108e8fe3c180b8ec493e95b971f Mon Sep 17 00:00:00 2001
From: Qinqing Liu
Date: Sun, 19 Dec 2021 20:42:39 -0500
Subject: [PATCH] fix training error

---
 .DS_Store                            | Bin 10244 -> 10244 bytes
 tensorflow/.DS_Store                 | Bin 10244 -> 10244 bytes
 tensorflow/script/config.py          |   9 ++-
 tensorflow/script/dataset.py         |   6 +-
 tensorflow/script/network_cls.py     |   4 +-
 tensorflow/script/network_factory.py |  19 ++++-
 tensorflow/script/test_reg_model.py  |   7 +-
 tensorflow/script/tfsolver.py        | 114 ++++++++++++++++++++++++---
 8 files changed, 135 insertions(+), 24 deletions(-)

diff --git a/.DS_Store b/.DS_Store
index 6d64b5fecb68f528f2bd6bd0871fb22834317563..a0fb2999cb857e1a519a9a84a6833bdbbdf49c47 100644
GIT binary patch
[base85 binary deltas for .DS_Store and tensorflow/.DS_Store omitted; the config.py hunk is missing here]

diff --git a/tensorflow/script/dataset.py b/tensorflow/script/dataset.py
--- a/tensorflow/script/dataset.py
+++ b/tensorflow/script/dataset.py
@@ ... @@
     if shuffle_size > 1: dataset = dataset.shuffle(shuffle_size)
-    itr = dataset.map(self.parse_example, num_parallel_calls=8) \
-                 .batch(batch_size).map(merge_octrees, num_parallel_calls=8) \
-                 .prefetch(8).make_one_shot_iterator()
+    itr = dataset.map(self.parse_example, num_parallel_calls=36) \
+                 .batch(batch_size).map(merge_octrees, num_parallel_calls=36) \
+                 .prefetch(36).make_one_shot_iterator()
     return itr if return_iter else itr.get_next()
 
 class GridDataset:
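A note on the dataset.py hunk above: the pipeline's parallelism is pinned from 8 up to 36 threads, which only pays off on a host with that many cores. If the goal is simply to keep the input pipeline saturated, TF 1.13 and later 1.x releases can tune this at runtime instead; a minimal sketch under that assumption, reusing the script's own parse_example and merge_octrees:

    import tensorflow as tf

    AUTO = tf.data.experimental.AUTOTUNE  # available from TF 1.13 onward

    def build_iterator(dataset, parse_example, merge_octrees, batch_size):
        # Same pipeline shape as dataset.py, but the thread pools and
        # prefetch buffer are sized at runtime rather than hard-coded.
        itr = dataset.map(parse_example, num_parallel_calls=AUTO) \
                     .batch(batch_size) \
                     .map(merge_octrees, num_parallel_calls=AUTO) \
                     .prefetch(AUTO) \
                     .make_one_shot_iterator()
        return itr

With AUTOTUNE the same script ports across machines without retuning the constant.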
diff --git a/tensorflow/script/network_cls.py b/tensorflow/script/network_cls.py
index 8bf7279..918a69c 100644
--- a/tensorflow/script/network_cls.py
+++ b/tensorflow/script/network_cls.py
@@ -5,7 +5,7 @@ from ocnn import *
 # octree-based resnet55
 def network_resnet(octree, flags, training=True, reuse=None):
   depth = flags.depth
-  channels = [2048, 1024, 512, 256, 128, 64, 32, 16, 8]
+  channels = [2048, 1024, 512, 256, 128, 64, 32, 16, 8, 8, 8]
   with tf.variable_scope("ocnn_resnet", reuse=reuse):
     data = octree_property(octree, property_name="feature", dtype=tf.float32,
                            depth=depth, channel=flags.channel)
@@ -65,7 +65,7 @@ def network_resnet_grids(grids, flags, training=True, reuse=None):
 # the ocnn in the paper
 def network_ocnn(octree, flags, training=True, reuse=None):
   depth = flags.depth
-  channels = [2048, 1024, 512, 256, 128, 64, 32, 16, 8]  #[512, 256, 128, 64, 32, 16, 8, 4, 2]  #[2048, 1024, 512, 256, 128, 64, 32, 32, 32]
+  channels = [2048, 1024, 512, 256, 128, 64, 32, 16, 8, 8, 8]  #[512, 256, 128, 64, 32, 16, 8, 4, 2]  #[2048, 1024, 512, 256, 128, 64, 32, 32, 32]
   with tf.variable_scope("ocnn", reuse=reuse):
     data = octree_property(octree, property_name="feature", dtype=tf.float32,
                            depth=depth, channel=flags.channel)

diff --git a/tensorflow/script/network_factory.py b/tensorflow/script/network_factory.py
index f7b1b84..74e1164 100644
--- a/tensorflow/script/network_factory.py
+++ b/tensorflow/script/network_factory.py
@@ -1,14 +1,31 @@
 import tensorflow as tf
 from network_cls import network_ocnn, network_resnet, network_cnn_grids, network_resnet_grids
+#from network_unet import network_unet
+#from network_hrnet import HRNet
+# from network_unet_scannet import network_unet34
 
 def cls_network(octree, flags, training, reuse=False):
-  if flags.name.lower() == 'vgg':
+  if flags.name.lower() == 'ocnn':
     return network_ocnn(octree, flags, training, reuse)
   elif flags.name.lower() == 'resnet':
     return network_resnet(octree, flags, training, reuse)
+  elif flags.name.lower() == 'hrnet':
+    return HRNet(flags).network_cls(octree, training, reuse)
   elif flags.name.lower() == 'cnn_grids':
     return network_cnn_grids(octree, flags, training, reuse)
   elif flags.name.lower() == 'resnet_grids':
     return network_resnet_grids(octree, flags, training, reuse)
   else:
     print('Error, no network: ' + flags.name)
+
+def seg_network(octree, flags, training, reuse=False, pts=None, mask=None):
+  if flags.name.lower() == 'unet':
+    return network_unet(octree, flags, training, reuse)
+  elif flags.name.lower() == 'hrnet':
+    return HRNet(flags).network_seg(octree, training, reuse, pts, mask)
+  # elif flags.name.lower() == 'unet_scannet':
+  #   return network_unet34(octree, flags, training, reuse, pts, mask)
+  else:
+    print('Error, no network: ' + flags.name)
+
+

diff --git a/tensorflow/script/test_reg_model.py b/tensorflow/script/test_reg_model.py
index 96b2ea0..ea9432e 100644
--- a/tensorflow/script/test_reg_model.py
+++ b/tensorflow/script/test_reg_model.py
@@ -27,6 +27,9 @@ def check_input(dataset='test', training=False, reuse=False, task='class'):
 if __name__ == '__main__':
   # solver = TFSolver(FLAGS.SOLVER, check_input)
   # solver.check_grids()
-
+  #print(FLAGS.SOLVER)
   solver = TFSolver(FLAGS, get_output)
-  solver.test_ave()
\ No newline at end of file
+  test_size_dic = {'CASF': 285, 'general_2019': 1146, 'refined_2019': 394, 'decoy': 1460, 'training_15241': 15241, 'training_15235': 15235}
+  solver.test_ave(test_size=test_size_dic[FLAGS.DATA.name])
+
+  #solver.test_ave()
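Two details in the hunks above are worth flagging. In network_factory.py, seg_network still dispatches to network_unet and HRNet even though both imports are now commented out, so requesting 'unet' or 'hrnet' raises a NameError at call time. And in test_reg_model.py, an unrecognized dataset name fails with a bare KeyError; a small defensive sketch (the dict keys and FLAGS.DATA.name come from the patch, the guard and its message are our addition):

    test_size_dic = {'CASF': 285, 'general_2019': 1146, 'refined_2019': 394,
                     'decoy': 1460, 'training_15241': 15241, 'training_15235': 15235}
    name = FLAGS.DATA.name
    if name not in test_size_dic:
        raise KeyError('no test size registered for dataset %r' % name)
    solver.test_ave(test_size=test_size_dic[name])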
diff --git a/tensorflow/script/tfsolver.py b/tensorflow/script/tfsolver.py
index 6c293be..d1db296 100644
--- a/tensorflow/script/tfsolver.py
+++ b/tensorflow/script/tfsolver.py
@@ -96,7 +96,7 @@ class TFSolver:
     self.build_train_graph()
 
     # qq: add
-    self.qq_set_update_after_k_round()
+    # self.qq_set_update_after_k_round()
 
     # checkpoint
     start_iter = 1
@@ -124,16 +124,16 @@ class TFSolver:
         # training
         # qq: revise the training, to update gradients after multiple iterations
         # first 2 lines are original code.
-        # summary, _ = sess.run([self.summ_train, self.train_op])
-        # summary_writer.add_summary(summary, i)
-        if i == 0:
-          sess.run(self.zero_ops)
-        if i % 10 != 0 or i == 0:
-          sess.run(self.accum_ops)
-        else:
-          sess.run(self.accum_ops)
-          sess.run(self.train_step)
-          sess.run(self.zero_ops)
+        summary, _ = sess.run([self.summ_train, self.train_op])
+        summary_writer.add_summary(summary, i)
+        #if i == 0:
+        #  sess.run(self.zero_ops)
+        #if i % 10 != 0 or i == 0:
+        #  sess.run(self.accum_ops)
+        #else:
+        #  sess.run(self.accum_ops)
+        #  sess.run(self.train_step)
+        #  sess.run(self.zero_ops)
         # qq: end revise
 
         # testing
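The hunk above reverts training to the stock summary/train_op step and comments out the k-step gradient-accumulation path (its setup call, qq_set_update_after_k_round, is disabled in the earlier hunk). For context, here is a minimal TF 1.x sketch of the ops such a helper typically builds; the method's real body is not part of this patch, so the helper name, signature, and the division by k are assumptions:

    import tensorflow as tf

    def make_accumulation_ops(loss, optimizer, k=10):
        # Non-trainable buffers that accumulate gradients across k batches.
        tvars = tf.trainable_variables()
        accums = [tf.Variable(tf.zeros_like(v), trainable=False) for v in tvars]
        grads = tf.gradients(loss, tvars)  # assumes every variable receives a gradient
        zero_ops = [a.assign(tf.zeros_like(a)) for a in accums]
        accum_ops = [a.assign_add(g) for a, g in zip(accums, grads)]
        # Apply the averaged gradient once every k accumulation steps.
        train_step = optimizer.apply_gradients(
            [(a / float(k), v) for a, v in zip(accums, tvars)])
        return zero_ops, accum_ops, train_step

In the commented-out driver loop, accum_ops runs every iteration, while train_step followed by zero_ops fires only when i % 10 == 0 and i > 0, i.e. once per ten accumulated batches.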
@@ -205,7 +205,7 @@ class TFSolver:
   def test(self):
     # build graph
     self.build_test_graph()
-    self.qq_set_update_after_k_round()
+    #self.qq_set_update_after_k_round()
 
     # checkpoint
     assert(self.flags.ckpt)   # the self.flags.ckpt should be provided
@@ -276,6 +276,7 @@ class TFSolver:
       iter_test_result = sess.run(outputs)
       test_logits.append(iter_test_result[0])
       test_labels.append(iter_test_result[1])
+      # print(iter_test_result[0], iter_test_result[1])
 
     all_preds = np.array(test_logits).reshape(test_size, -1)
     all_labels = np.array(test_labels).reshape(test_size, -1)
@@ -286,10 +287,20 @@ class TFSolver:
     all_labels_mean = all_labels.mean(axis=0)
     all_preds_mean = all_preds.mean(axis=0)
+    #all_labels = all_labels.reshape(test_size, -1)
+    #all_preds = all_preds.reshape(test_size, -1)
+    #all_labels_mean = all_labels.mean(axis=1)
+    #all_preds_mean = all_preds.mean(axis=1)
 
     # if abs(all_labels.std(axis=0).sum()) < 1e-4:
     #   print(all_labels.std(axis=0))
     #   print(all_labels)
+    print(all_labels_mean)
+    #print(all_preds_mean)
+    import pandas as pd
+    df = pd.DataFrame({'label': all_labels_mean, 'pred': all_preds_mean})
+    df.to_csv('pred_label.csv')
+
     def report_reg_metrics(all_labels, all_preds):
       from scipy.stats import pearsonr, spearmanr, kendalltau
       from sklearn.metrics import roc_curve, auc, mean_squared_error, mean_absolute_error, r2_score
@@ -311,8 +322,86 @@ class TFSolver:
       sns.regplot(all_labels, all_probs)
       plt.show()
 
+    def report_cluster_corr(all_labels, all_probs):
+      import pandas as pd
+      clusters = pd.read_excel(r'./predicted clusters.xlsx', engine='openpyxl')
+      clusters = clusters[:285]
+      clusters.at[15, 'PDB code'] = "1E66"
+      clusters.at[171, 'PDB code'] = "3E92"
+      clusters.at[172, 'PDB code'] = "3E93"
+
+      with open(r'./points_list_test_reg.txt', "r") as fid:
+        pred_list = []
+        for line in fid.readlines():
+          pred_list.append(line.strip('\n'))
+
+      #pred_values = []
+      #for i in clusters["PDB code"]:  # loop over each protein's PDB code
+      #  for j, item in enumerate(pred_list):  # loop over each line of the prediction file
+      #    item = item.upper()  # the txt file PDB codes are lowercase; we need uppercase
+      #    if item[18:22] == i:  # item[18:22] is the PDB code, matching prediction to true value
+      #      x = item[44:]  # item[44:] is the prediction value
+      #      x = float(x)  # turn the prediction value from string to float
+      #      x = all_probs[j]
+      #      pred_values.append(x)
+      #clusters["pred"] = pred_values  # add a column of predicted values to the cluster dataframe
+      #print(clusters)
+
+      #corr = clusters.groupby('Cluster ID')[['Binding constant', 'pred']].corr().iloc[0::2, -1]
+      import matplotlib.pyplot as plt
+      import seaborn as sns
+      #plt.figure()
+      #sns.distplot(corr, kde=False)
+      #plt.xlabel('Correlation')
+      #plt.ylabel('Count')
+      #plt.savefig('./cluster.png')
+
+      #mean_df = clusters.groupby('Cluster ID').mean()
+      #plt.figure()
+      #sns.regplot(mean_df['Binding constant'], mean_df['pred'])
+      #plt.xlabel('Cluster Mean Label')
+      #plt.ylabel('Cluster Mean Pred')
+      #plt.savefig('./cluster_mean.png')
+      #print('Inter cluster corr: {}'.format(np.corrcoef(mean_df['Binding constant'], mean_df['pred'])[0, 1]))
+
+      print("Double Verify")
+      cluster_list = []
+      id_list = []
+      clusters = clusters.set_index('PDB code')
+      for j, item in enumerate(pred_list):
+        item = item.upper()
+        id = item[18:22]
+        cluster_list.append(clusters.loc[id, 'Cluster ID'])
+        id_list.append(id)
+        print(id, all_labels[j], all_probs[j], clusters.loc[id, 'Binding constant'])
+
+      new_df = pd.DataFrame({"pred": all_probs, "label": all_labels, "cluster": cluster_list, "id": id_list})
+      corr = new_df.groupby('cluster')[['label', 'pred']].corr().iloc[0::2, -1]
+      plt.figure()
+      sns.distplot(corr, kde=False)
+      plt.xlabel('Correlation')
+      plt.ylabel('Count')
+      plt.savefig('./cluster.png')
+      print('Corr: {}'.format(list(np.array(corr))))
+      #print(new_df)
+      new_df.to_csv('result.csv')
+
+      mean_df = new_df.groupby('cluster').mean()
+      plt.figure()
+      sns.regplot(mean_df['label'], mean_df['pred'])
+      plt.xlabel('Cluster Mean Label')
+      plt.ylabel('Cluster Mean Pred')
+      plt.savefig('./cluster_mean.png')
+      print('Inter cluster corr: {}'.format(np.corrcoef(mean_df['label'], mean_df['pred'])[0, 1]))
+
+      print("<0: ", (corr < 0).sum())
+      print(">0.8: ", (corr >= 0.8).sum())
+      print(">0.9: ", (corr >= 0.9).sum())
+      print('min: ', corr.min())
+
     report_reg_metrics(all_labels_mean, all_preds_mean)
     report_scatter(all_labels_mean, all_preds_mean)
+    report_cluster_corr(all_labels_mean, all_preds_mean)
 
   def check_grids(self, test_size=285):
     # build graph
@@ -341,3 +430,4 @@ class TFSolver:
 
   def run(self):
     eval('self.{}()'.format(self.flags.run))
+
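The per-cluster statistics in report_cluster_corr rely on a compact pandas idiom: groupby('cluster')[['label', 'pred']].corr() yields one 2x2 correlation matrix per cluster, and .iloc[0::2, -1] slices the off-diagonal label-vs-pred entry out of each. A self-contained toy illustration (synthetic data and cluster names are hypothetical):

    import numpy as np
    import pandas as pd

    rng = np.random.RandomState(0)
    df = pd.DataFrame({
        'cluster': np.repeat(['a', 'b', 'c'], 5),
        'label': rng.randn(15),
    })
    df['pred'] = df['label'] + 0.3 * rng.randn(15)

    # Every other row of the last column is the label-vs-pred correlation
    # from each cluster's 2x2 matrix: one Pearson r per cluster.
    corr = df.groupby('cluster')[['label', 'pred']].corr().iloc[0::2, -1]
    print(corr)

This per-cluster series is what feeds the distplot histogram and the "<0 / >=0.8 / >=0.9" counts printed in the hunk above.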