Skip to content

checkpointing changes #19

Merged
merged 1 commit into from
Feb 5, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
96 changes: 59 additions & 37 deletions BML_project/active_learning/ss_active_learning.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,27 +12,25 @@

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

def label_samples(uncertain_samples, validation_data):
labels = [validation_data[sample_id]['label'] for sample_id in uncertain_samples]
return uncertain_samples, labels
def label_samples(uncertain_samples, validation_dataset):
labeled_samples = [(sample, validation_dataset[sample]['label']) for sample in uncertain_samples]
return labeled_samples

def stochastic_uncertainty_sampling(gp_model, gp_likelihood, val_loader, n_samples, n_batches, n_components=2):
def stochastic_uncertainty_sampling(gp_model, gp_likelihood, val_loader, n_samples, n_batches, device):
gp_model.eval()
gp_likelihood.eval()
uncertain_sample_indices = []
sampled_batches = random.sample(list(val_loader), n_batches) # Randomly sample n_batches from val_loader
sampled_batches = random.sample(list(val_loader), min(n_batches, len(val_loader))) # Ensure we don't exceed available batches

with torch.no_grad():
for batch in sampled_batches:
# reduced_data = apply_tsne(batch['data'].reshape(batch['data'].size(0), -1), n_components=n_components)
# reduced_data_tensor = torch.Tensor(reduced_data).to(device)
reduced_data_tensor = batch['data'].view(batch['data'].size(0), -1).to(device)
predictions = gp_likelihood(gp_model(reduced_data_tensor))
var = predictions.variance
top_indices = torch.argsort(-var.flatten())[:n_samples]
uncertain_sample_indices.extend(top_indices.cpu().numpy())

return uncertain_sample_indices[:n_samples]
data = batch['data'].to(device)
predictions = gp_likelihood(gp_model(data))
variances = predictions.variance
_, top_uncertain_indices = torch.topk(variances.view(-1), n_samples)
uncertain_sample_indices.extend(top_uncertain_indices.cpu().numpy())

return list(set(uncertain_sample_indices[:n_samples])) # Remove duplicates and slice if needed

