Skip to content

Commit

Permalink
Browse files Browse the repository at this point in the history
Did a re-factor and re-naming
  • Loading branch information
RussellBentley committed Apr 13, 2024
1 parent 61c3c25 commit c5185c5
Show file tree
Hide file tree
Showing 13 changed files with 552 additions and 595 deletions.
1 change: 1 addition & 0 deletions .gitignore
@@ -1 +1,2 @@
*.ipynb_checkpoints
__pycache__
487 changes: 0 additions & 487 deletions Challenge_01.ipynb

This file was deleted.

138 changes: 30 additions & 108 deletions Imbalanced_01.ipynb
Expand Up @@ -2,14 +2,16 @@
"cells": [
{
"cell_type": "code",
"execution_count": 222,
"execution_count": 1,
"id": "bbb5053b-e99c-41e7-b818-7d7efc8a434f",
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"import math\n",
"\n",
"import fp\n",
"\n",
"import matplotlib.pyplot as plt\n",
"from matplotlib.colors import ListedColormap\n",
"from matplotlib import ticker\n",
Expand All @@ -27,103 +29,7 @@
},
{
"cell_type": "code",
"execution_count": 223,
"id": "0076cd82-08f1-42ef-99c9-d3b0ec52e9f5",
"metadata": {},
"outputs": [],
"source": [
"# Returns the aspect ratio of a given matplotlib axis\n",
"# Useful for adjusting the aspect ratio of our plots\n",
"# https://stackoverflow.com/questions/41597177/get-aspect-ratio-of-axes\n",
"from operator import sub\n",
"def get_aspect(ax):\n",
" # Total figure size\n",
" figW, figH = ax.get_figure().get_size_inches()\n",
" # Axis size on figure\n",
" _, _, w, h = ax.get_position().bounds\n",
" # Ratio of display units\n",
" disp_ratio = (figH * h) / (figW * w)\n",
" # Ratio of data units\n",
" # Negative over negative because of the order of subtraction\n",
" data_ratio = sub(*ax.get_ylim()) / sub(*ax.get_xlim())\n",
"\n",
" return disp_ratio / data_ratio"
]
},
{
"cell_type": "code",
"execution_count": 330,
"id": "7e1a9beb-0631-40dc-bdee-abb45fa4c06e",
"metadata": {},
"outputs": [],
"source": [
"# Compute the Perfomance Metrics\n",
"# - A trained classifier\n",
"# - x, y test data\n",
"# returns labels, \n",
"# 1 -> true positive\n",
"# 2 -> true negative\n",
"# 3 -> false positive\n",
"# 4 -> false negative\n",
"def score_classifier(clf, x, y):\n",
" score = clf.score(x, y)\n",
" (m, _) = x.shape\n",
" p = clf.predict(x)\n",
" true_pos = 0\n",
" true_neg = 0\n",
" false_pos = 0\n",
" false_neg = 0\n",
" labels = []\n",
" for i in range(0, m):\n",
" x_i = x[i]\n",
" y_i = y[i]\n",
" y_h = p[i]\n",
"\n",
" \n",
" if y_i == y_h and y_h == 1:\n",
" true_pos += 1\n",
" labels.append(1)\n",
" \n",
" if y_i == y_h and y_h == 0:\n",
" true_neg += 1\n",
" labels.append(2)\n",
"\n",
" if y_i != y_h and y_h == 1:\n",
" false_pos += 1\n",
" labels.append(3)\n",
"\n",
" if y_i != y_h and y_h == 0:\n",
" false_neg += 1\n",
" labels.append(4)\n",
"\n",
" # Metrics from Table 2 of paper\n",
" tnr = true_neg / (true_neg + false_pos)\n",
" tpr = true_pos / (true_pos + false_neg)\n",
" g_mean = math.sqrt(tnr * tpr)\n",
" precision = true_pos / (true_pos + false_pos)\n",
" f_measure = (2 * precision * tpr) / (precision + tpr)\n",
" \n",
" print(f\"TP: {true_pos}, TN: {true_neg}, FP: {false_pos}, FN: {false_neg}\")\n",
" print(f\"tnr: {tnr}, tpr: {tpr}, g_mean: {g_mean}, precision: {precision}, f_measure: {f_measure}, score: {score}\")\n",
"\n",
" return labels\n",
"\n",
"# Returns four datasets\n",
"# filtered by label type.\n",
"# useful for labeling scatter plots\n",
"def filter_points(x, labels):\n",
" x_tp = x[[i for i, l in enumerate(labels) if l == 1],:]\n",
" x_tn = x[[i for i, l in enumerate(labels) if l == 2],:]\n",
" x_fp = x[[i for i, l in enumerate(labels) if l == 3],:]\n",
" x_fn = x[[i for i, l in enumerate(labels) if l == 4],:]\n",
" print(f\"tp: {x_tp.shape}, tn: {x_tn.shape}, fp: {x_fp.shape}, fn: {x_fn.shape}\")\n",
" return (x_tp, x_tn, x_fp, x_fn)\n",
" "
]
},
{
"cell_type": "code",
"execution_count": 331,
"execution_count": 2,
"id": "587b3456-e7ff-4c6d-9078-ae0f48528c84",
"metadata": {},
"outputs": [
Expand Down Expand Up @@ -173,7 +79,7 @@
},
{
"cell_type": "code",
"execution_count": 333,
"execution_count": 4,
"id": "af6cbcfd-9f04-4936-a14c-dc2057e76774",
"metadata": {},
"outputs": [
Expand All @@ -199,10 +105,10 @@
{
"data": {
"text/plain": [
"<matplotlib.colorbar.Colorbar at 0x29c0765c0>"
"<matplotlib.colorbar.Colorbar at 0x15fcac6a0>"
]
},
"execution_count": 333,
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
},
Expand Down Expand Up @@ -243,19 +149,19 @@
"clf1.fit(x_train, y_train)\n",
"score = clf1.score(x_test, y_test)\n",
"print(\"random forest\")\n",
"rf_labels = score_classifier(clf1, x_test, y_test)\n",
"(_, rf_labels, _) = fp.score_classifier(clf1, x_test, y_test)\n",
"\n",
"clf2 = RandomForestClassifier(class_weight = {0:0.05, 1:0.95}, **common_params)\n",
"clf2.fit(x_train, y_train)\n",
"score = clf2.score(x_test, y_test)\n",
"print(\"weighted random forest\")\n",
"wf_labels = score_classifier(clf2, x_test, y_test)\n",
"(_, wf_labels, _) = fp.score_classifier(clf2, x_test, y_test)\n",
"\n",
"clf3 = BalancedRandomForestClassifier(sampling_strategy='all', replacement=True, bootstrap=False, **common_params)\n",
"clf3.fit(x_train, y_train)\n",
"score = clf3.score(x_test, y_test)\n",
"print(\"balanced forest\")\n",
"bf_labels = score_classifier(clf3, x_test, y_test)\n",
"(_, bf_labels, _) = fp.score_classifier(clf3, x_test, y_test)\n",
"\n",
"# For plotting the decision boundary\n",
"cm = plt.cm.RdBu\n",
Expand Down Expand Up @@ -290,7 +196,7 @@
"\n",
"rf_points_ax = fig.add_subplot(gs[0, 0])\n",
"rf_boundary_ax = fig.add_subplot(gs[1, 0])\n",
"print(f\"Zoom plot aspect ratio (Adjust height_ratio until this is ~1): {get_aspect(rf_points_ax)}\")\n",
"print(f\"Zoom plot aspect ratio (Adjust height_ratio until this is ~1): {fp.get_aspect_ratio(rf_points_ax)}\")\n",
"rf_points_ax.xaxis.set_major_locator(ticker.NullLocator()) \n",
"rf_boundary_ax.xaxis.set_major_locator(ticker.NullLocator())\n",
"rf_points_ax.yaxis.set_major_locator(ticker.NullLocator()) \n",
Expand All @@ -301,7 +207,7 @@
" ax = rf_boundary_ax, \n",
" **common_boundary_args\n",
")\n",
"(x_tp, x_tn, x_fp, x_fn) = filter_points(x_test, rf_labels)\n",
"(x_tp, x_tn, x_fp, x_fn) = fp.filter_points(x_test, rf_labels)\n",
"rf_points_ax.scatter(x_tp[:,0], x_tp[:,1], c = tp_c, label='True Positive', **common_scatter_args)\n",
"rf_points_ax.scatter(x_tn[:,0], x_tn[:,1], c = tn_c, label='True Negative', **common_scatter_args)\n",
"rf_points_ax.scatter(x_fp[:,0], x_fp[:,1], c = fp_c, label='False Positive', **common_scatter_args)\n",
Expand All @@ -320,7 +226,7 @@
" ax = wf_boundary_ax,\n",
" **common_boundary_args\n",
")\n",
"(x_tp, x_tn, x_fp, x_fn) = filter_points(x_test, wf_labels)\n",
"(x_tp, x_tn, x_fp, x_fn) = fp.filter_points(x_test, wf_labels)\n",
"wf_points_ax.scatter(x_tp[:,0], x_tp[:,1], c = tp_c, label='True Positive', **common_scatter_args)\n",
"wf_points_ax.scatter(x_tn[:,0], x_tn[:,1], c = tn_c, label='True Negative', **common_scatter_args)\n",
"wf_points_ax.scatter(x_fp[:,0], x_fp[:,1], c = fp_c, label='False Positive', **common_scatter_args)\n",
Expand All @@ -339,7 +245,7 @@
" ax = bf_boundary_ax,\n",
" **common_boundary_args\n",
")\n",
"(x_tp, x_tn, x_fp, x_fn) = filter_points(x_test, bf_labels)\n",
"(x_tp, x_tn, x_fp, x_fn) = fp.filter_points(x_test, bf_labels)\n",
"bf_points_ax.scatter(x_tp[:,0], x_tp[:,1], c = tp_c, label='True Positive', **common_scatter_args)\n",
"bf_points_ax.scatter(x_tn[:,0], x_tn[:,1], c = tn_c, label='True Negative', **common_scatter_args)\n",
"bf_points_ax.scatter(x_fp[:,0], x_fp[:,1], c = fp_c, label='False Positive', **common_scatter_args)\n",
Expand All @@ -361,6 +267,22 @@
" label='Decision Probability',\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "3b319a0d-cf8b-4ed5-8d20-47eb24934b81",
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"id": "51c3e353-05a2-48bd-be13-3f945307579c",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
Expand Down
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.

0 comments on commit c5185c5

Please sign in to comment.