Skip to content

Luis #21

Merged
merged 3 commits into from
Apr 18, 2024
Merged

Luis #21

Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Binary file not shown.
28 changes: 27 additions & 1 deletion BML_project/active_learning/ss_active_learning.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,30 @@ def run_minibatch_kmeans(data_loader, n_clusters, device, batch_size=100):

return minibatch_kmeans

""" from sklearn.cluster import MiniBatchKMeans
import numpy as np
def run_minibatch_kmeans(data_loader, n_clusters, device, batch_size=100, n_init='auto'):
# Initialize MiniBatchKMeans with explicit n_init to suppress the FutureWarning
minibatch_kmeans = MiniBatchKMeans(n_clusters=n_clusters, n_init=n_init, random_state=0, batch_size=batch_size)
# Prepare an empty list to collect all data for fitting
all_data = []
# Iterate through data_loader and collect data
for batch in data_loader:
# Assuming 'data' is a key in your batch dict that contains the features
data = batch['data'].view(batch['data'].size(0), -1).cpu().numpy() # Adjust as necessary
all_data.append(data)
# Concatenate all data collected from the batches
all_data_np = np.concatenate(all_data, axis=0)
# Fit MiniBatchKMeans with all collected data at once
minibatch_kmeans.fit(all_data_np)
return minibatch_kmeans """

