Permalink
Cannot retrieve contributors at this time
Name already in use
A tag already exists with the provided branch name. Many Git commands accept both tag and branch names, so creating this branch may cause unexpected behavior. Are you sure you want to create this branch?
B_ML_Project/main_darren_v1-8GJQ9R3.py
Go to file
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
260 lines (216 sloc)
9.5 KB
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# -*- coding: utf-8 -*-
"""
Created on Wed Apr 18 12:52:53 2024
@author: lrmercadod
"""
import datetime as dt
import os
import time

import gpytorch
import torch
import torch.nn as nn
from sklearn.metrics import precision_recall_fscore_support, roc_auc_score
from sklearn.preprocessing import label_binarize

# Import my own functions and classes
from utils.pathmaster import PathMaster
from utils.dataloader import preprocess_data
# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Sizes of the variational multitask GP defined in this file.
num_latents = 6  # Number of latent GPs that are mixed into the task outputs
num_tasks = 4  # Number of outputs (tasks) produced by the model
num_inducing_points = 50  # Inducing points per latent GP; independent of the two sizes above
class MultitaskGPModel(gpytorch.models.ApproximateGP):
    """Variational multitask GP using a linear model of coregionalization (LMC).

    A batch of latent GPs, each with its own inducing points, constant mean and
    scaled RBF kernel, is wrapped in ``LMCVariationalStrategy`` so that calling
    the model yields a multitask distribution rather than a batch output.

    Args:
        input_dim: Length of one flattened input vector. Defaults to
            ``128 * 128`` (a flattened 128x128 image).
        n_latents: Number of latent GPs. ``None`` uses the module-level
            ``num_latents``.
        n_tasks: Number of model outputs. ``None`` uses the module-level
            ``num_tasks``.
        n_inducing: Inducing points per latent GP. ``None`` uses the
            module-level ``num_inducing_points``.
    """

    def __init__(self, input_dim=128 * 128, n_latents=None, n_tasks=None, n_inducing=None):
        # Resolve the fallbacks at call time so ``MultitaskGPModel()`` behaves
        # exactly as before, including if the module constants are rebound.
        n_latents = num_latents if n_latents is None else n_latents
        n_tasks = num_tasks if n_tasks is None else n_tasks
        n_inducing = num_inducing_points if n_inducing is None else n_inducing
        latent_batch = torch.Size([n_latents])

        # Use a different set of inducing points for each latent function.
        inducing_points = torch.rand(n_latents, n_inducing, input_dim)

        # The CholeskyVariationalDistribution is marked as batch so that a
        # separate variational distribution is learned per latent function.
        variational_distribution = gpytorch.variational.CholeskyVariationalDistribution(
            inducing_points.size(-2), batch_shape=latent_batch
        )

        # Wrapping the VariationalStrategy in an LMCVariationalStrategy makes
        # the output a MultitaskMultivariateNormal rather than a batch output.
        variational_strategy = gpytorch.variational.LMCVariationalStrategy(
            gpytorch.variational.VariationalStrategy(
                self, inducing_points, variational_distribution, learn_inducing_locations=True
            ),
            num_tasks=n_tasks,
            num_latents=n_latents,
            latent_dim=-1,
        )
        super().__init__(variational_strategy)

        # Mean and covariance modules are marked as batch so each latent
        # function learns its own set of hyperparameters.
        self.mean_module = gpytorch.means.ConstantMean(batch_shape=latent_batch)
        self.covar_module = gpytorch.kernels.ScaleKernel(
            gpytorch.kernels.RBFKernel(batch_shape=latent_batch),
            batch_shape=latent_batch,
        )

    def forward(self, x):
        """Return the latent GP prior at ``x`` as a ``MultivariateNormal``."""
        mean_x = self.mean_module(x)
        covar_x = self.covar_module(x)
        return gpytorch.distributions.MultivariateNormal(mean_x, covar_x)
