Skip to content

Commit

Permalink
Merge pull request #19 from lrm22005/Luis
Browse files Browse the repository at this point in the history
checkpointing changes
  • Loading branch information
lrm22005 committed Feb 5, 2024
2 parents fd3d77e + 3c61965 commit 8462e1a
Show file tree
Hide file tree
Showing 4 changed files with 116 additions and 66 deletions.
96 changes: 59 additions & 37 deletions BML_project/active_learning/ss_active_learning.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,27 +12,25 @@

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

def label_samples(uncertain_samples, validation_dataset):
    """Pair every selected sample with its label from the validation dataset.

    Args:
        uncertain_samples: Iterable of indices/keys accepted by
            ``validation_dataset[...]``.
        validation_dataset: Indexable collection whose items expose a
            ``'label'`` entry.

    Returns:
        list[tuple]: ``(sample, label)`` pairs, in the order the samples
        were given.
    """
    return [
        (sample, validation_dataset[sample]['label'])
        for sample in uncertain_samples
    ]

def stochastic_uncertainty_sampling(gp_model, gp_likelihood, val_loader, n_samples, n_batches, device=None):
    """Select the highest-variance positions from a random subset of batches.

    Both models are switched to eval mode, up to ``n_batches`` batches are
    drawn at random from ``val_loader``, and for each batch the positions with
    the largest predictive variance are collected.

    Args:
        gp_model: Callable model exposing ``eval()``.
        gp_likelihood: Callable likelihood exposing ``eval()``; its output
            must expose a ``variance`` tensor.
        val_loader: Iterable of batches; each batch is a mapping with a
            ``'data'`` tensor.
        n_samples: Maximum number of indices to return (also the per-batch
            top-k size).
        n_batches: Maximum number of batches to draw.
        device: Device the batch data is moved to. Defaults to CUDA when
            available, otherwise CPU.

    Returns:
        list[int]: Up to ``n_samples`` distinct indices, most uncertain first
        within each batch, in the order batches were drawn.

    NOTE(review): the indices are positions inside each batch's *flattened*
    variance tensor, not dataset-level indices, while the caller feeds them
    to ``label_samples`` to index the dataset -- confirm this is intended.
    """
    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    gp_model.eval()
    gp_likelihood.eval()

    batches = list(val_loader)
    # Never ask random.sample for more batches than exist (or fewer than 0).
    sampled_batches = random.sample(batches, max(0, min(n_batches, len(batches))))

    candidate_indices = []
    with torch.no_grad():
        for batch in sampled_batches:
            data = batch['data'].to(device)
            predictions = gp_likelihood(gp_model(data))
            variances = predictions.variance.reshape(-1)
            # torch.topk raises when k exceeds the number of elements.
            k = max(0, min(n_samples, variances.numel()))
            _, top_uncertain_indices = torch.topk(variances, k)
            candidate_indices.extend(top_uncertain_indices.cpu().tolist())

    # De-duplicate *before* truncating, and keep insertion order so the
    # result is deterministic for a given draw of batches.
    unique_indices = list(dict.fromkeys(candidate_indices))
    return unique_indices[:max(n_samples, 0)]

