From 3c6196597b16929038d9328464306aec9cc2ada9 Mon Sep 17 00:00:00 2001 From: Luis Roberto Mercado Diaz Date: Mon, 5 Feb 2024 17:04:21 -0500 Subject: [PATCH] checkpointing changes Just changes --- .../active_learning/ss_active_learning.py | 96 ++++++++++++------- BML_project/main_checkpoints.py | 50 +++++++--- BML_project/models/ss_gp_model_checkpoint.py | 9 +- BML_project/utils_gp/data_loader.py | 27 +++--- 4 files changed, 116 insertions(+), 66 deletions(-) diff --git a/BML_project/active_learning/ss_active_learning.py b/BML_project/active_learning/ss_active_learning.py index 546c4ad..4c44836 100644 --- a/BML_project/active_learning/ss_active_learning.py +++ b/BML_project/active_learning/ss_active_learning.py @@ -12,27 +12,25 @@ device = torch.device("cuda" if torch.cuda.is_available() else "cpu") -def label_samples(uncertain_samples, validation_data): - labels = [validation_data[sample_id]['label'] for sample_id in uncertain_samples] - return uncertain_samples, labels +def label_samples(uncertain_samples, validation_dataset): + labeled_samples = [(sample, validation_dataset[sample]['label']) for sample in uncertain_samples] + return labeled_samples -def stochastic_uncertainty_sampling(gp_model, gp_likelihood, val_loader, n_samples, n_batches, n_components=2): +def stochastic_uncertainty_sampling(gp_model, gp_likelihood, val_loader, n_samples, n_batches, device): gp_model.eval() gp_likelihood.eval() uncertain_sample_indices = [] - sampled_batches = random.sample(list(val_loader), n_batches) # Randomly sample n_batches from val_loader - + sampled_batches = random.sample(list(val_loader), min(n_batches, len(val_loader))) # Ensure we don't exceed available batches + with torch.no_grad(): for batch in sampled_batches: - # reduced_data = apply_tsne(batch['data'].reshape(batch['data'].size(0), -1), n_components=n_components) - # reduced_data_tensor = torch.Tensor(reduced_data).to(device) - reduced_data_tensor = batch['data'].view(batch['data'].size(0), -1).to(device) - predictions = gp_likelihood(gp_model(reduced_data_tensor)) - var = predictions.variance - top_indices = torch.argsort(-var.flatten())[:n_samples] - uncertain_sample_indices.extend(top_indices.cpu().numpy()) - - return uncertain_sample_indices[:n_samples] + data = batch['data'].to(device) + predictions = gp_likelihood(gp_model(data)) + variances = predictions.variance + _, top_uncertain_indices = torch.topk(variances.view(-1), n_samples) + uncertain_sample_indices.extend(top_uncertain_indices.cpu().numpy()) + + return list(set(uncertain_sample_indices[:n_samples])) # Remove duplicates and slice if needed # def uncertainty_sampling(gp_model, gp_likelihood, val_loader, n_samples, n_components=2): # gp_model.eval() @@ -87,34 +85,58 @@ def stochastic_compare_kmeans_gp_predictions(kmeans_model, gp_model, data_loader import random -def refined_uncertainty_sampling(gp_model, gp_likelihood, kmeans_model, data_loader, n_samples, n_batches, uncertainty_threshold=0.2): - gp_model.eval() - gp_likelihood.eval() - uncertain_sample_indices = [] +# def refined_uncertainty_sampling(gp_model, gp_likelihood, kmeans_model, data_loader, n_samples, n_batches, uncertainty_threshold=0.2): +# gp_model.eval() +# gp_likelihood.eval() +# uncertain_sample_indices = [] - # Calculate the total number of batches in the DataLoader - total_batches = len(data_loader) +# # Calculate the total number of batches in the DataLoader +# total_batches = len(data_loader) - # Ensure that n_batches does not exceed total_batches - n_batches = min(n_batches, total_batches) +# # Ensure that n_batches does not exceed total_batches +# n_batches = min(n_batches, total_batches) - # Randomly sample n_batches from data_loader - sampled_batches = random.sample(list(data_loader), n_batches) +# # Randomly sample n_batches from data_loader +# sampled_batches = random.sample(list(data_loader), n_batches) +# with torch.no_grad(): +# for batch in sampled_batches: +# data_tensor = batch['data'].view(batch['data'].size(0), -1).to(device) +# gp_predictions = gp_likelihood(gp_model(data_tensor)) +# kmeans_predictions = kmeans_model.predict(data_tensor.cpu().numpy()) + +# # Calculate the difference between K-means and GP predictions +# disagreement = (gp_predictions.mean.argmax(dim=-1).cpu().numpy() != kmeans_predictions).astype(int) + +# # Calculate uncertainty based on variance of GP predictions +# uncertainty = gp_predictions.variance.cpu().numpy() + +# # Select samples where the disagreement is high and the model is uncertain +# uncertain_indices = np.where((disagreement > 0) & (uncertainty > uncertainty_threshold))[0] +# uncertain_sample_indices.extend(uncertain_indices) + +# return uncertain_sample_indices[:n_samples] + +def refined_uncertainty_sampling(gp_model, gp_likelihood, kmeans_model, val_loader, n_samples, n_batches, uncertainty_threshold, device): + gp_model.eval() + gp_likelihood.eval() + uncertain_samples = [] + + sampled_batches = random.sample(list(val_loader), min(n_batches, len(val_loader))) + with torch.no_grad(): for batch in sampled_batches: - data_tensor = batch['data'].view(batch['data'].size(0), -1).to(device) - gp_predictions = gp_likelihood(gp_model(data_tensor)) - kmeans_predictions = kmeans_model.predict(data_tensor.cpu().numpy()) + data = batch['data'].to(device) + gp_predictions = gp_likelihood(gp_model(data)) + variances = gp_predictions.variance + high_uncertainty_indices = variances.view(-1) > uncertainty_threshold - # Calculate the difference between K-means and GP predictions - disagreement = (gp_predictions.mean.argmax(dim=-1).cpu().numpy() != kmeans_predictions).astype(int) + kmeans_predictions = kmeans_model.predict(data.cpu().numpy()) + gp_class_predictions = gp_predictions.mean.argmax(-1).cpu().numpy() - # Calculate uncertainty based on variance of GP predictions - uncertainty = gp_predictions.variance.cpu().numpy() + disagreement = gp_class_predictions != kmeans_predictions + uncertain_indices = np.where(disagreement & high_uncertainty_indices.cpu().numpy())[0] - # Select samples where the disagreement is high and the model is uncertain - uncertain_indices = np.where((disagreement > 0) & (uncertainty > uncertainty_threshold))[0] - uncertain_sample_indices.extend(uncertain_indices) - - return uncertain_sample_indices[:n_samples] + uncertain_samples.extend(uncertain_indices) + + return uncertain_samples[:n_samples] diff --git a/BML_project/main_checkpoints.py b/BML_project/main_checkpoints.py index 9098257..cb6299c 100644 --- a/BML_project/main_checkpoints.py +++ b/BML_project/main_checkpoints.py @@ -10,7 +10,7 @@ from utils.data_loader import preprocess_data, split_uids, update_train_loader_with_uncertain_samples from models.ss_gp_model import MultitaskGPModel, train_gp_model from utils_gp.ss_evaluation import stochastic_evaluation, evaluate_model_on_all_data -from active_learning.ss_active_learning import stochastic_uncertainty_sampling, run_minibatch_kmeans, stochastic_compare_kmeans_gp_predictions +from active_learning.ss_active_learning import stochastic_uncertainty_sampling, run_minibatch_kmeans, stochastic_compare_kmeans_gp_predictions, label_samples from utils.visualization import plot_comparative_results, plot_training_performance, plot_results device = torch.device("cuda" if torch.cuda.is_available() else "cpu") @@ -27,10 +27,11 @@ def main(): # Attempt to resume from the last saved batch index if a dataset checkpoint exists dataset_checkpoint_path = 'dataset_checkpoint.pt' + current_batch_index = 0 # Default to start from beginning if no checkpoint if os.path.exists(dataset_checkpoint_path): - for loader in [train_loader, val_loader, test_loader]: - loader.dataset.load_checkpoint(dataset_checkpoint_path) - print(f"Resuming from batch index {loader.dataset.get_current_batch_index()}") + checkpoint = torch.load(dataset_checkpoint_path) + current_batch_index = checkpoint.get('current_batch_index', 0) + print(f"Resuming from batch index {current_batch_index}") kmeans_model = run_minibatch_kmeans(train_loader, n_clusters=n_classes, device=device) @@ -42,7 +43,19 @@ def main(): } # Initial model training - model, likelihood, training_metrics = train_gp_model(train_loader, val_loader, num_iterations=50, n_classes=n_classes) + # model, likelihood, training_metrics = train_gp_model(train_loader, val_loader, num_iterations=50, n_classes=n_classes) + + # Call the train_gp_model function with batch_size + model, likelihood, training_metrics = train_gp_model( + train_loader=train_loader, + val_loader=val_loader, + num_iterations=50, + n_classes=n_classes, + patience=10, + checkpoint_path='model_checkpoint_full.pt', + resume_training=True, # or False + batch_size=batch_size + ) # Save the training metrics for future visualization results['train_loss'].extend(training_metrics['train_loss']) @@ -52,23 +65,30 @@ def main(): active_learning_iterations = 10 # Active Learning Iterations - for iteration in tqdm.tqdm(range(active_learning_iterations), desc='Active Learning', unit='iteration', leave=True): + for iteration in range(active_learning_iterations): + print(f"Active Learning Iteration: {iteration+1}/{active_learning_iterations}") # Perform uncertainty sampling to select new samples from the validation set - uncertain_sample_indices = stochastic_uncertainty_sampling(model, likelihood, val_loader, n_samples=batch_size, n_batches=5) + uncertain_sample_indices = stochastic_uncertainty_sampling(model, likelihood, val_loader, n_samples=50, n_batches=5, device=device) + + labeled_samples = label_samples(uncertain_sample_indices, val_loader.dataset) # Update the training loader with uncertain samples - train_loader = update_train_loader_with_uncertain_samples(train_loader, uncertain_sample_indices, batch_size) + train_loader = update_train_loader_with_uncertain_samples(train_loader, labeled_samples, batch_size) # Optionally, save the dataset state at intervals or after certain conditions - train_loader.dataset.save_checkpoint(dataset_checkpoint_path, current_batch_index=None) # Here, manage the index as needed + train_loader.dataset.save_checkpoint(dataset_checkpoint_path, current_batch_index=None) + # model, likelihood, val_metrics = train_gp_model(train_loader, val_loader, num_iterations=10, n_classes=n_classes, patience=10, checkpoint_path='model_checkpoint_last.pt') + # Re-train the model with the updated training data - model, likelihood, val_metrics = train_gp_model(train_loader, val_loader, num_iterations=10, n_classes=n_classes, patience=10, checkpoint_path='model_checkpoint_last.pt') - + model, likelihood, val_metrics = train_gp_model( + train_loader, val_loader, num_iterations=10, n_classes=n_classes, patience=10, + checkpoint_path='model_checkpoint_last.pt', resume_training=True, batch_size=batch_size + ) # Store the validation metrics after each active learning iteration - results['validation_metrics']['precision'].append(val_metrics['precision']) - results['validation_metrics']['recall'].append(val_metrics['recall']) - results['validation_metrics']['f1'].append(val_metrics['f1']) + results['Active_validation_metrics']['precision'].append(val_metrics['precision']) + results['Active_validation_metrics']['recall'].append(val_metrics['recall']) + results['Active_validation_metrics']['f1'].append(val_metrics['f1']) # Compare K-Means with GP model predictions after retraining gp_vs_kmeans_data, original_labels = stochastic_compare_kmeans_gp_predictions(kmeans_model, model, train_loader, n_batches=5, device=device) @@ -84,7 +104,7 @@ def main(): plot_comparative_results(test_gp_vs_kmeans_data, test_original_labels) # Visualization of results - plot_training_performance(results['train_loss'], results['validation_metrics']) + plot_training_performance(results['train_loss'], results['Active_validation_metrics']) plot_results(results['test_metrics']) # Print final test metrics diff --git a/BML_project/models/ss_gp_model_checkpoint.py b/BML_project/models/ss_gp_model_checkpoint.py index 8249e46..282c73d 100644 --- a/BML_project/models/ss_gp_model_checkpoint.py +++ b/BML_project/models/ss_gp_model_checkpoint.py @@ -69,7 +69,7 @@ def forward(self, x): return latent_pred -def train_gp_model(train_loader, val_loader, num_iterations=50, n_classes=4, patience=10, checkpoint_path='model_checkpoint_full.pt', resume_training=False): +def train_gp_model(train_loader, val_loader, num_iterations=50, n_classes=4, patience=10, checkpoint_path='model_checkpoint_full.pt', resume_training=False, batch_size=1024): model = MultitaskGPModel().to(device) likelihood = gpytorch.likelihoods.SoftmaxLikelihood(num_features=4, num_classes=4).to(device) optimizer = torch.optim.Adam(model.parameters(), lr=0.1) @@ -77,14 +77,16 @@ def train_gp_model(train_loader, val_loader, num_iterations=50, n_classes=4, pat # Load checkpoint if resuming training start_epoch = 0 + current_batch_index = 0 # Default value in case it's not found in the checkpoint if resume_training and os.path.exists(checkpoint_path): checkpoint = torch.load(checkpoint_path) model.load_state_dict(checkpoint['model_state_dict']) likelihood.load_state_dict(checkpoint['likelihood_state_dict']) optimizer.load_state_dict(checkpoint['optimizer_state_dict']) start_epoch = checkpoint.get('epoch', 0) + 1 # Resume from the next epoch - print(f"Resuming training from epoch {start_epoch}") - + current_batch_index = checkpoint.get('current_batch_index', 0) # Retrieve the last batch index if it exists + print(f"Resuming training from epoch {start_epoch}, batch index {current_batch_index}") + best_val_loss = float('inf') epochs_no_improve = 0 metrics = {'precision': [], 'recall': [], 'f1_score': [], 'auc_roc': [], 'train_loss': []} @@ -150,6 +152,7 @@ def train_gp_model(train_loader, val_loader, num_iterations=50, n_classes=4, pat 'likelihood_state_dict': likelihood.state_dict(), 'optimizer_state_dict': optimizer.state_dict(), 'best_val_loss': best_val_loss, + 'current_batch_index': train_loader.dataset.get_current_batch_index() # Get and save the current batch index # Include other metrics as needed }, checkpoint_path) diff --git a/BML_project/utils_gp/data_loader.py b/BML_project/utils_gp/data_loader.py index 54caf28..67169b1 100644 --- a/BML_project/utils_gp/data_loader.py +++ b/BML_project/utils_gp/data_loader.py @@ -95,7 +95,7 @@ def split_uids(): return clinical_trial_train, clinical_trial_test, clinical_trial_unlabeled class CustomDataset(Dataset): - def __init__(self, data_path, labels_path, UIDs, standardize=True, data_format='csv', read_all_labels=False): + def __init__(self, data_path, labels_path, UIDs, standardize=True, data_format='csv', read_all_labels=False, start_idx=0): self.data_path = data_path self.labels_path = labels_path self.UIDs = UIDs @@ -104,6 +104,7 @@ def __init__(self, data_path, labels_path, UIDs, standardize=True, data_format=' self.read_all_labels = read_all_labels self.transforms = ToTensor() self.refresh_dataset() + self.start_idx = start_idx # Add this line # Initialize the current batch index to None self.current_batch_index = None @@ -143,12 +144,13 @@ def load_checkpoint(self, checkpoint_path): self.current_batch_index = checkpoint.get('current_batch_index') def __getitem__(self, idx): - segment_name = self.segment_names[idx] + actual_idx = idx + self.start_idx + segment_name = self.segment_names[actual_idx] label = self.labels[segment_name] - if hasattr(self, 'all_data') and idx < len(self.all_data): + if hasattr(self, 'all_data') and actual_idx < len(self.all_data): # Data is stored in memory - time_freq_tensor = self.all_data[idx] + time_freq_tensor = self.all_data[actual_idx] else: # Load data on-the-fly based on the segment_name time_freq_tensor = self.load_data(segment_name) @@ -163,6 +165,9 @@ def set_current_batch_index(self, index): def get_current_batch_index(self): return self.current_batch_index + def set_start_idx(self, index): + self.start_idx = index + def add_data_label_pair(self, data, label): # Assign a unique ID or name for the new data new_id = len(self.segment_names) @@ -232,8 +237,8 @@ def standard_scaling(self, data): data = scaler.fit_transform(data.reshape(-1, data.shape[-1])).reshape(data.shape) return torch.Tensor(data) -def load_data_split_batched(data_path, labels_path, UIDs, batch_size, standardize=False, data_format='csv', read_all_labels=False, drop_last=False, num_workers=4): - dataset = CustomDataset(data_path, labels_path, UIDs, standardize, data_format, read_all_labels) +def load_data_split_batched(data_path, labels_path, UIDs, batch_size, standardize=False, data_format='csv', read_all_labels=False, drop_last=False, num_workers=4, start_idx=0): + dataset = CustomDataset(data_path, labels_path, UIDs, standardize, data_format, read_all_labels, start_idx=start_idx) dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False, drop_last=drop_last, num_workers=num_workers, prefetch_factor=2) return dataloader @@ -268,12 +273,12 @@ def get_data_paths(data_format, is_linux=False, is_hpc=False): return data_path, labels_path, saving_path # Function to extract and preprocess data -def preprocess_data(data_format, clinical_trial_train, clinical_trial_test, clinical_trial_unlabeled, batch_size, read_all_labels=False): - # Extracts paths and loads data into train, validation, and test loaders +def preprocess_data(data_format, clinical_trial_train, clinical_trial_test, clinical_trial_unlabeled, batch_size, read_all_labels=False, current_batch_index=0): + start_idx = current_batch_index * batch_size data_path, labels_path, saving_path = get_data_paths(data_format) - train_loader = load_data_split_batched(data_path, labels_path, clinical_trial_train, batch_size, standardize=True, data_format=data_format, read_all_labels=read_all_labels) - val_loader = load_data_split_batched(data_path, labels_path, clinical_trial_test, batch_size, standardize=True, data_format=data_format, read_all_labels=read_all_labels) - test_loader = load_data_split_batched(data_path, labels_path, clinical_trial_unlabeled, batch_size, standardize=True, data_format=data_format, read_all_labels=read_all_labels) + train_loader = load_data_split_batched(data_path, labels_path, clinical_trial_train, batch_size, standardize=True, data_format=data_format, read_all_labels=read_all_labels, start_idx=start_idx) + val_loader = load_data_split_batched(data_path, labels_path, clinical_trial_test, batch_size, standardize=True, data_format=data_format, read_all_labels=read_all_labels, start_idx=start_idx) + test_loader = load_data_split_batched(data_path, labels_path, clinical_trial_unlabeled, batch_size, standardize=True, data_format=data_format, read_all_labels=read_all_labels, start_idx=start_idx) return train_loader, val_loader, test_loader def map_samples_to_uids(uncertain_sample_indices, dataset):