-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Interpretation and addition of changes based on Darren's and Cassey's code
I re-added some changes that Cassey made earlier so that the code can run in my new environment. The change saves the data loader's stepping progress, which makes the process easier to follow and allows the model to be restarted without losing the time already spent downloading. Darren's code is a renovated version of our previous methods that provides flexibility and the ability to run the code in any environment. For now, the reference is a script that is still being debugged; later it will be reorganized into a more suitable structure of definitions.
- Loading branch information
Luis Roberto Mercado Diaz
committed
Apr 20, 2024
1 parent
861e742
commit bfe5b1c
Showing
37 changed files
with
5,652 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,260 @@ | ||
# -*- coding: utf-8 -*- | ||
""" | ||
Created on Wed Apr 18 12:52:53 2024 | ||
@author: lrmercadod | ||
""" | ||
import torch | ||
import torch.nn as nn | ||
import time | ||
import datetime as dt | ||
import gpytorch | ||
from sklearn.metrics import precision_recall_fscore_support, roc_auc_score | ||
from sklearn.preprocessing import label_binarize | ||
|
||
# Import my own functions and classes | ||
from utils.pathmaster import PathMaster | ||
from utils.dataloader import preprocess_data | ||
|
||
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Sizes shared by the GP model defined in this file.
num_latents = 6  # Number of latent GPs; should match the complexity of the data or the number of tasks
num_tasks = 4  # Number of model outputs; should match the number of output classes or tasks
num_inducing_points = 50  # Inducing points per latent GP; independent of the above, must cover the input space
|
||
class MultitaskGPModel(gpytorch.models.ApproximateGP):
    """Variational multitask GP using a linear model of coregionalization.

    ``num_latents`` latent GPs, each with its own inducing points, constant
    mean and scaled RBF kernel, are combined by an ``LMCVariationalStrategy``
    into ``num_tasks`` outputs. All sizes come from the module-level constants
    ``num_latents``, ``num_tasks`` and ``num_inducing_points``.
    """

    def __init__(self):
        # Batch shape shared by everything that is learned per latent function.
        latent_batch = torch.Size([num_latents])

        # A separate, randomly initialised set of inducing locations for each
        # latent function. The feature size 128 * 128 assumes flattened
        # 128x128 images.
        inducing_points = torch.rand(num_latents, num_inducing_points, 128 * 128)

        # Batched so that one variational distribution is learned per latent.
        variational_distribution = gpytorch.variational.CholeskyVariationalDistribution(
            num_inducing_points, batch_shape=latent_batch
        )

        # The plain strategy yields a batch of latent GPs; the LMC wrapper
        # turns that batch into a MultitaskMultivariateNormal output.
        base_strategy = gpytorch.variational.VariationalStrategy(
            self,
            inducing_points,
            variational_distribution,
            learn_inducing_locations=True,
        )
        lmc_strategy = gpytorch.variational.LMCVariationalStrategy(
            base_strategy,
            num_tasks=num_tasks,
            num_latents=num_latents,
            latent_dim=-1,
        )

        super().__init__(lmc_strategy)

        # Batched mean and kernel: one set of hyperparameters per latent.
        self.mean_module = gpytorch.means.ConstantMean(batch_shape=latent_batch)
        rbf_kernel = gpytorch.kernels.RBFKernel(batch_shape=latent_batch)
        self.covar_module = gpytorch.kernels.ScaleKernel(rbf_kernel, batch_shape=latent_batch)

    def forward(self, x):
        """Return the latent GP prior at ``x`` as a ``MultivariateNormal``."""
        return gpytorch.distributions.MultivariateNormal(
            self.mean_module(x), self.covar_module(x)
        )
