
Commit

Merge branch 'main' into Cassey
doh16101 committed Feb 4, 2024
2 parents 781947f + 528c4be commit 928e7fe
Showing 15 changed files with 1,919 additions and 0 deletions.
Empty file.
94 changes: 94 additions & 0 deletions BML_project/main_checkpoints.py
@@ -0,0 +1,94 @@
# -*- coding: utf-8 -*-
"""
Created on Thu Feb 1 19:43:31 2024
@author: lrm22005
"""
import os
import tqdm
import torch
from utils.data_loader import preprocess_data, split_uids, update_train_loader_with_uncertain_samples
from models.ss_gp_model import MultitaskGPModel, train_gp_model
from utils_gp.ss_evaluation import stochastic_evaluation, evaluate_model_on_all_data
from active_learning.ss_active_learning import stochastic_uncertainty_sampling, run_minibatch_kmeans, stochastic_compare_kmeans_gp_predictions
from utils.visualization import plot_comparative_results, plot_training_performance, plot_results

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

def main():
    # Set parameters like n_classes, batch_size, etc.
    n_classes = 4
    batch_size = 1024
    clinical_trial_train, clinical_trial_test, clinical_trial_unlabeled = split_uids()
    data_format = 'pt'

    # Preprocess data
    train_loader, val_loader, test_loader = preprocess_data(data_format, clinical_trial_train, clinical_trial_test, clinical_trial_unlabeled, batch_size)

    # Attempt to resume from the last saved batch index if a dataset checkpoint exists.
    # Only the training dataset is checkpointed later in main(), so only it is restored here.
    dataset_checkpoint_path = 'dataset_checkpoint.pt'
    if os.path.exists(dataset_checkpoint_path):
        train_loader.dataset.load_checkpoint(dataset_checkpoint_path)
        print(f"Resuming from batch index {train_loader.dataset.get_current_batch_index()}")

    kmeans_model = run_minibatch_kmeans(train_loader, n_clusters=n_classes, device=device)

    # Initialize result storage
    results = {
        'train_loss': [],
        'validation_metrics': {'precision': [], 'recall': [], 'f1': [], 'auc_roc': []},
        'test_metrics': None
    }

    # Initial model training
    model, likelihood, training_metrics = train_gp_model(train_loader, val_loader, num_iterations=50, n_classes=n_classes)

    # Save the training metrics for future visualization
    results['train_loss'].extend(training_metrics['train_loss'])
    results['validation_metrics']['precision'].extend(training_metrics['precision'])
    results['validation_metrics']['recall'].extend(training_metrics['recall'])
    results['validation_metrics']['f1'].extend(training_metrics['f1_score'])

    active_learning_iterations = 10
    # Active Learning Iterations
    for iteration in tqdm.tqdm(range(active_learning_iterations), desc='Active Learning', unit='iteration', leave=True):
        # Perform uncertainty sampling to select new samples from the validation set
        uncertain_sample_indices = stochastic_uncertainty_sampling(model, likelihood, val_loader, n_samples=batch_size, n_batches=5)

        # Update the training loader with uncertain samples
        train_loader = update_train_loader_with_uncertain_samples(train_loader, uncertain_sample_indices, batch_size)

        # Optionally, save the dataset state at intervals or after certain conditions;
        # one way to manage the batch index is sketched at the end of this file
        train_loader.dataset.save_checkpoint(dataset_checkpoint_path, current_batch_index=None)

        # Re-train the model with the updated training data
        model, likelihood, val_metrics = train_gp_model(train_loader, val_loader, num_iterations=10, n_classes=n_classes, patience=10, checkpoint_path='model_checkpoint_last.pt')

        # Store the validation metrics after each active learning iteration;
        # train_gp_model returns per-epoch lists under 'precision', 'recall', 'f1_score'
        results['validation_metrics']['precision'].extend(val_metrics['precision'])
        results['validation_metrics']['recall'].extend(val_metrics['recall'])
        results['validation_metrics']['f1'].extend(val_metrics['f1_score'])

    # Compare K-Means with GP model predictions after retraining
    gp_vs_kmeans_data, original_labels = stochastic_compare_kmeans_gp_predictions(kmeans_model, model, train_loader, n_batches=5, device=device)

    plot_comparative_results(gp_vs_kmeans_data, original_labels)

    # Final evaluation on test set
    test_metrics = evaluate_model_on_all_data(model, likelihood, test_loader, device, n_classes)
    test_kmeans_model = run_minibatch_kmeans(test_loader, n_clusters=n_classes, device=device)

    results['test_metrics'] = test_metrics
    test_gp_vs_kmeans_data, test_original_labels = stochastic_compare_kmeans_gp_predictions(test_kmeans_model, model, test_loader, n_batches=5, device=device)
    plot_comparative_results(test_gp_vs_kmeans_data, test_original_labels)

    # Visualization of results
    plot_training_performance(results['train_loss'], results['validation_metrics'])
    plot_results(results['test_metrics'])

    # Print final test metrics
    print("Final Test Metrics:", results['test_metrics'])

if __name__ == "__main__":
    main()
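
# Hedged sketch (not part of the committed file): one way to manage the batch
# index that the dataset's save_checkpoint() expects, saving inside the training
# loop so a crash can resume mid-epoch. The interval of 100 batches is illustrative.
def _save_dataset_progress(train_loader, dataset_checkpoint_path, every_n_batches=100):
    for batch_index, batch in enumerate(train_loader):
        # ... training step on `batch` goes here ...
        if batch_index % every_n_batches == 0:
            train_loader.dataset.set_current_batch_index(batch_index)
            train_loader.dataset.save_checkpoint(dataset_checkpoint_path)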
218 changes: 218 additions & 0 deletions BML_project/models/ss_gp_model_checkpoint.py
@@ -0,0 +1,218 @@
# -*- coding: utf-8 -*-
"""
Created on Fri Feb 2 15:14:51 2024
@author: lrm22005
"""
import os
import numpy as np
from tqdm import tqdm
import torch
import gpytorch
from sklearn.metrics import precision_recall_fscore_support, roc_auc_score
from sklearn.preprocessing import label_binarize

num_latents = 6 # This should match the complexity of your data or the number of tasks
num_tasks = 4 # This should match the number of output classes or tasks
num_inducing_points = 50 # This is independent and should be sufficient for the input space

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

class MultitaskGPModel(gpytorch.models.ApproximateGP):
    def __init__(self):
        # Use a different set of inducing points for each latent function;
        # the last dimension must match the flattened input size (127 * 128)
        inducing_points = torch.rand(num_latents, num_inducing_points, 127 * 128)

        # We have to mark the CholeskyVariationalDistribution as batch
        # so that we learn a variational distribution for each task
        variational_distribution = gpytorch.variational.CholeskyVariationalDistribution(
            inducing_points.size(-2), batch_shape=torch.Size([num_latents])
        )

        # We have to wrap the VariationalStrategy in a LMCVariationalStrategy
        # so that the output will be a MultitaskMultivariateNormal rather than a batch output
        variational_strategy = gpytorch.variational.LMCVariationalStrategy(
            gpytorch.variational.VariationalStrategy(
                self, inducing_points, variational_distribution, learn_inducing_locations=True
            ),
            num_tasks=num_tasks,
            num_latents=num_latents,
            latent_dim=-1
        )

        super().__init__(variational_strategy)

        # The mean and covariance modules should be marked as batch
        # so we learn a different set of hyperparameters for each latent function
        self.mean_module = gpytorch.means.ConstantMean(batch_shape=torch.Size([num_latents]))
        self.covar_module = gpytorch.kernels.ScaleKernel(
            gpytorch.kernels.RBFKernel(batch_shape=torch.Size([num_latents])),
            batch_shape=torch.Size([num_latents])
        )

    def forward(self, x):
        # The forward function is written as if we were dealing with each output
        # dimension in batch. x must already be flattened so that its last
        # dimension matches the inducing points' last dimension (127 * 128).
        mean_x = self.mean_module(x)
        covar_x = self.covar_module(x)
        latent_pred = gpytorch.distributions.MultivariateNormal(mean_x, covar_x)
        return latent_pred
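
# A minimal shape-check sketch (not part of the committed pipeline; the batch
# size of 8 is arbitrary) showing that the LMC strategy above yields one mean
# per task for a batch of flattened inputs:
def _check_model_shapes():
    model = MultitaskGPModel()
    x = torch.rand(8, 127 * 128)  # batch of 8 flattened segments
    dist = model(x)               # a MultitaskMultivariateNormal
    assert dist.mean.shape == torch.Size([8, num_tasks])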

def train_gp_model(train_loader, val_loader, num_iterations=50, n_classes=4, patience=10, checkpoint_path='model_checkpoint_full.pt', resume_training=False):
    model = MultitaskGPModel().to(device)
    likelihood = gpytorch.likelihoods.SoftmaxLikelihood(num_features=num_tasks, num_classes=n_classes).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.1)
    mll = gpytorch.mlls.VariationalELBO(likelihood, model, num_data=len(train_loader.dataset))

    # The best model gets its own file so the per-epoch checkpoint below does not overwrite it
    best_checkpoint_path = os.path.splitext(checkpoint_path)[0] + '_best.pt'

    # Load checkpoint if resuming training
    start_epoch = 0
    if resume_training and os.path.exists(checkpoint_path):
        checkpoint = torch.load(checkpoint_path)
        model.load_state_dict(checkpoint['model_state_dict'])
        likelihood.load_state_dict(checkpoint['likelihood_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        start_epoch = checkpoint.get('epoch', 0) + 1  # Resume from the next epoch
        print(f"Resuming training from epoch {start_epoch}")

    best_val_loss = float('inf')
    epochs_no_improve = 0
    metrics = {'precision': [], 'recall': [], 'f1_score': [], 'auc_roc': [], 'train_loss': []}

    for epoch in range(start_epoch, num_iterations):
        model.train()
        likelihood.train()
        for batch_index, train_batch in enumerate(train_loader):
            optimizer.zero_grad()
            train_x = train_batch['data'].reshape(train_batch['data'].size(0), -1).to(device)
            train_y = train_batch['label'].to(device)
            output = model(train_x)
            loss = -mll(output, train_y)
            metrics['train_loss'].append(loss.item())
            loss.backward()
            optimizer.step()

        # Stochastic validation on a random subset of the validation set
        model.eval()
        likelihood.eval()
        with torch.no_grad():
            val_fraction = 1.0  # 1.0 = the full validation set; lower it to subsample
            val_indices = torch.randperm(len(val_loader.dataset))[:int(val_fraction * len(val_loader.dataset))]
            val_loss = 0.0
            val_labels = []
            val_predictions = []
            for idx in val_indices:
                val_batch = val_loader.dataset[idx]
                val_x = val_batch['data'].reshape(-1).unsqueeze(0).to(device)
                val_y = torch.tensor([val_batch['label']], device=device)
                val_output = model(val_x)
                val_loss_batch = -mll(val_output, val_y).sum()
                val_loss += val_loss_batch.item()
                val_labels.append(val_y.item())
                val_predictions.append(val_output.mean.argmax(dim=-1).item())

        precision, recall, f1, _ = precision_recall_fscore_support(val_labels, val_predictions, average='macro')
        # auc_roc = roc_auc_score(label_binarize(val_labels, classes=np.arange(n_classes)),
        #                         label_binarize(val_predictions, classes=np.arange(n_classes)),
        #                         multi_class='ovr')

        metrics['precision'].append(precision)
        metrics['recall'].append(recall)
        metrics['f1_score'].append(f1)
        # metrics['auc_roc'].append(auc_roc)
        val_loss /= len(val_indices)

        if val_loss < best_val_loss:
            best_val_loss = val_loss
            epochs_no_improve = 0
            torch.save({'model_state_dict': model.state_dict(),
                        'likelihood_state_dict': likelihood.state_dict(),
                        'optimizer_state_dict': optimizer.state_dict()}, best_checkpoint_path)
        else:
            epochs_no_improve += 1

        # Save a resumable checkpoint at the end of each epoch
        torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'likelihood_state_dict': likelihood.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'best_val_loss': best_val_loss,
            # Include other metrics as needed
        }, checkpoint_path)

        if epochs_no_improve >= patience:
            print(f"Early stopping triggered at epoch {epoch+1}")
            break

    # Load the best model at the end of training
    if os.path.exists(best_checkpoint_path):
        checkpoint = torch.load(best_checkpoint_path)
        model.load_state_dict(checkpoint['model_state_dict'])
        likelihood.load_state_dict(checkpoint['likelihood_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

    return model, likelihood, metrics

def semi_supervised_labeling(kmeans_model, gp_model, gp_likelihood, data_loader, confidence_threshold=0.8):
    gp_model.eval()
    gp_likelihood.eval()
    labeled_samples = []

    with torch.no_grad():
        for batch in data_loader:
            data_tensor = batch['data'].view(batch['data'].size(0), -1).to(device)
            kmeans_predictions = kmeans_model.predict(data_tensor.cpu().numpy())
            gp_predictions = gp_likelihood(gp_model(data_tensor))

            # SoftmaxLikelihood yields a Categorical distribution; take the max
            # class probability (averaged over likelihood samples) as the confidence
            probs = gp_predictions.probs
            if probs.dim() == 3:  # (num_likelihood_samples, batch, n_classes)
                probs = probs.mean(dim=0)
            confidences, gp_classes = probs.max(dim=-1)
            confident_indices = confidences.cpu().numpy() > confidence_threshold

            # Use GP predictions where the model is confident, K-Means labels otherwise
            for i, confident in enumerate(confident_indices):
                if confident:
                    labeled_samples.append((data_tensor[i], gp_classes[i].item()))
                else:
                    labeled_samples.append((data_tensor[i], kmeans_predictions[i]))

    return labeled_samples
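
# Hedged usage sketch: `unlabeled_loader` is an assumed DataLoader of unlabeled
# segments, and the batch size is illustrative. The pseudo-labeled pairs from
# semi_supervised_labeling() are wrapped in a DataLoader for further training.
def _pseudo_label_loader(kmeans_model, gp_model, gp_likelihood, unlabeled_loader, batch_size=1024):
    from torch.utils.data import DataLoader, TensorDataset
    labeled = semi_supervised_labeling(kmeans_model, gp_model, gp_likelihood, unlabeled_loader)
    xs = torch.stack([x.cpu() for x, _ in labeled])
    ys = torch.tensor([int(y) for _, y in labeled])
    return DataLoader(TensorDataset(xs, ys), batch_size=batch_size, shuffle=True)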

def calculate_elbo(model, likelihood, data_loader):
    """
    Calculates the ELBO (Evidence Lower Bound) score for the model on the given data.

    Args:
        model: The trained Gaussian Process model.
        likelihood: The likelihood associated with the GP model.
        data_loader: DataLoader providing the data over which to calculate the ELBO.

    Returns:
        elbo_score: The calculated ELBO, averaged over all data samples.
    """
    model.eval()
    likelihood.eval()
    mll = gpytorch.mlls.VariationalELBO(likelihood, model, num_data=len(data_loader.dataset))

    with torch.no_grad():
        elbo_score = 0.0
        for batch in data_loader:
            train_x = batch['data'].reshape(batch['data'].size(0), -1).to(device)
            train_y = batch['label'].to(device)
            output = model(train_x)
            # The mll object returns the ELBO directly (the training loss is its negative)
            elbo_score += mll(output, train_y).sum().item()

        # Average the ELBO over all data samples
        elbo_score /= len(data_loader.dataset)

    return elbo_score
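
# Hedged usage sketch: compare the ELBO on a held-out loader before and after a
# retraining step. `retrain_fn` is a placeholder callable that returns an
# updated (model, likelihood) pair; a higher ELBO is better.
def _elbo_delta(model, likelihood, val_loader, retrain_fn):
    before = calculate_elbo(model, likelihood, val_loader)
    model, likelihood = retrain_fn(model, likelihood)
    after = calculate_elbo(model, likelihood, val_loader)
    print(f"Validation ELBO: {before:.4f} -> {after:.4f}")
    return before, after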
30 changes: 30 additions & 0 deletions BML_project/utils_gp/data_loader.py
@@ -115,6 +115,9 @@ def __init__(self, data_path, labels_path, UIDs, standardize=True, data_format='
        self.transforms = ToTensor()
        self.refresh_dataset()

        # Initialize the current batch index to None
        self.current_batch_index = None

    def refresh_dataset(self):
        # Extract unique segment names and their corresponding labels
        self.segment_names, self.labels = self.extract_segment_names_and_labels()
@@ -129,6 +132,25 @@ def add_uids(self, new_uids):

    def __len__(self):
        return len(self.segment_names)

    def save_checkpoint(self, checkpoint_path, current_batch_index=None):
        checkpoint = {
            'segment_names': self.segment_names,
            'labels': self.labels,
            'UIDs': self.UIDs,
            # Save the current batch index if provided
            'current_batch_index': current_batch_index if current_batch_index is not None else self.current_batch_index
        }
        torch.save(checkpoint, checkpoint_path)

    def load_checkpoint(self, checkpoint_path):
        checkpoint = torch.load(checkpoint_path)
        self.segment_names = checkpoint['segment_names']
        self.labels = checkpoint['labels']
        self.UIDs = checkpoint['UIDs']
        self.refresh_dataset()
        # Load the current batch index if it exists in the checkpoint
        self.current_batch_index = checkpoint.get('current_batch_index')

    def __getitem__(self, idx):
        segment_name = self.segment_names[idx]
@@ -142,6 +164,14 @@ def __getitem__(self, idx):
        time_freq_tensor = self.load_data(segment_name)
        return {'data': time_freq_tensor, 'label': label, 'segment_name': segment_name}

    # New method to set the current batch index
    def set_current_batch_index(self, index):
        self.current_batch_index = index

    # New method to get the current batch index
    def get_current_batch_index(self):
        return self.current_batch_index

    def add_data_label_pair(self, data, label):
        # Assign a unique ID or name for the new data
        new_id = len(self.segment_names)
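
# Hedged round-trip sketch for the checkpoint methods above (`dataset` is
# assumed to be an instance of this dataset class; the index 42 is arbitrary):
def _roundtrip_dataset_checkpoint(dataset, path='dataset_checkpoint.pt'):
    dataset.set_current_batch_index(42)
    dataset.save_checkpoint(path)
    dataset.load_checkpoint(path)
    assert dataset.get_current_batch_index() == 42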