From 5de15262aa1d74421fc23b97d7c2e2410808c3d7 Mon Sep 17 00:00:00 2001 From: Luis Roberto Mercado Diaz Date: Sun, 21 Apr 2024 16:55:05 -0400 Subject: [PATCH] UPDATES --- main_darren_v1.py | 27 ++++++++++++++++++++++----- 1 file changed, 22 insertions(+), 5 deletions(-) diff --git a/main_darren_v1.py b/main_darren_v1.py index 29ec642..6178aff 100644 --- a/main_darren_v1.py +++ b/main_darren_v1.py @@ -95,7 +95,8 @@ def load_data(data_path, labels_path, batch_size, binary=False): return dataloader def train_gp_model(train_loader, val_loader, num_iterations=50, n_classes=4, patience=10, - checkpoint_path='model_checkpoint.pt', resume_training=False): + checkpoint_path='model_checkpoint.pt', data_checkpoint_path='data_checkpoint.pt', + resume_training=False, plot_path='training_plot.png'): model = MultitaskGPModel().to(device) likelihood = gpytorch.likelihoods.SoftmaxLikelihood(num_features=4, num_classes=4).to(device) optimizer = torch.optim.Adam(model.parameters(), lr=0.1) @@ -136,9 +137,9 @@ def train_gp_model(train_loader, val_loader, num_iterations=50, n_classes=4, pat # Stochastic validation model.eval() likelihood.eval() + val_loss = 0.0 # Initialize val_loss here with torch.no_grad(): val_indices = torch.randperm(len(val_loader.dataset))[:int(0.1 * len(val_loader.dataset))] - val_loss = 0.0 val_labels = [] val_predictions = [] for idx in val_indices: @@ -212,6 +213,7 @@ def evaluate_gp_model(test_loader, model, likelihood, n_classes=4): return metrics def main(): + print("Step 1: Loading paths and parameters") # Paths base_path = r"\\grove.ad.uconn.edu\\research\\ENGR_Chon\Darren\\NIH_Pulsewatch" smote_type = 'Cassey5k_SMOTE' @@ -231,23 +233,37 @@ def main(): else: n_classes = 3 patience = round(n_epochs / 10) if n_epochs > 50 else 5 - save = True resume_checkpoint_path = None batch_size = 256 + print("Step 2: Loading data") # Data loading train_loader = load_data(data_path_train, labels_path_train, batch_size, binary) val_loader = load_data(data_path_val, labels_path_val, batch_size, binary) test_loader = load_data(data_path_test, labels_path_test, batch_size, binary) - # Training and validation + print("Step 3: Loading data checkpoints") + # Data loading with checkpointing + data_checkpoint_path = 'data_checkpoint.pt' + if os.path.exists(data_checkpoint_path): + train_loader.dataset.load_checkpoint(data_checkpoint_path) + val_loader.dataset.load_checkpoint(data_checkpoint_path) + test_loader.dataset.load_checkpoint(data_checkpoint_path) + + print("Step 4: Training and validation") + # Training and validation with checkpointing and plotting + model_checkpoint_path = 'model_checkpoint.pt' + plot_path = 'training_plot.png' start_time = time.time() model, likelihood, metrics = train_gp_model(train_loader, val_loader, n_epochs, - n_classes, patience, save) + n_classes, patience, + model_checkpoint_path, data_checkpoint_path, + resume_checkpoint_path is not None, plot_path) end_time = time.time() time_passed = end_time - start_time print('\nTraining and validation took %.2f minutes' % (time_passed / 60)) + print("Step 5: Evaluation") # Evaluation start_time = time.time() test_metrics = evaluate_gp_model(test_loader, model, likelihood, n_classes) @@ -255,6 +271,7 @@ def main(): time_passed = end_time - start_time print('\nTesting took %.2f seconds' % time_passed) + print("Step 6: Printing test metrics") print('Test Metrics:') print('Precision: %.4f' % test_metrics['precision']) print('Recall: %.4f' % test_metrics['recall'])