Showing 15 changed files with 1,919 additions and 0 deletions.
@@ -0,0 +1,94 @@
# -*- coding: utf-8 -*-
"""
Created on Thu Feb 1 19:43:31 2024
@author: lrm22005
"""
import os
import tqdm
import torch
from utils.data_loader import preprocess_data, split_uids, update_train_loader_with_uncertain_samples
from models.ss_gp_model import MultitaskGPModel, train_gp_model
from utils_gp.ss_evaluation import stochastic_evaluation, evaluate_model_on_all_data
from active_learning.ss_active_learning import stochastic_uncertainty_sampling, run_minibatch_kmeans, stochastic_compare_kmeans_gp_predictions
from utils.visualization import plot_comparative_results, plot_training_performance, plot_results

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

def main():
    # Set parameters like n_classes, batch_size, etc.
    n_classes = 4
    batch_size = 1024
    clinical_trial_train, clinical_trial_test, clinical_trial_unlabeled = split_uids()
    data_format = 'pt'

    # Preprocess data
    train_loader, val_loader, test_loader = preprocess_data(data_format, clinical_trial_train, clinical_trial_test, clinical_trial_unlabeled, batch_size)

    # Attempt to resume from the last saved batch index if a dataset checkpoint exists
    dataset_checkpoint_path = 'dataset_checkpoint.pt'
    if os.path.exists(dataset_checkpoint_path):
        for loader in [train_loader, val_loader, test_loader]:
            loader.dataset.load_checkpoint(dataset_checkpoint_path)
            print(f"Resuming from batch index {loader.dataset.get_current_batch_index()}")

    kmeans_model = run_minibatch_kmeans(train_loader, n_clusters=n_classes, device=device)

    # Initialize result storage
    results = {
        'train_loss': [],
        'validation_metrics': {'precision': [], 'recall': [], 'f1': [], 'auc_roc': []},
        'test_metrics': None
    }

    # Initial model training
    model, likelihood, training_metrics = train_gp_model(train_loader, val_loader, num_iterations=50, n_classes=n_classes)

    # Save the training metrics for future visualization
    results['train_loss'].extend(training_metrics['train_loss'])
    results['validation_metrics']['precision'].extend(training_metrics['precision'])
    results['validation_metrics']['recall'].extend(training_metrics['recall'])
    results['validation_metrics']['f1'].extend(training_metrics['f1_score'])

    active_learning_iterations = 10
    # Active Learning Iterations
    for iteration in tqdm.tqdm(range(active_learning_iterations), desc='Active Learning', unit='iteration', leave=True):
        # Perform uncertainty sampling to select new samples from the validation set
        uncertain_sample_indices = stochastic_uncertainty_sampling(model, likelihood, val_loader, n_samples=batch_size, n_batches=5)

        # Update the training loader with uncertain samples
        train_loader = update_train_loader_with_uncertain_samples(train_loader, uncertain_sample_indices, batch_size)

        # Optionally, save the dataset state at intervals or after certain conditions
        train_loader.dataset.save_checkpoint(dataset_checkpoint_path, current_batch_index=None)  # Manage the index as needed

        # Re-train the model with the updated training data
        model, likelihood, val_metrics = train_gp_model(train_loader, val_loader, num_iterations=10, n_classes=n_classes, patience=10, checkpoint_path='model_checkpoint_last.pt')

        # Store the validation metrics after each active learning iteration.
        # train_gp_model returns lists keyed 'precision', 'recall' and 'f1_score'
        # ('f1' does not exist in the returned dict), so extend the running lists
        # with the matching keys rather than appending whole lists.
        results['validation_metrics']['precision'].extend(val_metrics['precision'])
        results['validation_metrics']['recall'].extend(val_metrics['recall'])
        results['validation_metrics']['f1'].extend(val_metrics['f1_score'])

    # Compare K-Means with GP model predictions after retraining
    gp_vs_kmeans_data, original_labels = stochastic_compare_kmeans_gp_predictions(kmeans_model, model, train_loader, n_batches=5, device=device)

    plot_comparative_results(gp_vs_kmeans_data, original_labels)

    # Final evaluation on test set
    test_metrics = evaluate_model_on_all_data(model, likelihood, test_loader, device, n_classes)
    test_kmeans_model = run_minibatch_kmeans(test_loader, n_clusters=n_classes, device=device)

    results['test_metrics'] = test_metrics
    test_gp_vs_kmeans_data, test_original_labels = stochastic_compare_kmeans_gp_predictions(test_kmeans_model, model, test_loader, n_batches=5, device=device)
    plot_comparative_results(test_gp_vs_kmeans_data, test_original_labels)

    # Visualization of results
    plot_training_performance(results['train_loss'], results['validation_metrics'])
    plot_results(results['test_metrics'])

    # Print final test metrics
    print("Final Test Metrics:", results['test_metrics'])

if __name__ == "__main__":
    main()
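The acquisition function is imported from active_learning.ss_active_learning and its body is not part of this commit. A minimal sketch of a stochastic uncertainty sampler with the same call signature (the entropy-based scoring and per-sample loop are assumptions, not the repository's actual implementation) might look like:

import torch

def stochastic_uncertainty_sampling(model, likelihood, data_loader, n_samples, n_batches,
                                    batch_size=1024,
                                    device=torch.device("cuda" if torch.cuda.is_available() else "cpu")):
    """Score a random pool of roughly n_batches * batch_size points by predictive
    entropy and return the dataset indices of the n_samples most uncertain ones."""
    model.eval()
    likelihood.eval()
    pool_size = min(n_batches * batch_size, len(data_loader.dataset))
    pool = torch.randperm(len(data_loader.dataset))[:pool_size]
    entropies = []
    with torch.no_grad():
        for idx in pool:
            sample = data_loader.dataset[idx]
            x = sample['data'].reshape(1, -1).to(device)
            # SoftmaxLikelihood yields a Categorical; average its class probabilities
            # over the leading Monte Carlo sample dimension before scoring.
            probs = likelihood(model(x)).probs.mean(0).squeeze(0)
            entropies.append(-(probs * probs.clamp_min(1e-12).log()).sum().item())
    ranked = torch.tensor(entropies).argsort(descending=True)[:n_samples]
    return pool[ranked].tolist()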
@@ -0,0 +1,218 @@
# -*- coding: utf-8 -*-
"""
Created on Fri Feb 2 15:14:51 2024
@author: lrm22005
"""
import os
import numpy as np
from tqdm import tqdm
import torch
import gpytorch
from sklearn.metrics import precision_recall_fscore_support, roc_auc_score
from sklearn.preprocessing import label_binarize

num_latents = 6  # This should match the complexity of your data or the number of tasks
num_tasks = 4  # This should match the number of output classes or tasks
num_inducing_points = 50  # This is independent and should be sufficient for the input space

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

class MultitaskGPModel(gpytorch.models.ApproximateGP):
    def __init__(self):
        # Use a different set of inducing points for each latent function.
        # The last dimension must match the flattened input size fed to the
        # model (here 127 * 128 = 16256 features per sample).
        inducing_points = torch.rand(num_latents, num_inducing_points, 127 * 128)

        # We have to mark the CholeskyVariationalDistribution as batch
        # so that we learn a variational distribution for each task
        variational_distribution = gpytorch.variational.CholeskyVariationalDistribution(
            inducing_points.size(-2), batch_shape=torch.Size([num_latents])
        )

        # We have to wrap the VariationalStrategy in a LMCVariationalStrategy
        # so that the output will be a MultitaskMultivariateNormal rather than a batch output
        variational_strategy = gpytorch.variational.LMCVariationalStrategy(
            gpytorch.variational.VariationalStrategy(
                self, inducing_points, variational_distribution, learn_inducing_locations=True
            ),
            num_tasks=num_tasks,
            num_latents=num_latents,
            latent_dim=-1
        )

        super().__init__(variational_strategy)

        # The mean and covariance modules should be marked as batch
        # so we learn a different set of hyperparameters
        self.mean_module = gpytorch.means.ConstantMean(batch_shape=torch.Size([num_latents]))
        self.covar_module = gpytorch.kernels.ScaleKernel(
            gpytorch.kernels.RBFKernel(batch_shape=torch.Size([num_latents])),
            batch_shape=torch.Size([num_latents])
        )

    def forward(self, x):
        # The forward function should be written as if we were dealing with each
        # output dimension in batch. The input must arrive flattened, i.e. with the
        # same last dimension as the inducing points, so callers reshape each batch
        # to (N, 127 * 128) before the call.
        mean_x = self.mean_module(x)
        covar_x = self.covar_module(x)
        latent_pred = gpytorch.distributions.MultivariateNormal(mean_x, covar_x)
        return latent_pred
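A quick way to sanity-check the shapes (an illustrative snippet, not part of the commit; the 127 * 128 input size mirrors the inducing-point dimensionality above):

# Smoke test: the model should emit a MultitaskMultivariateNormal over the 4 tasks,
# and SoftmaxLikelihood should turn it into a Categorical over the 4 classes.
model = MultitaskGPModel().to(device)
likelihood = gpytorch.likelihoods.SoftmaxLikelihood(num_features=num_tasks, num_classes=num_tasks).to(device)
x = torch.rand(8, 127 * 128, device=device)
with torch.no_grad():
    latent = model(x)            # mean shape: (8, num_tasks)
    y_dist = likelihood(latent)  # Categorical; probs shape: (S, 8, num_tasks)
print(type(latent).__name__, latent.mean.shape, y_dist.probs.shape)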

def train_gp_model(train_loader, val_loader, num_iterations=50, n_classes=4, patience=10, checkpoint_path='model_checkpoint_full.pt', resume_training=False):
    model = MultitaskGPModel().to(device)
    likelihood = gpytorch.likelihoods.SoftmaxLikelihood(num_features=4, num_classes=4).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.1)
    mll = gpytorch.mlls.VariationalELBO(likelihood, model, num_data=len(train_loader.dataset))

    # Load checkpoint if resuming training
    start_epoch = 0
    if resume_training and os.path.exists(checkpoint_path):
        checkpoint = torch.load(checkpoint_path)
        model.load_state_dict(checkpoint['model_state_dict'])
        likelihood.load_state_dict(checkpoint['likelihood_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        start_epoch = checkpoint.get('epoch', 0) + 1  # Resume from the next epoch
        print(f"Resuming training from epoch {start_epoch}")

    best_val_loss = float('inf')
    epochs_no_improve = 0
    metrics = {'precision': [], 'recall': [], 'f1_score': [], 'auc_roc': [], 'train_loss': []}

    for epoch in range(start_epoch, num_iterations):
        model.train()
        likelihood.train()
        for batch_index, train_batch in enumerate(train_loader):
            optimizer.zero_grad()
            train_x = train_batch['data'].reshape(train_batch['data'].size(0), -1).to(device)
            train_y = train_batch['label'].to(device)
            output = model(train_x)
            loss = -mll(output, train_y)
            metrics['train_loss'].append(loss.item())
            loss.backward()
            optimizer.step()

        # Stochastic validation: score a random subset of the validation set.
        # The fraction here is 1.0 (the full set); lower it to subsample.
        model.eval()
        likelihood.eval()
        with torch.no_grad():
            val_indices = torch.randperm(len(val_loader.dataset))[:int(1.0 * len(val_loader.dataset))]
            val_loss = 0.0
            val_labels = []
            val_predictions = []
            for idx in val_indices:
                val_batch = val_loader.dataset[idx]
                val_x = val_batch['data'].reshape(-1).unsqueeze(0).to(device)
                val_y = torch.tensor([val_batch['label']], device=device)
                val_output = model(val_x)
                val_loss_batch = -mll(val_output, val_y).sum()
                val_loss += val_loss_batch.item()
                val_labels.append(val_y.item())
                val_predictions.append(val_output.mean.argmax(dim=-1).item())

        precision, recall, f1, _ = precision_recall_fscore_support(val_labels, val_predictions, average='macro', zero_division=0)
        # auc_roc = roc_auc_score(label_binarize(val_labels, classes=np.arange(n_classes)),
        #                         label_binarize(val_predictions, classes=np.arange(n_classes)),
        #                         multi_class='ovr')

        metrics['precision'].append(precision)
        metrics['recall'].append(recall)
        metrics['f1_score'].append(f1)
        # metrics['auc_roc'].append(auc_roc)
        val_loss /= len(val_indices)

        if val_loss < best_val_loss:
            best_val_loss = val_loss
            epochs_no_improve = 0
            # Note: the per-epoch save below reuses checkpoint_path, so this
            # best-model snapshot is overwritten; use a separate path to keep it.
            torch.save({'model_state_dict': model.state_dict(),
                        'likelihood_state_dict': likelihood.state_dict(),
                        'optimizer_state_dict': optimizer.state_dict()}, checkpoint_path)
        else:
            epochs_no_improve += 1

        # Save a full checkpoint at the end of each epoch so training can resume
        torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'likelihood_state_dict': likelihood.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'best_val_loss': best_val_loss,
            # Include other metrics as needed
        }, checkpoint_path)

        # Check early stopping once per epoch, after the checkpoint is written
        # (the original duplicated this check and could skip the epoch save).
        if epochs_no_improve >= patience:
            print(f"Early stopping triggered at epoch {epoch+1}")
            break

    # Load the latest checkpoint at the end of training (since the per-epoch save
    # reuses checkpoint_path, this restores the last epoch, not necessarily the
    # best-validation snapshot).
    if os.path.exists(checkpoint_path):
        checkpoint = torch.load(checkpoint_path)
        model.load_state_dict(checkpoint['model_state_dict'])
        likelihood.load_state_dict(checkpoint['likelihood_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

    return model, likelihood, metrics
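Because every epoch writes a full checkpoint (epoch, model, likelihood, and optimizer state), an interrupted run can be resumed by calling train_gp_model again with resume_training=True, for example:

# Resume from the checkpoint written by a previous (possibly interrupted) run;
# training continues from the epoch after the one stored in the checkpoint.
model, likelihood, metrics = train_gp_model(
    train_loader, val_loader,
    num_iterations=50, n_classes=4, patience=10,
    checkpoint_path='model_checkpoint_full.pt',
    resume_training=True,
)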

def semi_supervised_labeling(kmeans_model, gp_model, gp_likelihood, data_loader, confidence_threshold=0.8):
    gp_model.eval()
    gp_likelihood.eval()
    labeled_samples = []

    with torch.no_grad():
        for batch in data_loader:
            data_tensor = batch['data'].view(batch['data'].size(0), -1).to(device)
            kmeans_predictions = kmeans_model.predict(data_tensor.cpu().numpy())
            gp_predictions = gp_likelihood(gp_model(data_tensor))

            # SoftmaxLikelihood yields a Categorical whose probs carry a leading
            # Monte Carlo sample dimension; average it out, then treat the top
            # class probability as the model's confidence. (The original called a
            # nonexistent .confidence() method; this is the intended computation.)
            class_probs = gp_predictions.probs.mean(0)
            confidences, gp_labels = class_probs.max(dim=-1)

            # Use GP predictions where the model is confident, K-Means otherwise
            for i, confident in enumerate((confidences > confidence_threshold).cpu().numpy()):
                if confident:
                    labeled_samples.append((data_tensor[i], gp_labels[i].item()))
                else:
                    labeled_samples.append((data_tensor[i], kmeans_predictions[i]))

    return labeled_samples
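The (sample, pseudo-label) pairs can then be fed back into training. A minimal sketch, assuming an unlabeled_loader plus the kmeans_model, model, and likelihood objects from above (the TensorDataset wrapping is illustrative, not part of this commit):

from torch.utils.data import DataLoader, TensorDataset

# Pseudo-label the unlabeled pool and wrap the result as a new training loader.
labeled = semi_supervised_labeling(kmeans_model, model, likelihood, unlabeled_loader)
xs = torch.stack([x.cpu() for x, _ in labeled])
ys = torch.tensor([int(y) for _, y in labeled], dtype=torch.long)
pseudo_loader = DataLoader(TensorDataset(xs, ys), batch_size=1024, shuffle=True)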

def calculate_elbo(model, likelihood, data_loader):
    """
    Calculates the ELBO (Evidence Lower Bound) score for the model on the given data.

    Args:
    - model: The trained Gaussian Process model.
    - likelihood: The likelihood associated with the GP model.
    - data_loader: DataLoader providing the data over which to calculate ELBO.

    Returns:
    - elbo_score: The calculated ELBO score (higher is better).
    """
    model.eval()
    likelihood.eval()
    mll = gpytorch.mlls.VariationalELBO(likelihood, model, num_data=len(data_loader.dataset))

    with torch.no_grad():
        elbo_score = 0.0
        for batch in data_loader:
            train_x = batch['data'].reshape(batch['data'].size(0), -1).to(device)
            train_y = batch['label'].to(device)
            output = model(train_x)
            # mll already returns the ELBO (the training loss above is its negation),
            # so accumulate it directly rather than negating it again.
            elbo_score += mll(output, train_y).sum().item()

    # Average the ELBO over all data samples
    elbo_score /= len(data_loader.dataset)

    return elbo_score
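Since a higher ELBO indicates a better variational fit, the score is handy for comparing checkpoints across active-learning rounds, e.g. (illustrative usage):

# Track how the variational fit on the validation set evolves as labeled data grows.
elbo_before = calculate_elbo(model, likelihood, val_loader)
# ... acquire labels, retrain ...
elbo_after = calculate_elbo(model, likelihood, val_loader)
print(f"Validation ELBO: {elbo_before:.4f} -> {elbo_after:.4f}")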