def train_gp_model(train_loader, val_loader, num_iterations=50, n_classes=4, patience=10,
                   checkpoint_path='model_checkpoint.pt', resume_training=False):
    """Train the multitask GP classifier with early stopping on a sampled validation loss.

    Args:
        train_loader: Iterable of batches; each batch is indexed with ``'data'``
            and ``'label'``. ``train_loader.dataset`` must support ``len()``.
        val_loader: Loader whose ``dataset`` supports ``len()`` and integer
            indexing, returning items indexed with ``'data'`` and ``'label'``.
        num_iterations: Exclusive upper bound on the epoch index.
        n_classes: Number of target classes.
        patience: Consecutive epochs without a new best validation loss that
            trigger early stopping.
        checkpoint_path: File the best model/likelihood/optimizer state is saved to.
        resume_training: When true and ``checkpoint_path`` exists, restore that
            state and continue after the epoch stored in it.

    Returns:
        Tuple ``(model, likelihood, metrics)``. ``metrics`` maps ``'precision'``,
        ``'recall'``, ``'f1_score'`` and ``'auc_roc'`` to one value per epoch and
        ``'train_loss'`` to one value per training batch.

    Raises:
        ValueError: If the validation dataset is empty.
    """
    n_val = len(val_loader.dataset)
    if n_val == 0:
        raise ValueError("val_loader.dataset is empty; cannot validate")
    # Validate on ~10% of the data each epoch, but always on at least one item
    # so the averaged loss below never divides by zero.
    n_val_sampled = max(1, int(0.1 * n_val))

    model = MultitaskGPModel().to(device)
    # The model emits ``num_tasks`` latent features; the softmax likelihood
    # mixes them into ``n_classes`` class logits.
    likelihood = gpytorch.likelihoods.SoftmaxLikelihood(
        num_features=num_tasks, num_classes=n_classes
    ).to(device)
    # The likelihood has parameters of its own (the mixing weights) that are
    # checkpointed below, so they must be optimised together with the model.
    optimizer = torch.optim.Adam(
        [{'params': model.parameters()}, {'params': likelihood.parameters()}],
        lr=0.1,
    )
    mll = gpytorch.mlls.VariationalELBO(likelihood, model, num_data=len(train_loader.dataset))

    start_epoch = 0
    if resume_training and os.path.exists(checkpoint_path):
        # map_location lets a checkpoint written on a GPU be restored on a CPU host.
        checkpoint = torch.load(checkpoint_path, map_location=device)
        model.load_state_dict(checkpoint['model_state_dict'])
        likelihood.load_state_dict(checkpoint['likelihood_state_dict'])
        try:
            optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        except ValueError:
            # Checkpoints written with a different parameter-group layout
            # cannot be restored into this optimizer; start it fresh instead.
            print("Optimizer state in checkpoint is incompatible; using a fresh optimizer")
        # The stored epoch had already finished, so continue with the next one.
        start_epoch = checkpoint.get('epoch', -1) + 1

    best_val_loss = float('inf')
    epochs_no_improve = 0
    metrics = {
        'precision': [],
        'recall': [],
        'f1_score': [],
        'auc_roc': [],
        'train_loss': []
    }
    class_ids = list(range(n_classes))

    for epoch in range(start_epoch, num_iterations):
        model.train()
        likelihood.train()
        for train_batch in train_loader:
            optimizer.zero_grad()
            # Flatten every sample to a single feature vector.
            train_x = train_batch['data'].reshape(train_batch['data'].size(0), -1).to(device)
            train_y = train_batch['label'].to(device)
            output = model(train_x)
            loss = -mll(output, train_y)
            metrics['train_loss'].append(loss.item())
            loss.backward()
            optimizer.step()

        # Stochastic validation on a fresh random subset each epoch
        model.eval()
        likelihood.eval()
        with torch.no_grad():
            val_indices = torch.randperm(n_val)[:n_val_sampled].tolist()
            val_loss = 0.0
            val_labels = []
            val_predictions = []
            for idx in val_indices:
                val_batch = val_loader.dataset[idx]
                val_x = val_batch['data'].reshape(-1).unsqueeze(0).to(device)
                val_y = torch.tensor([val_batch['label']], device=device)
                val_output = model(val_x)
                val_loss += (-mll(val_output, val_y).sum()).item()
                val_labels.append(val_y.item())
                # Class probabilities come from the likelihood (averaged over
                # its samples), not from the raw latent GP mean.
                class_probs = likelihood(val_output).probs.mean(0)
                val_predictions.append(class_probs.argmax(dim=-1).item())

        precision, recall, f1, _ = precision_recall_fscore_support(
            val_labels, val_predictions, average='macro'
        )
        try:
            auc_roc = roc_auc_score(label_binarize(val_labels, classes=class_ids),
                                    label_binarize(val_predictions, classes=class_ids),
                                    multi_class='ovr')
        except ValueError:
            # The random subset may miss a class entirely, in which case the
            # AUC is undefined; record NaN rather than aborting training.
            auc_roc = float('nan')
        metrics['precision'].append(precision)
        metrics['recall'].append(recall)
        metrics['f1_score'].append(f1)
        metrics['auc_roc'].append(auc_roc)

        val_loss /= len(val_indices)
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            epochs_no_improve = 0
            torch.save({'model_state_dict': model.state_dict(),
                        'likelihood_state_dict': likelihood.state_dict(),
                        'optimizer_state_dict': optimizer.state_dict(),
                        'epoch': epoch}, checkpoint_path)
        else:
            epochs_no_improve += 1
            if epochs_no_improve >= patience:
                print(f"Early stopping triggered at epoch {epoch+1}")
                break

    # Hand back the best checkpointed weights rather than the last epoch's.
    # The optimizer is not returned, so its state is not reloaded here.
    if os.path.exists(checkpoint_path):
        checkpoint = torch.load(checkpoint_path, map_location=device)
        model.load_state_dict(checkpoint['model_state_dict'])
        likelihood.load_state_dict(checkpoint['likelihood_state_dict'])
    return model, likelihood, metrics