# def uncertainty_sampling(gp_model, gp_likelihood, val_loader, n_samples, n_components=2):
# gp_model.eval()
Expand Down Expand Up @@ -87,34 +85,58 @@ def stochastic_compare_kmeans_gp_predictions(kmeans_model, gp_model, data_loader

import random

def refined_uncertainty_sampling(gp_model, gp_likelihood, kmeans_model, data_loader, n_samples, n_batches, uncertainty_threshold=0.2):
gp_model.eval()
gp_likelihood.eval()
uncertain_sample_indices = []
# def refined_uncertainty_sampling(gp_model, gp_likelihood, kmeans_model, data_loader, n_samples, n_batches, uncertainty_threshold=0.2):
# gp_model.eval()
# gp_likelihood.eval()
# uncertain_sample_indices = []

# Calculate the total number of batches in the DataLoader
total_batches = len(data_loader)
# # Calculate the total number of batches in the DataLoader
# total_batches = len(data_loader)

# Ensure that n_batches does not exceed total_batches
n_batches = min(n_batches, total_batches)
# # Ensure that n_batches does not exceed total_batches
# n_batches = min(n_batches, total_batches)

# Randomly sample n_batches from data_loader
sampled_batches = random.sample(list(data_loader), n_batches)
# # Randomly sample n_batches from data_loader
# sampled_batches = random.sample(list(data_loader), n_batches)

# with torch.no_grad():
# for batch in sampled_batches:
# data_tensor = batch['data'].view(batch['data'].size(0), -1).to(device)
# gp_predictions = gp_likelihood(gp_model(data_tensor))
# kmeans_predictions = kmeans_model.predict(data_tensor.cpu().numpy())

# # Calculate the difference between K-means and GP predictions
# disagreement = (gp_predictions.mean.argmax(dim=-1).cpu().numpy() != kmeans_predictions).astype(int)

# # Calculate uncertainty based on variance of GP predictions
# uncertainty = gp_predictions.variance.cpu().numpy()

# # Select samples where the disagreement is high and the model is uncertain
# uncertain_indices = np.where((disagreement > 0) & (uncertainty > uncertainty_threshold))[0]
# uncertain_sample_indices.extend(uncertain_indices)

# return uncertain_sample_indices[:n_samples]

def refined_uncertainty_sampling(gp_model, gp_likelihood, kmeans_model, val_loader, n_samples, n_batches, uncertainty_threshold, device):
gp_model.eval()
gp_likelihood.eval()
uncertain_samples = []

sampled_batches = random.sample(list(val_loader), min(n_batches, len(val_loader)))

with torch.no_grad():
for batch in sampled_batches:
data_tensor = batch['data'].view(batch['data'].size(0), -1).to(device)
gp_predictions = gp_likelihood(gp_model(data_tensor))
kmeans_predictions = kmeans_model.predict(data_tensor.cpu().numpy())
data = batch['data'].to(device)
gp_predictions = gp_likelihood(gp_model(data))
variances = gp_predictions.variance
high_uncertainty_indices = variances.view(-1) > uncertainty_threshold

# Calculate the difference between K-means and GP predictions
disagreement = (gp_predictions.mean.argmax(dim=-1).cpu().numpy() != kmeans_predictions).astype(int)
kmeans_predictions = kmeans_model.predict(data.cpu().numpy())
gp_class_predictions = gp_predictions.mean.argmax(-1).cpu().numpy()

# Calculate uncertainty based on variance of GP predictions
uncertainty = gp_predictions.variance.cpu().numpy()
disagreement = gp_class_predictions != kmeans_predictions
uncertain_indices = np.where(disagreement & high_uncertainty_indices.cpu().numpy())[0]

# Select samples where the disagreement is high and the model is uncertain
uncertain_indices = np.where((disagreement > 0) & (uncertainty > uncertainty_threshold))[0]
uncertain_sample_indices.extend(uncertain_indices)

return uncertain_sample_indices[:n_samples]
uncertain_samples.extend(uncertain_indices)

return uncertain_samples[:n_samples]
50 changes: 35 additions & 15 deletions BML_project/main_checkpoints.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from utils.data_loader import preprocess_data, split_uids, update_train_loader_with_uncertain_samples
from models.ss_gp_model import MultitaskGPModel, train_gp_model
from utils_gp.ss_evaluation import stochastic_evaluation, evaluate_model_on_all_data
from active_learning.ss_active_learning import stochastic_uncertainty_sampling, run_minibatch_kmeans, stochastic_compare_kmeans_gp_predictions
from active_learning.ss_active_learning import stochastic_uncertainty_sampling, run_minibatch_kmeans, stochastic_compare_kmeans_gp_predictions, label_samples
from utils.visualization import plot_comparative_results, plot_training_performance, plot_results

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
Expand All @@ -27,10 +27,11 @@ def main():

# Attempt to resume from the last saved batch index if a dataset checkpoint exists
dataset_checkpoint_path = 'dataset_checkpoint.pt'
current_batch_index = 0 # Default to start from beginning if no checkpoint
if os.path.exists(dataset_checkpoint_path):
for loader in [train_loader, val_loader, test_loader]:
loader.dataset.load_checkpoint(dataset_checkpoint_path)
print(f"Resuming from batch index {loader.dataset.get_current_batch_index()}")
checkpoint = torch.load(dataset_checkpoint_path)
current_batch_index = checkpoint.get('current_batch_index', 0)
print(f"Resuming from batch index {current_batch_index}")

kmeans_model = run_minibatch_kmeans(train_loader, n_clusters=n_classes, device=device)

Expand All @@ -42,7 +43,19 @@ def main():
}

# Initial model training
model, likelihood, training_metrics = train_gp_model(train_loader, val_loader, num_iterations=50, n_classes=n_classes)
# model, likelihood, training_metrics = train_gp_model(train_loader, val_loader, num_iterations=50, n_classes=n_classes)

# Call the train_gp_model function with batch_size
model, likelihood, training_metrics = train_gp_model(
train_loader=train_loader,
val_loader=val_loader,
num_iterations=50,
n_classes=n_classes,
patience=10,
checkpoint_path='model_checkpoint_full.pt',
resume_training=True, # or False
batch_size=batch_size
)

# Save the training metrics for future visualization
results['train_loss'].extend(training_metrics['train_loss'])
Expand All @@ -52,23 +65,30 @@ def main():

active_learning_iterations = 10
# Active Learning Iterations
for iteration in tqdm.tqdm(range(active_learning_iterations), desc='Active Learning', unit='iteration', leave=True):
for iteration in range(active_learning_iterations):
print(f"Active Learning Iteration: {iteration+1}/{active_learning_iterations}")
# Perform uncertainty sampling to select new samples from the validation set
uncertain_sample_indices = stochastic_uncertainty_sampling(model, likelihood, val_loader, n_samples=batch_size, n_batches=5)
uncertain_sample_indices = stochastic_uncertainty_sampling(model, likelihood, val_loader, n_samples=50, n_batches=5, device=device)

