Commit
I made changes to the CustomDataset class. These changes make it possible to retrieve the current batch index and other information that lets our training run more effectively. By the way, the checkpoint main is just a temporary main script to which I added a usage example; the same goes for the model checkpoint version, which is also a usage example. I recommend merging these changes into the current version as the best practice, but I cannot evaluate them against the dataset myself, so I created this development version for manual evaluation. Co-Authored-By: Dong Han <dong.han@uconn.edu> Co-Authored-By: Darren Chen <darren.3.chen@uconn.edu>
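For context, a minimal sketch of what the CustomDataset checkpointing interface could look like, inferred from the calls the scripts below make (load_checkpoint, save_checkpoint, get_current_batch_index). The method names come from this commit; the internals here are assumptions, not the actual implementation:

import torch
from torch.utils.data import Dataset

class CustomDataset(Dataset):
    """Sketch only: the real class also loads and returns the actual data."""

    def __init__(self, segments, labels):
        self.segments = segments
        self.labels = labels
        self.current_batch_index = 0  # position to resume from, counted in batches

    def __len__(self):
        return len(self.segments)

    def __getitem__(self, idx):
        # idx is the index of an individual segment, supplied by the DataLoader
        return {'data': self.segments[idx], 'label': self.labels[idx]}

    def get_current_batch_index(self):
        return self.current_batch_index

    def save_checkpoint(self, path, current_batch_index=None):
        if current_batch_index is not None:
            self.current_batch_index = current_batch_index
        torch.save({'current_batch_index': self.current_batch_index}, path)

    def load_checkpoint(self, path):
        checkpoint = torch.load(path)
        self.current_batch_index = checkpoint.get('current_batch_index', 0)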
@@ -0,0 +1,94 @@
# -*- coding: utf-8 -*-
"""
Created on Thu Feb 1 19:43:31 2024
@author: lrm22005
"""
import os
import tqdm
import torch
from utils.data_loader import preprocess_data, split_uids, update_train_loader_with_uncertain_samples
from models.ss_gp_model import MultitaskGPModel, train_gp_model
from utils_gp.ss_evaluation import stochastic_evaluation, evaluate_model_on_all_data
from active_learning.ss_active_learning import stochastic_uncertainty_sampling, run_minibatch_kmeans, stochastic_compare_kmeans_gp_predictions
from utils.visualization import plot_comparative_results, plot_training_performance, plot_results

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

def main():
    # Set parameters like n_classes, batch_size, etc.
    n_classes = 4
    batch_size = 1024
    clinical_trial_train, clinical_trial_test, clinical_trial_unlabeled = split_uids()
    data_format = 'pt'

    # Preprocess data
    train_loader, val_loader, test_loader = preprocess_data(data_format, clinical_trial_train, clinical_trial_test, clinical_trial_unlabeled, batch_size)

    # Attempt to resume from the last saved batch index if a dataset checkpoint exists
    dataset_checkpoint_path = 'dataset_checkpoint.pt'
    if os.path.exists(dataset_checkpoint_path):
        for loader in [train_loader, val_loader, test_loader]:
            loader.dataset.load_checkpoint(dataset_checkpoint_path)
            print(f"Resuming from batch index {loader.dataset.get_current_batch_index()}")

    kmeans_model = run_minibatch_kmeans(train_loader, n_clusters=n_classes, device=device)

    # Initialize result storage
    results = {
        'train_loss': [],
        'validation_metrics': {'precision': [], 'recall': [], 'f1': [], 'auc_roc': []},
        'test_metrics': None
    }

    # Initial model training
    model, likelihood, training_metrics = train_gp_model(train_loader, val_loader, num_iterations=50, n_classes=n_classes)

    # Save the training metrics for future visualization
    results['train_loss'].extend(training_metrics['train_loss'])
    results['validation_metrics']['precision'].extend(training_metrics['precision'])
    results['validation_metrics']['recall'].extend(training_metrics['recall'])
    results['validation_metrics']['f1'].extend(training_metrics['f1_score'])

    active_learning_iterations = 10
    # Active Learning Iterations
    for iteration in tqdm.tqdm(range(active_learning_iterations), desc='Active Learning', unit='iteration', leave=True):
        # Perform uncertainty sampling to select new samples from the validation set
        uncertain_sample_indices = stochastic_uncertainty_sampling(model, likelihood, val_loader, n_samples=batch_size, n_batches=5)

        # Update the training loader with uncertain samples
        train_loader = update_train_loader_with_uncertain_samples(train_loader, uncertain_sample_indices, batch_size)

        # Optionally, save the dataset state at intervals or after certain conditions
        train_loader.dataset.save_checkpoint(dataset_checkpoint_path, current_batch_index=None)  # Here, manage the index as needed

        # Re-train the model with the updated training data
        model, likelihood, val_metrics = train_gp_model(train_loader, val_loader, num_iterations=10, n_classes=n_classes, patience=10, checkpoint_path='model_checkpoint_last.pt')

        # Store the validation metrics after each active learning iteration
        # (train_gp_model returns lists of per-epoch values under the key 'f1_score')
        results['validation_metrics']['precision'].extend(val_metrics['precision'])
        results['validation_metrics']['recall'].extend(val_metrics['recall'])
        results['validation_metrics']['f1'].extend(val_metrics['f1_score'])

    # Compare K-Means with GP model predictions after retraining
    gp_vs_kmeans_data, original_labels = stochastic_compare_kmeans_gp_predictions(kmeans_model, model, train_loader, n_batches=5, device=device)

    plot_comparative_results(gp_vs_kmeans_data, original_labels)

    # Final evaluation on test set
    test_metrics = evaluate_model_on_all_data(model, likelihood, test_loader, device, n_classes)
    test_kmeans_model = run_minibatch_kmeans(test_loader, n_clusters=n_classes, device=device)

    results['test_metrics'] = test_metrics
    test_gp_vs_kmeans_data, test_original_labels = stochastic_compare_kmeans_gp_predictions(test_kmeans_model, model, test_loader, n_batches=5, device=device)
    plot_comparative_results(test_gp_vs_kmeans_data, test_original_labels)

    # Visualization of results
    plot_training_performance(results['train_loss'], results['validation_metrics'])
    plot_results(results['test_metrics'])

    # Print final test metrics
    print("Final Test Metrics:", results['test_metrics'])

if __name__ == "__main__":
    main()
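
The save_checkpoint call in the active learning loop above passes current_batch_index=None, leaving the index to be managed by the caller. A sketch of how a real index might be supplied while iterating manually (hypothetical; assumes the CustomDataset interface sketched under the commit message):

# Hypothetical sketch: supply a real batch index when saving the dataset state
for batch_index, batch in enumerate(train_loader):
    ...  # forward/backward pass on `batch`
    if batch_index % 100 == 0:  # e.g., checkpoint every 100 batches
        train_loader.dataset.save_checkpoint(dataset_checkpoint_path,
                                             current_batch_index=batch_index)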
@@ -0,0 +1,218 @@
# -*- coding: utf-8 -*-
"""
Created on Fri Feb 2 15:14:51 2024
@author: lrm22005
"""
import os
import numpy as np
from tqdm import tqdm
import torch
import gpytorch
from sklearn.metrics import precision_recall_fscore_support, roc_auc_score
from sklearn.preprocessing import label_binarize

num_latents = 6  # This should match the complexity of your data or the number of tasks
num_tasks = 4  # This should match the number of output classes or tasks
num_inducing_points = 50  # Number of inducing points per latent; independent of the dataset size

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

class MultitaskGPModel(gpytorch.models.ApproximateGP):
    def __init__(self):
        # Use a different set of inducing points for each latent function.
        # The input dimension is 127 * 128 = 16256, i.e. flattened 127x128 segments
        inducing_points = torch.rand(num_latents, num_inducing_points, 127 * 128)

        # We have to mark the CholeskyVariationalDistribution as batch
        # so that we learn a variational distribution for each task
        variational_distribution = gpytorch.variational.CholeskyVariationalDistribution(
            inducing_points.size(-2), batch_shape=torch.Size([num_latents])
        )

        # We have to wrap the VariationalStrategy in an LMCVariationalStrategy
        # so that the output will be a MultitaskMultivariateNormal rather than a batch output
        variational_strategy = gpytorch.variational.LMCVariationalStrategy(
            gpytorch.variational.VariationalStrategy(
                self, inducing_points, variational_distribution, learn_inducing_locations=True
            ),
            num_tasks=num_tasks,
            num_latents=num_latents,
            latent_dim=-1
        )

        super().__init__(variational_strategy)

        # The mean and covariance modules should be marked as batch
        # so we learn a different set of hyperparameters for each latent function
        self.mean_module = gpytorch.means.ConstantMean(batch_shape=torch.Size([num_latents]))
        self.covar_module = gpytorch.kernels.ScaleKernel(
            gpytorch.kernels.RBFKernel(batch_shape=torch.Size([num_latents])),
            batch_shape=torch.Size([num_latents])
        )

    def forward(self, x):
        # The forward function is written as if we were dealing with each output
        # dimension in batch. x must already be flattened so that its last
        # dimension matches the inducing points (here 127 * 128).
        mean_x = self.mean_module(x)
        covar_x = self.covar_module(x)
        latent_pred = gpytorch.distributions.MultivariateNormal(mean_x, covar_x)
        return latent_pred

def train_gp_model(train_loader, val_loader, num_iterations=50, n_classes=4, patience=10, checkpoint_path='model_checkpoint_full.pt', resume_training=False):
    model = MultitaskGPModel().to(device)
    likelihood = gpytorch.likelihoods.SoftmaxLikelihood(num_features=4, num_classes=4).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.1)
    mll = gpytorch.mlls.VariationalELBO(likelihood, model, num_data=len(train_loader.dataset))

    # Load checkpoint if resuming training
    start_epoch = 0
    if resume_training and os.path.exists(checkpoint_path):
        checkpoint = torch.load(checkpoint_path)
        model.load_state_dict(checkpoint['model_state_dict'])
        likelihood.load_state_dict(checkpoint['likelihood_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        start_epoch = checkpoint.get('epoch', 0) + 1  # Resume from the next epoch
        print(f"Resuming training from epoch {start_epoch}")

    best_val_loss = float('inf')
    epochs_no_improve = 0
    metrics = {'precision': [], 'recall': [], 'f1_score': [], 'auc_roc': [], 'train_loss': []}

    for epoch in range(start_epoch, num_iterations):
        for batch_index, train_batch in enumerate(train_loader):
            model.train()
            likelihood.train()
            optimizer.zero_grad()
            train_x = train_batch['data'].reshape(train_batch['data'].size(0), -1).to(device)
            train_y = train_batch['label'].to(device)
            output = model(train_x)
            loss = -mll(output, train_y)
            metrics['train_loss'].append(loss.item())
            loss.backward()
            optimizer.step()

        # Stochastic validation: sample a subset of the validation set
        model.eval()
        likelihood.eval()
        with torch.no_grad():
            # The fraction below is currently 1.0, i.e. the full validation set
            val_indices = torch.randperm(len(val_loader.dataset))[:int(1.0 * len(val_loader.dataset))]
            val_loss = 0.0
            val_labels = []
            val_predictions = []
            for idx in val_indices:
                val_batch = val_loader.dataset[idx]
                val_x = val_batch['data'].reshape(-1).unsqueeze(0).to(device)  # Use reshape here
                val_y = torch.tensor([val_batch['label']], device=device)
                val_output = model(val_x)
                val_loss_batch = -mll(val_output, val_y).sum()
                val_loss += val_loss_batch.item()
                val_labels.append(val_y.item())
                val_predictions.append(val_output.mean.argmax(dim=-1).item())

        precision, recall, f1, _ = precision_recall_fscore_support(val_labels, val_predictions, average='macro')
        # auc_roc = roc_auc_score(label_binarize(val_labels, classes=np.arange(n_classes)),
        #                         label_binarize(val_predictions, classes=np.arange(n_classes)),
        #                         multi_class='ovr')

        metrics['precision'].append(precision)
        metrics['recall'].append(recall)
        metrics['f1_score'].append(f1)
        # metrics['auc_roc'].append(auc_roc)
        val_loss /= len(val_indices)

        if val_loss < best_val_loss:
            best_val_loss = val_loss
            epochs_no_improve = 0
            torch.save({'model_state_dict': model.state_dict(),
                        'likelihood_state_dict': likelihood.state_dict(),
                        'optimizer_state_dict': optimizer.state_dict()}, checkpoint_path)
        else:
            epochs_no_improve += 1
            if epochs_no_improve >= patience:
                print(f"Early stopping triggered at epoch {epoch+1}")
                break

        # Save checkpoint at the end of each epoch.
        # Note: this uses the same path as the best-model save above, so the file
        # ends up holding the latest epoch, not necessarily the best one.
        torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'likelihood_state_dict': likelihood.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'best_val_loss': best_val_loss,
            # Include other metrics as needed
        }, checkpoint_path)

        if epochs_no_improve >= patience:
            print(f"Early stopping triggered at epoch {epoch+1}")
            break

    # Optionally, load the best model at the end of training
    if os.path.exists(checkpoint_path):
        checkpoint = torch.load(checkpoint_path)
        model.load_state_dict(checkpoint['model_state_dict'])
        likelihood.load_state_dict(checkpoint['likelihood_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

    return model, likelihood, metrics

def semi_supervised_labeling(kmeans_model, gp_model, gp_likelihood, data_loader, confidence_threshold=0.8):
    gp_model.eval()
    gp_likelihood.eval()
    labeled_samples = []

    with torch.no_grad():
        for batch in data_loader:
            data_tensor = batch['data'].view(batch['data'].size(0), -1).to(device)
            kmeans_predictions = kmeans_model.predict(data_tensor.cpu().numpy())
            # SoftmaxLikelihood returns a sampled categorical distribution; average
            # over its Monte Carlo sample dimension to get per-sample class probabilities
            gp_predictions = gp_likelihood(gp_model(data_tensor))
            probs = gp_predictions.probs.mean(0)
            confidences, gp_classes = probs.max(dim=-1)

            # Use GP predictions where the model is confident, otherwise fall back to K-Means
            for i, confident in enumerate((confidences > confidence_threshold).cpu().numpy()):
                if confident:
                    labeled_samples.append((data_tensor[i], gp_classes[i].item()))
                else:
                    labeled_samples.append((data_tensor[i], kmeans_predictions[i]))

    return labeled_samples

def calculate_elbo(model, likelihood, data_loader):
    """
    Calculates the ELBO (Evidence Lower Bound) score for the model on the given data.

    Args:
    - model: The trained Gaussian Process model.
    - likelihood: The likelihood associated with the GP model.
    - data_loader: DataLoader providing the data over which to calculate the ELBO.

    Returns:
    - elbo_score: The calculated ELBO score.
    """
    model.eval()
    likelihood.eval()
    mll = gpytorch.mlls.VariationalELBO(likelihood, model, num_data=len(data_loader.dataset))

    with torch.no_grad():
        elbo_score = 0.0
        for batch in data_loader:
            train_x = batch['data'].reshape(batch['data'].size(0), -1).to(device)
            train_y = batch['label'].to(device)
            output = model(train_x)
            # mll returns the ELBO itself (the training loss is its negation)
            elbo_score += mll(output, train_y).sum().item()

        # Average the ELBO over all data samples
        elbo_score /= len(data_loader.dataset)

    return elbo_score
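
Neither semi_supervised_labeling nor calculate_elbo is exercised by the temporary main above. A hedged usage sketch, assuming loaders and a kmeans_model produced the same way as in the main script (unlabeled_loader is a hypothetical loader over the unlabeled pool):

# Hypothetical usage of the helpers above, after an initial training run
model, likelihood, metrics = train_gp_model(train_loader, val_loader,
                                            num_iterations=50, n_classes=4,
                                            resume_training=True)  # pick up from the saved epoch

pseudo_labeled = semi_supervised_labeling(kmeans_model, model, likelihood,
                                          unlabeled_loader, confidence_threshold=0.8)
print(f"Pseudo-labeled {len(pseudo_labeled)} samples")

elbo = calculate_elbo(model, likelihood, val_loader)
print(f"Validation ELBO: {elbo:.4f}")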
2 comments on commit 528c4be
Basically both can be called idx or index, but these names refer to different indexes.

`idx` in `__getitem__(self, idx)` is the index passed to the dataset object when an item is retrieved, such as when you iterate over the dataset or when you index it directly with `dataset[idx]`. It represents the index of the specific data point being requested.

`self.current_batch_index` is a class attribute that we can use to keep track of the current position within the dataset during an operation, such as training, where we might want to save and resume the process. It can represent the index of the batch rather than of an individual data point, depending on how we use it.

If you want to use `self.current_batch_index` to resume training from a specific batch, it should be managed so that it corresponds to the batch index, not the index of an individual item within the batch. Compare the example in the dataset checkpointing with the examples in the training checkpointing: we read the position from `self.current_batch_index`, but when the loader iterates, by default it goes through `__getitem__`.
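As an illustration of that relationship (a sketch, not code from this commit): with a sequential, non-shuffled loader, the item `idx` values served for a batch are a deterministic function of the batch index.

# Illustrative only: the item indices __getitem__ serves for a given batch,
# assuming shuffle=False and a plain sequential sampler
batch_size = 1024
dataset_len = 100_000   # hypothetical dataset size
batch_index = 7         # hypothetical batch at which training stopped
first_idx = batch_index * batch_size
item_indices = range(first_idx, min(first_idx + batch_size, dataset_len))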
Then it sounds like:
- when you resume training, you can only resume from the first segment in this batch, not the exact segment inside the batch where you stopped, because the batch index and the segment `idx` in `__getitem__` are two different things;
- it seems like the `idx` in `__getitem__` can be automatically generated when you know the batch index, but I do not know which function inside PyTorch does that. Is it inside the checkpoint saving function?

I will test your code on Colab and see if it works.
Dear @lrm22005, how does `self.current_batch_index` interact with the `idx` in `__getitem__`?