Skip to content
Permalink
main
Switch branches/tags

Name already in use

A tag already exists with the provided branch name. Many Git commands accept both tag and branch names, so creating this branch may cause unexpected behavior. Are you sure you want to create this branch?
Go to file
 
 
Cannot retrieve contributors at this time
# -*- coding: utf-8 -*-
"""
Created on Wed Apr 18 12:52:53 2024
@author: lrmercadod
"""
import datetime as dt
import os
import time

import gpytorch
import torch
import torch.nn as nn
from sklearn.metrics import precision_recall_fscore_support, roc_auc_score
from sklearn.preprocessing import label_binarize

# Import my own functions and classes
from utils.dataloader import preprocess_data
from utils.pathmaster import PathMaster
# Run on the GPU when CUDA is available; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Sizes read by MultitaskGPModel when it is constructed.
num_latents = 6  # Number of latent GPs; scale with data complexity or the number of tasks
num_tasks = 4  # Number of model outputs; should equal the number of output classes/tasks
num_inducing_points = 50  # Inducing points per latent GP; chosen independently of the above
class MultitaskGPModel(gpytorch.models.ApproximateGP):
    """Approximate (variational) multitask GP using an LMC variational strategy.

    The model keeps ``num_latents`` latent GPs, each with its own inducing
    points, constant mean and scaled RBF kernel, and mixes them into
    ``num_tasks`` outputs through ``LMCVariationalStrategy``. All sizes come
    from the module-level ``num_latents``, ``num_tasks`` and
    ``num_inducing_points`` values.
    """

    def __init__(self):
        # Every per-latent component shares this batch shape so that each
        # latent function learns its own parameters.
        latent_batch = torch.Size([num_latents])

        # A separate random set of inducing locations per latent function.
        # The feature size 128 * 128 assumes flattened 128x128 images.
        initial_inducing = torch.rand(num_latents, num_inducing_points, 128 * 128)

        # Batched variational distribution: one per latent function.
        per_latent_q = gpytorch.variational.CholeskyVariationalDistribution(
            initial_inducing.size(-2), batch_shape=latent_batch
        )

        # The plain strategy is wrapped in the LMC strategy so the model emits
        # a multitask distribution instead of a batch of independent outputs.
        base_strategy = gpytorch.variational.VariationalStrategy(
            self, initial_inducing, per_latent_q, learn_inducing_locations=True
        )
        lmc_strategy = gpytorch.variational.LMCVariationalStrategy(
            base_strategy,
            num_tasks=num_tasks,
            num_latents=num_latents,
            latent_dim=-1,
        )
        super().__init__(lmc_strategy)

        # Batched mean and kernel: one set of hyperparameters per latent GP.
        self.mean_module = gpytorch.means.ConstantMean(batch_shape=latent_batch)
        rbf_kernel = gpytorch.kernels.RBFKernel(batch_shape=latent_batch)
        self.covar_module = gpytorch.kernels.ScaleKernel(rbf_kernel, batch_shape=latent_batch)

    def forward(self, x):
        """Return the latent GP prior at ``x`` as a ``MultivariateNormal``."""
        return gpytorch.distributions.MultivariateNormal(
            self.mean_module(x), self.covar_module(x)
        )
