Permalink
Name already in use
A tag already exists with the provided branch name. Many Git commands accept both tag and branch names, so creating this branch may cause unexpected behavior. Are you sure you want to create this branch?
cse5819-FinalProject/fp.py
Go to file. This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
166 lines (143 sloc)
4.22 KB
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
''' | |
fp.py | |
A module of helper functions for our final project. | |
''' | |
import numpy as np | |
import math | |
# get_aspect_ratio(ax) | |
# | |
# Returns the aspect ratio of a given matplotlib axis (ax). | |
# Useful for adjusting the aspect ratio of our plots | |
# | |
# Source | |
# https://stackoverflow.com/questions/41597177/get-aspect-ratio-of-axes | |
from operator import sub | |
def get_aspect_ratio(ax):
    """Return the aspect ratio of the matplotlib axes ``ax``.

    The result is the height/width ratio of the axes box on the figure
    (figure size in inches scaled by the axes' position bounds) divided by
    the height/width ratio of the current data limits. Useful for adjusting
    the aspect ratio of our plots.

    Adapted from
    https://stackoverflow.com/questions/41597177/get-aspect-ratio-of-axes
    """
    fig_width, fig_height = ax.get_figure().get_size_inches()

    # Position bounds are (x0, y0, width, height) in figure-relative units.
    _x0, _y0, frac_width, frac_height = ax.get_position().bounds
    display_ratio = (fig_height * frac_height) / (fig_width * frac_width)

    y_first, y_second = ax.get_ylim()
    x_first, x_second = ax.get_xlim()
    # Both spans are taken as (first - second); for non-inverted axes the two
    # negative spans cancel, while an inverted axis flips the result's sign.
    data_ratio = (y_first - y_second) / (x_first - x_second)

    return display_ratio / data_ratio
def _safe_ratio(numerator, denominator):
    """Return ``numerator / denominator``, or 0 when the denominator is 0."""
    if denominator == 0:
        return 0
    return numerator / denominator


def score_classifier(clf, x, y, cutoff=None):
    """Compute predictions, per-sample outcome labels and metrics for ``clf``.

    Parameters
    ----------
    clf : fitted classifier
        Must provide ``score(x, y)`` and ``predict(x)``. It must also provide
        ``predict_proba(x)`` when ``cutoff`` is given.
    x : array-like
        Samples, one per row. Only ``x.shape[0]`` is read here; ``x`` is
        otherwise passed straight to the classifier.
    y : array-like
        True class of each sample, assumed binary 0/1. It is indexed
        positionally, so a pandas Series with any index works.
    cutoff : float, optional
        When given, a sample is predicted positive (1) if its column-1
        probability from ``predict_proba`` is strictly greater than
        ``cutoff``. Otherwise ``clf.predict`` is used.

    Returns
    -------
    (p, labels, metrics)
        p : predicted class of each sample.
        labels : list of prediction labels, one per sample whose prediction
            is 0 or 1 (other predicted values get no label):
            1 -> true positive, 2 -> true negative,
            3 -> false positive, 4 -> false negative.
        metrics : dict with the counts "tp", "tn", "fp", "fn"; the rates
            "tnr", "tpr", "fpr"; "g_mean", "precision", "f_measure"
            (each 0 when its denominator is 0); and "score", the value
            of ``clf.score(x, y)``.
    """
    # NOTE(review): clf.score uses clf's own predict(), so when ``cutoff`` is
    # given, "score" does not reflect the thresholded predictions below.
    score = clf.score(x, y)
    n_samples = x.shape[0]

    if cutoff is None:
        p = clf.predict(x)
    else:
        probs = np.asarray(clf.predict_proba(x))
        # Column 1 holds the positive-class probability. Comparing with the
        # scalar directly avoids casting ``cutoff`` to the probabilities' dtype.
        p = (probs[:, 1] > cutoff).astype(int)

    # Positional access even if y is e.g. a pandas Series with a custom index.
    y_true = np.asarray(y)

    true_pos = true_neg = false_pos = false_neg = 0
    labels = []
    for i in range(n_samples):
        actual, predicted = y_true[i], p[i]
        if predicted == 1:
            if actual == predicted:
                true_pos += 1
                labels.append(1)
            else:
                false_pos += 1
                labels.append(3)
        elif predicted == 0:
            if actual == predicted:
                true_neg += 1
                labels.append(2)
            else:
                false_neg += 1
                labels.append(4)

    # Metrics from Table 2 of the paper.
    tnr = _safe_ratio(true_neg, true_neg + false_pos)
    tpr = _safe_ratio(true_pos, true_pos + false_neg)
    fpr = _safe_ratio(false_pos, false_pos + true_neg)
    g_mean = math.sqrt(tnr * tpr)
    precision = _safe_ratio(true_pos, true_pos + false_pos)
    f_measure = _safe_ratio(2 * precision * tpr, precision + tpr)

    metrics = {
        "tp": true_pos,
        "tn": true_neg,
        "fp": false_pos,
        "fn": false_neg,
        "tnr": tnr,
        "tpr": tpr,
        "fpr": fpr,
        "g_mean": g_mean,
        "precision": precision,
        "f_measure": f_measure,
        "score": score,
    }
    return (p, labels, metrics)
