Skip to content

Commit

Permalink
Tried to solve the error. Running it on my Linux computer now
Browse files Browse the repository at this point in the history
  • Loading branch information
doh16101 committed Jan 25, 2024
1 parent ddbd109 commit 781947f
Show file tree
Hide file tree
Showing 6 changed files with 65 additions and 9 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -7,3 +7,4 @@ VAE.py
model_checkpoint.pt
GP_original_data.py
Attention_network.py
*.pt
Binary file not shown.
4 changes: 2 additions & 2 deletions BML_project/active_learning/ss_active_learning.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,8 +55,8 @@ def run_minibatch_kmeans(data_loader, n_clusters, device, batch_size=100):
# Iterate through data_loader and fit MiniBatchKMeans
for batch in data_loader:
data = batch['data'].view(batch['data'].size(0), -1).to(device).cpu().numpy()
# minibatch_kmeans.partial_fit(data)
minibatch_kmeans.fit(data) # Dong, 01/22/2024: Debug
minibatch_kmeans.partial_fit(data)
# minibatch_kmeans.fit(data) # Dong, 01/22/2024: Debug

return minibatch_kmeans

Expand Down
61 changes: 58 additions & 3 deletions BML_project/ss_main.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,17 +11,44 @@
from utils_gp.ss_evaluation import stochastic_evaluation, evaluate_model_on_all_data
from active_learning.ss_active_learning import stochastic_uncertainty_sampling, run_minibatch_kmeans, stochastic_compare_kmeans_gp_predictions
from utils_gp.visualization import plot_comparative_results, plot_training_performance, plot_results
import os
import pickle

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

class CheckpointManager:
    """Persist and restore per-loader training checkpoints as pickle files.

    Each checkpoint is stored in ``checkpoint_dir`` under the name
    ``<loader_name>_checkpoint.pkl`` and holds a dictionary with two keys:
    ``'iteration'`` and ``'additional_state'``.
    """

    def __init__(self, checkpoint_dir):
        """Remember the checkpoint directory and make sure it exists.

        Args:
            checkpoint_dir: Directory in which checkpoint files are kept.
                Missing parent directories are created as well.
        """
        self.checkpoint_dir = checkpoint_dir
        # exist_ok=True avoids the check-then-create race of a separate
        # os.path.exists() test (e.g. two jobs starting at the same time).
        os.makedirs(checkpoint_dir, exist_ok=True)

    def _checkpoint_path(self, loader_name):
        """Return the checkpoint file path used for ``loader_name``."""
        return os.path.join(self.checkpoint_dir, f"{loader_name}_checkpoint.pkl")

    def save_checkpoint(self, loader_name, iteration, additional_state):
        """Serialize a checkpoint for ``loader_name`` to disk.

        The data is first written to a temporary sibling file and then moved
        over the real checkpoint, so an interrupted or failed save never
        leaves a truncated checkpoint behind; the previous one stays intact.

        Args:
            loader_name: Name used to build the checkpoint file name.
            iteration: Iteration value to store in the checkpoint.
            additional_state: Any picklable object holding extra state.

        Raises:
            Whatever ``pickle.dump`` or the file operations raise, e.g. when
            ``additional_state`` is not picklable.
        """
        checkpoint_path = self._checkpoint_path(loader_name)
        checkpoint = {
            'iteration': iteration,
            'additional_state': additional_state,
        }
        tmp_path = f"{checkpoint_path}.tmp"
        try:
            with open(tmp_path, 'wb') as f:
                pickle.dump(checkpoint, f)
            # Rename within the same directory replaces the old file in one step.
            os.replace(tmp_path, checkpoint_path)
        finally:
            # Only still present when the dump or the replace failed.
            if os.path.exists(tmp_path):
                os.remove(tmp_path)

    def load_checkpoint(self, loader_name):
        """Load the checkpoint previously saved for ``loader_name``.

        Args:
            loader_name: Name used to build the checkpoint file name.

        Returns:
            The checkpoint dictionary, or ``None`` when no checkpoint file
            exists for this loader.
        """
        checkpoint_path = self._checkpoint_path(loader_name)
        try:
            with open(checkpoint_path, 'rb') as f:
                # WARNING: unpickling can execute arbitrary code; only load
                # checkpoint files that this project wrote itself.
                return pickle.load(f)
        except FileNotFoundError:
            return None

def main():
# Set parameters like n_classes, batch_size, etc.
n_classes = 4
batch_size = 1024
clinical_trial_train, clinical_trial_test, clinical_trial_unlabeled = split_uids()
data_format = 'pt'
# Preprocess data
train_loader, val_loader, test_loader = preprocess_data(data_format, clinical_trial_train, clinical_trial_test, clinical_trial_unlabeled, batch_size)
train_loader, val_loader, test_loader, saving_path = preprocess_data(data_format, clinical_trial_train, clinical_trial_test, clinical_trial_unlabeled, batch_size)
print('Debug: len(train_loader)',len(train_loader))