# def compare_kmeans_gp_predictions(kmeans_model, gp_model, data_loader, device):
# # Compare K-Means with GP model predictions
# all_data, all_labels = [], []
Expand All @@ -80,7 +104,9 @@ def stochastic_compare_kmeans_gp_predictions(kmeans_model, gp_model, data_loader
kmeans_predictions = kmeans_model.predict(data.cpu().numpy())
all_labels.append(labels.cpu().numpy())
all_data.append((gp_predictions, kmeans_predictions))

print(f"Processed batch size: {len(current_batch_labels)}, Cumulative original_labels size: {len(original_labels)}, Cumulative gp_predictions size: {len(gp_predictions)}")
if len(current_batch_labels) < expected_batch_size:
print(f"Last batch processed with size: {len(current_batch_labels)}")
return all_data, np.concatenate(all_labels)

import random
Expand Down
64 changes: 64 additions & 0 deletions BML_project/main_checkpoints_updated.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
# -*- coding: utf-8 -*-
"""
Created on Wed Feb 7 15:34:31 2024
@author: lrm22005
"""
import os
import tqdm
import torch
from utils.data_loader import preprocess_data, split_uids, update_train_loader_with_uncertain_samples
from models.ss_gp_model import MultitaskGPModel, train_gp_model
from utils_gp.ss_evaluation import stochastic_evaluation, evaluate_model_on_all_data
from active_learning.ss_active_learning import stochastic_uncertainty_sampling, run_minibatch_kmeans, stochastic_compare_kmeans_gp_predictions, label_samples
from utils.visualization import plot_comparative_results, plot_training_performance, plot_results

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

def main():
    """Train the GP classifier, run the active-learning loop, then evaluate and plot.

    Steps, in order:
      1. Split subject UIDs and build the train/validation/test loaders.
      2. Train the initial model and record its loss / validation metrics.
      3. For a fixed number of active-learning iterations, pick uncertain
         samples (selected from ``val_loader``), add them to the training
         loader and re-train.
      4. Evaluate on the test loader, plot, and print the final metrics.

    All collected numbers live in the local ``results`` dict; nothing is returned.
    """
    n_classes = 4
    batch_size = 1024
    data_format = 'pt'
    active_learning_iterations = 10

    clinical_trial_train, clinical_trial_test, clinical_trial_unlabeled = split_uids()
    train_loader, val_loader, test_loader = preprocess_data(
        data_format, clinical_trial_train, clinical_trial_test, clinical_trial_unlabeled, batch_size
    )

    # Initialize result storage
    results = {
        'train_loss': [],
        'validation_metrics': {'precision': [], 'recall': [], 'f1': [], 'auc_roc': []},
        'active_learning': {'validation_metrics': []},  # One entry per active learning iteration
        'test_metrics': None
    }

    # Initial model training
    model, likelihood, training_metrics = train_gp_model(
        train_loader, val_loader, num_iterations=50, n_classes=n_classes,
        patience=10, checkpoint_path='model_checkpoint_full.pt'
    )

    # Save initial training metrics.
    results['train_loss'].extend(training_metrics['train_loss'])
    # The training metrics are read under 'f1_score' while `results` stores the
    # series under 'f1'; indexing `results` with 'f1_score' raised KeyError, so
    # translate the key explicitly instead of reusing one name for both dicts.
    metric_sources = {'precision': 'precision', 'recall': 'recall', 'f1': 'f1_score'}
    for result_key, source_key in metric_sources.items():
        results['validation_metrics'][result_key].extend(training_metrics[source_key])
    # 'auc_roc' is declared in `results` but is not guaranteed to be reported by
    # the trainer, so copy it only when it is present.
    if 'auc_roc' in training_metrics:
        results['validation_metrics']['auc_roc'].extend(training_metrics['auc_roc'])

    for iteration in tqdm.tqdm(range(active_learning_iterations), desc='Active Learning', unit='iteration'):
        uncertain_sample_indices = stochastic_uncertainty_sampling(
            model, likelihood, val_loader, n_samples=batch_size, device=device
        )
        train_loader = update_train_loader_with_uncertain_samples(
            train_loader, uncertain_sample_indices, batch_size
        )

        # Re-train the model with updated training data
        model, likelihood, val_metrics = train_gp_model(
            train_loader, val_loader, num_iterations=10, n_classes=n_classes,
            patience=10, checkpoint_path='model_checkpoint_last.pt'
        )

        # Store validation metrics for each active learning iteration
        results['active_learning']['validation_metrics'].append(val_metrics)

    # Final evaluations
    test_metrics = evaluate_model_on_all_data(model, likelihood, test_loader, device, n_classes)
    results['test_metrics'] = test_metrics

    # Visualization of results
    plot_training_performance(results['train_loss'], results['validation_metrics'])
    plot_results(results['test_metrics'])  # Adjust this function to handle the structure of test_metrics

    print("Final Test Metrics:", results['test_metrics'])

if __name__ == "__main__":
main()
Binary file modified BML_project/models/__pycache__/ss_gp_model.cpython-311.pyc
Binary file not shown.
51 changes: 36 additions & 15 deletions BML_project/models/ss_gp_model_checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,30 +69,40 @@ def forward(self, x):

return latent_pred

def train_gp_model(train_loader, val_loader, num_iterations=50, n_classes=4, patience=10, checkpoint_path='model_checkpoint_full.pt', resume_training=False, batch_size=1024):
def train_gp_model(train_loader, val_loader, num_iterations=50, n_classes=4, patience=10, checkpoint_path='model_checkpoint_full.pt', resume_checkpoint_path=None):
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = MultitaskGPModel().to(device)
likelihood = gpytorch.likelihoods.SoftmaxLikelihood(num_features=4, num_classes=4).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.1)
mll = gpytorch.mlls.VariationalELBO(likelihood, model, num_data=len(train_loader.dataset))
best_val_loss = float('inf')
epochs_no_improve = 0

# Load checkpoint if resuming training
start_epoch = 0
current_batch_index = 0 # Default value in case it's not found in the checkpoint
if resume_training and os.path.exists(checkpoint_path):
checkpoint = torch.load(checkpoint_path)
start_batch = 0