# def uncertainty_sampling(gp_model, gp_likelihood, val_loader, n_samples, n_components=2):
# gp_model.eval()
Expand Down Expand Up @@ -87,34 +85,58 @@ def stochastic_compare_kmeans_gp_predictions(kmeans_model, gp_model, data_loader

import random

def refined_uncertainty_sampling(gp_model, gp_likelihood, kmeans_model, val_loader, n_samples, n_batches, uncertainty_threshold=0.2, device=None):
    """Select positions where the GP is uncertain AND disagrees with K-means.

    Up to ``n_batches`` batches are drawn at random from ``val_loader``. In
    each batch a position is kept when its GP predictive variance exceeds
    ``uncertainty_threshold`` and the GP class (argmax of the predictive mean
    over the last axis) differs from the K-means cluster assignment.

    Args:
        gp_model: Callable model exposing ``eval()``.
        gp_likelihood: Callable likelihood exposing ``eval()``; its output
            must expose ``mean`` and ``variance`` tensors.
        kmeans_model: Object with ``predict(ndarray)`` returning one cluster
            id per row.
        val_loader: Iterable of batches; each batch is a mapping with a
            ``'data'`` tensor.
        n_samples: Maximum number of indices to return.
        n_batches: Maximum number of batches to draw.
        uncertainty_threshold: Variance above which a position counts as
            uncertain.
        device: Device the batch data is moved to. Defaults to CUDA when
            available, otherwise CPU.

    Returns:
        list[int]: At most ``n_samples`` batch-local indices, in the order
        the batches were drawn.

    NOTE(review): the flattened variance mask is combined element-wise with
    the per-sample disagreement mask, so this assumes one variance value per
    sample -- confirm against the likelihood actually used. The returned
    indices are batch-local, not dataset-level.
    """
    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    gp_model.eval()
    gp_likelihood.eval()

    batches = list(val_loader)
    sampled_batches = random.sample(batches, max(0, min(n_batches, len(batches))))

    uncertain_samples = []
    with torch.no_grad():
        for batch in sampled_batches:
            data = batch['data'].to(device)
            gp_predictions = gp_likelihood(gp_model(data))

            # reshape (not view) so non-contiguous variance tensors also work.
            variances = gp_predictions.variance.reshape(-1)
            high_uncertainty = (variances > uncertainty_threshold).cpu().numpy()

            # K-means gets one flat feature row per sample, as the previous
            # implementation of this function did before calling predict().
            flat_features = data.reshape(data.size(0), -1).cpu().numpy()
            kmeans_predictions = kmeans_model.predict(flat_features)
            gp_class_predictions = gp_predictions.mean.argmax(-1).cpu().numpy()

            disagreement = gp_class_predictions != kmeans_predictions
            uncertain_indices = np.where(disagreement & high_uncertainty)[0]
            # Plain Python ints are friendlier to downstream indexing/serialising.
            uncertain_samples.extend(uncertain_indices.tolist())

    return uncertain_samples[:max(n_samples, 0)]
50 changes: 35 additions & 15 deletions BML_project/main_checkpoints.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from utils.data_loader import preprocess_data, split_uids, update_train_loader_with_uncertain_samples
from models.ss_gp_model import MultitaskGPModel, train_gp_model
from utils_gp.ss_evaluation import stochastic_evaluation, evaluate_model_on_all_data
from active_learning.ss_active_learning import stochastic_uncertainty_sampling, run_minibatch_kmeans, stochastic_compare_kmeans_gp_predictions
from active_learning.ss_active_learning import stochastic_uncertainty_sampling, run_minibatch_kmeans, stochastic_compare_kmeans_gp_predictions, label_samples
from utils.visualization import plot_comparative_results, plot_training_performance, plot_results

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
Expand All @@ -27,10 +27,11 @@ def main():

# Attempt to resume from the last saved batch index if a dataset checkpoint exists
dataset_checkpoint_path = 'dataset_checkpoint.pt'
current_batch_index = 0 # Default to start from beginning if no checkpoint
if os.path.exists(dataset_checkpoint_path):
for loader in [train_loader, val_loader, test_loader]:
loader.dataset.load_checkpoint(dataset_checkpoint_path)
print(f"Resuming from batch index {loader.dataset.get_current_batch_index()}")
checkpoint = torch.load(dataset_checkpoint_path)
current_batch_index = checkpoint.get('current_batch_index', 0)
print(f"Resuming from batch index {current_batch_index}")

kmeans_model = run_minibatch_kmeans(train_loader, n_clusters=n_classes, device=device)

Expand All @@ -42,7 +43,19 @@ def main():
}