kmeans_model = run_minibatch_kmeans(train_loader, n_clusters=n_classes, device=device)
Expand All @@ -43,9 +70,27 @@ def main():
results['validation_metrics']['f1'].extend(training_metrics['f1_score'])
# results['validation_metrics']['auc_roc'].extend(training_metrics['auc_roc'])

# --- Dong: copied from GP_Original_Checkpoint.py ---
# Initialize the CheckpointManager
checkpoint_manager = CheckpointManager(saving_path)

# Attempt to load a training checkpoint
train_checkpoint = checkpoint_manager.load_checkpoint('train')
start_iteration = train_checkpoint['iteration'] if train_checkpoint else 0
# Dong, 01/25/2024: save it first before entering the active learning.
additional_state = {
'model_state': model.state_dict(),
'likelihood':likelihood,
'val_loader':val_loader,
'train_loader':train_loader
# Include other states like optimizer, scheduler, etc.
}
checkpoint_manager.save_checkpoint('train', start_iteration, additional_state)
# ---------------------------------------------------

active_learning_iterations = 10
# Active Learning Iterations
for iteration in tqdm(range(active_learning_iterations), desc='Active Learning', unit='iteration', leave=True):
for iteration in tqdm(range(start_iteration,active_learning_iterations), desc='Active Learning', unit='iteration', leave=True):
# Perform uncertainty sampling to select new samples from the validation set
uncertain_sample_indices = stochastic_uncertainty_sampling(model, likelihood, val_loader, n_samples=batch_size, n_batches=5)

Expand All @@ -58,9 +103,19 @@ def main():
# Store the validation metrics after each active learning iteration
results['validation_metrics']['precision'].append(val_metrics['precision'])
results['validation_metrics']['recall'].append(val_metrics['recall'])
results['validation_metrics']['f1'].append(val_metrics['f1'])
results['validation_metrics']['f1'].append(val_metrics['f1_score'])
# results['validation_metrics']['auc_roc'].append(val_metrics['auc_roc'])

# Save checkpoint at the end of each iteration
additional_state = {
'model_state': model.state_dict(),
'likelihood':likelihood,
'val_loader':val_loader,
'train_loader':train_loader
# Include other states like optimizer, scheduler, etc.
}
checkpoint_manager.save_checkpoint('train', iteration, additional_state)

# Compare K-Means with GP model predictions after retraining
gp_vs_kmeans_data, original_labels = stochastic_compare_kmeans_gp_predictions(kmeans_model, model, train_loader, n_batches=5, device=device)

Expand Down
Binary file modified BML_project/utils_gp/__pycache__/data_loader.cpython-311.pyc
Binary file not shown.
8 changes: 4 additions & 4 deletions BML_project/utils_gp/data_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,9 +98,9 @@ def split_uids():
print(f'Clinical trial: selected {len(clinical_trial_test)} UIDs for testing {clinical_trial_test}')
print(f'Clinical trial: selected {len(clinical_trial_unlabeled)} UIDs for unlabeled {clinical_trial_unlabeled}')

clinical_trial_train = [clinical_trial_train[0]]
clinical_trial_test = [clinical_trial_test[0]]
clinical_trial_unlabeled = clinical_trial_unlabeled[0:4]
# clinical_trial_train = [clinical_trial_train[0]]
# clinical_trial_test = [clinical_trial_test[0]]
# clinical_trial_unlabeled = clinical_trial_unlabeled[0:4]

return clinical_trial_train, clinical_trial_test, clinical_trial_unlabeled

Expand Down Expand Up @@ -260,7 +260,7 @@ def preprocess_data(data_format, clinical_trial_train, clinical_trial_test, clin
train_loader = load_data_split_batched(data_path, labels_path, clinical_trial_train, batch_size, standardize=True, data_format=data_format, read_all_labels=read_all_labels)
val_loader = load_data_split_batched(data_path, labels_path, clinical_trial_test, batch_size, standardize=True, data_format=data_format, read_all_labels=read_all_labels)
test_loader = load_data_split_batched(data_path, labels_path, clinical_trial_unlabeled, batch_size, standardize=True, data_format=data_format, read_all_labels=read_all_labels)
return train_loader, val_loader, test_loader
return train_loader, val_loader, test_loader, saving_path

def map_samples_to_uids(uncertain_sample_indices, dataset):
"""
Expand Down

0 comments on commit 781947f

Please sign in to comment.