def metric_stats(ms, ddof=0):
    """Summarise a list of metric dicts, as computed by score_classifier.

    Parameters
    ----------
    ms : list of dict
        Metric dicts. The keys of ``ms[0]`` are summarised, so every other
        dict must contain at least those keys.
    ddof : int, optional
        Delta degrees of freedom passed to ``np.std``. The default 0 gives
        the population standard deviation; use 1 for the sample standard
        deviation.

    Returns
    -------
    dict
        Same keys as ``ms[0]``, each mapped to a ``(mean, std)`` tuple.
        An empty ``ms`` yields an empty dict.
    """
    if not ms:
        return {}
    result = {}
    for key in ms[0]:
        values = [metrics[key] for metrics in ms]
        result[key] = (np.mean(values), np.std(values, ddof=ddof))
    return result
def filter_points(x, labels, verbose=True):
    """Split the rows of ``x`` by their prediction label.

    Useful for making labeled scatter plots.

    Parameters
    ----------
    x : numpy.ndarray
        2-D array of points; row ``i`` is paired with ``labels[i]``.
    labels : sequence of int
        Prediction labels as returned by score_classifier:
        1 -> true positive, 2 -> true negative,
        3 -> false positive, 4 -> false negative.
    verbose : bool, optional
        When True (the default, matching the previous behaviour), print the
        shape of each of the four subsets.

    Returns
    -------
    (x_tp, x_tn, x_fp, x_fn)
        The rows of ``x`` carrying labels 1, 2, 3 and 4 respectively.
    """
    label_array = np.asarray(labels)
    # Integer row indices (not a boolean mask), so labels need not span all of x.
    x_tp, x_tn, x_fp, x_fn = (
        x[np.flatnonzero(label_array == code), :] for code in (1, 2, 3, 4)
    )
    if verbose:
        print(f"tp: {x_tp.shape}, tn: {x_tn.shape}, fp: {x_fp.shape}, fn: {x_fn.shape}")
    return (x_tp, x_tn, x_fp, x_fn)
def find_best_index(metric_stat_list, metric_name, higher_is_better=True):
    """Return the index of the metric-stats dict with the best mean.

    Parameters
    ----------
    metric_stat_list : list of dict
        Dicts as returned by metric_stats, mapping metric names to
        ``(mean, std)`` tuples.
    metric_name : str
        Key whose mean (element 0 of the tuple) is compared.
    higher_is_better : bool, optional
        True (the default) selects the largest mean. Pass False for metrics
        where lower is better, such as "fpr".

    Returns
    -------
    int
        Index of the best entry. Ties resolve to the smallest index.

    Raises
    ------
    IndexError
        If ``metric_stat_list`` is empty.
    """
    if len(metric_stat_list) == 0:
        raise IndexError("find_best_index: metric_stat_list is empty")
    # max/min return the first extreme element, so ties pick the lowest index.
    pick = max if higher_is_better else min
    return pick(
        range(len(metric_stat_list)),
        key=lambda i: metric_stat_list[i][metric_name][0],
    )