Merge branch 'main' into Luis
lrm22005 committed Apr 21, 2024
2 parents 5de1526 + 09dec70 commit ec87427
Showing 2 changed files with 2 additions and 284 deletions.
GP_Original_checkpoint.py (4 changes: 2 additions & 2 deletions)
@@ -43,8 +43,8 @@ def get_data_paths(data_format, is_linux=False, is_hpc=False):
         saving_base_path = "/gpfs/scratchfs1/hfp14002/lrm22005/Casseys_case/Project_1_analysis"
     else:
         # R:\ENGR_Chon\Dong\MATLAB_generate_results\NIH_PulseWatch
-        base_path = "R:\ENGR_Chon\Dong\MATLAB_generate_results\\NIH_PulseWatch"
-        labels_base_path = "R:\ENGR_Chon\\NIH_Pulsewatch_Database\Adjudication_UConn"
+        base_path = r"\\grove.ad.uconn.edu\\research\\ENGR_Chon\Dong\MATLAB_generate_results\\NIH_PulseWatch"
+        labels_base_path = r"\\grove.ad.uconn.edu\\research\\ENGR_Chon\\NIH_Pulsewatch_Database\Adjudication_UConn"
         saving_base_path = r"\\grove.ad.uconn.edu\research\ENGR_Chon\Luis\Research\Casseys_case"
     if data_format == 'csv':
         data_path = os.path.join(base_path, "TFS_csv")
main_darren_v1.py (282 changes: 0 additions & 282 deletions)
@@ -1,282 +0,0 @@
import os
import torch
import gpytorch
from sklearn.metrics import precision_recall_fscore_support, roc_auc_score
from sklearn.preprocessing import label_binarize
from torch.utils.data import Dataset, DataLoader
import numpy as np
import random
import time

# Seeds
torch.manual_seed(42)
np.random.seed(42)
random.seed(42)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

num_latents = 6 # This should match the complexity of your data or the number of tasks
num_tasks = 4 # This should match the number of output classes or tasks
num_inducing_points = 50 # This is independent and should be sufficient for the input space

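# Model: a variational multitask GP. Each of the `num_latents` latent functions gets its
# own inducing points in the flattened 128*128 input space, and an LMC (linear model of
# coregionalization) strategy mixes the latent functions into `num_tasks` correlated outputs.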
class MultitaskGPModel(gpytorch.models.ApproximateGP):
    def __init__(self):
        # Let's use a different set of inducing points for each latent function
        inducing_points = torch.rand(num_latents, num_inducing_points, 128 * 128)  # Assuming flattened 128x128 images

        # We have to mark the CholeskyVariationalDistribution as batch
        # so that we learn a variational distribution for each task
        variational_distribution = gpytorch.variational.CholeskyVariationalDistribution(
            inducing_points.size(-2), batch_shape=torch.Size([num_latents])
        )

        # We have to wrap the VariationalStrategy in a LMCVariationalStrategy
        # so that the output will be a MultitaskMultivariateNormal rather than a batch output
        variational_strategy = gpytorch.variational.LMCVariationalStrategy(
            gpytorch.variational.VariationalStrategy(
                self, inducing_points, variational_distribution, learn_inducing_locations=True
            ),
            num_tasks=num_tasks,
            num_latents=num_latents,
            latent_dim=-1
        )

        super().__init__(variational_strategy)

        # The mean and covariance modules should be marked as batch
        # so we learn a different set of hyperparameters
        self.mean_module = gpytorch.means.ConstantMean(batch_shape=torch.Size([num_latents]))
        self.covar_module = gpytorch.kernels.ScaleKernel(
            gpytorch.kernels.RBFKernel(batch_shape=torch.Size([num_latents])),
            batch_shape=torch.Size([num_latents])
        )

    def forward(self, x):
        mean_x = self.mean_module(x)
        covar_x = self.covar_module(x)
        latent_pred = gpytorch.distributions.MultivariateNormal(mean_x, covar_x)
        return latent_pred

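# Data: each segment is stored as <segment_name>.pt under `data_path`; labels are read
# from a CSV whose rows are "segment_name,label" (header skipped). With binary=True,
# label 2 (PAC/PVC) is folded into class 0 (non-AF).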
class CustomDataset(Dataset):
    def __init__(self, data_path, labels_path, binary=False):
        self.data_path = data_path
        self.labels_path = labels_path
        self.binary = binary
        self.segment_names, self.labels = self.extract_segment_names_and_labels()

    def __len__(self):
        return len(self.segment_names)

    def __getitem__(self, idx):
        segment_name = self.segment_names[idx]
        label = self.labels[segment_name]
        data_tensor = torch.load(os.path.join(self.data_path, segment_name + '.pt'))
        return {'data': data_tensor, 'label': label, 'segment_name': segment_name}

    def extract_segment_names_and_labels(self):
        segment_names = []
        labels = {}

        with open(self.labels_path, 'r') as file:
            lines = file.readlines()
            for line in lines[1:]:  # Skip the header line
                segment_name, label = line.strip().split(',')
                label = int(float(label))  # Convert the label to float first, then to int
                if self.binary and label == 2:
                    label = 0  # Convert PAC/PVC to non-AF (0) for binary classification
                segment_names.append(segment_name)
                labels[segment_name] = label

        return segment_names, labels

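# Thin wrapper that wires the dataset into a shuffled DataLoader.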
def load_data(data_path, labels_path, batch_size, binary=False):
    dataset = CustomDataset(data_path, labels_path, binary)
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
    return dataloader

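# Training: the variational ELBO is maximized with Adam. After each epoch a random 10%
# subset of the validation set is scored ("stochastic validation"), the checkpoint with
# the lowest validation loss is saved, and training stops early after `patience` epochs
# without improvement. Note that the SoftmaxLikelihood below is constructed with
# num_classes=4 regardless of the n_classes argument, and the plot_path and
# data_checkpoint_path arguments are accepted but not used in this function.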
def train_gp_model(train_loader, val_loader, num_iterations=50, n_classes=4, patience=10,
                   checkpoint_path='model_checkpoint.pt', data_checkpoint_path='data_checkpoint.pt',
                   resume_training=False, plot_path='training_plot.png'):
    model = MultitaskGPModel().to(device)
    likelihood = gpytorch.likelihoods.SoftmaxLikelihood(num_features=4, num_classes=4).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.1)
    mll = gpytorch.mlls.VariationalELBO(likelihood, model, num_data=len(train_loader.dataset))

    start_epoch = 0
    if resume_training and os.path.exists(checkpoint_path):
        checkpoint = torch.load(checkpoint_path)
        model.load_state_dict(checkpoint['model_state_dict'])
        likelihood.load_state_dict(checkpoint['likelihood_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        start_epoch = checkpoint.get('epoch', 0)

    best_val_loss = float('inf')
    epochs_no_improve = 0

    metrics = {
        'precision': [],
        'recall': [],
        'f1_score': [],
        'auc_roc': [],
        'train_loss': []
    }

    for epoch in range(start_epoch, num_iterations):
        model.train()
        likelihood.train()
        for train_batch in train_loader:
            optimizer.zero_grad()
            train_x = train_batch['data'].reshape(train_batch['data'].size(0), -1).to(device)
            train_y = train_batch['label'].to(device)
            output = model(train_x)
            loss = -mll(output, train_y)
            metrics['train_loss'].append(loss.item())
            loss.backward()
            optimizer.step()

        # Stochastic validation
        model.eval()
        likelihood.eval()
        val_loss = 0.0  # Initialize val_loss here
        with torch.no_grad():
            val_indices = torch.randperm(len(val_loader.dataset))[:int(0.1 * len(val_loader.dataset))]
            val_labels = []
            val_predictions = []
            for idx in val_indices:
                val_batch = val_loader.dataset[idx]
                val_x = val_batch['data'].reshape(-1).unsqueeze(0).to(device)
                val_y = torch.tensor([val_batch['label']], device=device)
                val_output = model(val_x)
                val_loss_batch = -mll(val_output, val_y).sum()
                val_loss += val_loss_batch.item()
                val_labels.append(val_y.item())
                val_predictions.append(val_output.mean.argmax(dim=-1).item())

        precision, recall, f1, _ = precision_recall_fscore_support(val_labels, val_predictions, average='macro')
        auc_roc = roc_auc_score(label_binarize(val_labels, classes=range(n_classes)),
                                label_binarize(val_predictions, classes=range(n_classes)),
                                multi_class='ovr')

        metrics['precision'].append(precision)
        metrics['recall'].append(recall)
        metrics['f1_score'].append(f1)
        metrics['auc_roc'].append(auc_roc)
        val_loss /= len(val_indices)

        if val_loss < best_val_loss:
            best_val_loss = val_loss
            epochs_no_improve = 0
            torch.save({'model_state_dict': model.state_dict(),
                        'likelihood_state_dict': likelihood.state_dict(),
                        'optimizer_state_dict': optimizer.state_dict(),
                        'epoch': epoch}, checkpoint_path)
        else:
            epochs_no_improve += 1
            if epochs_no_improve >= patience:
                print(f"Early stopping triggered at epoch {epoch+1}")
                break

    if os.path.exists(checkpoint_path):
        checkpoint = torch.load(checkpoint_path)
        model.load_state_dict(checkpoint['model_state_dict'])
        likelihood.load_state_dict(checkpoint['likelihood_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

    return model, likelihood, metrics

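# Evaluation: predictions are the argmax of the posterior mean over tasks; macro
# precision/recall/F1 are computed directly, and AUC-ROC is computed one-vs-rest
# from the binarized hard predictions rather than from class probabilities.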
def evaluate_gp_model(test_loader, model, likelihood, n_classes=4):
    model.eval()
    likelihood.eval()
    test_labels = []
    test_predictions = []

    with torch.no_grad():
        for test_batch in test_loader:
            test_x = test_batch['data'].reshape(test_batch['data'].size(0), -1).to(device)
            test_y = test_batch['label'].to(device)
            test_output = model(test_x)
            test_labels.extend(test_y.tolist())
            test_predictions.extend(test_output.mean.argmax(dim=-1).tolist())

    precision, recall, f1, _ = precision_recall_fscore_support(test_labels, test_predictions, average='macro')
    auc_roc = roc_auc_score(label_binarize(test_labels, classes=range(n_classes)),
                            label_binarize(test_predictions, classes=range(n_classes)),
                            multi_class='ovr')

    metrics = {
        'precision': precision,
        'recall': recall,
        'f1_score': f1,
        'auc_roc': auc_roc
    }

    return metrics

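# Entry point: builds train/validate/test loaders from the Cassey5k_SMOTE holdout split,
# trains with model checkpointing, and reports test metrics. Note that CustomDataset
# defines no load_checkpoint method, so Step 3 only succeeds when data_checkpoint.pt
# does not already exist.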
def main():
    print("Step 1: Loading paths and parameters")
    # Paths
    base_path = r"\\grove.ad.uconn.edu\\research\\ENGR_Chon\Darren\\NIH_Pulsewatch"
    smote_type = 'Cassey5k_SMOTE'
    split = 'holdout_60_10_30'
    data_path_train = os.path.join(base_path, "TFS_pt", smote_type, split, "train")
    data_path_val = os.path.join(base_path, "TFS_pt", smote_type, split, "validate")
    data_path_test = os.path.join(base_path, "TFS_pt", smote_type, split, "test")
    labels_path_train = os.path.join(base_path, "TFS_pt", smote_type, split, "Cassey5k_SMOTE_train_names_labels.csv")
    labels_path_val = os.path.join(base_path, "TFS_pt", smote_type, split, "Cassey5k_SMOTE_validate_names_labels.csv")
    labels_path_test = os.path.join(base_path, "TFS_pt", smote_type, split, "Cassey5k_SMOTE_test_names_labels.csv")

    # Parameters
    binary = False
    n_epochs = 100
    if binary:
        n_classes = 2
    else:
        n_classes = 3
    patience = round(n_epochs / 10) if n_epochs > 50 else 5
    resume_checkpoint_path = None
    batch_size = 256

    print("Step 2: Loading data")
    # Data loading
    train_loader = load_data(data_path_train, labels_path_train, batch_size, binary)
    val_loader = load_data(data_path_val, labels_path_val, batch_size, binary)
    test_loader = load_data(data_path_test, labels_path_test, batch_size, binary)

    print("Step 3: Loading data checkpoints")
    # Data loading with checkpointing
    data_checkpoint_path = 'data_checkpoint.pt'
    if os.path.exists(data_checkpoint_path):
        train_loader.dataset.load_checkpoint(data_checkpoint_path)
        val_loader.dataset.load_checkpoint(data_checkpoint_path)
        test_loader.dataset.load_checkpoint(data_checkpoint_path)

    print("Step 4: Training and validation")
    # Training and validation with checkpointing and plotting
    model_checkpoint_path = 'model_checkpoint.pt'
    plot_path = 'training_plot.png'
    start_time = time.time()
    model, likelihood, metrics = train_gp_model(train_loader, val_loader, n_epochs,
                                                n_classes, patience,
                                                model_checkpoint_path, data_checkpoint_path,
                                                resume_checkpoint_path is not None, plot_path)
    end_time = time.time()
    time_passed = end_time - start_time
    print('\nTraining and validation took %.2f minutes' % (time_passed / 60))

    print("Step 5: Evaluation")
    # Evaluation
    start_time = time.time()
    test_metrics = evaluate_gp_model(test_loader, model, likelihood, n_classes)
    end_time = time.time()
    time_passed = end_time - start_time
    print('\nTesting took %.2f seconds' % time_passed)

    print("Step 6: Printing test metrics")
    print('Test Metrics:')
    print('Precision: %.4f' % test_metrics['precision'])
    print('Recall: %.4f' % test_metrics['recall'])
    print('F1 Score: %.4f' % test_metrics['f1_score'])
    print('AUC-ROC: %.4f' % test_metrics['auc_roc'])


if __name__ == '__main__':
    main()
