Skip to content

Commit

Permalink
UPDATES
Browse files Browse the repository at this point in the history
  • Loading branch information
lrm22005 committed Apr 21, 2024
1 parent a84d01e commit 5de1526
Showing 1 changed file with 22 additions and 5 deletions.
27 changes: 22 additions & 5 deletions main_darren_v1.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,8 @@ def load_data(data_path, labels_path, batch_size, binary=False):
return dataloader

def train_gp_model(train_loader, val_loader, num_iterations=50, n_classes=4, patience=10,
checkpoint_path='model_checkpoint.pt', resume_training=False):
checkpoint_path='model_checkpoint.pt', data_checkpoint_path='data_checkpoint.pt',
resume_training=False, plot_path='training_plot.png'):
model = MultitaskGPModel().to(device)
likelihood = gpytorch.likelihoods.SoftmaxLikelihood(num_features=4, num_classes=4).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.1)
Expand Down Expand Up @@ -136,9 +137,9 @@ def train_gp_model(train_loader, val_loader, num_iterations=50, n_classes=4, pat
# Stochastic validation
model.eval()
likelihood.eval()
val_loss = 0.0 # Initialize val_loss here
with torch.no_grad():
val_indices = torch.randperm(len(val_loader.dataset))[:int(0.1 * len(val_loader.dataset))]
val_loss = 0.0
val_labels = []
val_predictions = []
for idx in val_indices:
Expand Down Expand Up @@ -212,6 +213,7 @@ def evaluate_gp_model(test_loader, model, likelihood, n_classes=4):
return metrics

def main():
print("Step 1: Loading paths and parameters")
# Paths
base_path = r"\\grove.ad.uconn.edu\\research\\ENGR_Chon\Darren\\NIH_Pulsewatch"
smote_type = 'Cassey5k_SMOTE'
Expand All @@ -231,30 +233,45 @@ def main():
else:
n_classes = 3
patience = round(n_epochs / 10) if n_epochs > 50 else 5
save = True
resume_checkpoint_path = None
batch_size = 256

print("Step 2: Loading data")
# Data loading
train_loader = load_data(data_path_train, labels_path_train, batch_size, binary)
val_loader = load_data(data_path_val, labels_path_val, batch_size, binary)
test_loader = load_data(data_path_test, labels_path_test, batch_size, binary)

# Training and validation
print("Step 3: Loading data checkpoints")
# Data loading with checkpointing
data_checkpoint_path = 'data_checkpoint.pt'
if os.path.exists(data_checkpoint_path):
train_loader.dataset.load_checkpoint(data_checkpoint_path)
val_loader.dataset.load_checkpoint(data_checkpoint_path)
test_loader.dataset.load_checkpoint(data_checkpoint_path)

print("Step 4: Training and validation")
# Training and validation with checkpointing and plotting
model_checkpoint_path = 'model_checkpoint.pt'
plot_path = 'training_plot.png'
start_time = time.time()
model, likelihood, metrics = train_gp_model(train_loader, val_loader, n_epochs,
n_classes, patience, save)
n_classes, patience,
model_checkpoint_path, data_checkpoint_path,
resume_checkpoint_path is not None, plot_path)
end_time = time.time()
time_passed = end_time - start_time
print('\nTraining and validation took %.2f minutes' % (time_passed / 60))

print("Step 5: Evaluation")
# Evaluation
start_time = time.time()
test_metrics = evaluate_gp_model(test_loader, model, likelihood, n_classes)
end_time = time.time()
time_passed = end_time - start_time
print('\nTesting took %.2f seconds' % time_passed)

print("Step 6: Printing test metrics")
print('Test Metrics:')
print('Precision: %.4f' % test_metrics['precision'])
print('Recall: %.4f' % test_metrics['recall'])
Expand Down

0 comments on commit 5de1526

Please sign in to comment.