From 4f589f1a6deecf6d7c00fe669ba340aba66e6dcc Mon Sep 17 00:00:00 2001
From: Luis Roberto Mercado Diaz
Date: Tue, 2 Apr 2024 17:08:08 -0400
Subject: [PATCH] Fix minor errors in plotting

---
 .../active_learning/ss_active_learning.py    |  2 +-
 BML_project/main_checkpoints_updated.py      | 64 ++++++++++++++++++
 BML_project/models/ss_gp_model_checkpoint.py | 51 ++++++++++----
 BML_project/ss_main.py                       |  4 +-
 .../__pycache__/data_loader.cpython-311.pyc  | Bin 17950 -> 19704 bytes
 BML_project/utils_gp/data_loader.py          | 28 +++----
 6 files changed, 114 insertions(+), 35 deletions(-)
 create mode 100644 BML_project/main_checkpoints_updated.py

diff --git a/BML_project/active_learning/ss_active_learning.py b/BML_project/active_learning/ss_active_learning.py
index 4c44836..d75b67f 100644
--- a/BML_project/active_learning/ss_active_learning.py
+++ b/BML_project/active_learning/ss_active_learning.py
@@ -80,7 +80,7 @@ def stochastic_compare_kmeans_gp_predictions(kmeans_model, gp_model, data_loader
         kmeans_predictions = kmeans_model.predict(data.cpu().numpy())
         all_labels.append(labels.cpu().numpy())
         all_data.append((gp_predictions, kmeans_predictions))
-
+        print(f"Processed batch size: {len(labels)}, cumulative labels: {sum(len(l) for l in all_labels)}, gp_predictions size: {len(gp_predictions)}")
     return all_data, np.concatenate(all_labels)
 
 import random
diff --git a/BML_project/main_checkpoints_updated.py b/BML_project/main_checkpoints_updated.py
new file mode 100644
index 0000000..234291c
--- /dev/null
+++ b/BML_project/main_checkpoints_updated.py
@@ -0,0 +1,64 @@
+# -*- coding: utf-8 -*-
+"""
+Created on Wed Feb 7 15:34:31 2024
+
+@author: lrm22005
+"""
+import os
+import tqdm
+import torch
+from utils_gp.data_loader import preprocess_data, split_uids, update_train_loader_with_uncertain_samples
+from models.ss_gp_model import MultitaskGPModel, train_gp_model
+from utils_gp.ss_evaluation import stochastic_evaluation, evaluate_model_on_all_data
+from active_learning.ss_active_learning import stochastic_uncertainty_sampling, run_minibatch_kmeans, stochastic_compare_kmeans_gp_predictions, label_samples
+from utils.visualization import plot_comparative_results, plot_training_performance, plot_results
+
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+def main():
+    n_classes = 4
+    batch_size = 1024
+    clinical_trial_train, clinical_trial_test, clinical_trial_unlabeled = split_uids()
+    data_format = 'pt'
+
+    train_loader, val_loader, test_loader = preprocess_data(data_format, clinical_trial_train, clinical_trial_test, clinical_trial_unlabeled, batch_size)
+
+    # Initialize result storage
+    results = {
+        'train_loss': [],
+        'validation_metrics': {'precision': [], 'recall': [], 'f1_score': [], 'auc_roc': []},
+        'active_learning': {'validation_metrics': []},  # Validation metrics for each active learning iteration
+        'test_metrics': None
+    }
+
+    # Initial model training
+    model, likelihood, training_metrics = train_gp_model(train_loader, val_loader, num_iterations=50, n_classes=n_classes, patience=10, checkpoint_path='model_checkpoint_full.pt')
+
+    # Save initial training metrics
+    results['train_loss'].extend(training_metrics['train_loss'])
+    for metric in ['precision', 'recall', 'f1_score']:
+        results['validation_metrics'][metric].extend(training_metrics[metric])
+
+    active_learning_iterations = 10
+    for iteration in tqdm.tqdm(range(active_learning_iterations), desc='Active Learning', unit='iteration'):
+        uncertain_sample_indices = stochastic_uncertainty_sampling(model, likelihood, val_loader, n_samples=batch_size, device=device)
+        train_loader = update_train_loader_with_uncertain_samples(train_loader, uncertain_sample_indices, batch_size)
+
+        # Re-train the model with the updated training data
+        model, likelihood, val_metrics = train_gp_model(train_loader, val_loader, num_iterations=10, n_classes=n_classes, patience=10, checkpoint_path='model_checkpoint_last.pt')
+
+        # Store validation metrics for each active learning iteration
+        results['active_learning']['validation_metrics'].append(val_metrics)
+
+    # Final evaluation on the test set
+    test_metrics = evaluate_model_on_all_data(model, likelihood, test_loader, device, n_classes)
+    results['test_metrics'] = test_metrics
+
+    # Visualization of results
+    plot_training_performance(results['train_loss'], results['validation_metrics'])
+    plot_results(results['test_metrics'])  # Adjust this function to handle the structure of test_metrics
+
+    print("Final Test Metrics:", results['test_metrics'])
+
+if __name__ == "__main__":
+    main()
\ No newline at end of file
diff --git a/BML_project/models/ss_gp_model_checkpoint.py b/BML_project/models/ss_gp_model_checkpoint.py
index 282c73d..c11f8cb 100644
--- a/BML_project/models/ss_gp_model_checkpoint.py
+++ b/BML_project/models/ss_gp_model_checkpoint.py
@@ -69,30 +69,40 @@ def forward(self, x):
         return latent_pred
 
-def train_gp_model(train_loader, val_loader, num_iterations=50, n_classes=4, patience=10, checkpoint_path='model_checkpoint_full.pt', resume_training=False, batch_size=1024):
+def train_gp_model(train_loader, val_loader, num_iterations=50, n_classes=4, patience=10, checkpoint_path='model_checkpoint_full.pt', resume_checkpoint_path=None):
+    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
     model = MultitaskGPModel().to(device)
     likelihood = gpytorch.likelihoods.SoftmaxLikelihood(num_features=4, num_classes=4).to(device)
     optimizer = torch.optim.Adam(model.parameters(), lr=0.1)
     mll = gpytorch.mlls.VariationalELBO(likelihood, model, num_data=len(train_loader.dataset))
+    best_val_loss = float('inf')
+    epochs_no_improve = 0
 
-    # Load checkpoint if resuming training
     start_epoch = 0
-    current_batch_index = 0  # Default value in case it's not found in the checkpoint
-    if resume_training and os.path.exists(checkpoint_path):
-        checkpoint = torch.load(checkpoint_path)
+    start_batch = 0
+
+    # Resume from checkpoint if specified
+    if resume_checkpoint_path is not None and os.path.exists(resume_checkpoint_path):
+        checkpoint = torch.load(resume_checkpoint_path)
         model.load_state_dict(checkpoint['model_state_dict'])
         likelihood.load_state_dict(checkpoint['likelihood_state_dict'])
         optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
-        start_epoch = checkpoint.get('epoch', 0) + 1  # Resume from the next epoch
-        current_batch_index = checkpoint.get('current_batch_index', 0)  # Retrieve the last batch index if it exists
-        print(f"Resuming training from epoch {start_epoch}, batch index {current_batch_index}")
-
-    best_val_loss = float('inf')
-    epochs_no_improve = 0
-    metrics = {'precision': [], 'recall': [], 'f1_score': [], 'auc_roc': [], 'train_loss': []}
-
-    for epoch in range(start_epoch, num_iterations):
+        start_epoch = checkpoint.get('epoch', 0)
+        start_batch = checkpoint.get('batch', 0)
+        print(f"Resuming training from epoch {start_epoch}, batch {start_batch}")
+
+    metrics = {
+        'precision': [],
+        'recall': [],
+        'f1_score': [],
+        'train_loss': []
+    }
+
+    for epoch in tqdm.tqdm(range(start_epoch, num_iterations), desc='Training', unit='epoch', leave=False):
         for batch_index, train_batch in enumerate(train_loader):
+            if epoch == start_epoch and batch_index <= start_batch:
+                continue  # Skip batches up to and including the saved batch index (that batch already completed)
+
             model.train()
             likelihood.train()
             optimizer.zero_grad()
@@ -104,6 +114,17 @@ def train_gp_model(train_loader, val_loader, num_iterations=50, n_classes=4, pat
             loss.backward()
             optimizer.step()
 
+            # Save checkpoint at intervals or based on other conditions
+            if (batch_index + 1) % 100 == 0:  # Example condition
+                torch.save({
+                    'epoch': epoch,
+                    'batch': batch_index,
+                    'model_state_dict': model.state_dict(),
+                    'likelihood_state_dict': likelihood.state_dict(),
+                    'optimizer_state_dict': optimizer.state_dict()
+                }, checkpoint_path)
+                print(f"Checkpoint saved at epoch {epoch}, batch {batch_index}")
+
         # Stochastic validation
         model.eval()
         likelihood.eval()
@@ -160,7 +181,7 @@ def train_gp_model(train_loader, val_loader, num_iterations=50, n_classes=4, pat
             print(f"Early stopping triggered at epoch {epoch+1}")
             break
 
-    # Optionally, load the best model at the end of training
+    # Ensure the latest checkpoint is loaded after the loop in case of early stopping
     if os.path.exists(checkpoint_path):
         checkpoint = torch.load(checkpoint_path)
         model.load_state_dict(checkpoint['model_state_dict'])
diff --git a/BML_project/ss_main.py b/BML_project/ss_main.py
index a610684..1b9cfce 100644
--- a/BML_project/ss_main.py
+++ b/BML_project/ss_main.py
@@ -6,7 +6,7 @@
 """
 import tqdm
 import torch
-from utils.data_loader import preprocess_data, split_uids, update_train_loader_with_uncertain_samples
+from utils_gp.data_loader import preprocess_data, split_uids, update_train_loader_with_uncertain_samples
 from models.ss_gp_model import MultitaskGPModel, train_gp_model
 from utils_gp.ss_evaluation import stochastic_evaluation, evaluate_model_on_all_data
 from active_learning.ss_active_learning import stochastic_uncertainty_sampling, run_minibatch_kmeans, stochastic_compare_kmeans_gp_predictions
@@ -71,6 +71,8 @@ def main():
     results['test_metrics'] = test_metrics
     test_gp_vs_kmeans_data, test_original_labels = stochastic_compare_kmeans_gp_predictions(test_kmeans_model, model, test_loader, n_batches=5, device=device)
 
+    # Check label/prediction counts before plot_comparative_results computes its confusion matrix
+    print(f"Length of test_original_labels: {len(test_original_labels)}, number of (gp, kmeans) prediction batches: {len(test_gp_vs_kmeans_data)}")
     plot_comparative_results(test_gp_vs_kmeans_data, test_original_labels)
 
     # Visualization of results
diff --git a/BML_project/utils_gp/__pycache__/data_loader.cpython-311.pyc b/BML_project/utils_gp/__pycache__/data_loader.cpython-311.pyc
index 0b1a7ebdb4b25cf36cb492dfb1b7d20074e1bb08..ed2626e5a52ed6bc28cce7621bfb3c6c869cfc69 100644
GIT binary patch
[two binary deltas (3792 and 2120 bytes) for the compiled .pyc omitted]
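Note on the pattern above: the central change in this patch is the mid-epoch checkpoint/resume cycle in train_gp_model. The checkpoint bundles the epoch and batch indices with the model, likelihood, and optimizer state dicts, and on resume the inner loop skips batches that had already completed. Below is a minimal sketch of that cycle, assuming a plain torch.nn.Linear classifier and a synthetic DataLoader in place of the project's GPyTorch model and loaders; the names train_with_checkpoints and tiny_checkpoint.pt are illustrative only, not part of the patch.

# Minimal sketch of the save/resume cycle, with torch.nn.Linear standing in
# for MultitaskGPModel. Assumption: the real checkpoint also stores
# likelihood_state_dict, omitted here for brevity.
import os
import torch

def train_with_checkpoints(loader, num_epochs=2, checkpoint_path='tiny_checkpoint.pt',
                           resume_checkpoint_path=None, save_every=10):
    model = torch.nn.Linear(8, 4)            # stand-in for the GP model
    optimizer = torch.optim.Adam(model.parameters(), lr=0.1)
    loss_fn = torch.nn.CrossEntropyLoss()    # stand-in for the variational ELBO

    start_epoch, start_batch = 0, -1
    if resume_checkpoint_path is not None and os.path.exists(resume_checkpoint_path):
        ckpt = torch.load(resume_checkpoint_path)
        model.load_state_dict(ckpt['model_state_dict'])
        optimizer.load_state_dict(ckpt['optimizer_state_dict'])
        start_epoch, start_batch = ckpt['epoch'], ckpt['batch']
        print(f"Resuming from epoch {start_epoch}, batch {start_batch}")

    for epoch in range(start_epoch, num_epochs):
        for batch_index, (x, y) in enumerate(loader):
            # Skip batches completed before the checkpoint was written; the
            # checkpoint is saved after optimizer.step(), so batch start_batch
            # itself is already done.
            if epoch == start_epoch and batch_index <= start_batch:
                continue
            optimizer.zero_grad()
            loss_fn(model(x), y).backward()
            optimizer.step()
            # Periodic mid-epoch checkpoint, mirroring the patch's
            # "(batch_index + 1) % 100 == 0" example condition
            if (batch_index + 1) % save_every == 0:
                torch.save({'epoch': epoch, 'batch': batch_index,
                            'model_state_dict': model.state_dict(),
                            'optimizer_state_dict': optimizer.state_dict()},
                           checkpoint_path)

if __name__ == "__main__":
    # Synthetic stand-in for the project's DataLoader: 512 samples, 4 classes
    dataset = torch.utils.data.TensorDataset(torch.randn(512, 8),
                                             torch.randint(0, 4, (512,)))
    loader = torch.utils.data.DataLoader(dataset, batch_size=32)
    train_with_checkpoints(loader, resume_checkpoint_path='tiny_checkpoint.pt')

Saving (epoch, batch) only after optimizer.step() is what makes the "batch_index <= start_batch" skip safe: the recorded batch has fully completed, so a resumed run continues from the next batch rather than repeating or losing one.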