labeled_samples = label_samples(uncertain_sample_indices, val_loader.dataset)

# Update the training loader with uncertain samples
train_loader = update_train_loader_with_uncertain_samples(train_loader, uncertain_sample_indices, batch_size)
train_loader = update_train_loader_with_uncertain_samples(train_loader, labeled_samples, batch_size)

# Optionally, save the dataset state at intervals or after certain conditions
train_loader.dataset.save_checkpoint(dataset_checkpoint_path, current_batch_index=None) # Here, manage the index as needed
train_loader.dataset.save_checkpoint(dataset_checkpoint_path, current_batch_index=None)

# model, likelihood, val_metrics = train_gp_model(train_loader, val_loader, num_iterations=10, n_classes=n_classes, patience=10, checkpoint_path='model_checkpoint_last.pt')

# Re-train the model with the updated training data
model, likelihood, val_metrics = train_gp_model(train_loader, val_loader, num_iterations=10, n_classes=n_classes, patience=10, checkpoint_path='model_checkpoint_last.pt')

model, likelihood, val_metrics = train_gp_model(
train_loader, val_loader, num_iterations=10, n_classes=n_classes, patience=10,
checkpoint_path='model_checkpoint_last.pt', resume_training=True, batch_size=batch_size
)
# Store the validation metrics after each active learning iteration
results['validation_metrics']['precision'].append(val_metrics['precision'])
results['validation_metrics']['recall'].append(val_metrics['recall'])
results['validation_metrics']['f1'].append(val_metrics['f1'])
results['Active_validation_metrics']['precision'].append(val_metrics['precision'])
results['Active_validation_metrics']['recall'].append(val_metrics['recall'])
results['Active_validation_metrics']['f1'].append(val_metrics['f1'])

# Compare K-Means with GP model predictions after retraining
gp_vs_kmeans_data, original_labels = stochastic_compare_kmeans_gp_predictions(kmeans_model, model, train_loader, n_batches=5, device=device)
Expand All @@ -84,7 +104,7 @@ def main():
plot_comparative_results(test_gp_vs_kmeans_data, test_original_labels)

# Visualization of results
plot_training_performance(results['train_loss'], results['validation_metrics'])
plot_training_performance(results['train_loss'], results['Active_validation_metrics'])
plot_results(results['test_metrics'])

# Print final test metrics
Expand Down
9 changes: 6 additions & 3 deletions BML_project/models/ss_gp_model_checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,22 +69,24 @@ def forward(self, x):

return latent_pred

def train_gp_model(train_loader, val_loader, num_iterations=50, n_classes=4, patience=10, checkpoint_path='model_checkpoint_full.pt', resume_training=False):
def train_gp_model(train_loader, val_loader, num_iterations=50, n_classes=4, patience=10, checkpoint_path='model_checkpoint_full.pt', resume_training=False, batch_size=1024):
model = MultitaskGPModel().to(device)
likelihood = gpytorch.likelihoods.SoftmaxLikelihood(num_features=4, num_classes=4).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.1)
mll = gpytorch.mlls.VariationalELBO(likelihood, model, num_data=len(train_loader.dataset))

