diff --git a/.gitignore b/.gitignore index 11d1435..233cecd 100644 --- a/.gitignore +++ b/.gitignore @@ -7,3 +7,4 @@ VAE.py model_checkpoint.pt GP_original_data.py Attention_network.py +*.pt diff --git a/BML_project/active_learning/__pycache__/ss_active_learning.cpython-311.pyc b/BML_project/active_learning/__pycache__/ss_active_learning.cpython-311.pyc index 9504984..bfc5c60 100644 Binary files a/BML_project/active_learning/__pycache__/ss_active_learning.cpython-311.pyc and b/BML_project/active_learning/__pycache__/ss_active_learning.cpython-311.pyc differ diff --git a/BML_project/active_learning/ss_active_learning.py b/BML_project/active_learning/ss_active_learning.py index a488281..0de3533 100644 --- a/BML_project/active_learning/ss_active_learning.py +++ b/BML_project/active_learning/ss_active_learning.py @@ -55,8 +55,8 @@ def run_minibatch_kmeans(data_loader, n_clusters, device, batch_size=100): # Iterate through data_loader and fit MiniBatchKMeans for batch in data_loader: data = batch['data'].view(batch['data'].size(0), -1).to(device).cpu().numpy() - # minibatch_kmeans.partial_fit(data) - minibatch_kmeans.fit(data) # Dong, 01/22/2024: Debug + minibatch_kmeans.partial_fit(data) + # minibatch_kmeans.fit(data) # Dong, 01/22/2024: Debug return minibatch_kmeans diff --git a/BML_project/ss_main.py b/BML_project/ss_main.py index 2716179..73d6fb6 100644 --- a/BML_project/ss_main.py +++ b/BML_project/ss_main.py @@ -11,9 +11,36 @@ from utils_gp.ss_evaluation import stochastic_evaluation, evaluate_model_on_all_data from active_learning.ss_active_learning import stochastic_uncertainty_sampling, run_minibatch_kmeans, stochastic_compare_kmeans_gp_predictions from utils_gp.visualization import plot_comparative_results, plot_training_performance, plot_results +import os +import pickle device = torch.device("cuda" if torch.cuda.is_available() else "cpu") +class CheckpointManager: + def __init__(self, checkpoint_dir): + self.checkpoint_dir = checkpoint_dir # Store the directory path for checkpoints + if not os.path.exists(checkpoint_dir): # Check if the directory exists + os.makedirs(checkpoint_dir) # Create the directory if it does not exist + + def save_checkpoint(self, loader_name, iteration, additional_state): + # Construct the checkpoint file path using the loader name + checkpoint_path = os.path.join(self.checkpoint_dir, f"{loader_name}_checkpoint.pkl") + checkpoint = { + 'iteration': iteration, # Store the current iteration + 'additional_state': additional_state # Store any additional state information + } + with open(checkpoint_path, 'wb') as f: # Open the file in write-binary mode + pickle.dump(checkpoint, f) # Serialize the checkpoint dictionary to the file + + def load_checkpoint(self, loader_name): + # Construct the checkpoint file path using the loader name + checkpoint_path = os.path.join(self.checkpoint_dir, f"{loader_name}_checkpoint.pkl") + try: + with open(checkpoint_path, 'rb') as f: # Open the file in read-binary mode + return pickle.load(f) # Deserialize the checkpoint file and return it + except FileNotFoundError: # Handle the case where the checkpoint file does not exist + return None # Return None if the file is not found + def main(): # Set parameters like n_classes, batch_size, etc. n_classes = 4 @@ -21,7 +48,7 @@ def main(): clinical_trial_train, clinical_trial_test, clinical_trial_unlabeled = split_uids() data_format = 'pt' # Preprocess data - train_loader, val_loader, test_loader = preprocess_data(data_format, clinical_trial_train, clinical_trial_test, clinical_trial_unlabeled, batch_size) + train_loader, val_loader, test_loader, saving_path = preprocess_data(data_format, clinical_trial_train, clinical_trial_test, clinical_trial_unlabeled, batch_size) print('Debug: len(train_loader)',len(train_loader)) kmeans_model = run_minibatch_kmeans(train_loader, n_clusters=n_classes, device=device) @@ -43,9 +70,27 @@ def main(): results['validation_metrics']['f1'].extend(training_metrics['f1_score']) # results['validation_metrics']['auc_roc'].extend(training_metrics['auc_roc']) + # --- Dong: copied from GP_Original_Checkpoint.py --- + # Initialize the CheckpointManager + checkpoint_manager = CheckpointManager(saving_path) + + # Attempt to load a training checkpoint + train_checkpoint = checkpoint_manager.load_checkpoint('train') + start_iteration = train_checkpoint['iteration'] if train_checkpoint else 0 + # Dong, 01/25/2024: save it first before entering the active learning. + additional_state = { + 'model_state': model.state_dict(), + 'likelihood':likelihood, + 'val_loader':val_loader, + 'train_loader':train_loader + # Include other states like optimizer, scheduler, etc. + } + checkpoint_manager.save_checkpoint('train', start_iteration, additional_state) + # --------------------------------------------------- + active_learning_iterations = 10 # Active Learning Iterations - for iteration in tqdm(range(active_learning_iterations), desc='Active Learning', unit='iteration', leave=True): + for iteration in tqdm(range(start_iteration,active_learning_iterations), desc='Active Learning', unit='iteration', leave=True): # Perform uncertainty sampling to select new samples from the validation set uncertain_sample_indices = stochastic_uncertainty_sampling(model, likelihood, val_loader, n_samples=batch_size, n_batches=5) @@ -58,9 +103,19 @@ def main(): # Store the validation metrics after each active learning iteration results['validation_metrics']['precision'].append(val_metrics['precision']) results['validation_metrics']['recall'].append(val_metrics['recall']) - results['validation_metrics']['f1'].append(val_metrics['f1']) + results['validation_metrics']['f1'].append(val_metrics['f1_score']) # results['validation_metrics']['auc_roc'].append(val_metrics['auc_roc']) + # Save checkpoint at the end of each iteration + additional_state = { + 'model_state': model.state_dict(), + 'likelihood':likelihood, + 'val_loader':val_loader, + 'train_loader':train_loader + # Include other states like optimizer, scheduler, etc. + } + checkpoint_manager.save_checkpoint('train', iteration, additional_state) + # Compare K-Means with GP model predictions after retraining gp_vs_kmeans_data, original_labels = stochastic_compare_kmeans_gp_predictions(kmeans_model, model, train_loader, n_batches=5, device=device) diff --git a/BML_project/utils_gp/__pycache__/data_loader.cpython-311.pyc b/BML_project/utils_gp/__pycache__/data_loader.cpython-311.pyc index e131c15..f58fbb2 100644 Binary files a/BML_project/utils_gp/__pycache__/data_loader.cpython-311.pyc and b/BML_project/utils_gp/__pycache__/data_loader.cpython-311.pyc differ diff --git a/BML_project/utils_gp/data_loader.py b/BML_project/utils_gp/data_loader.py index 921467d..5b84b8d 100644 --- a/BML_project/utils_gp/data_loader.py +++ b/BML_project/utils_gp/data_loader.py @@ -98,9 +98,9 @@ def split_uids(): print(f'Clinical trial: selected {len(clinical_trial_test)} UIDs for testing {clinical_trial_test}') print(f'Clinical trial: selected {len(clinical_trial_unlabeled)} UIDs for unlabeled {clinical_trial_unlabeled}') - clinical_trial_train = [clinical_trial_train[0]] - clinical_trial_test = [clinical_trial_test[0]] - clinical_trial_unlabeled = clinical_trial_unlabeled[0:4] + # clinical_trial_train = [clinical_trial_train[0]] + # clinical_trial_test = [clinical_trial_test[0]] + # clinical_trial_unlabeled = clinical_trial_unlabeled[0:4] return clinical_trial_train, clinical_trial_test, clinical_trial_unlabeled @@ -260,7 +260,7 @@ def preprocess_data(data_format, clinical_trial_train, clinical_trial_test, clin train_loader = load_data_split_batched(data_path, labels_path, clinical_trial_train, batch_size, standardize=True, data_format=data_format, read_all_labels=read_all_labels) val_loader = load_data_split_batched(data_path, labels_path, clinical_trial_test, batch_size, standardize=True, data_format=data_format, read_all_labels=read_all_labels) test_loader = load_data_split_batched(data_path, labels_path, clinical_trial_unlabeled, batch_size, standardize=True, data_format=data_format, read_all_labels=read_all_labels) - return train_loader, val_loader, test_loader + return train_loader, val_loader, test_loader, saving_path def map_samples_to_uids(uncertain_sample_indices, dataset): """