From bd2bc7f769546a4e1263a496380e91170b071279 Mon Sep 17 00:00:00 2001
From: Qinqing Liu
Date: Mon, 6 Apr 2020 20:58:52 -0400
Subject: [PATCH] add correlation analysis

---
 __pycache__/utils.cpython-37.pyc | Bin 2148 -> 2711 bytes
 main.py                          | 25 ++++++++++---------------
 utils.py                         | 16 +++++++++++++++-
 3 files changed, 25 insertions(+), 16 deletions(-)

diff --git a/__pycache__/utils.cpython-37.pyc b/__pycache__/utils.cpython-37.pyc
index ce006f9e9a316d4fd76b984ca864565a057e90e1..d3f7261f159e91121377e08b6ff7547ecda2ad69 100644
GIT binary patch
delta 1051
[base85-encoded binary delta for the compiled bytecode omitted]

delta 472
[base85-encoded binary delta for the compiled bytecode omitted]

diff --git a/main.py b/main.py
index 1dcb57d..69b7387 100644
--- a/main.py
+++ b/main.py
@@ -4,29 +4,22 @@
 from sklearn.model_selection import train_test_split
 from sklearn.metrics import accuracy_score
 from utils import *
+import statsmodels.graphics.api as smg
 
 K = 5
 random_seed = 0
 test_ratio = 0.2
 
-def read_data():
-    with open('./breast-cancer-wisconsin.data', 'r') as data_fid:
-        lines = data_fid.readlines()
-    records = []
-    for line in lines:
-        if '?' in line:
-            line = line.replace('?', '11')
-        line = line.split(',')
-        line = [int(item) for item in line][1:]
-        records.append(line)
-    records = np.array(records)
-    X = records[:, :-1]
-    y = records[:, -1]
-    return X, y
-
 X, y = read_data()
 X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=test_ratio, random_state=random_seed)
 
+# Correlation Analysis
+corr_matrix = np.corrcoef(X_train.T)
+smg.plot_corr(corr_matrix, xnames = ["Clump_Thickness", "Cell_Size", "Cell_Shape",
+    "Marginal_Adhesion", "Single_Epithelial_Cell_Size", "Bare_Nuclei", "Bland_Chromatin",
+    "Normal_Nucleoli", "Mitoses"])
+plt.show()
+
 # Naive Bayes
 clf = CategoricalNB()
 clf.fit(X_train, y_train)
@@ -35,6 +28,8 @@ def read_data():
 print('accuracy is: {}'.format(acc))
 save_cm_figs(y_test, y_pred, 'NB_ori') # confusion matrix
 
+
+
 # X_2 = X[Y==2] #benign
 # X_4 = X[Y==4] #cancer

diff --git a/utils.py b/utils.py
index 99df8bc..5d21c5e 100644
--- a/utils.py
+++ b/utils.py
@@ -3,6 +3,21 @@
 from sklearn.metrics import classification_report, confusion_matrix
 import itertools
+def read_data():
+    with open('./breast-cancer-wisconsin.data', 'r') as data_fid:
+        lines = data_fid.readlines()
+    records = []
+    for line in lines:
+        if '?' in line:
+            line = line.replace('?', '11')
+        line = line.split(',')
+        line = [int(item) for item in line][1:]
+        records.append(line)
+    records = np.array(records)
+    X = records[:, :-1]
+    y = records[:, -1]
+    return X, y
+
 def plot_confusion_matrix(y_true, y_pred,
                           classes = ['benign', 'cancer'],
                           normalize=False,
                           title='Confusion matrix',
@@ -60,4 +75,3 @@ def save_cm_figs(y_true, y_pred, arc, target_names = ['begign', 'cancer']):
                           title='Normalized confusion matrix')
     plt.savefig('./result/conf_norm_{}.png'.format(arc), bbox_inches='tight')
     plt.show()
-
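
For reference, the correlation step this patch introduces can be exercised on its own. The sketch below is a minimal, self-contained approximation: a random integer matrix stands in for the X_train that main.py builds via read_data() and train_test_split, and the feature names follow the same Breast Cancer Wisconsin attribute order used in the patch.

import numpy as np
import matplotlib.pyplot as plt
import statsmodels.graphics.api as smg

feature_names = ["Clump_Thickness", "Cell_Size", "Cell_Shape",
                 "Marginal_Adhesion", "Single_Epithelial_Cell_Size",
                 "Bare_Nuclei", "Bland_Chromatin", "Normal_Nucleoli", "Mitoses"]

# Placeholder data: integer features in 1..10, standing in for the real X_train.
rng = np.random.default_rng(0)
X_train = rng.integers(1, 11, size=(200, 9))

# np.corrcoef treats each row as a variable, so transpose (samples, features).
corr_matrix = np.corrcoef(X_train.T)

# plot_corr renders the 9x9 matrix as a labelled correlation heatmap.
smg.plot_corr(corr_matrix, xnames=feature_names)
plt.show()

With the real data loaded through read_data(), the same two calls reproduce the figure that main.py now shows before fitting the CategoricalNB classifier.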