Skip to content
Permalink
main
Switch branches/tags

Name already in use

A tag already exists with the provided branch name. Many Git commands accept both tag and branch names, so creating this branch may cause unexpected behavior. Are you sure you want to create this branch?
Go to file
RussellBentley I'm done
Latest commit b706129 Apr 26, 2024 History
0 contributors

Users who have contributed to this file

'''
fp.py
A module of helper functions for our final project.
'''
import numpy as np
import math
# get_aspect_ratio(ax)
#
# Returns the aspect ratio of a given matplotlib axis (ax).
# Useful for adjusting the aspect ratio of our plots
#
# Source
# https://stackoverflow.com/questions/41597177/get-aspect-ratio-of-axes
from operator import sub
def get_aspect_ratio(ax):
    """Return the aspect ratio of the matplotlib axis ``ax``.

    The result is the display ratio (axis height over axis width on the
    figure) divided by the data ratio (y-limit span over x-limit span).
    """
    # Overall figure size, as reported by get_size_inches().
    fig_width, fig_height = ax.get_figure().get_size_inches()

    # Only the last two entries of the position bounds (width, height)
    # are needed here.
    _, _, rel_width, rel_height = ax.get_position().bounds
    display_ratio = (fig_height * rel_height) / (fig_width * rel_width)

    # Both spans are computed as (first - second); the two signs cancel
    # in the quotient, so the subtraction order does not matter.
    y_first, y_second = ax.get_ylim()
    x_first, x_second = ax.get_xlim()
    data_ratio = (y_first - y_second) / (x_first - x_second)

    return display_ratio / data_ratio
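# score_classifier(clf, x, y, cutoff = None)
#
# For a classifier (clf), and data x, y
# compute various performance metrics.
# If cutoff is given, class 1 is predicted when the second column
# of clf.predict_proba(x) is greater than cutoff (otherwise class 0).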
#
# Returns (p, l, m) where:
# p -> predicted class
# l -> prediction label (true pos, ...)
# m -> dict of metrics
#
# Prediction labels:
# 1 -> true positive
# 2 -> true negative
# 3 -> false positive
# 4 -> false negative
def score_classifier(clf, x, y, cutoff = None):
    """Score a binary classifier (classes 0 and 1) on the data x, y.

    Parameters:
        clf: classifier object providing score(x, y), predict(x) and,
            when cutoff is used, predict_proba(x).
        x: samples; only ever handed to clf, never indexed here.
        y: true classes, one per sample, iterated in order.
        cutoff: optional threshold. When None, predictions come from
            clf.predict(x). Otherwise a sample is predicted as 1 when
            the second column of clf.predict_proba(x) is strictly
            greater than cutoff, and as 0 otherwise.

    Returns (p, labels, metrics) where:
        p: predicted class per sample.
        labels: prediction label per sample:
            1 -> true positive, 2 -> true negative,
            3 -> false positive, 4 -> false negative.
            A prediction that is neither 0 nor 1 is not counted and
            gets no entry in labels.
        metrics: dict with keys "tp", "tn", "fp", "fn", "tnr", "tpr",
            "fpr", "g_mean", "precision", "f_measure" and "score".
            Any ratio whose denominator is zero is reported as 0.
    """
    # NOTE(review): "score" is whatever clf.score(x, y) returns; it is
    # computed independently of cutoff, so it may not match the
    # cutoff-based predictions in p.
    score = clf.score(x, y)

    if cutoff is None:
        p = clf.predict(x)
    else:
        probs = np.asarray(clf.predict_proba(x))
        p = (probs[:, 1] > cutoff).astype(int)

    true_pos = 0
    true_neg = 0
    false_pos = 0
    false_neg = 0
    labels = []
    # Walk true and predicted classes in lockstep; x itself is not needed.
    for y_i, y_h in zip(y, p):
        if y_h == 1:
            if y_i == y_h:
                true_pos += 1
                labels.append(1)
            else:
                false_pos += 1
                labels.append(3)
        elif y_h == 0:
            if y_i == y_h:
                true_neg += 1
                labels.append(2)
            else:
                false_neg += 1
                labels.append(4)

    # Metrics from Table 2 of paper
    tnr = _safe_ratio(true_neg, true_neg + false_pos)
    tpr = _safe_ratio(true_pos, true_pos + false_neg)
    fpr = _safe_ratio(false_pos, false_pos + true_neg)
    g_mean = math.sqrt(tnr * tpr)
    precision = _safe_ratio(true_pos, true_pos + false_pos)
    f_measure = _safe_ratio(2 * precision * tpr, precision + tpr)

    metrics = {
        "tp": true_pos,
        "tn": true_neg,
        "fp": false_pos,
        "fn": false_neg,
        "tnr": tnr,
        "tpr": tpr,
        "fpr": fpr,
        "g_mean": g_mean,
        "precision": precision,
        "f_measure": f_measure,
        "score": score
    }
    return (p, labels, metrics)


def _safe_ratio(numerator, denominator):
    """Return numerator / denominator, or 0 when denominator is 0."""
    if denominator == 0:
        return 0
    return numerator / denominator
# metric_stats(ms)
#
# Compute metric statistics for a list of metric dicts,
# as computed by score_classifier
#
# Returns a metric stats object
# that has the same keys, but now each value is a tuple
# (mean, std)
def metric_stats(ms):
    """Compute per-metric statistics over a list of metric dicts.

    Parameters:
        ms: list of metric dicts as produced by score_classifier. The
            keys are taken from the first dict, and every dict must
            contain each of those keys.

    Returns a dict with the same keys, where each value is the tuple
    (mean, std) of that metric across all dicts, computed with np.mean
    and np.std. An empty list yields an empty dict.
    """
    if len(ms) == 0:
        return {}

    result = {}
    for key in ms[0]:
        values = [m[key] for m in ms]
        result[key] = (np.mean(values), np.std(values))
    return result
# filter_points(x, labels)
#
# Filter data points x by their prediction label
# (from score_classifier).
# Returns four datasets.
#
# Useful for making labeled scatter plots
def filter_points(x, labels):
    """Split the rows of x into four datasets by prediction label.

    labels holds one prediction label per row of x (1 = true positive,
    2 = true negative, 3 = false positive, 4 = false negative). x must
    support x[list_of_row_indices, :] indexing (e.g. a 2-D numpy array).

    Prints the shape of each subset and returns
    (x_tp, x_tn, x_fp, x_fn).
    """
    def rows_with(code):
        # Select the rows whose label equals code, in their original order.
        selected = [pos for pos, label in enumerate(labels) if label == code]
        return x[selected, :]

    x_tp = rows_with(1)
    x_tn = rows_with(2)
    x_fp = rows_with(3)
    x_fn = rows_with(4)
    print(f"tp: {x_tp.shape}, tn: {x_tn.shape}, fp: {x_fp.shape}, fn: {x_fn.shape}")
    return (x_tp, x_tn, x_fp, x_fn)
# find_best_index(metric_stat_list, metric_name)
#
# Find the index of the metric stat dict in metric_stat_list
# with the best (highest) mean value for metric_name.
def find_best_index(metric_stat_list, metric_name):
    """Return the index of the entry with the highest mean for metric_name.

    Parameters:
        metric_stat_list: list of metric stat dicts, where each value is
            a tuple whose first element is the mean (as produced by
            metric_stats).
        metric_name: key of the metric to compare.

    Returns the index of the dict with the largest mean; ties resolve
    to the lowest index.

    Raises IndexError when metric_stat_list is empty.
    """
    if len(metric_stat_list) == 0:
        raise IndexError("find_best_index: metric_stat_list is empty")
    # max() returns the first maximal element, so the earliest index wins ties.
    return max(
        range(len(metric_stat_list)),
        key=lambda i: metric_stat_list[i][metric_name][0],
    )