def train_gp_model(train_loader, val_loader, num_iterations=50, n_classes=4, patience=10,
                   checkpoint_path='model_checkpoint.pt', resume_training=False):
    """Train a ``MultitaskGPModel`` classifier with early stopping.

    Each epoch runs one pass over ``train_loader`` and then scores a random
    10% subset (at least one sample) of ``val_loader.dataset``. The best
    state by validation loss is written to ``checkpoint_path`` and reloaded
    before returning.

    Args:
        train_loader: Iterable of dict batches with ``'data'`` and ``'label'``
            tensors; must expose ``.dataset`` (its length scales the ELBO).
        val_loader: Loader whose ``.dataset`` supports ``len()`` and integer
            indexing, returning dicts with ``'data'`` and ``'label'``.
        num_iterations: Epoch index at which training stops.
        n_classes: Number of classes for the softmax likelihood and metrics.
        patience: Consecutive non-improving epochs tolerated before stopping.
        checkpoint_path: File the best checkpoint is saved to / resumed from.
        resume_training: If true and ``checkpoint_path`` exists, restore the
            model, likelihood and (when compatible) optimizer state and
            continue from the epoch after the saved one.

    Returns:
        ``(model, likelihood, metrics)``. ``metrics`` maps ``'train_loss'`` to
        per-batch losses and ``'precision'``, ``'recall'``, ``'f1_score'``,
        ``'auc_roc'`` to per-epoch validation values.

    Raises:
        ValueError: If the validation dataset is empty.
    """
    val_dataset = val_loader.dataset
    n_val = len(val_dataset)
    if n_val == 0:
        raise ValueError("val_loader.dataset is empty; cannot run validation")
    # Always validate on at least one sample so the mean loss is well defined.
    n_val_subset = max(1, int(0.1 * n_val))

    model = MultitaskGPModel().to(device)
    # The likelihood consumes the model's num_tasks outputs and maps them to
    # n_classes class logits through its mixing weights.
    likelihood = gpytorch.likelihoods.SoftmaxLikelihood(
        num_features=num_tasks, num_classes=n_classes
    ).to(device)
    # The likelihood parameters must be optimised together with the model;
    # otherwise the mixing weights stay at their initial values.
    optimizer = torch.optim.Adam(
        [{'params': model.parameters()}, {'params': likelihood.parameters()}],
        lr=0.1,
    )
    mll = gpytorch.mlls.VariationalELBO(likelihood, model, num_data=len(train_loader.dataset))

    start_epoch = 0
    best_val_loss = float('inf')
    if resume_training and os.path.exists(checkpoint_path):
        # map_location lets a checkpoint written on GPU be restored on CPU.
        checkpoint = torch.load(checkpoint_path, map_location=device)
        model.load_state_dict(checkpoint['model_state_dict'])
        likelihood.load_state_dict(checkpoint['likelihood_state_dict'])
        try:
            optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        except ValueError:
            # Checkpoints written with a different parameter-group layout
            # cannot restore the optimizer; continue with a fresh one.
            print("Optimizer state in checkpoint is incompatible; starting with a fresh optimizer")
        # The saved epoch already completed, so continue with the next one.
        start_epoch = checkpoint.get('epoch', -1) + 1
        best_val_loss = checkpoint.get('best_val_loss', float('inf'))

    epochs_no_improve = 0
    metrics = {
        'precision': [],
        'recall': [],
        'f1_score': [],
        'auc_roc': [],
        'train_loss': [],
    }
    class_ids = list(range(n_classes))

    for epoch in range(start_epoch, num_iterations):
        model.train()
        likelihood.train()
        for train_batch in train_loader:
            optimizer.zero_grad()
            batch_data = train_batch['data']
            # Flatten every sample to a feature vector.
            train_x = batch_data.reshape(batch_data.size(0), -1).to(device)
            train_y = train_batch['label'].to(device)
            loss = -mll(model(train_x), train_y)
            metrics['train_loss'].append(loss.item())
            loss.backward()
            optimizer.step()

        # Stochastic validation on a fresh random subset each epoch
        model.eval()
        likelihood.eval()
        val_indices = torch.randperm(n_val)[:n_val_subset].tolist()
        val_loss = 0.0
        val_labels = []
        val_predictions = []
        with torch.no_grad():
            for idx in val_indices:
                sample = val_dataset[idx]
                val_x = sample['data'].reshape(-1).unsqueeze(0).to(device)
                val_y = torch.tensor([sample['label']], device=device)
                val_output = model(val_x)
                val_loss += -mll(val_output, val_y).sum().item()
                # Class probabilities come from the likelihood, averaged over
                # its samples; the raw latent mean is not a class score.
                class_probs = likelihood(val_output).probs.mean(0)
                val_labels.append(val_y.item())
                val_predictions.append(class_probs.argmax(dim=-1).item())

        precision, recall, f1, _ = precision_recall_fscore_support(
            val_labels, val_predictions, average='macro'
        )
        try:
            auc_roc = roc_auc_score(label_binarize(val_labels, classes=class_ids),
                                    label_binarize(val_predictions, classes=class_ids),
                                    multi_class='ovr')
        except ValueError:
            # AUC is undefined when a class is missing from the random
            # subset; record NaN rather than aborting the training run.
            auc_roc = float('nan')
        metrics['precision'].append(precision)
        metrics['recall'].append(recall)
        metrics['f1_score'].append(f1)
        metrics['auc_roc'].append(auc_roc)

        val_loss /= len(val_indices)
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            epochs_no_improve = 0
            torch.save({'model_state_dict': model.state_dict(),
                        'likelihood_state_dict': likelihood.state_dict(),
                        'optimizer_state_dict': optimizer.state_dict(),
                        'epoch': epoch,
                        'best_val_loss': best_val_loss}, checkpoint_path)
        else:
            epochs_no_improve += 1
            if epochs_no_improve >= patience:
                print(f"Early stopping triggered at epoch {epoch+1}")
                break

    # Return the best checkpointed weights rather than the last-epoch ones.
    if os.path.exists(checkpoint_path):
        checkpoint = torch.load(checkpoint_path, map_location=device)
        model.load_state_dict(checkpoint['model_state_dict'])
        likelihood.load_state_dict(checkpoint['likelihood_state_dict'])
    return model, likelihood, metrics
