
Commit 861e742
Cassey manually updated the data loader changes.
doh16101 committed Apr 4, 2024
1 parent a9de770 commit 861e742
Showing 4 changed files with 348 additions and 105 deletions.
22 changes: 22 additions & 0 deletions BML_project/models/Colab_example_dataloader_2024_04_04.ipynb
@@ -0,0 +1,22 @@
{
  "cells": [
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "# R:\\ENGR_Chon\\Darren\\NIH_Pulsewatch\\Poincare_pt\\128x128\n",
        "# Darren created the PT files again (because UID 120 has missing files in the original csv file)\n",
        "# I need to prepare for my interview, and I will tar those PT files again and test your code on Colab later."
      ]
    }
  ],
  "metadata": {
    "language_info": {
      "name": "python"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 2
}
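The new notebook is only a placeholder so far: the single cell records the plan to re-tar Darren's 128x128 Poincaré PT files and exercise the data loader on Colab. A minimal sketch of what that follow-up cell might look like is below; the tarball path, extraction directory, and batch size are assumptions, and only the preprocess_data_test signature and batch keys are taken from the rest of this commit.

# Hypothetical Colab cell; tarball path, extraction directory, and batch size are assumptions.
import os
import tarfile

tar_path = '/content/drive/MyDrive/Poincare_pt_128x128.tar'   # assumed location of the re-tarred PT files
extract_dir = '/content/Poincare_pt/128x128'
os.makedirs(extract_dir, exist_ok=True)
with tarfile.open(tar_path) as tf:
    tf.extractall(extract_dir)   # unpack one .pt file per segment

# Build the loader the same way ss_main.py does (signature taken from this commit):
# from utils_gp.data_loader import preprocess_data_test
# train_loader = preprocess_data_test(data_format='pt',
#                                     clinical_trial_unlabeled=clinical_trial_unlabeled,
#                                     batch_size=64,            # assumed value
#                                     finished_seg_names=[],
#                                     read_all_labels=False)
# batch = next(iter(train_loader))
# print(batch['data'].shape, batch['label'].shape, batch['segment_name'][:3])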
138 changes: 128 additions & 10 deletions BML_project/models/ss_gp_model.py
@@ -4,12 +4,15 @@
@author: lrm22005
"""
import os
import numpy as np
from tqdm import tqdm
import torch
import gpytorch
from sklearn.metrics import precision_recall_fscore_support, roc_auc_score
from sklearn.preprocessing import label_binarize
from utils_gp.data_loader import preprocess_data_train_val,preprocess_data_test
import time

num_latents = 6 # This should match the complexity of your data or the number of tasks
num_tasks = 4 # This should match the number of output classes or tasks
@@ -70,11 +73,49 @@ def forward(self, x):
return latent_pred


def train_gp_model(train_loader, val_loader, num_iterations=50, n_classes=4, patience=10, checkpoint_path='model_checkpoint_full.pt'):
def train_gp_model(train_loader, val_loader, batch_size,\
data_format, clinical_trial_train, clinical_trial_test,\
clinical_trial_unlabeled,\
num_iterations=50, n_classes=4, patience=10, checkpoint_path='model_checkpoint_full.pt',\
resume_training=False,\
datackpt_name = 'dataset_checkpoint.pt',modelckpt_name = 'model_checkpoint_full.pt'):
print(f'Debug: resume_training:{resume_training}, checkpoint_path: {checkpoint_path}')
model = MultitaskGPModel().to(device)
likelihood = gpytorch.likelihoods.SoftmaxLikelihood(num_features=4, num_classes=4).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.1)
mll = gpytorch.mlls.VariationalELBO(likelihood, model, num_data=len(train_loader.dataset))

# Load the checkpoint if resuming training for the GP model.
start_epoch = 0
flag_reload_dataloader = False # By default we do not need to rebuild the train loader in a new epoch.
ckpt_model_file = os.path.join(checkpoint_path,modelckpt_name)
if resume_training and os.path.exists(ckpt_model_file):
print(f'Debug: loading ckpt: {ckpt_model_file}')
checkpoint = torch.load(ckpt_model_file)
model.load_state_dict(checkpoint['model_state_dict'])
likelihood.load_state_dict(checkpoint['likelihood_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
start_epoch = checkpoint.get('epoch', 0) # Resume from the same epoch because it was not finished.

# Update the dataloader if there are segments finished.
finished_seg_names = checkpoint['finished_seg_names']

if len(finished_seg_names) > 0:
# There were segments used in training. Only update the train loader.
flag_reload_dataloader = True
print('Debug: renewing train_loader now...')
startTime_for_tictoc = time.time()
# ---- Dong, 02/15/2024: I want to test training on a large dataset and resuming training. ----
# train_loader,_,_ = preprocess_data_train_val(data_format, clinical_trial_train, clinical_trial_test, batch_size, finished_seg_names,\
# read_all_labels=False)
train_loader = preprocess_data_test(data_format = data_format, \
clinical_trial_unlabeled=clinical_trial_unlabeled, \
batch_size=batch_size,\
finished_seg_names=finished_seg_names,\
read_all_labels=False)
endTime_for_tictoc = time.time() - startTime_for_tictoc
print(f'Debug: took {endTime_for_tictoc} seconds to renew the train_loader')

best_val_loss = float('inf')
epochs_no_improve = 0

@@ -86,19 +127,69 @@ def train_gp_model(train_loader, val_loader, num_iterations=50, n_classes=4, pat
'train_loss': [] # Add a list to store training losses
}

for epoch in tqdm(range(num_iterations), desc='Training', unit='epoch', leave=False):
for train_batch in train_loader:
for epoch in tqdm(range(start_epoch,num_iterations), desc='Training', unit='epoch', leave=False):
finished_idx = []
finished_seg_names = []
for batch_index, train_batch in enumerate(train_loader):
print(f'Debug: now in a new batch of data! {batch_index}/{len(train_loader)}') # train_batch is the image data.
model.train()
likelihood.train()
optimizer.zero_grad()

train_x = train_batch['data'].reshape(train_batch['data'].size(0), -1).to(device) # Use reshape here
train_y = train_batch['label'].to(device)
# Record the dataloader indices and segment names of the batches finished so far.
temp_finished_idx = train_batch['idx']
temp_finished_seg_names = train_batch['segment_name']
print('Debug: temp_finished_idx:',temp_finished_idx)
print('Debug: temp_finished_segment_name:',temp_finished_seg_names)
finished_idx.append(temp_finished_idx)
finished_seg_names.append(temp_finished_seg_names)
output = model(train_x)
loss = -mll(output, train_y)
metrics['train_loss'].append(loss.item()) # Store the training loss
loss.backward()
optimizer.step()

save_ckpt_model_path = os.path.join(checkpoint_path,modelckpt_name)
torch.save({
'epoch': epoch,
'model_state_dict': model.state_dict(),
'likelihood_state_dict': likelihood.state_dict(),
'optimizer_state_dict': optimizer.state_dict(),
'best_val_loss': best_val_loss,
'finished_seg_names':finished_seg_names,
'finished_idx':finished_idx
# Include other metrics as needed
}, save_ckpt_model_path)

# Optionally, save the dataset state at intervals or after certain conditions
save_ckpt_dataset_path = os.path.join(checkpoint_path,datackpt_name)
train_loader.dataset.save_checkpoint(save_ckpt_dataset_path) # Here, manage the index as needed

# import sys
# if epoch == 3 and batch_index == 5:
# sys.exit(f"Debug: Manually stop the program at epoch {epoch} batch {batch_index}.")

# Reset the finished segments again because we finished one epoch.
finished_idx = []
finished_seg_names = []
if flag_reload_dataloader:
print('Debug: reset the train_loader now...')
# Reset the train dataloader now.
startTime_for_tictoc = time.time()
# --- Dong, 02/15/2024:
# train_loader,_,_ = preprocess_data_train_val(data_format, clinical_trial_train, clinical_trial_test, batch_size, finished_seg_names,\
# read_all_labels=False)
train_loader = preprocess_data_test(data_format = data_format, \
clinical_trial_unlabeled=clinical_trial_unlabeled, \
batch_size=batch_size,\
finished_seg_names=finished_seg_names,\
read_all_labels=False)
endTime_for_tictoc = time.time() - startTime_for_tictoc
print(f'Debug: took {endTime_for_tictoc} seconds to reset the train_loader')
flag_reload_dataloader = False # Turn off the flag for resetting the train dataloader.

# Stochastic validation
model.eval()
likelihood.eval()
@@ -131,19 +222,46 @@ def train_gp_model(train_loader, val_loader, num_iterations=50, n_classes=4, pat
if val_loss < best_val_loss:
best_val_loss = val_loss
epochs_no_improve = 0
torch.save({'model_state_dict': model.state_dict(),
'likelihood_state_dict': likelihood.state_dict(),
'optimizer_state_dict': optimizer.state_dict()}, checkpoint_path)
# torch.save({'model_state_dict': model.state_dict(),
# 'likelihood_state_dict': likelihood.state_dict(),
# 'optimizer_state_dict': optimizer.state_dict(),
# 'train_loader':train_loader,
# 'val_loader':val_loader
# }, checkpoint_path)
else:
epochs_no_improve += 1
if epochs_no_improve >= patience:
print(f"Early stopping triggered at epoch {epoch+1}")
break

# Save checkpoint at the end of each epoch
save_ckpt_model_path = os.path.join(checkpoint_path,modelckpt_name)
torch.save({
'epoch': epoch,
'model_state_dict': model.state_dict(),
'likelihood_state_dict': likelihood.state_dict(),
'optimizer_state_dict': optimizer.state_dict(),
'best_val_loss': best_val_loss,
'finished_seg_names':finished_seg_names,
'finished_idx':finished_idx
# Include other metrics as needed
}, save_ckpt_model_path)
print('Debug: saved model checkpoint for this epoch:', save_ckpt_model_path)

# Optionally, save the dataset state at intervals or after certain conditions
save_ckpt_dataset_path = os.path.join(checkpoint_path,datackpt_name)
train_loader.dataset.save_checkpoint(save_ckpt_dataset_path) # Finished all batches, so start from zero.

if epochs_no_improve >= patience:
print(f"Early stopping triggered at epoch {epoch+1}")
break

checkpoint = torch.load(checkpoint_path)
model.load_state_dict(checkpoint['model_state_dict'])
likelihood.load_state_dict(checkpoint['likelihood_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
# Optionally, load the best model at the end of training
if os.path.exists(checkpoint_path):
checkpoint = torch.load(checkpoint_path)
model.load_state_dict(checkpoint['model_state_dict'])
likelihood.load_state_dict(checkpoint['likelihood_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

return model, likelihood, metrics

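For context on the resume path added above: the per-batch torch.save() call now records the epoch together with finished_seg_names and finished_idx, and preprocess_data_test rebuilds the train loader with those segments excluded. Below is a minimal sketch of reading such a checkpoint back; only the dictionary keys are taken from the diff, while the path and the commented restore calls are placeholders, not part of the commit.

# Sketch of inspecting/resuming from the checkpoint written by train_gp_model.
# Only the dictionary keys come from the commit; the file path is assumed.
import os
import torch

ckpt_file = os.path.join('checkpoints', 'model_checkpoint_full.pt')   # assumed checkpoint_path + modelckpt_name
if os.path.exists(ckpt_file):
    ckpt = torch.load(ckpt_file, map_location='cpu')
    finished = ckpt.get('finished_seg_names', [])          # one entry per finished batch
    print('Resuming from epoch', ckpt.get('epoch', 0))
    print('Segments already trained on this epoch:', sum(len(b) for b in finished))
    # model.load_state_dict(ckpt['model_state_dict'])
    # likelihood.load_state_dict(ckpt['likelihood_state_dict'])
    # optimizer.load_state_dict(ckpt['optimizer_state_dict'])
    # train_loader = preprocess_data_test(..., finished_seg_names=finished, ...)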
74 changes: 64 additions & 10 deletions BML_project/ss_main.py
@@ -6,13 +6,18 @@
"""
from tqdm import tqdm
import torch
from utils_gp.data_loader import preprocess_data, split_uids, update_train_loader_with_uncertain_samples
from utils_gp.data_loader import preprocess_data_train_val, split_uids, update_train_loader_with_uncertain_samples, preprocess_data_test
from models.ss_gp_model import MultitaskGPModel, train_gp_model
from utils_gp.ss_evaluation import stochastic_evaluation, evaluate_model_on_all_data
from active_learning.ss_active_learning import stochastic_uncertainty_sampling, run_minibatch_kmeans, stochastic_compare_kmeans_gp_predictions
from utils_gp.visualization import plot_comparative_results, plot_training_performance, plot_results
import os
import pickle
from datetime import datetime
now = datetime.now() # Get the time now for model checkpoint saving.

dt_string = now.strftime("%Y_%m_%d_%H_%M_%S") # YYYY_mm_dd_HH_MM_SS, for model saving.
print("The date and time suffix of the model file is", dt_string)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

@@ -48,8 +53,29 @@ def main():
clinical_trial_train, clinical_trial_test, clinical_trial_unlabeled = split_uids()
data_format = 'pt'
# Preprocess data
train_loader, val_loader, test_loader, saving_path = preprocess_data(data_format, clinical_trial_train, clinical_trial_test, clinical_trial_unlabeled, batch_size)
# ---- Dong, 02/15/2024: I want to test loading a large dataset. ----
# train_loader, val_loader, saving_path = preprocess_data_train_val(data_format = data_format, \
_, val_loader, saving_path = preprocess_data_train_val(data_format = data_format, \
clinical_trial_train=clinical_trial_train, \
clinical_trial_test=clinical_trial_test, \
batch_size=batch_size,\
finished_seg_names = [],\
read_all_labels=False)
# ---- Dong, 02/15/2024: I want to test loading a large dataset. ----
# test_loader = preprocess_data_test(data_format = data_format, \
train_loader = preprocess_data_test(data_format = data_format, \
clinical_trial_unlabeled=clinical_trial_unlabeled, \
batch_size=batch_size,\
finished_seg_names=[],\
read_all_labels=False)

menu_segment_names = train_loader.dataset.segment_names # All the segments to be run in the training dataset.
menu_labels = train_loader.dataset.labels # All the ground truth labels
print('Debug: len(menu_segment_names)',len(menu_segment_names))
print('Debug: len(menu_labels)',len(menu_labels))

print('Debug: len(train_loader)',len(train_loader))
print('Debug: dir(train_loader.dataset)',dir(train_loader.dataset))

kmeans_model = run_minibatch_kmeans(train_loader, n_clusters=n_classes, device=device)

@@ -61,7 +87,21 @@ def main():
}

# Initial model training
model, likelihood, training_metrics = train_gp_model(train_loader, val_loader, num_iterations=50, n_classes=n_classes)
model, likelihood, training_metrics = train_gp_model(
train_loader = train_loader,
val_loader = val_loader,
num_iterations=50,
n_classes=n_classes,
patience=10,
checkpoint_path=saving_path,
resume_training=True,
datackpt_name = 'dataset_checkpoint.pt',
modelckpt_name = 'model_checkpoint_full.pt',
batch_size=batch_size,
data_format = data_format,
clinical_trial_train = clinical_trial_train,
clinical_trial_test = clinical_trial_test,
clinical_trial_unlabeled=clinical_trial_unlabeled) # Dong: remember to change this function in its code.

# Save the training metrics for future visualization
results['train_loss'].extend(training_metrics['train_loss'])
@@ -77,6 +117,7 @@ def main():
# Attempt to load a training checkpoint
train_checkpoint = checkpoint_manager.load_checkpoint('train')
start_iteration = train_checkpoint['iteration'] if train_checkpoint else 0
print('Debug: start_iteration is:',start_iteration)
# Dong, 01/25/2024: save it first before entering the active learning.
additional_state = {
'model_state': model.state_dict(),
@@ -91,15 +132,20 @@ def main():
active_learning_iterations = 10
# Active Learning Iterations
for iteration in tqdm(range(start_iteration,active_learning_iterations), desc='Active Learning', unit='iteration', leave=True):
print(f"Active Learning Iteration: {iteration+1}/{active_learning_iterations}")
# Perform uncertainty sampling to select new samples from the validation set
uncertain_sample_indices = stochastic_uncertainty_sampling(model, likelihood, val_loader, n_samples=batch_size, n_batches=5)

uncertain_sample_indices = stochastic_uncertainty_sampling(model, likelihood, val_loader, n_samples=50, n_batches=5, device=device)
labeled_samples = label_samples(uncertain_sample_indices, val_loader.dataset)
# Update the training loader with uncertain samples
train_loader = update_train_loader_with_uncertain_samples(train_loader, uncertain_sample_indices, batch_size)
print(f"Updated training data size: {len(train_loader.dataset)}")
train_loader = update_train_loader_with_uncertain_samples(train_loader, labeled_samples, batch_size)

# Optionally, save the dataset state at intervals or after certain conditions
train_loader.dataset.save_checkpoint(dataset_checkpoint_path) # Here, manage the index as needed

# Re-train the model with the updated training data
model, likelihood, val_metrics = train_gp_model(train_loader, val_loader, num_iterations=10, n_classes=n_classes, patience=10, checkpoint_path='model_checkpoint_last.pt')
model, likelihood, val_metrics = train_gp_model(
train_loader, val_loader, num_iterations=10, n_classes=n_classes, patience=10,
checkpoint_path=saving_path, resume_training=True, batch_size=batch_size)

# Store the validation metrics after each active learning iteration
results['validation_metrics']['precision'].append(val_metrics['precision'])
@@ -123,13 +169,21 @@ def main():
plot_comparative_results(gp_vs_kmeans_data, original_labels)

# Final evaluation on test set
import subprocess
print('Start running the bash script...')
subprocess.call("./BML_project/untar_unlabeled_PT.sh")
print('Finished running the bash script.')

test_loader = preprocess_data_test(data_format = data_format, \
clinical_trial_unlabeled=clinical_trial_unlabeled, \
batch_size=batch_size,\
finished_seg_names=[],\
read_all_labels=False)
test_metrics = evaluate_model_on_all_data(model, likelihood, test_loader, device, n_classes)
test_kmeans_model = run_minibatch_kmeans(test_loader, n_clusters=n_classes, device=device)

results['test_metrics'] = test_metrics
test_gp_vs_kmeans_data, test_original_labels = stochastic_compare_kmeans_gp_predictions(test_kmeans_model, model, test_loader, n_batches=5, device=device)

print(f"Length of original_labels: {len(original_labels)}, Length of gp_predictions: {len(gp_predictions)}")
plot_comparative_results(test_gp_vs_kmeans_data, test_original_labels)

# Visualization of results
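The final-evaluation block above shells out to ./BML_project/untar_unlabeled_PT.sh before building the test loader. That script is not part of this diff; the sketch below is a rough pure-Python stand-in, on the assumption that the script simply unpacks the unlabeled PT tarballs into the data directory (the archive pattern and destination are made up for illustration). Keeping it as a shell script, as the commit does, is just as reasonable on Colab; the Python version only removes the dependency on a bash-capable environment.

# Hypothetical Python equivalent of untar_unlabeled_PT.sh; the archive pattern and
# destination directory are assumptions, not taken from the commit.
import glob
import os
import tarfile

dest_dir = 'data/unlabeled_pt'                                  # assumed destination for the unlabeled .pt segments
os.makedirs(dest_dir, exist_ok=True)
for archive in sorted(glob.glob('data/unlabeled_pt_*.tar')):    # assumed naming pattern
    with tarfile.open(archive) as tf:
        tf.extractall(dest_dir)
    print(f'Extracted {archive} into {dest_dir}')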