def evaluate_gp_model(test_loader, model, likelihood, n_classes=4):
    """Evaluate a trained GP classifier on every batch of ``test_loader``.

    Args:
        test_loader: Iterable of batches; each batch is indexed with ``'data'``
            and ``'label'``.
        model: Trained GP model; called on the flattened batch.
        likelihood: Trained likelihood; ``likelihood(model(x))`` must expose
            ``probs`` with a leading sample dimension (as
            ``gpytorch.likelihoods.SoftmaxLikelihood`` does).
        n_classes: Number of target classes, used to binarize labels for the AUC.

    Returns:
        Dict with macro-averaged ``'precision'``, ``'recall'``, ``'f1_score'``
        and one-vs-rest ``'auc_roc'`` computed from the hard predictions.
    """
    model.eval()
    likelihood.eval()
    test_labels = []
    test_predictions = []
    with torch.no_grad():
        for test_batch in test_loader:
            # Flatten every sample to a single feature vector.
            test_x = test_batch['data'].reshape(test_batch['data'].size(0), -1).to(device)
            test_y = test_batch['label'].to(device)
            # Predict through the likelihood: it maps the latent GP outputs to
            # class probabilities, which are averaged over its samples.
            class_probs = likelihood(model(test_x)).probs.mean(0)
            test_labels.extend(test_y.tolist())
            test_predictions.extend(class_probs.argmax(dim=-1).tolist())

    class_ids = list(range(n_classes))
    precision, recall, f1, _ = precision_recall_fscore_support(
        test_labels, test_predictions, average='macro'
    )
    auc_roc = roc_auc_score(label_binarize(test_labels, classes=class_ids),
                            label_binarize(test_predictions, classes=class_ids),
                            multi_class='ovr')
    return {
        'precision': precision,
        'recall': recall,
        'f1_score': f1,
        'auc_roc': auc_roc
    }
def main():
    """Configure the run, train the GP classifier, then print its test metrics."""
    # Device and drives
    is_linux = False
    is_hpc = False
    is_internal = False
    is_external = True
    binary = False

    # Input
    is_tfs = True

    # Database
    database = 'mimic3'

    # Initialize the focus
    focus = 'thesis_results_database_multiclass'

    # Initialize the file tag
    file_tag = 'MIMIC_III'

    # Image resolution
    img_res = '128x128_float16'

    # Data type: the type to convert the data into when it is loaded in
    data_type = torch.float32

    # Create a PathMaster object
    pathmaster = PathMaster(is_linux, is_hpc, is_tfs, is_internal, is_external, focus, file_tag, img_res)

    # Image dimensions
    img_channels = 1
    img_size = 128
    downsample = None
    standardize = True

    # Run parameters
    n_epochs = 100
    n_classes = 2 if binary else 3
    patience = round(n_epochs / 10) if n_epochs > 50 else 5

    # Resume checkpoint: set to a checkpoint file path to continue a previous run
    resume_checkpoint_path = None

    # Data loading details
    batch_size = 256

    # Preprocess database data
    # NOTE(review): assumes preprocess_data returns (train, val, test) loaders,
    # since all three are used below -- confirm against utils.dataloader.
    train_loader, val_loader, test_loader = preprocess_data(
        database, batch_size, standardize, img_channels, img_size,
        downsample, data_type, pathmaster, binary)

    # Only override the checkpoint settings when resuming; otherwise keep
    # train_gp_model's own defaults.
    checkpoint_kwargs = {}
    if resume_checkpoint_path is not None:
        checkpoint_kwargs = {'checkpoint_path': resume_checkpoint_path,
                             'resume_training': True}

    # Training and validation
    start_time = time.time()
    model, likelihood, metrics = train_gp_model(train_loader, val_loader,
                                                num_iterations=n_epochs,
                                                n_classes=n_classes,
                                                patience=patience,
                                                **checkpoint_kwargs)
    end_time = time.time()
    time_passed = end_time - start_time
    print('\nTraining and validation took %.2f minutes' % (time_passed / 60))

    # Evaluation
    start_time = time.time()
    test_metrics = evaluate_gp_model(test_loader, model, likelihood, n_classes)
    end_time = time.time()
    time_passed = end_time - start_time
    print('\nTesting took %.2f seconds' % time_passed)

    print('Test Metrics:')
    print('Precision: %.4f' % test_metrics['precision'])
    print('Recall: %.4f' % test_metrics['recall'])
    print('F1 Score: %.4f' % test_metrics['f1_score'])
    print('AUC-ROC: %.4f' % test_metrics['auc_roc'])


if __name__ == '__main__':
    main()