Checkpoint on dataloader from CustomDataset Class
I made changes to the CustomDataset class. These changes now make it possible to retrieve the current batch and other information that helps our training perform better. By the way, the checkpoint main is just a temporary main where I added an example of the usage.

The same is the case with the model checkpoint version: it is an example of usage. I recommend, as best practice, merging this change into the current version, but I cannot evaluate it on the dataset, so I created this development version to evaluate it manually.
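
A minimal sketch of the intended resume call for the model checkpoint (the signature is taken from ss_gp_model_checkpoint.py below; the loaders and hyperparameter values are placeholders):

from models.ss_gp_model_checkpoint import train_gp_model

# Resume from the saved checkpoint: train_gp_model restores the model,
# likelihood, and optimizer state and continues from the stored epoch.
model, likelihood, metrics = train_gp_model(
    train_loader, val_loader,  # placeholder loaders over the CustomDataset
    num_iterations=50, n_classes=4,
    checkpoint_path='model_checkpoint_full.pt',
    resume_training=True,
)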

Co-Authored-By: Dong Han <dong.han@uconn.edu>
Co-Authored-By: Darren Chen <darren.3.chen@uconn.edu>
3 people committed Feb 2, 2024
1 parent c9809a3 commit 528c4be
Showing 4 changed files with 342 additions and 0 deletions.
Empty file.
94 changes: 94 additions & 0 deletions BML_project/main_checkpoints.py
@@ -0,0 +1,94 @@
# -*- coding: utf-8 -*-
"""
Created on Thu Feb 1 19:43:31 2024
@author: lrm22005
"""
import os
import tqdm
import torch
from utils.data_loader import preprocess_data, split_uids, update_train_loader_with_uncertain_samples
from models.ss_gp_model import MultitaskGPModel, train_gp_model
from utils_gp.ss_evaluation import stochastic_evaluation, evaluate_model_on_all_data
from active_learning.ss_active_learning import stochastic_uncertainty_sampling, run_minibatch_kmeans, stochastic_compare_kmeans_gp_predictions
from utils.visualization import plot_comparative_results, plot_training_performance, plot_results

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

def main():
    # Set parameters like n_classes, batch_size, etc.
    n_classes = 4
    batch_size = 1024
    clinical_trial_train, clinical_trial_test, clinical_trial_unlabeled = split_uids()
    data_format = 'pt'

    # Preprocess data
    train_loader, val_loader, test_loader = preprocess_data(data_format, clinical_trial_train, clinical_trial_test, clinical_trial_unlabeled, batch_size)

    # Attempt to resume from the last saved batch index if a dataset checkpoint exists
    dataset_checkpoint_path = 'dataset_checkpoint.pt'
    if os.path.exists(dataset_checkpoint_path):
        for loader in [train_loader, val_loader, test_loader]:
            loader.dataset.load_checkpoint(dataset_checkpoint_path)
            print(f"Resuming from batch index {loader.dataset.get_current_batch_index()}")

    kmeans_model = run_minibatch_kmeans(train_loader, n_clusters=n_classes, device=device)

    # Initialize result storage
    results = {
        'train_loss': [],
        'validation_metrics': {'precision': [], 'recall': [], 'f1': [], 'auc_roc': []},
        'test_metrics': None
    }

    # Initial model training
    model, likelihood, training_metrics = train_gp_model(train_loader, val_loader, num_iterations=50, n_classes=n_classes)

    # Save the training metrics for future visualization
    results['train_loss'].extend(training_metrics['train_loss'])
    results['validation_metrics']['precision'].extend(training_metrics['precision'])
    results['validation_metrics']['recall'].extend(training_metrics['recall'])
    results['validation_metrics']['f1'].extend(training_metrics['f1_score'])

    active_learning_iterations = 10
    # Active learning iterations
    for iteration in tqdm.tqdm(range(active_learning_iterations), desc='Active Learning', unit='iteration', leave=True):
        # Perform uncertainty sampling to select new samples from the validation set
        uncertain_sample_indices = stochastic_uncertainty_sampling(model, likelihood, val_loader, n_samples=batch_size, n_batches=5)

        # Update the training loader with uncertain samples
        train_loader = update_train_loader_with_uncertain_samples(train_loader, uncertain_sample_indices, batch_size)

        # Optionally, save the dataset state at intervals or after certain conditions.
        # Passing current_batch_index=None keeps the index the dataset already tracks.
        train_loader.dataset.save_checkpoint(dataset_checkpoint_path, current_batch_index=None)

        # Re-train the model with the updated training data
        model, likelihood, val_metrics = train_gp_model(train_loader, val_loader, num_iterations=10, n_classes=n_classes, patience=10, checkpoint_path='model_checkpoint_last.pt')

        # Store the validation metrics after each active learning iteration;
        # train_gp_model returns per-epoch lists (keyed 'f1_score'), so extend
        results['validation_metrics']['precision'].extend(val_metrics['precision'])
        results['validation_metrics']['recall'].extend(val_metrics['recall'])
        results['validation_metrics']['f1'].extend(val_metrics['f1_score'])

    # Compare K-Means with GP model predictions after retraining
    gp_vs_kmeans_data, original_labels = stochastic_compare_kmeans_gp_predictions(kmeans_model, model, train_loader, n_batches=5, device=device)

    plot_comparative_results(gp_vs_kmeans_data, original_labels)

    # Final evaluation on the test set
    test_metrics = evaluate_model_on_all_data(model, likelihood, test_loader, device, n_classes)
    test_kmeans_model = run_minibatch_kmeans(test_loader, n_clusters=n_classes, device=device)

    results['test_metrics'] = test_metrics
    test_gp_vs_kmeans_data, test_original_labels = stochastic_compare_kmeans_gp_predictions(test_kmeans_model, model, test_loader, n_batches=5, device=device)
    plot_comparative_results(test_gp_vs_kmeans_data, test_original_labels)

    # Visualization of results
    plot_training_performance(results['train_loss'], results['validation_metrics'])
    plot_results(results['test_metrics'])

    # Print final test metrics
    print("Final Test Metrics:", results['test_metrics'])

if __name__ == "__main__":
    main()
218 changes: 218 additions & 0 deletions BML_project/models/ss_gp_model_checkpoint.py
@@ -0,0 +1,218 @@
# -*- coding: utf-8 -*-
"""
Created on Fri Feb 2 15:14:51 2024
@author: lrm22005
"""
import os
import numpy as np
from tqdm import tqdm
import torch
import gpytorch
from sklearn.metrics import precision_recall_fscore_support, roc_auc_score
from sklearn.preprocessing import label_binarize

num_latents = 6 # This should match the complexity of your data or the number of tasks
num_tasks = 4 # This should match the number of output classes or tasks
num_inducing_points = 50 # This is independent and should be sufficient for the input space

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

class MultitaskGPModel(gpytorch.models.ApproximateGP):
    def __init__(self):
        # Let's use a different set of inducing points for each latent function
        inducing_points = torch.rand(num_latents, num_inducing_points, 127 * 128)  # Assuming flattened inputs of 127 * 128 features (must match the flattened segment size)

        # We have to mark the CholeskyVariationalDistribution as batch
        # so that we learn a variational distribution for each task
        variational_distribution = gpytorch.variational.CholeskyVariationalDistribution(
            inducing_points.size(-2), batch_shape=torch.Size([num_latents])
        )

        # We have to wrap the VariationalStrategy in a LMCVariationalStrategy
        # so that the output will be a MultitaskMultivariateNormal rather than a batch output
        variational_strategy = gpytorch.variational.LMCVariationalStrategy(
            gpytorch.variational.VariationalStrategy(
                self, inducing_points, variational_distribution, learn_inducing_locations=True
            ),
            num_tasks=num_tasks,
            num_latents=num_latents,
            latent_dim=-1
        )

        super().__init__(variational_strategy)

        # The mean and covariance modules should be marked as batch
        # so we learn a different set of hyperparameters
        self.mean_module = gpytorch.means.ConstantMean(batch_shape=torch.Size([num_latents]))
        self.covar_module = gpytorch.kernels.ScaleKernel(
            gpytorch.kernels.RBFKernel(batch_shape=torch.Size([num_latents])),
            batch_shape=torch.Size([num_latents])
        )

    def forward(self, x):
        # The forward function is written as if we were dealing with each output
        # dimension in batch. x must already be flattened so that its last
        # dimension matches the inducing points' feature dimension (127 * 128).
        # print(f"Input shape: {x.shape}")
        # x = x.view(x.size(0), -1)  # Flattening the images
        mean_x = self.mean_module(x)
        covar_x = self.covar_module(x)

        # Debugging: print shapes of intermediate outputs
        # print(f"Mean shape: {mean_x.shape}, Covariance shape: {covar_x.shape}")
        latent_pred = gpytorch.distributions.MultivariateNormal(mean_x, covar_x)
        # print(f"Latent prediction shape: {latent_pred.mean.shape}, {latent_pred.covariance_matrix.shape}")

        return latent_pred

def train_gp_model(train_loader, val_loader, num_iterations=50, n_classes=4, patience=10, checkpoint_path='model_checkpoint_full.pt', resume_training=False):
    model = MultitaskGPModel().to(device)
    likelihood = gpytorch.likelihoods.SoftmaxLikelihood(num_features=num_tasks, num_classes=n_classes).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.1)
    mll = gpytorch.mlls.VariationalELBO(likelihood, model, num_data=len(train_loader.dataset))

    # Load checkpoint if resuming training
    start_epoch = 0
    if resume_training and os.path.exists(checkpoint_path):
        checkpoint = torch.load(checkpoint_path)
        model.load_state_dict(checkpoint['model_state_dict'])
        likelihood.load_state_dict(checkpoint['likelihood_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        start_epoch = checkpoint.get('epoch', 0) + 1  # Resume from the next epoch
        print(f"Resuming training from epoch {start_epoch}")

    best_val_loss = float('inf')
    epochs_no_improve = 0
    metrics = {'precision': [], 'recall': [], 'f1_score': [], 'auc_roc': [], 'train_loss': []}

    for epoch in range(start_epoch, num_iterations):
        model.train()
        likelihood.train()
        for batch_index, train_batch in enumerate(train_loader):
            optimizer.zero_grad()
            train_x = train_batch['data'].reshape(train_batch['data'].size(0), -1).to(device)
            train_y = train_batch['label'].to(device)
            output = model(train_x)
            loss = -mll(output, train_y)
            metrics['train_loss'].append(loss.item())
            loss.backward()
            optimizer.step()

        # Stochastic validation: score a random subset of the validation set
        model.eval()
        likelihood.eval()
        with torch.no_grad():
            val_fraction = 1.0  # fraction of the validation set to sample; 1.0 uses all of it
            val_indices = torch.randperm(len(val_loader.dataset))[:int(val_fraction * len(val_loader.dataset))]
            val_loss = 0.0
            val_labels = []
            val_predictions = []
            for idx in val_indices.tolist():
                val_batch = val_loader.dataset[idx]
                val_x = val_batch['data'].reshape(-1).unsqueeze(0).to(device)
                val_y = torch.tensor([val_batch['label']], device=device)
                val_output = model(val_x)
                val_loss_batch = -mll(val_output, val_y).sum()
                val_loss += val_loss_batch.item()
                val_labels.append(val_y.item())
                val_predictions.append(val_output.mean.argmax(dim=-1).item())

            precision, recall, f1, _ = precision_recall_fscore_support(val_labels, val_predictions, average='macro')
            # auc_roc = roc_auc_score(label_binarize(val_labels, classes=np.arange(n_classes)),
            #                         label_binarize(val_predictions, classes=np.arange(n_classes)),
            #                         multi_class='ovr')

            metrics['precision'].append(precision)
            metrics['recall'].append(recall)
            metrics['f1_score'].append(f1)
            # metrics['auc_roc'].append(auc_roc)
            val_loss /= len(val_indices)

        if val_loss < best_val_loss:
            best_val_loss = val_loss
            epochs_no_improve = 0
            # Best-model checkpoint; record the epoch so training can resume from it
            torch.save({'epoch': epoch,
                        'model_state_dict': model.state_dict(),
                        'likelihood_state_dict': likelihood.state_dict(),
                        'optimizer_state_dict': optimizer.state_dict(),
                        'best_val_loss': best_val_loss}, checkpoint_path)
        else:
            epochs_no_improve += 1

        # Rolling end-of-epoch checkpoint, written to a separate file so that it
        # does not overwrite the best-model checkpoint saved above
        torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'likelihood_state_dict': likelihood.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'best_val_loss': best_val_loss,
            # Include other metrics as needed
        }, checkpoint_path.replace('.pt', '_last.pt'))

        if epochs_no_improve >= patience:
            print(f"Early stopping triggered at epoch {epoch+1}")
            break

    # Optionally, load the best model at the end of training
    if os.path.exists(checkpoint_path):
        checkpoint = torch.load(checkpoint_path)
        model.load_state_dict(checkpoint['model_state_dict'])
        likelihood.load_state_dict(checkpoint['likelihood_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

    return model, likelihood, metrics

def semi_supervised_labeling(kmeans_model, gp_model, gp_likelihood, data_loader, confidence_threshold=0.8):
    gp_model.eval()
    gp_likelihood.eval()
    labeled_samples = []

    with torch.no_grad():
        for batch in data_loader:
            data_tensor = batch['data'].view(batch['data'].size(0), -1).to(device)
            kmeans_predictions = kmeans_model.predict(data_tensor.cpu().numpy())
            gp_predictions = gp_likelihood(gp_model(data_tensor))

            # Use GP predictions where the model is confident. gpytorch
            # distributions expose class probabilities via `.probs` (there is
            # no `.confidence()` method); we take the max class probability as
            # the confidence, averaging over likelihood samples if present.
            probs = gp_predictions.probs
            if probs.dim() == 3:
                probs = probs.mean(dim=0)
            confidences, predicted_classes = probs.max(dim=-1)
            confident_indices = confidences.cpu().numpy() > confidence_threshold
            for i, confident in enumerate(confident_indices):
                if confident:
                    labeled_samples.append((data_tensor[i], predicted_classes[i].item()))
                else:
                    labeled_samples.append((data_tensor[i], kmeans_predictions[i]))

    return labeled_samples

def calculate_elbo(model, likelihood, data_loader):
    """
    Calculates the ELBO (Evidence Lower Bound) score for the model on the given data.
    Args:
    - model: The trained Gaussian Process model.
    - likelihood: The likelihood associated with the GP model.
    - data_loader: DataLoader providing the data over which to calculate ELBO.
    Returns:
    - elbo_score: The calculated ELBO score.
    """
    model.eval()
    likelihood.eval()
    mll = gpytorch.mlls.VariationalELBO(likelihood, model, num_data=len(data_loader.dataset))

    with torch.no_grad():
        elbo_score = 0.0
        for batch in data_loader:
            train_x = batch['data'].reshape(batch['data'].size(0), -1).to(device)
            train_y = batch['label'].to(device)
            output = model(train_x)
            # mll returns the ELBO directly (the training loss is its negative),
            # so accumulate it without flipping the sign
            elbo_score += mll(output, train_y).sum().item()

        # Average the ELBO over all data samples
        elbo_score /= len(data_loader.dataset)

    return elbo_score
30 changes: 30 additions & 0 deletions BML_project/utils_gp/data_loader.py
@@ -105,6 +105,9 @@ def __init__(self, data_path, labels_path, UIDs, standardize=True, data_format='
        self.transforms = ToTensor()
        self.refresh_dataset()

        # Initialize the current batch index to None
        self.current_batch_index = None

    def refresh_dataset(self):
        # Extract unique segment names and their corresponding labels
        self.segment_names, self.labels = self.extract_segment_names_and_labels()
@@ -119,6 +122,25 @@ def add_uids(self, new_uids):

    def __len__(self):
        return len(self.segment_names)

    def save_checkpoint(self, checkpoint_path, current_batch_index=None):
        checkpoint = {
            'segment_names': self.segment_names,
            'labels': self.labels,
            'UIDs': self.UIDs,
            # Save the current batch index if provided
            'current_batch_index': current_batch_index if current_batch_index is not None else self.current_batch_index
        }
        torch.save(checkpoint, checkpoint_path)

    def load_checkpoint(self, checkpoint_path):
        checkpoint = torch.load(checkpoint_path)
        self.segment_names = checkpoint['segment_names']
        self.labels = checkpoint['labels']
        self.UIDs = checkpoint['UIDs']
        self.refresh_dataset()
        # Load the current batch index if it exists in the checkpoint
        self.current_batch_index = checkpoint.get('current_batch_index')

    def __getitem__(self, idx):
        segment_name = self.segment_names[idx]

Inline comment from doh16101 (Author Collaborator), Feb 4, 2024:

Dear @lrm22005, how does self.current_batch_index interact with the idx in __getitem__?

@@ -133,6 +155,14 @@ def __getitem__(self, idx):

        return {'data': time_freq_tensor, 'label': label, 'segment_name': segment_name}

    # New method to set the current batch index
    def set_current_batch_index(self, index):
        self.current_batch_index = index

    # New method to get the current batch index
    def get_current_batch_index(self):
        return self.current_batch_index

Inline comment from doh16101 (Author Collaborator), Feb 5, 2024:

Dear @lrm22005, as you can see from the issue I just posted (#18), the setter and getter of self.current_batch_index are never called. Do you know how these two functions we wrote can be run inside the torch Dataset?


    def add_data_label_pair(self, data, label):
        # Assign a unique ID or name for the new data
        new_id = len(self.segment_names)

2 comments on commit 528c4be

lrm22005 (Owner Author) commented:

Basically both can be called idx or index, but these methods return different indexes.

idx in __getitem__(self, idx) is the index passed to the dataset object when an item is retrieved, such as when you iterate over the dataset or index it directly with dataset[idx]. It represents the index of the specific data point being requested.

self.current_batch_index is a class attribute we can use to keep track of the current position within the dataset during an operation, such as training, where we might want to save and resume the process. It can represent the index of the batch rather than of an individual data point, depending on how we use it.

If you want to use self.current_batch_index to resume training from a specific batch, it should be managed so that it corresponds to the batch index, not the index of an individual item within the batch. Compare the example in the dataset checkpointing with the examples in the training checkpointing: we read the index through get_current_batch_index, but while the loader is running it iterates through __getitem__ by default.
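
A minimal sketch of how the saved index could drive resumption, assuming a sequential (shuffle=False) DataLoader over a CustomDataset; the skip-ahead loop is an illustration, not code from this commit:

from torch.utils.data import DataLoader

loader = DataLoader(dataset, batch_size=1024, shuffle=False)  # `dataset` is a CustomDataset

# After load_checkpoint(), get_current_batch_index() reports where we stopped.
start_batch = dataset.get_current_batch_index() or 0
for batch_index, batch in enumerate(loader):
    if batch_index < start_batch:
        continue  # skip batches already processed before the interruption
    # ... training step on batch['data'], batch['label'] ...
    dataset.set_current_batch_index(batch_index)
    dataset.save_checkpoint('dataset_checkpoint.pt', current_batch_index=batch_index)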

doh16101 (Collaborator) commented:

> (quoting @lrm22005's reply above)

Then it sounds like:

  1. when you resume training, you can only resume from the first segment in that batch, not the exact segment inside the batch where you stopped, because the batch index and the segment idx in __getitem__ are two different things;
  2. it seems the idx values in __getitem__ could be generated automatically once you know the batch index, but I do not know which function inside PyTorch does that. Is it inside the checkpoint-saving function?

I will test your code on Colab and see if it works.
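
On point 2: as far as I know there is no built-in PyTorch function that maps a batch index back to item indices, but with a sequential (shuffle=False) loader the mapping is plain arithmetic. A sketch under that assumption (the batch size and resume point are hypothetical):

from torch.utils.data import DataLoader

batch_size = 1024

def indices_for_batch(batch_index, dataset_len, batch_size):
    # __getitem__ indices covered by one batch of a sequential loader
    start = batch_index * batch_size
    return range(start, min(start + batch_size, dataset_len))

# Resume mid-epoch from batch 7 by iterating only the remaining indices;
# DataLoader accepts any iterable of indices as `sampler`.
resume_from = 7
tail = range(resume_from * batch_size, len(dataset))
resume_loader = DataLoader(dataset, batch_size=batch_size, sampler=tail)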