def evaluate_gp_model(test_loader, model, likelihood, n_classes=4):
    """Evaluate a trained GP classifier on ``test_loader``.

    Args:
        test_loader: Iterable of dict batches with ``'data'`` and ``'label'``
            tensors.
        model: Trained GP model; called on flattened inputs.
        likelihood: Likelihood that turns the model output into class
            probabilities.
        n_classes: Number of classes used to binarize labels for AUC-ROC.

    Returns:
        Dict with macro-averaged ``'precision'``, ``'recall'``, ``'f1_score'``
        and one-vs-rest ``'auc_roc'`` computed from one-hot hard predictions.
        ``'auc_roc'`` is NaN when it is undefined for the given labels.
    """
    model.eval()
    likelihood.eval()
    test_labels = []
    test_predictions = []
    with torch.no_grad():
        for test_batch in test_loader:
            batch_data = test_batch['data']
            # Flatten every sample to a feature vector.
            test_x = batch_data.reshape(batch_data.size(0), -1).to(device)
            # Predict from the likelihood's class probabilities (averaged
            # over its samples); the raw latent mean is not a class score.
            class_probs = likelihood(model(test_x)).probs.mean(0)
            test_labels.extend(test_batch['label'].tolist())
            test_predictions.extend(class_probs.argmax(dim=-1).tolist())

    precision, recall, f1, _ = precision_recall_fscore_support(
        test_labels, test_predictions, average='macro'
    )
    class_ids = list(range(n_classes))
    try:
        auc_roc = roc_auc_score(label_binarize(test_labels, classes=class_ids),
                                label_binarize(test_predictions, classes=class_ids),
                                multi_class='ovr')
    except ValueError:
        # AUC is undefined when a class never occurs in the labels; report
        # NaN instead of discarding the other metrics.
        auc_roc = float('nan')

    return {
        'precision': precision,
        'recall': recall,
        'f1_score': f1,
        'auc_roc': auc_roc,
    }
def main():
    """Configure the run, train the GP classifier, and report test metrics."""
    # Device and drives
    is_linux = False
    is_hpc = False
    is_internal = False
    is_external = True
    binary = False
    # Input
    is_tfs = True
    # Database
    database = 'mimic3'
    # Initialize the focus
    focus = 'thesis_results_database_multiclass'
    # Initialize the file tag
    file_tag = 'MIMIC_III'
    # Image resolution
    img_res = '128x128_float16'
    # Data type: the type to convert the data into when it is loaded in
    data_type = torch.float32
    # Model type
    model_type = torch.float32
    # Create a PathMaster object
    pathmaster = PathMaster(is_linux, is_hpc, is_tfs, is_internal, is_external, focus, file_tag, img_res)
    # Image dimensions
    img_channels = 1
    img_size = 128
    downsample = None
    standardize = True
    # Run parameters
    n_epochs = 100
    # NOTE(review): the non-binary case uses 3 classes while the module-level
    # num_tasks is 4 - confirm which value matches the dataset labels.
    n_classes = 2 if binary else 3
    patience = round(n_epochs / 10) if n_epochs > 50 else 5
    # Currently unused: train_gp_model has no switch to disable checkpointing.
    save = True
    # Resume checkpoint: set to an existing checkpoint file to resume from it
    resume_checkpoint_path = None
    # Data loading details
    data_format = 'pt'
    batch_size = 256
    # Preprocess database data
    test_loader = preprocess_data(database, batch_size, standardize, img_channels, img_size,
                                  downsample, data_type, pathmaster, binary)
    # NOTE(review): train_loader and val_loader are never created in this
    # function, so the call below raises NameError. The loader API that should
    # provide them is not visible here - TODO supply them before running.
    # Training and validation
    resume_training = resume_checkpoint_path is not None
    checkpoint_path = resume_checkpoint_path if resume_training else 'model_checkpoint.pt'
    start_time = time.time()
    # Keyword arguments keep every value bound to the intended parameter; the
    # previous positional call passed `save` as checkpoint_path and
    # `pathmaster` as resume_training.
    model, likelihood, metrics = train_gp_model(train_loader, val_loader,
                                                num_iterations=n_epochs,
                                                n_classes=n_classes,
                                                patience=patience,
                                                checkpoint_path=checkpoint_path,
                                                resume_training=resume_training)
    end_time = time.time()
    time_passed = end_time - start_time
    print('\nTraining and validation took %.2f minutes' % (time_passed / 60))
    # Evaluation
    start_time = time.time()
    test_metrics = evaluate_gp_model(test_loader, model, likelihood, n_classes)
    end_time = time.time()
    time_passed = end_time - start_time
    print('\nTesting took %.2f seconds' % time_passed)
    print('Test Metrics:')
    print('Precision: %.4f' % test_metrics['precision'])
    print('Recall: %.4f' % test_metrics['recall'])
    print('F1 Score: %.4f' % test_metrics['f1_score'])
    print('AUC-ROC: %.4f' % test_metrics['auc_roc'])


if __name__ == '__main__':
    main()