# Initial model training
model, likelihood, training_metrics = train_gp_model(train_loader, val_loader, num_iterations=50, n_classes=n_classes)
# model, likelihood, training_metrics = train_gp_model(train_loader, val_loader, num_iterations=50, n_classes=n_classes)

# Call the train_gp_model function with batch_size
model, likelihood, training_metrics = train_gp_model(
train_loader=train_loader,
val_loader=val_loader,
num_iterations=50,
n_classes=n_classes,
patience=10,
checkpoint_path='model_checkpoint_full.pt',
resume_training=True, # or False
batch_size=batch_size
)

# Save the training metrics for future visualization
results['train_loss'].extend(training_metrics['train_loss'])
Expand All @@ -52,23 +65,30 @@ def main():

active_learning_iterations = 10
# Active Learning Iterations
for iteration in tqdm.tqdm(range(active_learning_iterations), desc='Active Learning', unit='iteration', leave=True):
for iteration in range(active_learning_iterations):
print(f"Active Learning Iteration: {iteration+1}/{active_learning_iterations}")
# Perform uncertainty sampling to select new samples from the validation set
uncertain_sample_indices = stochastic_uncertainty_sampling(model, likelihood, val_loader, n_samples=batch_size, n_batches=5)
uncertain_sample_indices = stochastic_uncertainty_sampling(model, likelihood, val_loader, n_samples=50, n_batches=5, device=device)

labeled_samples = label_samples(uncertain_sample_indices, val_loader.dataset)

# Update the training loader with uncertain samples
train_loader = update_train_loader_with_uncertain_samples(train_loader, uncertain_sample_indices, batch_size)
train_loader = update_train_loader_with_uncertain_samples(train_loader, labeled_samples, batch_size)

# Optionally, save the dataset state at intervals or after certain conditions
train_loader.dataset.save_checkpoint(dataset_checkpoint_path, current_batch_index=None) # Here, manage the index as needed
train_loader.dataset.save_checkpoint(dataset_checkpoint_path, current_batch_index=None)

# model, likelihood, val_metrics = train_gp_model(train_loader, val_loader, num_iterations=10, n_classes=n_classes, patience=10, checkpoint_path='model_checkpoint_last.pt')

# Re-train the model with the updated training data
model, likelihood, val_metrics = train_gp_model(train_loader, val_loader, num_iterations=10, n_classes=n_classes, patience=10, checkpoint_path='model_checkpoint_last.pt')

model, likelihood, val_metrics = train_gp_model(
train_loader, val_loader, num_iterations=10, n_classes=n_classes, patience=10,
checkpoint_path='model_checkpoint_last.pt', resume_training=True, batch_size=batch_size
)
# Store the validation metrics after each active learning iteration
results['validation_metrics']['precision'].append(val_metrics['precision'])
results['validation_metrics']['recall'].append(val_metrics['recall'])
results['validation_metrics']['f1'].append(val_metrics['f1'])
results['Active_validation_metrics']['precision'].append(val_metrics['precision'])
results['Active_validation_metrics']['recall'].append(val_metrics['recall'])
results['Active_validation_metrics']['f1'].append(val_metrics['f1'])

# Compare K-Means with GP model predictions after retraining
gp_vs_kmeans_data, original_labels = stochastic_compare_kmeans_gp_predictions(kmeans_model, model, train_loader, n_batches=5, device=device)
Expand All @@ -84,7 +104,7 @@ def main():
plot_comparative_results(test_gp_vs_kmeans_data, test_original_labels)

# Visualization of results
plot_training_performance(results['train_loss'], results['validation_metrics'])
plot_training_performance(results['train_loss'], results['Active_validation_metrics'])
plot_results(results['test_metrics'])

# Print final test metrics
Expand Down
9 changes: 6 additions & 3 deletions BML_project/models/ss_gp_model_checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,22 +69,24 @@ def forward(self, x):

return latent_pred

