diff --git a/BML_project/active_learning/ss_active_learning.py b/BML_project/active_learning/ss_active_learning.py index 4c44836..d75b67f 100644 --- a/BML_project/active_learning/ss_active_learning.py +++ b/BML_project/active_learning/ss_active_learning.py @@ -80,7 +80,7 @@ def stochastic_compare_kmeans_gp_predictions(kmeans_model, gp_model, data_loader kmeans_predictions = kmeans_model.predict(data.cpu().numpy()) all_labels.append(labels.cpu().numpy()) all_data.append((gp_predictions, kmeans_predictions)) - + print(f"Processed batch size: {len(current_batch_labels)}, Cumulative original_labels size: {len(original_labels)}, Cumulative gp_predictions size: {len(gp_predictions)}") return all_data, np.concatenate(all_labels) import random diff --git a/BML_project/main_checkpoints_updated.py b/BML_project/main_checkpoints_updated.py new file mode 100644 index 0000000..234291c --- /dev/null +++ b/BML_project/main_checkpoints_updated.py @@ -0,0 +1,64 @@ +# -*- coding: utf-8 -*- +""" +Created on Wed Feb 7 15:34:31 2024 + +@author: lrm22005 +""" +import os +import tqdm +import torch +from utils.data_loader import preprocess_data, split_uids, update_train_loader_with_uncertain_samples +from models.ss_gp_model import MultitaskGPModel, train_gp_model +from utils_gp.ss_evaluation import stochastic_evaluation, evaluate_model_on_all_data +from active_learning.ss_active_learning import stochastic_uncertainty_sampling, run_minibatch_kmeans, stochastic_compare_kmeans_gp_predictions, label_samples +from utils.visualization import plot_comparative_results, plot_training_performance, plot_results + +device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + +def main(): + n_classes = 4 + batch_size = 1024 + clinical_trial_train, clinical_trial_test, clinical_trial_unlabeled = split_uids() + data_format = 'pt' + + train_loader, val_loader, test_loader = preprocess_data(data_format, clinical_trial_train, clinical_trial_test, clinical_trial_unlabeled, batch_size) + + # Initialize result storage + results = { + 'train_loss': [], + 'validation_metrics': {'precision': [], 'recall': [], 'f1': [], 'auc_roc': []}, + 'active_learning': {'validation_metrics': []}, # Store validation metrics for each active learning iteration + 'test_metrics': None + } + + # Initial model training + model, likelihood, training_metrics = train_gp_model(train_loader, val_loader, num_iterations=50, n_classes=n_classes, patience=10, checkpoint_path='model_checkpoint_full.pt') + + # Save initial training metrics + results['train_loss'].extend(training_metrics['train_loss']) + for metric in ['precision', 'recall', 'f1_score']: + results['validation_metrics'][metric].extend(training_metrics[metric]) + + active_learning_iterations = 10 + for iteration in tqdm.tqdm(range(active_learning_iterations), desc='Active Learning', unit='iteration'): + uncertain_sample_indices = stochastic_uncertainty_sampling(model, likelihood, val_loader, n_samples=batch_size, device=device) + train_loader = update_train_loader_with_uncertain_samples(train_loader, uncertain_sample_indices, batch_size) + + # Re-train the model with updated training data + model, likelihood, val_metrics = train_gp_model(train_loader, val_loader, num_iterations=10, n_classes=n_classes, patience=10, checkpoint_path='model_checkpoint_last.pt') + + # Store validation metrics for each active learning iteration + results['active_learning']['validation_metrics'].append(val_metrics) + + # Final evaluations + test_metrics = evaluate_model_on_all_data(model, likelihood, test_loader, device, n_classes) + results['test_metrics'] = test_metrics + + # Visualization of results + plot_training_performance(results['train_loss'], results['validation_metrics']) + plot_results(results['test_metrics']) # Adjust this function to handle the structure of test_metrics + + print("Final Test Metrics:", results['test_metrics']) + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/BML_project/models/ss_gp_model_checkpoint.py b/BML_project/models/ss_gp_model_checkpoint.py index 282c73d..c11f8cb 100644 --- a/BML_project/models/ss_gp_model_checkpoint.py +++ b/BML_project/models/ss_gp_model_checkpoint.py @@ -69,30 +69,40 @@ def forward(self, x): return latent_pred -def train_gp_model(train_loader, val_loader, num_iterations=50, n_classes=4, patience=10, checkpoint_path='model_checkpoint_full.pt', resume_training=False, batch_size=1024): +def train_gp_model(train_loader, val_loader, num_iterations=50, n_classes=4, patience=10, checkpoint_path='model_checkpoint_full.pt', resume_checkpoint_path=None): + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") model = MultitaskGPModel().to(device) likelihood = gpytorch.likelihoods.SoftmaxLikelihood(num_features=4, num_classes=4).to(device) optimizer = torch.optim.Adam(model.parameters(), lr=0.1) mll = gpytorch.mlls.VariationalELBO(likelihood, model, num_data=len(train_loader.dataset)) + best_val_loss = float('inf') + epochs_no_improve = 0 - # Load checkpoint if resuming training start_epoch = 0 - current_batch_index = 0 # Default value in case it's not found in the checkpoint - if resume_training and os.path.exists(checkpoint_path): - checkpoint = torch.load(checkpoint_path) + start_batch = 0 + + # Resume from checkpoint if specified + if resume_checkpoint_path is not None and os.path.exists(resume_checkpoint_path): + checkpoint = torch.load(resume_checkpoint_path) model.load_state_dict(checkpoint['model_state_dict']) likelihood.load_state_dict(checkpoint['likelihood_state_dict']) optimizer.load_state_dict(checkpoint['optimizer_state_dict']) - start_epoch = checkpoint.get('epoch', 0) + 1 # Resume from the next epoch - current_batch_index = checkpoint.get('current_batch_index', 0) # Retrieve the last batch index if it exists - print(f"Resuming training from epoch {start_epoch}, batch index {current_batch_index}") - - best_val_loss = float('inf') - epochs_no_improve = 0 - metrics = {'precision': [], 'recall': [], 'f1_score': [], 'auc_roc': [], 'train_loss': []} - - for epoch in range(start_epoch, num_iterations): + start_epoch = checkpoint.get('epoch', 0) + start_batch = checkpoint.get('batch', 0) + print(f"Resuming training from epoch {start_epoch}, batch {start_batch}") + + metrics = { + 'precision': [], + 'recall': [], + 'f1_score': [], + 'train_loss': [] + } + + for epoch in tqdm.tqdm(range(start_epoch, num_iterations), desc='Training', unit='epoch', leave=False): for batch_index, train_batch in enumerate(train_loader): + if epoch == start_epoch and batch_index < start_batch: + continue # Skip batches until the saved batch index + model.train() likelihood.train() optimizer.zero_grad() @@ -104,6 +114,17 @@ def train_gp_model(train_loader, val_loader, num_iterations=50, n_classes=4, pat loss.backward() optimizer.step() + # Save checkpoint at intervals or based on other conditions + if (batch_index + 1) % 100 == 0: # Example condition + torch.save({ + 'epoch': epoch, + 'batch': batch_index, + 'model_state_dict': model.state_dict(), + 'likelihood_state_dict': likelihood.state_dict(), + 'optimizer_state_dict': optimizer.state_dict() + }, checkpoint_path) + print(f"Checkpoint saved at epoch {epoch}, batch {batch_index}") + # Stochastic validation model.eval() likelihood.eval() @@ -160,7 +181,7 @@ def train_gp_model(train_loader, val_loader, num_iterations=50, n_classes=4, pat print(f"Early stopping triggered at epoch {epoch+1}") break - # Optionally, load the best model at the end of training + # Ensure to load the latest model state after the loop in case of early stopping if os.path.exists(checkpoint_path): checkpoint = torch.load(checkpoint_path) model.load_state_dict(checkpoint['model_state_dict']) diff --git a/BML_project/ss_main.py b/BML_project/ss_main.py index a610684..1b9cfce 100644 --- a/BML_project/ss_main.py +++ b/BML_project/ss_main.py @@ -6,7 +6,7 @@ """ import tqdm import torch -from utils.data_loader import preprocess_data, split_uids, update_train_loader_with_uncertain_samples +from utils_gp.data_loader import preprocess_data, split_uids, update_train_loader_with_uncertain_samples from models.ss_gp_model import MultitaskGPModel, train_gp_model from utils_gp.ss_evaluation import stochastic_evaluation, evaluate_model_on_all_data from active_learning.ss_active_learning import stochastic_uncertainty_sampling, run_minibatch_kmeans, stochastic_compare_kmeans_gp_predictions @@ -71,6 +71,8 @@ def main(): results['test_metrics'] = test_metrics test_gp_vs_kmeans_data, test_original_labels = stochastic_compare_kmeans_gp_predictions(test_kmeans_model, model, test_loader, n_batches=5, device=device) + # Before calling confusion_matrix in plot_comparative_results function + print(f"Length of original_labels: {len(original_labels)}, Length of gp_predictions: {len(gp_predictions)}") plot_comparative_results(test_gp_vs_kmeans_data, test_original_labels) # Visualization of results diff --git a/BML_project/utils_gp/__pycache__/data_loader.cpython-311.pyc b/BML_project/utils_gp/__pycache__/data_loader.cpython-311.pyc index 0b1a7eb..ed2626e 100644 Binary files a/BML_project/utils_gp/__pycache__/data_loader.cpython-311.pyc and b/BML_project/utils_gp/__pycache__/data_loader.cpython-311.pyc differ diff --git a/BML_project/utils_gp/data_loader.py b/BML_project/utils_gp/data_loader.py index 67169b1..fe3bc7e 100644 --- a/BML_project/utils_gp/data_loader.py +++ b/BML_project/utils_gp/data_loader.py @@ -103,34 +103,30 @@ def __init__(self, data_path, labels_path, UIDs, standardize=True, data_format=' self.data_format = data_format self.read_all_labels = read_all_labels self.transforms = ToTensor() + self.start_idx = start_idx # Initial batch index to start from, useful for resuming training self.refresh_dataset() - self.start_idx = start_idx # Add this line - # Initialize the current batch index to None + # Initialize the current batch index to None, this could be used if you want to track batch progress within the dataset itself self.current_batch_index = None def refresh_dataset(self): - # Extract unique segment names and their corresponding labels self.segment_names, self.labels = self.extract_segment_names_and_labels() def add_uids(self, new_uids): - # Ensure new UIDs are unique and not already in the dataset unique_new_uids = [uid for uid in new_uids if uid not in self.UIDs] - - # Add unique new UIDs and refresh the dataset self.UIDs.extend(unique_new_uids) self.refresh_dataset() def __len__(self): return len(self.segment_names) - - def save_checkpoint(self, checkpoint_path, current_batch_index=None): + + def save_checkpoint(self, checkpoint_path): + # Enhanced to automatically include 'start_idx' in the checkpoint checkpoint = { 'segment_names': self.segment_names, 'labels': self.labels, 'UIDs': self.UIDs, - # Save the current batch index if provided - 'current_batch_index': current_batch_index if current_batch_index is not None else self.current_batch_index + 'start_idx': self.start_idx # Now also saving start_idx } torch.save(checkpoint, checkpoint_path) @@ -139,32 +135,28 @@ def load_checkpoint(self, checkpoint_path): self.segment_names = checkpoint['segment_names'] self.labels = checkpoint['labels'] self.UIDs = checkpoint['UIDs'] + # Now also loading and setting start_idx from checkpoint + self.start_idx = checkpoint.get('start_idx', 0) self.refresh_dataset() - # Load the current batch index if it exists in the checkpoint - self.current_batch_index = checkpoint.get('current_batch_index') def __getitem__(self, idx): - actual_idx = idx + self.start_idx + actual_idx = (idx + self.start_idx) % len(self.segment_names) # Adjust index based on start_idx and wrap around if needed segment_name = self.segment_names[actual_idx] label = self.labels[segment_name] if hasattr(self, 'all_data') and actual_idx < len(self.all_data): - # Data is stored in memory time_freq_tensor = self.all_data[actual_idx] else: - # Load data on-the-fly based on the segment_name time_freq_tensor = self.load_data(segment_name) return {'data': time_freq_tensor, 'label': label, 'segment_name': segment_name} - # New method to set the current batch index def set_current_batch_index(self, index): self.current_batch_index = index - # New method to get the current batch index def get_current_batch_index(self): return self.current_batch_index - + def set_start_idx(self, index): self.start_idx = index