diff --git a/GP_Original_checkpoint.py b/GP_Original_checkpoint.py
index 0bdceae..777f962 100644
--- a/GP_Original_checkpoint.py
+++ b/GP_Original_checkpoint.py
@@ -43,8 +43,8 @@ def get_data_paths(data_format, is_linux=False, is_hpc=False):
         saving_base_path = "/gpfs/scratchfs1/hfp14002/lrm22005/Casseys_case/Project_1_analysis"
     else:
         # R:\ENGR_Chon\Dong\MATLAB_generate_results\NIH_PulseWatch
-        base_path = "R:\ENGR_Chon\Dong\MATLAB_generate_results\\NIH_PulseWatch"
-        labels_base_path = "R:\ENGR_Chon\\NIH_Pulsewatch_Database\Adjudication_UConn"
+        base_path = r"\\grove.ad.uconn.edu\\research\\ENGR_Chon\Dong\MATLAB_generate_results\\NIH_PulseWatch"
+        labels_base_path = r"\\grove.ad.uconn.edu\\research\\ENGR_Chon\\NIH_Pulsewatch_Database\Adjudication_UConn"
         saving_base_path = r"\\grove.ad.uconn.edu\research\ENGR_Chon\Luis\Research\Casseys_case"
     if data_format == 'csv':
         data_path = os.path.join(base_path, "TFS_csv")
