diff --git a/BML_project/debugging_zone/enhanced_data_loader.py b/BML_project/debugging_zone/enhanced_data_loader.py
new file mode 100644
index 0000000..e69de29
diff --git a/BML_project/main_checkpoints.py b/BML_project/main_checkpoints.py
new file mode 100644
index 0000000..9098257
--- /dev/null
+++ b/BML_project/main_checkpoints.py
@@ -0,0 +1,94 @@
+# -*- coding: utf-8 -*-
+"""
+Created on Thu Feb 1 19:43:31 2024
+
+@author: lrm22005
+"""
+import os
+import tqdm
+import torch
+# Import the checkpoint-aware loader and trainer introduced in this changeset
+from utils_gp.data_loader import preprocess_data, split_uids, update_train_loader_with_uncertain_samples
+from models.ss_gp_model_checkpoint import MultitaskGPModel, train_gp_model
+from utils_gp.ss_evaluation import stochastic_evaluation, evaluate_model_on_all_data
+from active_learning.ss_active_learning import stochastic_uncertainty_sampling, run_minibatch_kmeans, stochastic_compare_kmeans_gp_predictions
+from utils.visualization import plot_comparative_results, plot_training_performance, plot_results
+
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+def main():
+    # Set parameters such as n_classes and batch_size
+    n_classes = 4
+    batch_size = 1024
+    clinical_trial_train, clinical_trial_test, clinical_trial_unlabeled = split_uids()
+    data_format = 'pt'
+
+    # Preprocess data
+    train_loader, val_loader, test_loader = preprocess_data(data_format, clinical_trial_train, clinical_trial_test, clinical_trial_unlabeled, batch_size)
+
+    # Attempt to resume from the last saved batch index if a dataset checkpoint
+    # exists. The checkpoint is only ever written from the training dataset (see
+    # the active learning loop below), so it is restored into that dataset only.
+    dataset_checkpoint_path = 'dataset_checkpoint.pt'
+    if os.path.exists(dataset_checkpoint_path):
+        train_loader.dataset.load_checkpoint(dataset_checkpoint_path)
+        print(f"Resuming from batch index {train_loader.dataset.get_current_batch_index()}")
+
+    kmeans_model = run_minibatch_kmeans(train_loader, n_clusters=n_classes, device=device)
+
+    # Initialize result storage
+    results = {
+        'train_loss': [],
+        'validation_metrics': {'precision': [], 'recall': [], 'f1': [], 'auc_roc': []},
+        'test_metrics': None
+    }
+
+    # Initial model training
+    model, likelihood, training_metrics = train_gp_model(train_loader, val_loader, num_iterations=50, n_classes=n_classes)
+
+    # Save the training metrics for future visualization
+    results['train_loss'].extend(training_metrics['train_loss'])
+    results['validation_metrics']['precision'].extend(training_metrics['precision'])
+    results['validation_metrics']['recall'].extend(training_metrics['recall'])
+    results['validation_metrics']['f1'].extend(training_metrics['f1_score'])
+
+    active_learning_iterations = 10
+    # Active learning iterations
+    for iteration in tqdm.tqdm(range(active_learning_iterations), desc='Active Learning', unit='iteration', leave=True):
+        # Perform uncertainty sampling to select new samples from the validation set
+        uncertain_sample_indices = stochastic_uncertainty_sampling(model, likelihood, val_loader, n_samples=batch_size, n_batches=5)
+
+        # Update the training loader with uncertain samples
+        train_loader = update_train_loader_with_uncertain_samples(train_loader, uncertain_sample_indices, batch_size)
+
+        # Optionally, save the dataset state at intervals or after certain conditions
+        train_loader.dataset.save_checkpoint(dataset_checkpoint_path, current_batch_index=None)  # Manage the batch index here as needed
+
+        # Re-train the model with the updated training data
+        model, likelihood, val_metrics = train_gp_model(train_loader, val_loader, num_iterations=10, n_classes=n_classes, patience=10,
+                                                        checkpoint_path='model_checkpoint_last.pt')
+
+        # Store the validation metrics after each active learning iteration.
+        # train_gp_model returns per-epoch lists, keyed 'f1_score' (not 'f1').
+        results['validation_metrics']['precision'].extend(val_metrics['precision'])
+        results['validation_metrics']['recall'].extend(val_metrics['recall'])
+        results['validation_metrics']['f1'].extend(val_metrics['f1_score'])
+
+        # Compare K-Means with GP model predictions after retraining
+        gp_vs_kmeans_data, original_labels = stochastic_compare_kmeans_gp_predictions(kmeans_model, model, train_loader, n_batches=5, device=device)
+
+        plot_comparative_results(gp_vs_kmeans_data, original_labels)
+
+    # Final evaluation on test set
+    test_metrics = evaluate_model_on_all_data(model, likelihood, test_loader, device, n_classes)
+    test_kmeans_model = run_minibatch_kmeans(test_loader, n_clusters=n_classes, device=device)
+
+    results['test_metrics'] = test_metrics
+    test_gp_vs_kmeans_data, test_original_labels = stochastic_compare_kmeans_gp_predictions(test_kmeans_model, model, test_loader, n_batches=5, device=device)
+    plot_comparative_results(test_gp_vs_kmeans_data, test_original_labels)
+
+    # Visualization of results
+    plot_training_performance(results['train_loss'], results['validation_metrics'])
+    plot_results(results['test_metrics'])
+
+    # Print final test metrics
+    print("Final Test Metrics:", results['test_metrics'])
+
+if __name__ == "__main__":
+    main()
\ No newline at end of file
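Note on the dataset checkpoint above: main() always passes current_batch_index=None to save_checkpoint, and nothing ever calls set_current_batch_index, so the resumed batch index will always print as None. A minimal sketch of how a training loop could record real progress, using only the dataset methods added in utils_gp/data_loader.py below (the save interval of 100 is an arbitrary placeholder):

    for batch_index, batch in enumerate(train_loader):
        # ... run one optimization step on `batch` here ...
        train_loader.dataset.set_current_batch_index(batch_index)
        if batch_index % 100 == 0:
            train_loader.dataset.save_checkpoint('dataset_checkpoint.pt',
                                                 current_batch_index=batch_index)
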
diff --git a/BML_project/models/ss_gp_model_checkpoint.py b/BML_project/models/ss_gp_model_checkpoint.py
new file mode 100644
index 0000000..8249e46
--- /dev/null
+++ b/BML_project/models/ss_gp_model_checkpoint.py
@@ -0,0 +1,218 @@
+# -*- coding: utf-8 -*-
+"""
+Created on Fri Feb 2 15:14:51 2024
+
+@author: lrm22005
+"""
+import os
+import numpy as np
+from tqdm import tqdm
+import torch
+import gpytorch
+from sklearn.metrics import precision_recall_fscore_support, roc_auc_score
+from sklearn.preprocessing import label_binarize
+
+num_latents = 6  # Number of latent functions; should match the complexity of the data
+num_tasks = 4  # Number of output classes/tasks
+num_inducing_points = 50  # Per latent function; independent of the above and should be sufficient to cover the input space
+
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+class MultitaskGPModel(gpytorch.models.ApproximateGP):
+    def __init__(self):
+        # Use a different set of inducing points for each latent function.
+        # Inputs are flattened 127x128 time-frequency representations.
+        inducing_points = torch.rand(num_latents, num_inducing_points, 127 * 128)
+
+        # Mark the CholeskyVariationalDistribution as batch so that we learn
+        # a variational distribution for each latent function
+        variational_distribution = gpytorch.variational.CholeskyVariationalDistribution(
+            inducing_points.size(-2), batch_shape=torch.Size([num_latents])
+        )
+
+        # Wrap the VariationalStrategy in an LMCVariationalStrategy so that the
+        # output is a MultitaskMultivariateNormal rather than a batch output
+        variational_strategy = gpytorch.variational.LMCVariationalStrategy(
+            gpytorch.variational.VariationalStrategy(
+                self, inducing_points, variational_distribution, learn_inducing_locations=True
+            ),
+            num_tasks=num_tasks,
+            num_latents=num_latents,
+            latent_dim=-1
+        )
+
+        super().__init__(variational_strategy)
+
+        # The mean and covariance modules are marked as batch so that a
+        # different set of hyperparameters is learned per latent function
+        self.mean_module = gpytorch.means.ConstantMean(batch_shape=torch.Size([num_latents]))
+        self.covar_module = gpytorch.kernels.ScaleKernel(
+            gpytorch.kernels.RBFKernel(batch_shape=torch.Size([num_latents])),
+            batch_shape=torch.Size([num_latents])
+        )
+
+    def forward(self, x):
+        # The forward function is written as if each output dimension were
+        # handled in batch. x must already be flattened so that its last
+        # dimension matches the last dimension of the inducing points
+        # (127 * 128); callers reshape batches with x.view(x.size(0), -1).
+        mean_x = self.mean_module(x)
+        covar_x = self.covar_module(x)
+        latent_pred = gpytorch.distributions.MultivariateNormal(mean_x, covar_x)
+        return latent_pred
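As a quick sanity check of the shapes involved: the inducing points fix the expected input dimension at 127 * 128 = 16256, so callers must flatten each sample before calling the model. A minimal smoke test with a random, hypothetical batch:

    model = MultitaskGPModel()
    x = torch.rand(8, 127, 128)       # dummy batch of 8 time-frequency inputs
    x_flat = x.view(x.size(0), -1)    # [8, 16256], matching the inducing points
    with torch.no_grad():
        out = model(x_flat)           # a MultitaskMultivariateNormal
    print(out.mean.shape)             # torch.Size([8, 4]): one mean per task
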
+def train_gp_model(train_loader, val_loader, num_iterations=50, n_classes=4, patience=10, checkpoint_path='model_checkpoint_full.pt', resume_training=False):
+    model = MultitaskGPModel().to(device)
+    # num_features must match the model's num_tasks (the number of latent outputs)
+    likelihood = gpytorch.likelihoods.SoftmaxLikelihood(num_features=num_tasks, num_classes=n_classes).to(device)
+    optimizer = torch.optim.Adam(model.parameters(), lr=0.1)
+    mll = gpytorch.mlls.VariationalELBO(likelihood, model, num_data=len(train_loader.dataset))
+
+    # Load checkpoint if resuming training
+    start_epoch = 0
+    if resume_training and os.path.exists(checkpoint_path):
+        checkpoint = torch.load(checkpoint_path)
+        model.load_state_dict(checkpoint['model_state_dict'])
+        likelihood.load_state_dict(checkpoint['likelihood_state_dict'])
+        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
+        start_epoch = checkpoint.get('epoch', 0) + 1  # Resume from the next epoch
+        print(f"Resuming training from epoch {start_epoch}")
+
+    best_val_loss = float('inf')
+    epochs_no_improve = 0
+    metrics = {'precision': [], 'recall': [], 'f1_score': [], 'auc_roc': [], 'train_loss': []}
+
+    for epoch in range(start_epoch, num_iterations):
+        model.train()
+        likelihood.train()
+        for batch_index, train_batch in enumerate(train_loader):
+            optimizer.zero_grad()
+            train_x = train_batch['data'].reshape(train_batch['data'].size(0), -1).to(device)
+            train_y = train_batch['label'].to(device)
+            output = model(train_x)
+            loss = -mll(output, train_y)
+            metrics['train_loss'].append(loss.item())
+            loss.backward()
+            optimizer.step()
+
+        # Stochastic validation: evaluate on a random subset of the validation
+        # set. The fraction is currently 1.0 (the full set); lower it for a
+        # cheaper, truly stochastic estimate.
+        model.eval()
+        likelihood.eval()
+        with torch.no_grad():
+            val_indices = torch.randperm(len(val_loader.dataset))[:int(1.0 * len(val_loader.dataset))]
+            val_loss = 0.0
+            val_labels = []
+            val_predictions = []
+            for idx in val_indices:
+                val_batch = val_loader.dataset[idx]
+                val_x = val_batch['data'].reshape(-1).unsqueeze(0).to(device)  # flatten a single sample to [1, D]
+                val_y = torch.tensor([val_batch['label']], device=device)
+                val_output = model(val_x)
+                val_loss_batch = -mll(val_output, val_y).sum()
+                val_loss += val_loss_batch.item()
+                val_labels.append(val_y.item())
+                val_predictions.append(val_output.mean.argmax(dim=-1).item())
+
+            precision, recall, f1, _ = precision_recall_fscore_support(val_labels, val_predictions, average='macro')
+            # auc_roc = roc_auc_score(label_binarize(val_labels, classes=np.arange(n_classes)),
+            #                         label_binarize(val_predictions, classes=np.arange(n_classes)),
+            #                         multi_class='ovr')
+
+            metrics['precision'].append(precision)
+            metrics['recall'].append(recall)
+            metrics['f1_score'].append(f1)
+            # metrics['auc_roc'].append(auc_roc)
+            val_loss /= len(val_indices)
+
+        # Track the best validation loss for early stopping. Note that the
+        # end-of-epoch checkpoint below always writes to the same path, so the
+        # file holds the latest epoch's state (needed for resuming), not
+        # necessarily the best-scoring one.
+        if val_loss < best_val_loss:
+            best_val_loss = val_loss
+            epochs_no_improve = 0
+        else:
+            epochs_no_improve += 1
+
+        # Save a resume checkpoint at the end of each epoch
+        torch.save({
+            'epoch': epoch,
+            'model_state_dict': model.state_dict(),
+            'likelihood_state_dict': likelihood.state_dict(),
+            'optimizer_state_dict': optimizer.state_dict(),
+            'best_val_loss': best_val_loss,
+            # Include other metrics as needed
+        }, checkpoint_path)
+
+        if epochs_no_improve >= patience:
+            print(f"Early stopping triggered at epoch {epoch+1}")
+            break
+
+    # Optionally, reload the saved checkpoint at the end of training
+    if os.path.exists(checkpoint_path):
+        checkpoint = torch.load(checkpoint_path)
+        model.load_state_dict(checkpoint['model_state_dict'])
+        likelihood.load_state_dict(checkpoint['likelihood_state_dict'])
+        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
+
+    return model, likelihood, metrics
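The resume path is exercised by calling train_gp_model twice with the same checkpoint_path; a sketch (train_loader and val_loader are placeholders):

    # First run: trains from scratch, writing 'model_checkpoint_full.pt' each epoch
    model, likelihood, metrics = train_gp_model(train_loader, val_loader, num_iterations=50)

    # After an interruption: restores model/likelihood/optimizer state and
    # continues from the epoch recorded in the checkpoint
    model, likelihood, metrics = train_gp_model(train_loader, val_loader,
                                                num_iterations=50, resume_training=True)
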
+def semi_supervised_labeling(kmeans_model, gp_model, gp_likelihood, data_loader, confidence_threshold=0.8):
+    gp_model.eval()
+    gp_likelihood.eval()
+    labeled_samples = []
+
+    with torch.no_grad():
+        for batch in data_loader:
+            data_tensor = batch['data'].view(batch['data'].size(0), -1).to(device)
+            kmeans_predictions = kmeans_model.predict(data_tensor.cpu().numpy())
+            gp_predictions = gp_likelihood(gp_model(data_tensor))
+
+            # SoftmaxLikelihood yields a Categorical whose leading dimension is
+            # the likelihood's sample dimension; average it out to obtain
+            # per-sample class probabilities, and treat the top-class
+            # probability as the model's confidence.
+            probs = gp_predictions.probs
+            if probs.dim() == 3:
+                probs = probs.mean(dim=0)
+            confidences, gp_classes = probs.max(dim=-1)
+
+            # Use GP predictions where the model is confident, k-means otherwise
+            confident_indices = confidences.cpu().numpy() > confidence_threshold
+            for i, confident in enumerate(confident_indices):
+                if confident:
+                    labeled_samples.append((data_tensor[i], gp_classes[i].item()))
+                else:
+                    labeled_samples.append((data_tensor[i], kmeans_predictions[i]))
+
+    return labeled_samples
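A usage sketch for semi_supervised_labeling, assuming the k-means model from run_minibatch_kmeans and the GP returned by train_gp_model; unlabeled_loader is a placeholder for a DataLoader over the unlabeled pool:

    pseudo_labeled = semi_supervised_labeling(kmeans_model, model, likelihood,
                                              unlabeled_loader, confidence_threshold=0.8)
    # Each entry is (flattened_tensor, class_index); k-means labels fill in
    # wherever the GP's top-class probability fell below the threshold
    data_tensors, labels = zip(*pseudo_labeled)
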
+def calculate_elbo(model, likelihood, data_loader):
+    """
+    Calculates the ELBO (Evidence Lower Bound) score for the model on the given data.
+
+    Args:
+    - model: The trained Gaussian Process model.
+    - likelihood: The likelihood associated with the GP model.
+    - data_loader: DataLoader providing the data over which to calculate the ELBO.
+
+    Returns:
+    - elbo_score: The calculated ELBO score.
+    """
+    model.eval()
+    likelihood.eval()
+    mll = gpytorch.mlls.VariationalELBO(likelihood, model, num_data=len(data_loader.dataset))
+
+    with torch.no_grad():
+        elbo_score = 0.0
+        for batch in data_loader:
+            train_x = batch['data'].reshape(batch['data'].size(0), -1).to(device)
+            train_y = batch['label'].to(device)
+            output = model(train_x)
+            # mll returns the ELBO itself (the training loss is its negative),
+            # so accumulate it directly
+            elbo_score += mll(output, train_y).sum().item()
+
+        # Average the ELBO over all data samples
+        elbo_score /= len(data_loader.dataset)
+
+    return elbo_score
diff --git a/BML_project/utils_gp/data_loader.py b/BML_project/utils_gp/data_loader.py
index 14ed5fe..54caf28 100644
--- a/BML_project/utils_gp/data_loader.py
+++ b/BML_project/utils_gp/data_loader.py
@@ -105,6 +105,9 @@ def __init__(self, data_path, labels_path, UIDs, standardize=True, data_format='
         self.transforms = ToTensor()
         self.refresh_dataset()
 
+        # Initialize the current batch index to None
+        self.current_batch_index = None
+
     def refresh_dataset(self):
         # Extract unique segment names and their corresponding labels
         self.segment_names, self.labels = self.extract_segment_names_and_labels()
@@ -119,6 +122,25 @@ def add_uids(self, new_uids):
 
     def __len__(self):
         return len(self.segment_names)
+
+    def save_checkpoint(self, checkpoint_path, current_batch_index=None):
+        checkpoint = {
+            'segment_names': self.segment_names,
+            'labels': self.labels,
+            'UIDs': self.UIDs,
+            # Save the given batch index, falling back to the stored one
+            'current_batch_index': current_batch_index if current_batch_index is not None else self.current_batch_index
+        }
+        torch.save(checkpoint, checkpoint_path)
+
+    def load_checkpoint(self, checkpoint_path):
+        checkpoint = torch.load(checkpoint_path)
+        self.segment_names = checkpoint['segment_names']
+        self.labels = checkpoint['labels']
+        self.UIDs = checkpoint['UIDs']
+        self.refresh_dataset()  # rebuilds segment_names/labels from the restored UIDs
+        # Restore the current batch index if it exists in the checkpoint
+        self.current_batch_index = checkpoint.get('current_batch_index')
 
     def __getitem__(self, idx):
         segment_name = self.segment_names[idx]
@@ -133,6 +155,14 @@ def __getitem__(self, idx):
 
         return {'data': time_freq_tensor, 'label': label, 'segment_name': segment_name}
 
+    # New method to set the current batch index
+    def set_current_batch_index(self, index):
+        self.current_batch_index = index
+
+    # New method to get the current batch index
+    def get_current_batch_index(self):
+        return self.current_batch_index
+
     def add_data_label_pair(self, data, label):
         # Assign a unique ID or name for the new data
         new_id = len(self.segment_names)
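A round-trip check of the new checkpoint methods (assuming `ds` is an instance of the dataset class modified above):

    ds.set_current_batch_index(42)
    ds.save_checkpoint('dataset_checkpoint.pt')   # persists segment names, labels, UIDs, batch index
    ds.load_checkpoint('dataset_checkpoint.pt')   # restores state and refreshes the dataset
    assert ds.get_current_batch_index() == 42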