Skip to content

Commit

Permalink
add Classification Report and confusion matrix
Browse files Browse the repository at this point in the history
  • Loading branch information
Qinqing Liu committed Apr 7, 2020
1 parent 99c4a9e commit 824c640
Show file tree
Hide file tree
Showing 5 changed files with 65 additions and 0 deletions.
Binary file added __pycache__/utils.cpython-37.pyc
Binary file not shown.
2 changes: 2 additions & 0 deletions main.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from sklearn.naive_bayes import CategoricalNB
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score
from utils import *

K = 5
random_seed = 0
Expand Down Expand Up @@ -32,6 +33,7 @@ def read_data():
y_pred = clf.predict(X_test)
acc = accuracy_score(y_test, y_pred)
print('accuracy is: {}'.format(acc))
save_cm_figs(y_test, y_pred, 'NB_ori') # confusion matrix

# X_2 = X[Y==2] #benign
# X_4 = X[Y==4] #cancer
Expand Down
Binary file added result/conf_no_norm_NB_ori.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added result/conf_norm_NB_ori.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
63 changes: 63 additions & 0 deletions utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
import matplotlib.pyplot as plt
import numpy as np
from sklearn.metrics import classification_report, confusion_matrix
import itertools

def plot_confusion_matrix(y_true, y_pred, classes = ['benign', 'cancer'],
normalize=False,
title='Confusion matrix',
cmap=plt.cm.Blues):
"""
This function prints and plots the confusion matrix.
Normalization can be applied by setting `normalize=True`.
"""
print('Confusion Matrix')
cm = confusion_matrix(y_true, y_pred)
print('Classification Report')
target_names = ['benigh', 'cancer']
print(classification_report(y_true, y_pred, target_names=target_names))

if normalize:
cm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]
print("Normalized confusion matrix")
else:
print('Confusion matrix, without normalization')

print(cm)

plt.imshow(cm, interpolation='nearest', cmap=cmap)
plt.title(title)
plt.colorbar()
tick_marks = np.arange(len(classes))
plt.xticks(tick_marks, classes, rotation=45)
plt.yticks(tick_marks, classes)

fmt = '.2f' if normalize else 'd'
thresh = cm.max() / 2.
for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
plt.text(j, i, format(cm[i, j], fmt),
horizontalalignment="center",
color="white" if cm[i, j] > thresh else "black")

plt.ylabel('True label')
plt.xlabel('Predicted label')
plt.tight_layout()


def save_cm_figs(y_true, y_pred, arc, target_names = ['begign', 'cancer']):
# Compute confusion matrix

np.set_printoptions(precision=2)

# Plot non-normalized confusion matrix
plt.figure()
plot_confusion_matrix(y_true, y_pred, classes=target_names,
title='Confusion matrix, without normalization')
plt.savefig('./result/conf_no_norm_{}.png'.format(arc), bbox_inches='tight')
# Plot normalized confusion matrix
plt.figure()
plot_confusion_matrix(y_true, y_pred, classes=target_names, normalize=True,
title='Normalized confusion matrix')
plt.savefig('./result/conf_norm_{}.png'.format(arc), bbox_inches='tight')
plt.show()

0 comments on commit 824c640

Please sign in to comment.