The mapped R: drive has to be defined correctly, and that definition only sometimes works; pointing at the grove UNC path instead ensures the source connection every time.
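A minimal sketch of the idea behind this change (a hypothetical helper, not part of the commit below): prefer the grove UNC path and fall back to the mapped drive only if it actually resolves. The R:\ mapping shown here is an assumed example.

import os

# Hypothetical illustration of the commit's rationale: the mapped R: drive
# only sometimes resolves, while the UNC (grove) path is reliably reachable.
# The R:\ mapping below is an assumed example, not taken from this commit.
def resolve_base_path():
    unc_path = r"\\grove.ad.uconn.edu\research\ENGR_Chon\Darren\NIH_Pulsewatch"
    mapped_path = r"R:\ENGR_Chon\Darren\NIH_Pulsewatch"  # assumed drive mapping
    for candidate in (unc_path, mapped_path):
        if os.path.isdir(candidate):
            return candidate
    raise FileNotFoundError("Neither the UNC path nor the mapped drive is reachable.")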
Showing 2 changed files with 363 additions and 2 deletions.
@@ -0,0 +1,361 @@
import os
import torch
import gpytorch
from sklearn.metrics import precision_recall_fscore_support, roc_auc_score
from sklearn.preprocessing import label_binarize
from torch.utils.data import Dataset, DataLoader
import numpy as np
import random
import time
import matplotlib.pyplot as plt

# Seeds
torch.manual_seed(42)
np.random.seed(42)
random.seed(42)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

num_latents = 6  # This should match the complexity of your data or the number of tasks
num_tasks = 4  # This should match the number of output classes or tasks
num_inducing_points = 50  # This is independent and should be sufficient for the input space

class MultitaskGPModel(gpytorch.models.ApproximateGP):
    def __init__(self):
        # Let's use a different set of inducing points for each latent function
        inducing_points = torch.rand(num_latents, num_inducing_points, 128 * 128)  # Assuming flattened 128x128 images

        # We have to mark the CholeskyVariationalDistribution as batch
        # so that we learn a variational distribution for each task
        variational_distribution = gpytorch.variational.CholeskyVariationalDistribution(
            inducing_points.size(-2), batch_shape=torch.Size([num_latents])
        )

        # We have to wrap the VariationalStrategy in an LMCVariationalStrategy
        # so that the output will be a MultitaskMultivariateNormal rather than a batch output
        variational_strategy = gpytorch.variational.LMCVariationalStrategy(
            gpytorch.variational.VariationalStrategy(
                self, inducing_points, variational_distribution, learn_inducing_locations=True
            ),
            num_tasks=num_tasks,
            num_latents=num_latents,
            latent_dim=-1
        )

        super().__init__(variational_strategy)

        # The mean and covariance modules should be marked as batch
        # so we learn a different set of hyperparameters per latent
        self.mean_module = gpytorch.means.ConstantMean(batch_shape=torch.Size([num_latents]))
        self.covar_module = gpytorch.kernels.ScaleKernel(
            gpytorch.kernels.RBFKernel(batch_shape=torch.Size([num_latents])),
            batch_shape=torch.Size([num_latents])
        )

    def forward(self, x):
        mean_x = self.mean_module(x)
        covar_x = self.covar_module(x)
        latent_pred = gpytorch.distributions.MultivariateNormal(mean_x, covar_x)
        return latent_pred

class CustomDataset(Dataset):
    def __init__(self, data_path, labels_path, binary=False, start_idx=0):
        self.data_path = data_path
        self.labels_path = labels_path
        self.binary = binary
        self.start_idx = start_idx
        self.segment_names, self.labels = self.extract_segment_names_and_labels()

    def __len__(self):
        return len(self.segment_names)

    def __getitem__(self, idx):
        actual_idx = (idx + self.start_idx) % len(self.segment_names)
        segment_name = self.segment_names[actual_idx]
        label = self.labels[segment_name]
        data_tensor = torch.load(os.path.join(self.data_path, segment_name + '.pt'))
        return {'data': data_tensor, 'label': label, 'segment_name': segment_name}

    def extract_segment_names_and_labels(self):
        segment_names = []
        labels = {}

        with open(self.labels_path, 'r') as file:
            lines = file.readlines()
            for line in lines[1:]:  # Skip the header line
                segment_name, label = line.strip().split(',')
                label = int(float(label))  # Convert the label to float first, then to int
                if self.binary and label == 2:
                    label = 0  # Convert PAC/PVC to non-AF (0) for binary classification
                segment_names.append(segment_name)
                labels[segment_name] = label

        return segment_names, labels

    def set_start_idx(self, index):
        self.start_idx = index

    def save_checkpoint(self, checkpoint_path):
        checkpoint = {
            'segment_names': self.segment_names,
            'labels': self.labels,
            'start_idx': self.start_idx
        }
        torch.save(checkpoint, checkpoint_path)

    def load_checkpoint(self, checkpoint_path):
        checkpoint = torch.load(checkpoint_path)
        self.segment_names = checkpoint['segment_names']
        self.labels = checkpoint['labels']
        self.start_idx = checkpoint['start_idx']


def load_data(data_path, labels_path, batch_size, binary=False, start_idx=0):
    dataset = CustomDataset(data_path, labels_path, binary, start_idx)
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False)
    return dataloader

def train_gp_model(train_loader, val_loader, num_iterations=50, n_classes=4, patience=10,
                   checkpoint_path='model_checkpoint.pt', data_checkpoint_path='data_checkpoint.pt',
                   resume_training=False, plot_path='training_plot.png'):
    model = MultitaskGPModel().to(device)
    likelihood = gpytorch.likelihoods.SoftmaxLikelihood(num_features=num_tasks, num_classes=num_tasks).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.1)
    mll = gpytorch.mlls.VariationalELBO(likelihood, model, num_data=len(train_loader.dataset))

    start_epoch = 0
    if resume_training and os.path.exists(checkpoint_path):
        checkpoint = torch.load(checkpoint_path, map_location=device)
        model.load_state_dict(checkpoint['model_state_dict'])
        likelihood.load_state_dict(checkpoint['likelihood_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        start_epoch = checkpoint.get('epoch', 0)

    best_val_loss = float('inf')
    epochs_no_improve = 0

    metrics = {
        'precision': [],
        'recall': [],
        'f1_score': [],
        'auc_roc': [],
        'train_loss': []
    }

    # Per-epoch histories for plotting
    train_losses = []
    val_losses = []
    val_precisions = []
    val_recalls = []
    val_f1_scores = []
    val_auc_rocs = []

    for epoch in range(start_epoch, num_iterations):
        model.train()
        likelihood.train()
        epoch_train_losses = []
        for train_batch in train_loader:
            optimizer.zero_grad()
            train_x = train_batch['data'].reshape(train_batch['data'].size(0), -1).to(device)
            train_y = train_batch['label'].to(device)
            output = model(train_x)
            loss = -mll(output, train_y)
            epoch_train_losses.append(loss.item())
            loss.backward()
            optimizer.step()
        metrics['train_loss'].extend(epoch_train_losses)

        # Stochastic validation on a random 10% subset of the validation set
        model.eval()
        likelihood.eval()
        with torch.no_grad():
            val_indices = torch.randperm(len(val_loader.dataset))[:int(0.1 * len(val_loader.dataset))]
            val_loss = 0.0
            val_labels = []
            val_predictions = []
            for idx in val_indices:
                val_batch = val_loader.dataset[idx]
                val_x = val_batch['data'].reshape(-1).unsqueeze(0).to(device)
                val_y = torch.tensor([val_batch['label']], device=device)
                val_output = model(val_x)
                val_loss_batch = -mll(val_output, val_y).sum()
                val_loss += val_loss_batch.item()
                val_labels.append(val_y.item())
                val_predictions.append(val_output.mean.argmax(dim=-1).item())

            precision, recall, f1, _ = precision_recall_fscore_support(val_labels, val_predictions, average='macro')
            auc_roc = roc_auc_score(label_binarize(val_labels, classes=range(n_classes)),
                                    label_binarize(val_predictions, classes=range(n_classes)),
                                    multi_class='ovr')

            metrics['precision'].append(precision)
            metrics['recall'].append(recall)
            metrics['f1_score'].append(f1)
            metrics['auc_roc'].append(auc_roc)
            val_loss /= len(val_indices)

        # Append per-epoch metrics for plotting; this must happen after
        # validation so the validation quantities are defined before use
        train_losses.append(np.mean(epoch_train_losses))
        val_losses.append(val_loss)
        val_precisions.append(precision)
        val_recalls.append(recall)
        val_f1_scores.append(f1)
        val_auc_rocs.append(auc_roc)

        # Plot metrics (the figure at plot_path is refreshed every epoch)
        plt.figure(figsize=(12, 8))
        plt.subplot(2, 2, 1)
        plt.plot(train_losses, label='Training Loss')
        plt.plot(val_losses, label='Validation Loss')
        plt.legend()
        plt.title('Loss')
        plt.xlabel('Epoch')
        plt.ylabel('Loss')

        plt.subplot(2, 2, 2)
        plt.plot(val_precisions, label='Precision')
        plt.plot(val_recalls, label='Recall')
        plt.plot(val_f1_scores, label='F1 Score')
        plt.legend()
        plt.title('Validation Metrics')
        plt.xlabel('Epoch')
        plt.ylabel('Metric')

        plt.subplot(2, 2, 3)
        plt.plot(val_auc_rocs, label='AUC-ROC')
        plt.legend()
        plt.title('Validation AUC-ROC')
        plt.xlabel('Epoch')
        plt.ylabel('AUC-ROC')

        plt.tight_layout()
        plt.savefig(plot_path)
        plt.close()

        if val_loss < best_val_loss:
            best_val_loss = val_loss
            epochs_no_improve = 0
            torch.save({'model_state_dict': model.state_dict(),
                        'likelihood_state_dict': likelihood.state_dict(),
                        'optimizer_state_dict': optimizer.state_dict(),
                        'epoch': epoch}, checkpoint_path)
        else:
            epochs_no_improve += 1
            if epochs_no_improve >= patience:
                print(f"Early stopping triggered at epoch {epoch+1}")
                break

    # Restore the best checkpoint before returning
    if os.path.exists(checkpoint_path):
        checkpoint = torch.load(checkpoint_path, map_location=device)
        model.load_state_dict(checkpoint['model_state_dict'])
        likelihood.load_state_dict(checkpoint['likelihood_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

    # Save final model weights and other information
    torch.save({
        'model_state_dict': model.state_dict(),
        'likelihood_state_dict': likelihood.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'train_losses': train_losses,
        'val_losses': val_losses,
        'val_precisions': val_precisions,
        'val_recalls': val_recalls,
        'val_f1_scores': val_f1_scores,
        'val_auc_rocs': val_auc_rocs
    }, 'final_model_info.pt')

    return model, likelihood, metrics

def evaluate_gp_model(test_loader, model, likelihood, n_classes=4):
    model.eval()
    likelihood.eval()
    test_labels = []
    test_predictions = []

    with torch.no_grad():
        for test_batch in test_loader:
            test_x = test_batch['data'].reshape(test_batch['data'].size(0), -1).to(device)
            test_y = test_batch['label'].to(device)
            test_output = model(test_x)
            test_labels.extend(test_y.tolist())
            test_predictions.extend(test_output.mean.argmax(dim=-1).tolist())

    precision, recall, f1, _ = precision_recall_fscore_support(test_labels, test_predictions, average='macro')
    auc_roc = roc_auc_score(label_binarize(test_labels, classes=range(n_classes)),
                            label_binarize(test_predictions, classes=range(n_classes)),
                            multi_class='ovr')

    metrics = {
        'precision': precision,
        'recall': recall,
        'f1_score': f1,
        'auc_roc': auc_roc
    }

    return metrics

def main():
    print("Step 1: Loading paths and parameters")
    # Paths (UNC path to the grove share; the mapped R: drive is unreliable)
    base_path = r"\\grove.ad.uconn.edu\research\ENGR_Chon\Darren\NIH_Pulsewatch"
    smote_type = 'Cassey5k_SMOTE'
    split = 'holdout_60_10_30'
    data_path_train = os.path.join(base_path, "TFS_pt", smote_type, split, "train")
    data_path_val = os.path.join(base_path, "TFS_pt", smote_type, split, "validate")
    data_path_test = os.path.join(base_path, "TFS_pt", smote_type, split, "test")
    labels_path_train = os.path.join(base_path, "TFS_pt", smote_type, split, "Cassey5k_SMOTE_train_names_labels.csv")
    labels_path_val = os.path.join(base_path, "TFS_pt", smote_type, split, "Cassey5k_SMOTE_validate_names_labels.csv")
    labels_path_test = os.path.join(base_path, "TFS_pt", smote_type, split, "Cassey5k_SMOTE_test_names_labels.csv")

    # Parameters
    binary = False
    n_epochs = 100
    if binary:
        n_classes = 2
    else:
        n_classes = 3
    patience = round(n_epochs / 10) if n_epochs > 50 else 5
    resume_checkpoint_path = None
    batch_size = 256

    print("Step 2: Loading data")
    # Data loading
    train_loader = load_data(data_path_train, labels_path_train, batch_size, binary)
    val_loader = load_data(data_path_val, labels_path_val, batch_size, binary)
    test_loader = load_data(data_path_test, labels_path_test, batch_size, binary)

    print("Step 3: Loading data checkpoints")
    # Data loading with checkpointing (note: all three datasets restore from
    # the same checkpoint file if it exists)
    data_checkpoint_path = 'data_checkpoint.pt'
    if os.path.exists(data_checkpoint_path):
        train_loader.dataset.load_checkpoint(data_checkpoint_path)
        val_loader.dataset.load_checkpoint(data_checkpoint_path)
        test_loader.dataset.load_checkpoint(data_checkpoint_path)

    print("Step 4: Training and validation")
    # Training and validation with checkpointing and plotting
    model_checkpoint_path = 'model_checkpoint.pt'
    plot_path = 'training_plot.png'
    start_time = time.time()
    model, likelihood, metrics = train_gp_model(train_loader, val_loader, n_epochs,
                                                n_classes, patience,
                                                model_checkpoint_path, data_checkpoint_path,
                                                resume_checkpoint_path is not None, plot_path)
    end_time = time.time()
    time_passed = end_time - start_time
    print('\nTraining and validation took %.2f minutes' % (time_passed / 60))

    print("Step 5: Evaluation")
    # Evaluation
    start_time = time.time()
    test_metrics = evaluate_gp_model(test_loader, model, likelihood, n_classes)
    end_time = time.time()
    time_passed = end_time - start_time
    print('\nTesting took %.2f seconds' % time_passed)

    print("Step 6: Printing test metrics")
    print('Test Metrics:')
    print('Precision: %.4f' % test_metrics['precision'])
    print('Recall: %.4f' % test_metrics['recall'])
    print('F1 Score: %.4f' % test_metrics['f1_score'])
    print('AUC-ROC: %.4f' % test_metrics['auc_roc'])


if __name__ == '__main__':
    main()