Skip to content

Commit

Permalink
Browse files Browse the repository at this point in the history
add correlation analysis
  • Loading branch information
Qinqing Liu committed Apr 7, 2020
1 parent 824c640 commit bd2bc7f
Show file tree
Hide file tree
Showing 3 changed files with 25 additions and 16 deletions.
Binary file modified __pycache__/utils.cpython-37.pyc
Binary file not shown.
25 changes: 10 additions & 15 deletions main.py
Expand Up @@ -4,29 +4,22 @@ from sklearn.naive_bayes import CategoricalNB
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score
from utils import *
import statsmodels.graphics.api as smg

# NOTE(review): K is not referenced in the visible part of this script —
# presumably a fold/neighbour count used further down; confirm before removing.
K = 5
# Passed to train_test_split as random_state for the train/test split.
random_seed = 0
# Passed to train_test_split as test_size (share of records held out).
test_ratio = 0.2

def read_data(path='./breast-cancer-wisconsin.data'):
    """Load the comma-separated breast-cancer records into NumPy arrays.

    Each non-blank line is split on commas and every field is parsed as an
    integer. The first column is discarded (presumably a sample id -- confirm
    against the dataset description), the last column becomes the label
    vector and the columns in between form the feature matrix.

    Any ``?`` in a line is replaced by the literal ``11`` before parsing
    (presumably a stand-in category for missing values -- confirm).

    Args:
        path: Location of the data file. Defaults to the path that was
            previously hard-coded, so existing ``read_data()`` calls behave
            as before.

    Returns:
        A tuple ``(X, y)``: ``X`` holds every parsed column except the first
        and the last, ``y`` holds the last column.

    Raises:
        ValueError: If a non-blank line contains a field that is not an
            integer (after the ``?`` substitution).
    """
    records = []
    with open(path, 'r') as data_fid:
        # Stream the file line by line instead of loading it all at once.
        for line in data_fid:
            line = line.strip()
            if not line:
                # Blank/trailing empty lines used to crash in int('').
                continue
            fields = line.replace('?', '11').split(',')
            records.append([int(item) for item in fields][1:])
    records = np.array(records)
    return records[:, :-1], records[:, -1]

# Load the records and hold out a test split.
X, y = read_data()
X_train, X_test, y_train, y_test = train_test_split(
    X,
    y,
    test_size=test_ratio,
    random_state=random_seed,
)

# Correlation Analysis: pairwise correlation between the feature columns of
# the training split, drawn as a labelled matrix plot.
feature_names = [
    "Clump_Thickness",
    "Cell_Size",
    "Cell_Shape",
    "Marginal_Adhesion",
    "Single_Epithelial_Cell_Size",
    "Bare_Nuclei",
    "Bland_Chromatin",
    "Normal_Nucleoli",
    "Mitoses",
]
# rowvar=False treats columns as variables, same as correlating X_train.T.
corr_matrix = np.corrcoef(X_train, rowvar=False)
smg.plot_corr(corr_matrix, xnames=feature_names)
plt.show()

# Naive Bayes: fit() returns the estimator itself, so build and fit in one go.
clf = CategoricalNB().fit(X_train, y_train)
Expand All @@ -35,6 +28,8 @@ acc = accuracy_score(y_test, y_pred)
# Report the accuracy computed by accuracy_score(y_test, y_pred).
print('accuracy is: {}'.format(acc))
# 'NB_ori' is the tag save_cm_figs receives as its third argument (arc).
save_cm_figs(y_test, y_pred, 'NB_ori') # confusion matrix



# X_2 = X[Y==2] #benign
# X_4 = X[Y==4] #cancer

Expand Down
16 changes: 15 additions & 1 deletion utils.py
Expand Up @@ -3,6 +3,21 @@ import numpy as np
from sklearn.metrics import classification_report, confusion_matrix
import itertools

def read_data(path='./breast-cancer-wisconsin.data'):
    """Load the comma-separated breast-cancer records into NumPy arrays.

    Each non-blank line is split on commas and every field is parsed as an
    integer. The first column is discarded (presumably a sample id -- confirm
    against the dataset description), the last column becomes the label
    vector and the columns in between form the feature matrix.

    Any ``?`` in a line is replaced by the literal ``11`` before parsing
    (presumably a stand-in category for missing values -- confirm).

    Args:
        path: Location of the data file. Defaults to the path that was
            previously hard-coded, so existing ``read_data()`` calls behave
            as before.

    Returns:
        A tuple ``(X, y)``: ``X`` holds every parsed column except the first
        and the last, ``y`` holds the last column.

    Raises:
        ValueError: If a non-blank line contains a field that is not an
            integer (after the ``?`` substitution).
    """
    records = []
    with open(path, 'r') as data_fid:
        # Stream the file line by line instead of loading it all at once.
        for line in data_fid:
            line = line.strip()
            if not line:
                # Blank/trailing empty lines used to crash in int('').
                continue
            fields = line.replace('?', '11').split(',')
            records.append([int(item) for item in fields][1:])
    records = np.array(records)
    return records[:, :-1], records[:, -1]

def plot_confusion_matrix(y_true, y_pred, classes = ['benign', 'cancer'],
normalize=False,
title='Confusion matrix',
Expand Down Expand Up @@ -60,4 +75,3 @@ def save_cm_figs(y_true, y_pred, arc, target_names = ['begign', 'cancer']):
title='Normalized confusion matrix')
plt.savefig('./result/conf_norm_{}.png'.format(arc), bbox_inches='tight')
plt.show()

0 comments on commit bd2bc7f

Please sign in to comment.