diff --git a/main_darren_v1.py b/main_darren_v1.py
new file mode 100644
index 0000000..2473be0
--- /dev/null
+++ b/main_darren_v1.py
@@ -0,0 +1,458 @@
+"""Train and evaluate a multitask variational GP classifier on TFS segments.
+
+Segments are stored as one ``.pt`` tensor per segment, flattened, and fed to
+a GPyTorch LMC model with a softmax likelihood.
+"""
+import os
+import torch
+import gpytorch
+from sklearn.metrics import precision_recall_fscore_support, roc_auc_score
+from sklearn.preprocessing import label_binarize
+from torch.utils.data import Dataset, DataLoader
+import numpy as np
+import random
+import time
+import matplotlib.pyplot as plt
+
+# Seeds
+torch.manual_seed(42)
+np.random.seed(42)
+random.seed(42)
+
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+num_latents = 6  # This should match the complexity of your data or the number of tasks
+num_tasks = 4  # This should match the number of output classes or tasks
+num_inducing_points = 50  # This is independent and should be sufficient for the input space
+
+
+class MultitaskGPModel(gpytorch.models.ApproximateGP):
+    """Variational GP whose `num_tasks` outputs mix `num_latents` latent GPs (LMC).
+
+    Sizes come from the module-level `num_latents`, `num_tasks` and
+    `num_inducing_points` settings; inputs must have 128 * 128 features.
+    """
+
+    def __init__(self):
+        # Let's use a different set of inducing points for each latent function
+        inducing_points = torch.rand(num_latents, num_inducing_points, 128 * 128)  # Assuming flattened 128x128 images
+
+        # We have to mark the CholeskyVariationalDistribution as batch
+        # so that we learn a variational distribution for each task
+        variational_distribution = gpytorch.variational.CholeskyVariationalDistribution(
+            inducing_points.size(-2), batch_shape=torch.Size([num_latents])
+        )
+
+        # We have to wrap the VariationalStrategy in a LMCVariationalStrategy
+        # so that the output will be a MultitaskMultivariateNormal rather than a batch output
+        variational_strategy = gpytorch.variational.LMCVariationalStrategy(
+            gpytorch.variational.VariationalStrategy(
+                self, inducing_points, variational_distribution, learn_inducing_locations=True
+            ),
+            num_tasks=num_tasks,
+            num_latents=num_latents,
+            latent_dim=-1
+        )
+
+        super().__init__(variational_strategy)
+
+        # The mean and covariance modules should be marked as batch
+        # so we learn a different set of hyperparameters
+        self.mean_module = gpytorch.means.ConstantMean(batch_shape=torch.Size([num_latents]))
+        self.covar_module = gpytorch.kernels.ScaleKernel(
+            gpytorch.kernels.RBFKernel(batch_shape=torch.Size([num_latents])),
+            batch_shape=torch.Size([num_latents])
+        )
+
+    def forward(self, x):
+        """Return the latent GP prior at `x` as a MultivariateNormal."""
+        mean_x = self.mean_module(x)
+        covar_x = self.covar_module(x)
+        return gpytorch.distributions.MultivariateNormal(mean_x, covar_x)
+
+
+class CustomDataset(Dataset):
+    """Dataset of per-segment ``.pt`` tensors labelled by a names/labels CSV.
+
+    Each item is a dict with keys 'data', 'label' and 'segment_name'.
+    `start_idx` rotates the index so iteration can begin part-way through
+    the segment list.
+    """
+
+    def __init__(self, data_path, labels_path, binary=False, start_idx=0):
+        self.data_path = data_path
+        self.labels_path = labels_path
+        self.binary = binary
+        self.start_idx = start_idx
+        self.segment_names, self.labels = self.extract_segment_names_and_labels()
+
+    def __len__(self):
+        return len(self.segment_names)
+
+    def __getitem__(self, idx):
+        actual_idx = (idx + self.start_idx) % len(self.segment_names)
+        segment_name = self.segment_names[actual_idx]
+        label = self.labels[segment_name]
+        data_tensor = torch.load(os.path.join(self.data_path, segment_name + '.pt'))
+        return {'data': data_tensor, 'label': label, 'segment_name': segment_name}
+
+    def extract_segment_names_and_labels(self):
+        """Parse the labels CSV into (segment_names, {segment_name: label})."""
+        segment_names = []
+        labels = {}
+
+        with open(self.labels_path, 'r') as file:
+            next(file, None)  # Skip the header line
+            for line in file:
+                line = line.strip()
+                if not line:
+                    continue  # Blank lines (e.g. a trailing newline) carry no record
+                segment_name, label = line.split(',')
+                label = int(float(label))  # Convert the label to float first, then to int
+                if self.binary and label == 2:
+                    label = 0  # Convert PAC/PVC to non-AF (0) for binary classification
+                segment_names.append(segment_name)
+                labels[segment_name] = label
+
+        return segment_names, labels
+
+    def set_start_idx(self, index):
+        self.start_idx = index
+
+    def save_checkpoint(self, checkpoint_path):
+        """Save the segment list, labels and start index to `checkpoint_path`."""
+        checkpoint = {
+            'segment_names': self.segment_names,
+            'labels': self.labels,
+            'start_idx': self.start_idx
+        }
+        torch.save(checkpoint, checkpoint_path)
+
+    def load_checkpoint(self, checkpoint_path):
+        """Replace the segment list, labels and start index from `checkpoint_path`."""
+        checkpoint = torch.load(checkpoint_path)
+        self.segment_names = checkpoint['segment_names']
+        self.labels = checkpoint['labels']
+        self.start_idx = checkpoint['start_idx']
+
+
+def load_data(data_path, labels_path, batch_size, binary=False, start_idx=0):
+    """Build an unshuffled DataLoader over a CustomDataset."""
+    dataset = CustomDataset(data_path, labels_path, binary, start_idx)
+    return DataLoader(dataset, batch_size=batch_size, shuffle=False)
+
+
+def _predict_classes(likelihood, output):
+    """Return the most probable class index for each row of a model output.
+
+    The class probabilities come from passing `output` through the likelihood
+    and averaging over its samples. The argmax of the raw GP mean is not used
+    because the softmax likelihood mixes the GP outputs with its own weights.
+    """
+    return likelihood(output).probs.mean(0).argmax(-1)
+
+
+def _classification_metrics(labels, predictions, n_classes):
+    """Return macro precision, recall, F1 and one-vs-rest AUC-ROC.
+
+    AUC-ROC is computed from the binarized hard predictions and is NaN when
+    scikit-learn raises because it is undefined for the given labels.
+    """
+    precision, recall, f1, _ = precision_recall_fscore_support(labels, predictions, average='macro')
+    classes = list(range(n_classes))
+    try:
+        auc_roc = roc_auc_score(label_binarize(labels, classes=classes),
+                                label_binarize(predictions, classes=classes),
+                                multi_class='ovr')
+    except ValueError:
+        auc_roc = float('nan')
+    return precision, recall, f1, auc_roc
+
+
+def _plot_training_curves(train_losses, val_losses, val_precisions, val_recalls,
+                          val_f1_scores, val_auc_rocs, plot_path):
+    """Save the loss, precision/recall/F1 and AUC-ROC curves to `plot_path`."""
+    plt.figure(figsize=(12, 8))
+    plt.subplot(2, 2, 1)
+    plt.plot(train_losses, label='Training Loss')
+    plt.plot(val_losses, label='Validation Loss')
+    plt.legend()
+    plt.title('Loss')
+    plt.xlabel('Epoch')
+    plt.ylabel('Loss')
+
+    plt.subplot(2, 2, 2)
+    plt.plot(val_precisions, label='Precision')
+    plt.plot(val_recalls, label='Recall')
+    plt.plot(val_f1_scores, label='F1 Score')
+    plt.legend()
+    plt.title('Validation Metrics')
+    plt.xlabel('Epoch')
+    plt.ylabel('Metric')
+
+    plt.subplot(2, 2, 3)
+    plt.plot(val_auc_rocs, label='AUC-ROC')
+    plt.legend()
+    plt.title('Validation AUC-ROC')
+    plt.xlabel('Epoch')
+    plt.ylabel('AUC-ROC')
+
+    plt.tight_layout()
+    plt.savefig(plot_path)
+    plt.close()
+
+
+def train_gp_model(train_loader, val_loader, num_iterations=50, n_classes=4, patience=10,
+                   checkpoint_path='model_checkpoint.pt', data_checkpoint_path='data_checkpoint.pt',
+                   resume_training=False, plot_path='training_plot.png'):
+    """Train the GP classifier with per-epoch validation and early stopping.
+
+    Args:
+        train_loader: DataLoader yielding dicts with 'data' and 'label'.
+        val_loader: DataLoader whose dataset is sampled for validation.
+        num_iterations: Maximum number of epochs.
+        n_classes: Number of classes used when computing AUC-ROC.
+        patience: Epochs without validation improvement before stopping.
+        checkpoint_path: Where the best model/likelihood/optimizer state is saved.
+        data_checkpoint_path: Accepted for interface compatibility; unused here.
+        resume_training: Restore state from `checkpoint_path` if it exists.
+        plot_path: Image file rewritten with the training curves every epoch.
+
+    Returns:
+        (model, likelihood, metrics), where `metrics` maps 'precision',
+        'recall', 'f1_score' and 'auc_roc' to per-epoch validation values and
+        'train_loss' to per-batch training losses.
+
+    Raises:
+        ValueError: If the validation dataset is empty.
+    """
+    val_dataset = val_loader.dataset
+    if len(val_dataset) == 0:
+        raise ValueError("Validation dataset is empty")
+    # Validate on a random 10% of the validation set, but on at least one sample.
+    n_val_samples = max(1, int(0.1 * len(val_dataset)))
+
+    model = MultitaskGPModel().to(device)
+    likelihood = gpytorch.likelihoods.SoftmaxLikelihood(num_features=num_tasks, num_classes=4).to(device)
+    # The likelihood has trainable parameters of its own, so optimize them too.
+    optimizer = torch.optim.Adam(
+        [{'params': model.parameters()}, {'params': likelihood.parameters()}], lr=0.1
+    )
+    mll = gpytorch.mlls.VariationalELBO(likelihood, model, num_data=len(train_loader.dataset))
+
+    start_epoch = 0
+    if resume_training and os.path.exists(checkpoint_path):
+        checkpoint = torch.load(checkpoint_path, map_location=device)
+        model.load_state_dict(checkpoint['model_state_dict'])
+        likelihood.load_state_dict(checkpoint['likelihood_state_dict'])
+        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
+        # The stored epoch already finished, so continue with the next one.
+        start_epoch = checkpoint.get('epoch', -1) + 1
+
+    best_val_loss = float('inf')
+    epochs_no_improve = 0
+
+    metrics = {
+        'precision': [],
+        'recall': [],
+        'f1_score': [],
+        'auc_roc': [],
+        'train_loss': []
+    }
+
+    # Initialize lists to store metrics for plotting
+    train_losses = []
+    val_losses = []
+    val_precisions = []
+    val_recalls = []
+    val_f1_scores = []
+    val_auc_rocs = []
+
+    for epoch in range(start_epoch, num_iterations):
+        model.train()
+        likelihood.train()
+        epoch_losses = []
+        for train_batch in train_loader:
+            optimizer.zero_grad()
+            train_x = train_batch['data'].reshape(train_batch['data'].size(0), -1).to(device)
+            train_y = train_batch['label'].to(device)
+            output = model(train_x)
+            loss = -mll(output, train_y)
+            epoch_losses.append(loss.item())
+            loss.backward()
+            optimizer.step()
+        metrics['train_loss'].extend(epoch_losses)
+
+        # Stochastic validation
+        model.eval()
+        likelihood.eval()
+        val_loss = 0.0
+        val_labels = []
+        val_predictions = []
+        with torch.no_grad():
+            val_indices = torch.randperm(len(val_dataset))[:n_val_samples]
+            for idx in val_indices.tolist():
+                val_batch = val_dataset[idx]
+                val_x = val_batch['data'].reshape(-1).unsqueeze(0).to(device)
+                val_y = torch.tensor([val_batch['label']], device=device)
+                val_output = model(val_x)
+                val_loss += -mll(val_output, val_y).sum().item()
+                val_labels.append(val_y.item())
+                val_predictions.append(_predict_classes(likelihood, val_output).item())
+        val_loss /= n_val_samples
+
+        precision, recall, f1, auc_roc = _classification_metrics(val_labels, val_predictions, n_classes)
+        metrics['precision'].append(precision)
+        metrics['recall'].append(recall)
+        metrics['f1_score'].append(f1)
+        metrics['auc_roc'].append(auc_roc)
+
+        # Append this epoch's values for plotting; this has to happen after
+        # validation because the validation results do not exist before it.
+        train_losses.append(float(np.mean(epoch_losses)) if epoch_losses else float('nan'))
+        val_losses.append(val_loss)
+        val_precisions.append(precision)
+        val_recalls.append(recall)
+        val_f1_scores.append(f1)
+        val_auc_rocs.append(auc_roc)
+
+        _plot_training_curves(train_losses, val_losses, val_precisions, val_recalls,
+                              val_f1_scores, val_auc_rocs, plot_path)
+
+        if val_loss < best_val_loss:
+            best_val_loss = val_loss
+            epochs_no_improve = 0
+            torch.save({'model_state_dict': model.state_dict(),
+                        'likelihood_state_dict': likelihood.state_dict(),
+                        'optimizer_state_dict': optimizer.state_dict(),
+                        'epoch': epoch}, checkpoint_path)
+        else:
+            epochs_no_improve += 1
+            if epochs_no_improve >= patience:
+                print(f"Early stopping triggered at epoch {epoch+1}")
+                break
+
+    # Restore the best checkpoint before saving the final weights.
+    if os.path.exists(checkpoint_path):
+        checkpoint = torch.load(checkpoint_path, map_location=device)
+        model.load_state_dict(checkpoint['model_state_dict'])
+        likelihood.load_state_dict(checkpoint['likelihood_state_dict'])
+        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
+
+    # Save final model weights and other information
+    torch.save({
+        'model_state_dict': model.state_dict(),
+        'likelihood_state_dict': likelihood.state_dict(),
+        'optimizer_state_dict': optimizer.state_dict(),
+        'train_losses': train_losses,
+        'val_losses': val_losses,
+        'val_precisions': val_precisions,
+        'val_recalls': val_recalls,
+        'val_f1_scores': val_f1_scores,
+        'val_auc_rocs': val_auc_rocs
+    }, 'final_model_info.pt')
+
+    return model, likelihood, metrics
+
+
+def evaluate_gp_model(test_loader, model, likelihood, n_classes=4):
+    """Return macro precision, recall, F1 and AUC-ROC of the model on `test_loader`.
+
+    The result is a dict with keys 'precision', 'recall', 'f1_score' and
+    'auc_roc'.
+    """
+    model.eval()
+    likelihood.eval()
+    test_labels = []
+    test_predictions = []
+
+    with torch.no_grad():
+        for test_batch in test_loader:
+            test_x = test_batch['data'].reshape(test_batch['data'].size(0), -1).to(device)
+            test_output = model(test_x)
+            test_labels.extend(test_batch['label'].tolist())
+            test_predictions.extend(_predict_classes(likelihood, test_output).tolist())
+
+    precision, recall, f1, auc_roc = _classification_metrics(test_labels, test_predictions, n_classes)
+    return {
+        'precision': precision,
+        'recall': recall,
+        'f1_score': f1,
+        'auc_roc': auc_roc
+    }
+
+
+def main():
+    """Run data loading, training and test evaluation for the holdout split."""
+    print("Step 1: Loading paths and parameters")
+    # Paths
+    base_path = r"\\grove.ad.uconn.edu\\research\\ENGR_Chon\Darren\\NIH_Pulsewatch"
+    smote_type = 'Cassey5k_SMOTE'
+    split = 'holdout_60_10_30'
+    data_path_train = os.path.join(base_path, "TFS_pt", smote_type, split, "train")
+    data_path_val = os.path.join(base_path, "TFS_pt", smote_type, split, "validate")
+    data_path_test = os.path.join(base_path, "TFS_pt", smote_type, split, "test")
+    labels_path_train = os.path.join(base_path, "TFS_pt", smote_type, split, "Cassey5k_SMOTE_train_names_labels.csv")
+    labels_path_val = os.path.join(base_path, "TFS_pt", smote_type, split, "Cassey5k_SMOTE_validate_names_labels.csv")
+    labels_path_test = os.path.join(base_path, "TFS_pt", smote_type, split, "Cassey5k_SMOTE_test_names_labels.csv")
+
+    # Parameters
+    binary = False
+    n_epochs = 100
+    if binary:
+        n_classes = 2
+    else:
+        n_classes = 3
+    # NOTE(review): the model and likelihood are built for 4 outputs/classes;
+    # n_classes only feeds AUC-ROC - confirm the intended class count.
+    patience = round(n_epochs / 10) if n_epochs > 50 else 5
+    resume_checkpoint_path = None
+    batch_size = 256
+
+    print("Step 2: Loading data")
+    # Data loading
+    train_loader = load_data(data_path_train, labels_path_train, batch_size, binary)
+    val_loader = load_data(data_path_val, labels_path_val, batch_size, binary)
+    test_loader = load_data(data_path_test, labels_path_test, batch_size, binary)
+
+    print("Step 3: Loading data checkpoints")
+    # Data loading with checkpointing
+    # NOTE(review): one checkpoint file is loaded into all three datasets,
+    # which gives them the same segment list - confirm this is intended.
+    data_checkpoint_path = 'data_checkpoint.pt'
+    if os.path.exists(data_checkpoint_path):
+        train_loader.dataset.load_checkpoint(data_checkpoint_path)
+        val_loader.dataset.load_checkpoint(data_checkpoint_path)
+        test_loader.dataset.load_checkpoint(data_checkpoint_path)
+
+    print("Step 4: Training and validation")
+    # Training and validation with checkpointing and plotting
+    model_checkpoint_path = 'model_checkpoint.pt'
+    plot_path = 'training_plot.png'
+    start_time = time.time()
+    model, likelihood, metrics = train_gp_model(train_loader, val_loader, n_epochs,
+                                                n_classes, patience,
+                                                model_checkpoint_path, data_checkpoint_path,
+                                                resume_checkpoint_path is not None, plot_path)
+    end_time = time.time()
+    time_passed = end_time - start_time
+    print('\nTraining and validation took %.2f minutes' % (time_passed / 60))
+
+    print("Step 5: Evaluation")
+    # Evaluation
+    start_time = time.time()
+    test_metrics = evaluate_gp_model(test_loader, model, likelihood, n_classes)
+    end_time = time.time()
+    time_passed = end_time - start_time
+    print('\nTesting took %.2f seconds' % time_passed)
+
+    print("Step 6: Printing test metrics")
+    print('Test Metrics:')
+    print('Precision: %.4f' % test_metrics['precision'])
+    print('Recall: %.4f' % test_metrics['recall'])
+    print('F1 Score: %.4f' % test_metrics['f1_score'])
+    print('AUC-ROC: %.4f' % test_metrics['auc_roc'])
+
+
+if __name__ == '__main__':
+    main()
\ No newline at end of file