diff --git a/__pycache__/utils.cpython-37.pyc b/__pycache__/utils.cpython-37.pyc new file mode 100644 index 0000000..ce006f9 Binary files /dev/null and b/__pycache__/utils.cpython-37.pyc differ diff --git a/main.py b/main.py index d6b5a05..1dcb57d 100644 --- a/main.py +++ b/main.py @@ -3,6 +3,7 @@ from sklearn.naive_bayes import CategoricalNB from sklearn.model_selection import train_test_split from sklearn.metrics import accuracy_score +from utils import * K = 5 random_seed = 0 @@ -32,6 +33,7 @@ def read_data(): y_pred = clf.predict(X_test) acc = accuracy_score(y_test, y_pred) print('accuracy is: {}'.format(acc)) +save_cm_figs(y_test, y_pred, 'NB_ori') # confusion matrix # X_2 = X[Y==2] #benign # X_4 = X[Y==4] #cancer diff --git a/result/conf_no_norm_NB_ori.png b/result/conf_no_norm_NB_ori.png new file mode 100644 index 0000000..1be0c80 Binary files /dev/null and b/result/conf_no_norm_NB_ori.png differ diff --git a/result/conf_norm_NB_ori.png b/result/conf_norm_NB_ori.png new file mode 100644 index 0000000..8f54c1c Binary files /dev/null and b/result/conf_norm_NB_ori.png differ diff --git a/utils.py b/utils.py new file mode 100644 index 0000000..99df8bc --- /dev/null +++ b/utils.py @@ -0,0 +1,63 @@ +import matplotlib.pyplot as plt +import numpy as np +from sklearn.metrics import classification_report, confusion_matrix +import itertools + +def plot_confusion_matrix(y_true, y_pred, classes = ['benign', 'cancer'], + normalize=False, + title='Confusion matrix', + cmap=plt.cm.Blues): + """ + This function prints and plots the confusion matrix. + Normalization can be applied by setting `normalize=True`. + """ + print('Confusion Matrix') + cm = confusion_matrix(y_true, y_pred) + print('Classification Report') + target_names = ['benigh', 'cancer'] + print(classification_report(y_true, y_pred, target_names=target_names)) + + if normalize: + cm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis] + print("Normalized confusion matrix") + else: + print('Confusion matrix, without normalization') + + print(cm) + + plt.imshow(cm, interpolation='nearest', cmap=cmap) + plt.title(title) + plt.colorbar() + tick_marks = np.arange(len(classes)) + plt.xticks(tick_marks, classes, rotation=45) + plt.yticks(tick_marks, classes) + + fmt = '.2f' if normalize else 'd' + thresh = cm.max() / 2. + for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])): + plt.text(j, i, format(cm[i, j], fmt), + horizontalalignment="center", + color="white" if cm[i, j] > thresh else "black") + + plt.ylabel('True label') + plt.xlabel('Predicted label') + plt.tight_layout() + + +def save_cm_figs(y_true, y_pred, arc, target_names = ['begign', 'cancer']): + # Compute confusion matrix + + np.set_printoptions(precision=2) + + # Plot non-normalized confusion matrix + plt.figure() + plot_confusion_matrix(y_true, y_pred, classes=target_names, + title='Confusion matrix, without normalization') + plt.savefig('./result/conf_no_norm_{}.png'.format(arc), bbox_inches='tight') + # Plot normalized confusion matrix + plt.figure() + plot_confusion_matrix(y_true, y_pred, classes=target_names, normalize=True, + title='Normalized confusion matrix') + plt.savefig('./result/conf_norm_{}.png'.format(arc), bbox_inches='tight') + plt.show() +