# Load checkpoint if resuming training
start_epoch = 0
current_batch_index = 0 # Default value in case it's not found in the checkpoint
if resume_training and os.path.exists(checkpoint_path):
checkpoint = torch.load(checkpoint_path)
model.load_state_dict(checkpoint['model_state_dict'])
likelihood.load_state_dict(checkpoint['likelihood_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
start_epoch = checkpoint.get('epoch', 0) + 1 # Resume from the next epoch
print(f"Resuming training from epoch {start_epoch}")

current_batch_index = checkpoint.get('current_batch_index', 0) # Retrieve the last batch index if it exists
print(f"Resuming training from epoch {start_epoch}, batch index {current_batch_index}")

best_val_loss = float('inf')
epochs_no_improve = 0
metrics = {'precision': [], 'recall': [], 'f1_score': [], 'auc_roc': [], 'train_loss': []}
Expand Down Expand Up @@ -150,6 +152,7 @@ def train_gp_model(train_loader, val_loader, num_iterations=50, n_classes=4, pat
'likelihood_state_dict': likelihood.state_dict(),
'optimizer_state_dict': optimizer.state_dict(),
'best_val_loss': best_val_loss,
'current_batch_index': train_loader.dataset.get_current_batch_index() # Get and save the current batch index
# Include other metrics as needed
}, checkpoint_path)

Expand Down
27 changes: 16 additions & 11 deletions BML_project/utils_gp/data_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ def split_uids():
return clinical_trial_train, clinical_trial_test, clinical_trial_unlabeled

class CustomDataset(Dataset):
def __init__(self, data_path, labels_path, UIDs, standardize=True, data_format='csv', read_all_labels=False):
def __init__(self, data_path, labels_path, UIDs, standardize=True, data_format='csv', read_all_labels=False, start_idx=0):
self.data_path = data_path
self.labels_path = labels_path
self.UIDs = UIDs
Expand All @@ -104,6 +104,7 @@ def __init__(self, data_path, labels_path, UIDs, standardize=True, data_format='
self.read_all_labels = read_all_labels
self.transforms = ToTensor()
self.refresh_dataset()
self.start_idx = start_idx # Add this line

# Initialize the current batch index to None
self.current_batch_index = None
Expand Down Expand Up @@ -143,12 +144,13 @@ def load_checkpoint(self, checkpoint_path):
self.current_batch_index = checkpoint.get('current_batch_index')

def __getitem__(self, idx):
segment_name = self.segment_names[idx]
actual_idx = idx + self.start_idx
segment_name = self.segment_names[actual_idx]
label = self.labels[segment_name]

if hasattr(self, 'all_data') and idx < len(self.all_data):
if hasattr(self, 'all_data') and actual_idx < len(self.all_data):
# Data is stored in memory
time_freq_tensor = self.all_data[idx]
time_freq_tensor = self.all_data[actual_idx]
else:
# Load data on-the-fly based on the segment_name
time_freq_tensor = self.load_data(segment_name)
Expand All @@ -163,6 +165,9 @@ def set_current_batch_index(self, index):
def get_current_batch_index(self):
return self.current_batch_index

def set_start_idx(self, index):
self.start_idx = index

def add_data_label_pair(self, data, label):
# Assign a unique ID or name for the new data
new_id = len(self.segment_names)
Expand Down Expand Up @@ -232,8 +237,8 @@ def standard_scaling(self, data):
data = scaler.fit_transform(data.reshape(-1, data.shape[-1])).reshape(data.shape)
return torch.Tensor(data)

def load_data_split_batched(data_path, labels_path, UIDs, batch_size, standardize=False, data_format='csv', read_all_labels=False, drop_last=False, num_workers=4):
dataset = CustomDataset(data_path, labels_path, UIDs, standardize, data_format, read_all_labels)
def load_data_split_batched(data_path, labels_path, UIDs, batch_size, standardize=False, data_format='csv', read_all_labels=False, drop_last=False, num_workers=4, start_idx=0):
dataset = CustomDataset(data_path, labels_path, UIDs, standardize, data_format, read_all_labels, start_idx=start_idx)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False, drop_last=drop_last, num_workers=num_workers, prefetch_factor=2)
return dataloader

Expand Down Expand Up @@ -268,12 +273,12 @@ def get_data_paths(data_format, is_linux=False, is_hpc=False):
return data_path, labels_path, saving_path

# Function to extract and preprocess data
def preprocess_data(data_format, clinical_trial_train, clinical_trial_test, clinical_trial_unlabeled, batch_size, read_all_labels=False):
# Extracts paths and loads data into train, validation, and test loaders
def preprocess_data(data_format, clinical_trial_train, clinical_trial_test, clinical_trial_unlabeled, batch_size, read_all_labels=False, current_batch_index=0):
start_idx = current_batch_index * batch_size
data_path, labels_path, saving_path = get_data_paths(data_format)
train_loader = load_data_split_batched(data_path, labels_path, clinical_trial_train, batch_size, standardize=True, data_format=data_format, read_all_labels=read_all_labels)
val_loader = load_data_split_batched(data_path, labels_path, clinical_trial_test, batch_size, standardize=True, data_format=data_format, read_all_labels=read_all_labels)
test_loader = load_data_split_batched(data_path, labels_path, clinical_trial_unlabeled, batch_size, standardize=True, data_format=data_format, read_all_labels=read_all_labels)
train_loader = load_data_split_batched(data_path, labels_path, clinical_trial_train, batch_size, standardize=True, data_format=data_format, read_all_labels=read_all_labels, start_idx=start_idx)
val_loader = load_data_split_batched(data_path, labels_path, clinical_trial_test, batch_size, standardize=True, data_format=data_format, read_all_labels=read_all_labels, start_idx=start_idx)
test_loader = load_data_split_batched(data_path, labels_path, clinical_trial_unlabeled, batch_size, standardize=True, data_format=data_format, read_all_labels=read_all_labels, start_idx=start_idx)
return train_loader, val_loader, test_loader

def map_samples_to_uids(uncertain_sample_indices, dataset):
Expand Down