# Resume from checkpoint if specified
if resume_checkpoint_path is not None and os.path.exists(resume_checkpoint_path):
checkpoint = torch.load(resume_checkpoint_path)
model.load_state_dict(checkpoint['model_state_dict'])
likelihood.load_state_dict(checkpoint['likelihood_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
start_epoch = checkpoint.get('epoch', 0) + 1 # Resume from the next epoch
current_batch_index = checkpoint.get('current_batch_index', 0) # Retrieve the last batch index if it exists
print(f"Resuming training from epoch {start_epoch}, batch index {current_batch_index}")

best_val_loss = float('inf')
epochs_no_improve = 0
metrics = {'precision': [], 'recall': [], 'f1_score': [], 'auc_roc': [], 'train_loss': []}

for epoch in range(start_epoch, num_iterations):
start_epoch = checkpoint.get('epoch', 0)
start_batch = checkpoint.get('batch', 0)
print(f"Resuming training from epoch {start_epoch}, batch {start_batch}")

metrics = {
'precision': [],
'recall': [],
'f1_score': [],
'train_loss': []
}

for epoch in tqdm.tqdm(range(start_epoch, num_iterations), desc='Training', unit='epoch', leave=False):
for batch_index, train_batch in enumerate(train_loader):
if epoch == start_epoch and batch_index < start_batch:
continue # Skip batches until the saved batch index

model.train()
likelihood.train()
optimizer.zero_grad()
Expand All @@ -104,6 +114,17 @@ def train_gp_model(train_loader, val_loader, num_iterations=50, n_classes=4, pat
loss.backward()
optimizer.step()

# Save checkpoint at intervals or based on other conditions
if (batch_index + 1) % 100 == 0: # Example condition
torch.save({
'epoch': epoch,
'batch': batch_index,
'model_state_dict': model.state_dict(),
'likelihood_state_dict': likelihood.state_dict(),
'optimizer_state_dict': optimizer.state_dict()
}, checkpoint_path)
print(f"Checkpoint saved at epoch {epoch}, batch {batch_index}")

# Stochastic validation
model.eval()
likelihood.eval()
Expand Down Expand Up @@ -160,7 +181,7 @@ def train_gp_model(train_loader, val_loader, num_iterations=50, n_classes=4, pat
print(f"Early stopping triggered at epoch {epoch+1}")
break

# Optionally, load the best model at the end of training
# Ensure to load the latest model state after the loop in case of early stopping
if os.path.exists(checkpoint_path):
checkpoint = torch.load(checkpoint_path)
model.load_state_dict(checkpoint['model_state_dict'])
Expand Down
7 changes: 5 additions & 2 deletions BML_project/ss_main.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,11 @@
"""
import tqdm
import torch
from utils.data_loader import preprocess_data, split_uids, update_train_loader_with_uncertain_samples
from utils_gp.data_loader import preprocess_data, split_uids, update_train_loader_with_uncertain_samples
from models.ss_gp_model import MultitaskGPModel, train_gp_model
from utils_gp.ss_evaluation import stochastic_evaluation, evaluate_model_on_all_data
from active_learning.ss_active_learning import stochastic_uncertainty_sampling, run_minibatch_kmeans, stochastic_compare_kmeans_gp_predictions
from utils.visualization import plot_comparative_results, plot_training_performance, plot_results
from utils_gp.visualization import plot_comparative_results, plot_training_performance, plot_results

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

Expand Down Expand Up @@ -50,6 +50,7 @@ def main():

# Update the training loader with uncertain samples
train_loader = update_train_loader_with_uncertain_samples(train_loader, uncertain_sample_indices, batch_size)
print(f"Updated training data size: {len(train_loader.dataset)}")

# Re-train the model with the updated training data
model, likelihood, val_metrics = train_gp_model(train_loader, val_loader, num_iterations=10, n_classes=n_classes, patience=10, checkpoint_path='model_checkpoint_last.pt')
Expand All @@ -71,6 +72,8 @@ def main():

results['test_metrics'] = test_metrics
test_gp_vs_kmeans_data, test_original_labels = stochastic_compare_kmeans_gp_predictions(test_kmeans_model, model, test_loader, n_batches=5, device=device)

print(f"Length of original_labels: {len(original_labels)}, Length of gp_predictions: {len(gp_predictions)}")
plot_comparative_results(test_gp_vs_kmeans_data, test_original_labels)

# Visualization of results
Expand Down
Binary file modified BML_project/utils_gp/__pycache__/data_loader.cpython-311.pyc
Binary file not shown.
Binary file modified BML_project/utils_gp/__pycache__/ss_evaluation.cpython-311.pyc
Binary file not shown.
Binary file modified BML_project/utils_gp/__pycache__/visualization.cpython-311.pyc
Binary file not shown.
28 changes: 10 additions & 18 deletions BML_project/utils_gp/data_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,34 +103,30 @@ def __init__(self, data_path, labels_path, UIDs, standardize=True, data_format='
self.data_format = data_format
self.read_all_labels = read_all_labels
self.transforms = ToTensor()
self.start_idx = start_idx # Initial batch index to start from, useful for resuming training
self.refresh_dataset()
self.start_idx = start_idx # Add this line

# Initialize the current batch index to None
# Initialize the current batch index to None, this could be used if you want to track batch progress within the dataset itself
self.current_batch_index = None

def refresh_dataset(self):
# Extract unique segment names and their corresponding labels
self.segment_names, self.labels = self.extract_segment_names_and_labels()

def add_uids(self, new_uids):
    """Append UIDs that are not yet in the dataset, then rebuild the segment index.

    Args:
        new_uids: iterable of UIDs to add. Entries already present in
            ``self.UIDs`` and repeated entries within ``new_uids`` itself are
            added at most once; first-seen order is preserved.

    ``self.UIDs`` is extended in place and ``refresh_dataset()`` is always
    called afterwards, even when nothing new was added.
    """
    # NOTE(review): assumes UIDs are hashable (e.g. strings) -- confirm with callers.
    seen = set(self.UIDs)
    unique_new_uids = []
    for uid in new_uids:
        if uid in seen:
            continue
        # Track UIDs accepted in this call too, so a UID repeated inside
        # new_uids is not appended twice.
        seen.add(uid)
        unique_new_uids.append(uid)

    self.UIDs.extend(unique_new_uids)
    self.refresh_dataset()

def __len__(self):
return len(self.segment_names)

def save_checkpoint(self, checkpoint_path, current_batch_index=None):
    """Write the dataset's resumable state to ``checkpoint_path`` with ``torch.save``.

    Args:
        checkpoint_path: destination passed straight to ``torch.save``.
        current_batch_index: optional batch index to record. When omitted,
            the instance's ``current_batch_index`` attribute is stored
            (``None`` if the attribute does not exist).

    The saved dict holds ``segment_names``, ``labels``, ``UIDs``,
    ``start_idx`` and ``current_batch_index``.
    """
    if current_batch_index is None:
        current_batch_index = getattr(self, 'current_batch_index', None)

    checkpoint = {
        'segment_names': self.segment_names,
        'labels': self.labels,
        'UIDs': self.UIDs,
        # Offset applied in __getitem__; stored so a resumed run starts at the same place.
        'start_idx': self.start_idx,
        # Kept alongside start_idx so callers using either signature keep working.
        'current_batch_index': current_batch_index,
    }
    torch.save(checkpoint, checkpoint_path)

Expand All @@ -139,32 +135,28 @@ def load_checkpoint(self, checkpoint_path):
self.segment_names = checkpoint['segment_names']
self.labels = checkpoint['labels']
self.UIDs = checkpoint['UIDs']
# Now also loading and setting start_idx from checkpoint
self.start_idx = checkpoint.get('start_idx', 0)
self.refresh_dataset()
# Load the current batch index if it exists in the checkpoint
self.current_batch_index = checkpoint.get('current_batch_index')

def __getitem__(self, idx):
    """Return one sample as ``{'data', 'label', 'segment_name'}``.

    The requested position is shifted by ``self.start_idx`` and wrapped
    around the end of ``self.segment_names``, so a dataset resumed from a
    non-zero offset still exposes every segment exactly once.

    Args:
        idx: position in ``[-len, len)``; negative values count from the end.

    Raises:
        IndexError: if ``idx`` is outside ``[-len, len)``, which includes
            every index on an empty dataset.
    """
    n_segments = len(self.segment_names)
    # Bounds-check before wrapping: an unconditional modulo never raises
    # IndexError (so plain `for x in dataset` iteration would never stop) and
    # divides by zero when the dataset is empty.
    if not -n_segments <= idx < n_segments:
        raise IndexError(f"index {idx} is out of range for dataset of size {n_segments}")

    # Adjust index based on start_idx and wrap around if needed
    actual_idx = (idx + self.start_idx) % n_segments
    segment_name = self.segment_names[actual_idx]
    label = self.labels[segment_name]

    all_data = getattr(self, 'all_data', None)
    if all_data is not None and actual_idx < len(all_data):
        # Data is stored in memory
        time_freq_tensor = all_data[actual_idx]
    else:
        # Load data on-the-fly based on the segment_name
        time_freq_tensor = self.load_data(segment_name)

    return {'data': time_freq_tensor, 'label': label, 'segment_name': segment_name}

# New method to set the current batch index
def set_current_batch_index(self, index):
self.current_batch_index = index

# New method to get the current batch index
def get_current_batch_index(self):
return self.current_batch_index

def set_start_idx(self, index):
self.start_idx = index

Expand Down