|
||
def train_gp_model(train_loader, val_loader, num_iterations=50, n_classes=4, patience=10,
                   checkpoint_path='model_checkpoint.pt', resume_training=False):
    """Train the multitask GP classifier with early stopping on validation loss.

    Each epoch runs one pass over ``train_loader`` minimising the negative
    variational ELBO, then scores a random 10% sample of ``val_loader.dataset``
    (at least one item). The best-scoring state is written to
    ``checkpoint_path`` and reloaded before returning.

    Args:
        train_loader: Iterable of batches that are mappings with ``'data'`` and
            ``'label'`` entries; must expose ``.dataset`` with a length.
        val_loader: Loader whose ``.dataset`` is indexable and yields mappings
            with ``'data'`` and ``'label'`` entries.
        num_iterations: Upper bound on the number of epochs.
        n_classes: Number of target classes (likelihood size and metric labels).
        patience: Epochs without validation-loss improvement before stopping.
        checkpoint_path: File used to save and restore the best state.
        resume_training: When true and ``checkpoint_path`` exists, restore
            model, likelihood and optimizer state and continue after the
            epoch stored in the checkpoint.

    Returns:
        Tuple ``(model, likelihood, metrics)``. ``metrics`` maps
        ``'precision'``, ``'recall'``, ``'f1_score'`` and ``'auc_roc'`` to
        per-epoch lists and ``'train_loss'`` to a per-batch list.
    """
    # Local import: the file's import block does not bring `os` into scope.
    import os

    model = MultitaskGPModel().to(device)
    # The likelihood consumes the model's `num_tasks` outputs and maps them to
    # `n_classes` class logits through its mixing weights.
    likelihood = gpytorch.likelihoods.SoftmaxLikelihood(
        num_features=num_tasks, num_classes=n_classes
    ).to(device)
    # NOTE(review): only the model's parameters are optimised; the likelihood's
    # parameters are not passed to the optimizer. Confirm this is intended.
    optimizer = torch.optim.Adam(model.parameters(), lr=0.1)
    mll = gpytorch.mlls.VariationalELBO(likelihood, model, num_data=len(train_loader.dataset))

    start_epoch = 0
    if resume_training and os.path.exists(checkpoint_path):
        # map_location lets a checkpoint written on a GPU load on a CPU-only host.
        checkpoint = torch.load(checkpoint_path, map_location=device)
        model.load_state_dict(checkpoint['model_state_dict'])
        likelihood.load_state_dict(checkpoint['likelihood_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        # The checkpoint stores the index of the epoch that just finished, so
        # resume with the following one instead of repeating it.
        start_epoch = checkpoint.get('epoch', -1) + 1

    best_val_loss = float('inf')
    epochs_no_improve = 0

    metrics = {
        'precision': [],
        'recall': [],
        'f1_score': [],
        'auc_roc': [],
        'train_loss': []
    }

    for epoch in range(start_epoch, num_iterations):
        model.train()
        likelihood.train()
        for train_batch in train_loader:
            optimizer.zero_grad()
            # Flatten every sample to a single feature vector.
            train_x = train_batch['data'].reshape(train_batch['data'].size(0), -1).to(device)
            train_y = train_batch['label'].to(device)
            output = model(train_x)
            loss = -mll(output, train_y)
            metrics['train_loss'].append(loss.item())
            loss.backward()
            optimizer.step()

        # Stochastic validation on a random 10% of the validation set.
        model.eval()
        likelihood.eval()
        with torch.no_grad():
            n_val = len(val_loader.dataset)
            # Keep at least one sample so the average below is well defined
            # for validation sets with fewer than 10 items.
            n_sampled = max(1, int(0.1 * n_val))
            val_indices = torch.randperm(n_val)[:n_sampled]
            val_loss = 0.0
            val_labels = []
            val_predictions = []
            for idx in val_indices:
                val_batch = val_loader.dataset[idx]
                val_x = val_batch['data'].reshape(-1).unsqueeze(0).to(device)
                val_y = torch.tensor([val_batch['label']], device=device)
                val_output = model(val_x)
                val_loss += -mll(val_output, val_y).sum().item()
                # Predict through the likelihood: its mixing weights define the
                # class logits, so the argmax of the latent means is not a
                # class prediction. Average the probabilities over the
                # likelihood samples (leading dimension) before the argmax.
                class_probs = likelihood(val_output).probs.mean(0)
                val_labels.append(val_y.item())
                val_predictions.append(class_probs.argmax(dim=-1).item())

        precision, recall, f1, _ = precision_recall_fscore_support(
            val_labels, val_predictions, average='macro'
        )
        # AUC is computed on one-hot encoded hard predictions.
        auc_roc = roc_auc_score(label_binarize(val_labels, classes=range(n_classes)),
                                label_binarize(val_predictions, classes=range(n_classes)),
                                multi_class='ovr')

        metrics['precision'].append(precision)
        metrics['recall'].append(recall)
        metrics['f1_score'].append(f1)
        metrics['auc_roc'].append(auc_roc)
        val_loss /= len(val_indices)

        if val_loss < best_val_loss:
            best_val_loss = val_loss
            epochs_no_improve = 0
            torch.save({'model_state_dict': model.state_dict(),
                        'likelihood_state_dict': likelihood.state_dict(),
                        'optimizer_state_dict': optimizer.state_dict(),
                        'epoch': epoch}, checkpoint_path)
        else:
            epochs_no_improve += 1
            if epochs_no_improve >= patience:
                print(f"Early stopping triggered at epoch {epoch+1}")
                break

    # Hand back the best state seen during training, not the last one.
    if os.path.exists(checkpoint_path):
        checkpoint = torch.load(checkpoint_path, map_location=device)
        model.load_state_dict(checkpoint['model_state_dict'])
        likelihood.load_state_dict(checkpoint['likelihood_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

    return model, likelihood, metrics
|
||
def evaluate_gp_model(test_loader, model, likelihood, n_classes=4):
    """Score a trained GP classifier on a test set.

    Args:
        test_loader: Iterable of batches that are mappings with ``'data'`` and
            ``'label'`` entries.
        model: Trained GP model; its output is passed to ``likelihood``.
        likelihood: Trained likelihood that turns the model output into a
            distribution over classes.
        n_classes: Number of classes used to one-hot encode labels for AUC.

    Returns:
        Dict with macro-averaged ``'precision'``, ``'recall'`` and
        ``'f1_score'`` plus ``'auc_roc'``.
    """
    model.eval()
    likelihood.eval()
    test_labels = []
    test_predictions = []

    with torch.no_grad():
        for test_batch in test_loader:
            # Flatten every sample to a single feature vector.
            test_x = test_batch['data'].reshape(test_batch['data'].size(0), -1).to(device)
            test_y = test_batch['label'].to(device)
            # Predict through the likelihood rather than taking the argmax of
            # the latent means: the likelihood defines the class logits.
            # Average the class probabilities over the likelihood samples
            # (leading dimension) before the argmax.
            class_probs = likelihood(model(test_x)).probs.mean(0)
            test_labels.extend(test_y.tolist())
            test_predictions.extend(class_probs.argmax(dim=-1).tolist())

    precision, recall, f1, _ = precision_recall_fscore_support(
        test_labels, test_predictions, average='macro'
    )
    # AUC is computed on one-hot encoded hard predictions.
    auc_roc = roc_auc_score(label_binarize(test_labels, classes=range(n_classes)),
                            label_binarize(test_predictions, classes=range(n_classes)),
                            multi_class='ovr')

    metrics = {
        'precision': precision,
        'recall': recall,
        'f1_score': f1,
        'auc_roc': auc_roc
    }

    return metrics
|
||
def main():
    """Configure a run, train the GP classifier, and print its test metrics."""
    # Device and drives
    is_linux = False
    is_hpc = False
    is_internal = False
    is_external = True
    binary = False

    # Input
    is_tfs = True

    # Database
    database = 'mimic3'

    # Initialize the focus
    focus = 'thesis_results_database_multiclass'

    # Initialize the file tag
    file_tag = 'MIMIC_III'

    # Image resolution
    img_res = '128x128_float16'

    # Data type: the type to convert the data into when it is loaded in
    data_type = torch.float32

    # Model type
    model_type = torch.float32

    # Create a PathMaster object
    pathmaster = PathMaster(is_linux, is_hpc, is_tfs, is_internal, is_external, focus, file_tag, img_res)

    # Image dimensions
    img_channels = 1
    img_size = 128
    downsample = None
    standardize = True

    # Run parameters
    n_epochs = 100
    # NOTE(review): the module-level num_tasks is 4 while the multiclass
    # setting here is 3 -- confirm which class count is intended.
    n_classes = 2 if binary else 3
    patience = round(n_epochs / 10) if n_epochs > 50 else 5
    save = True

    # Resume checkpoint: set to a checkpoint file path to continue a previous run
    resume_checkpoint_path = None

    # Data loading details
    data_format = 'pt'
    batch_size = 256

    # Preprocess database data
    test_loader = preprocess_data(database, batch_size, standardize, img_channels, img_size,
                                  downsample, data_type, pathmaster, binary)

    # Training and validation
    # NOTE(review): train_loader and val_loader are not defined anywhere in
    # this function; only test_loader is created above. They must be built
    # with the project's training data loader before this call can run.
    checkpoint_path = resume_checkpoint_path or 'model_checkpoint.pt'
    start_time = time.time()
    # Keyword arguments keep each setting on its intended parameter; passed
    # positionally, `save` would land in `checkpoint_path` and `pathmaster`
    # in `resume_training`.
    model, likelihood, metrics = train_gp_model(
        train_loader,
        val_loader,
        num_iterations=n_epochs,
        n_classes=n_classes,
        patience=patience,
        checkpoint_path=checkpoint_path,
        resume_training=resume_checkpoint_path is not None,
    )
    time_passed = time.time() - start_time
    print(f'\nTraining and validation took {time_passed / 60:.2f} minutes')

    # Evaluation
    start_time = time.time()
    test_metrics = evaluate_gp_model(test_loader, model, likelihood, n_classes)
    time_passed = time.time() - start_time
    print(f'\nTesting took {time_passed:.2f} seconds')

    print('Test Metrics:')
    print(f"Precision: {test_metrics['precision']:.4f}")
    print(f"Recall: {test_metrics['recall']:.4f}")
    print(f"F1 Score: {test_metrics['f1_score']:.4f}")
    print(f"AUC-ROC: {test_metrics['auc_roc']:.4f}")
||
# Run the pipeline only when this file is executed as a script, not on import.
if __name__ == '__main__':
    main()
Oops, something went wrong.