def train_gp_model(train_loader, val_loader, num_iterations=50, n_classes=4, patience=10, checkpoint_path='model_checkpoint_full.pt', resume_training=False):
def train_gp_model(train_loader, val_loader, num_iterations=50, n_classes=4, patience=10, checkpoint_path='model_checkpoint_full.pt', resume_training=False, batch_size=1024):
model = MultitaskGPModel().to(device)
likelihood = gpytorch.likelihoods.SoftmaxLikelihood(num_features=4, num_classes=4).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.1)
mll = gpytorch.mlls.VariationalELBO(likelihood, model, num_data=len(train_loader.dataset))

# Load checkpoint if resuming training
start_epoch = 0
current_batch_index = 0 # Default value in case it's not found in the checkpoint
if resume_training and os.path.exists(checkpoint_path):
checkpoint = torch.load(checkpoint_path)
model.load_state_dict(checkpoint['model_state_dict'])
likelihood.load_state_dict(checkpoint['likelihood_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
start_epoch = checkpoint.get('epoch', 0) + 1 # Resume from the next epoch
print(f"Resuming training from epoch {start_epoch}")

current_batch_index = checkpoint.get('current_batch_index', 0) # Retrieve the last batch index if it exists
print(f"Resuming training from epoch {start_epoch}, batch index {current_batch_index}")

best_val_loss = float('inf')
epochs_no_improve = 0
metrics = {'precision': [], 'recall': [], 'f1_score': [], 'auc_roc': [], 'train_loss': []}
Expand Down Expand Up @@ -150,6 +152,7 @@ def train_gp_model(train_loader, val_loader, num_iterations=50, n_classes=4, pat
'likelihood_state_dict': likelihood.state_dict(),
'optimizer_state_dict': optimizer.state_dict(),
'best_val_loss': best_val_loss,
'current_batch_index': train_loader.dataset.get_current_batch_index() # Get and save the current batch index
# Include other metrics as needed
}, checkpoint_path)

Expand Down
27 changes: 16 additions & 11 deletions BML_project/utils_gp/data_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ def split_uids():
return clinical_trial_train, clinical_trial_test, clinical_trial_unlabeled

class CustomDataset(Dataset):
def __init__(self, data_path, labels_path, UIDs, standardize=True, data_format='csv', read_all_labels=False):
def __init__(self, data_path, labels_path, UIDs, standardize=True, data_format='csv', read_all_labels=False, start_idx=0):
self.data_path = data_path
self.labels_path = labels_path
self.UIDs = UIDs
Expand All @@ -104,6 +104,7 @@ def __init__(self, data_path, labels_path, UIDs, standardize=True, data_format='
self.read_all_labels = read_all_labels
self.transforms = ToTensor()
self.refresh_dataset()
self.start_idx = start_idx # Add this line

# Initialize the current batch index to None
self.current_batch_index = None
Expand Down Expand Up @@ -143,12 +144,13 @@ def load_checkpoint(self, checkpoint_path):
self.current_batch_index = checkpoint.get('current_batch_index')

def __getitem__(self, idx):
segment_name = self.segment_names[idx]
actual_idx = idx + self.start_idx
segment_name = self.segment_names[actual_idx]
label = self.labels[segment_name]

if hasattr(self, 'all_data') and idx < len(self.all_data):
if hasattr(self, 'all_data') and actual_idx < len(self.all_data):
# Data is stored in memory
time_freq_tensor = self.all_data[idx]
time_freq_tensor = self.all_data[actual_idx]
else:
# Load data on-the-fly based on the segment_name
time_freq_tensor = self.load_data(segment_name)
Expand All @@ -163,6 +165,9 @@ def set_current_batch_index(self, index):
def get_current_batch_index(self):
    """Return the stored ``current_batch_index`` attribute unchanged."""
    return self.current_batch_index

def set_start_idx(self, index):
    """Store ``index`` as the ``start_idx`` attribute.

    ``__getitem__`` adds ``start_idx`` to every requested index.
    """
    self.start_idx = index

def add_data_label_pair(self, data, label):
# Assign a unique ID or name for the new data
new_id = len(self.segment_names)
Expand Down Expand Up @@ -232,8 +237,8 @@ def standard_scaling(self, data):
data = scaler.fit_transform(data.reshape(-1, data.shape[-1])).reshape(data.shape)
return torch.Tensor(data)

def load_data_split_batched(data_path, labels_path, UIDs, batch_size, standardize=False, data_format='csv', read_all_labels=False, drop_last=False, num_workers=4, start_idx=0):
    """Build a ``CustomDataset`` and wrap it in a non-shuffling ``DataLoader``.

    Args:
        data_path: Forwarded to ``CustomDataset``.
        labels_path: Forwarded to ``CustomDataset``.
        UIDs: Forwarded to ``CustomDataset``.
        batch_size: Batch size of the returned loader.
        standardize: Forwarded to ``CustomDataset``.
        data_format: Forwarded to ``CustomDataset``.
        read_all_labels: Forwarded to ``CustomDataset``.
        drop_last: Forwarded to ``DataLoader``.
        num_workers: Number of loader worker processes; ``0`` loads in the
            main process.
        start_idx: Offset forwarded to ``CustomDataset`` (added to every
            index in its ``__getitem__``).

    Returns:
        DataLoader: Loader over the dataset with ``shuffle=False``.
    """
    dataset = CustomDataset(
        data_path, labels_path, UIDs, standardize, data_format, read_all_labels,
        start_idx=start_idx,
    )

    loader_kwargs = {
        'batch_size': batch_size,
        'shuffle': False,
        'drop_last': drop_last,
        'num_workers': num_workers,
    }
    # DataLoader rejects an explicit prefetch_factor when num_workers == 0,
    # so only request prefetching when worker processes are actually used.
    if num_workers > 0:
        loader_kwargs['prefetch_factor'] = 2

    return DataLoader(dataset, **loader_kwargs)

Expand Down Expand Up @@ -268,12 +273,12 @@ def get_data_paths(data_format, is_linux=False, is_hpc=False):
return data_path, labels_path, saving_path

# Function to extract and preprocess data
def preprocess_data(data_format, clinical_trial_train, clinical_trial_test, clinical_trial_unlabeled, batch_size, read_all_labels=False, current_batch_index=0):
    """Resolve the data paths and build the train, validation and test loaders.

    Args:
        data_format: Passed to ``get_data_paths`` and to every loader.
        clinical_trial_train: UIDs used for the training loader.
        clinical_trial_test: UIDs used for the validation loader.
        clinical_trial_unlabeled: UIDs used for the test loader.
        batch_size: Batch size of every loader.
        read_all_labels: Forwarded to every loader.
        current_batch_index: Batch to resume from; converted to a sample
            offset (``current_batch_index * batch_size``). ``None`` is
            treated as ``0``.

    Returns:
        tuple: ``(train_loader, val_loader, test_loader)``.

    NOTE(review): the same resume offset is applied to all three loaders,
    not only the training one -- confirm this is intended.
    """
    # The dataset initialises its batch index to None and checkpoints may be
    # saved with None, so guard against ``None * batch_size``.
    start_idx = (current_batch_index or 0) * batch_size
    data_path, labels_path, _saving_path = get_data_paths(data_format)

    def _build_loader(uids):
        # All three splits share every setting except the UID list.
        return load_data_split_batched(
            data_path, labels_path, uids, batch_size,
            standardize=True, data_format=data_format,
            read_all_labels=read_all_labels, start_idx=start_idx,
        )

    train_loader = _build_loader(clinical_trial_train)
    val_loader = _build_loader(clinical_trial_test)
    test_loader = _build_loader(clinical_trial_unlabeled)
    return train_loader, val_loader, test_loader

def map_samples_to_uids(uncertain_sample_indices, dataset):
Expand Down

0 comments on commit 8462e1a

Please sign in to comment.