diff --git a/BML_project/models/ss_gp_model_checkpoint.py b/BML_project/models/ss_gp_model_checkpoint.py
index 282c73d..c11f8cb 100644
--- a/BML_project/models/ss_gp_model_checkpoint.py
+++ b/BML_project/models/ss_gp_model_checkpoint.py
@@ -69,30 +69,40 @@ def forward(self, x):
         return latent_pred
 
-def train_gp_model(train_loader, val_loader, num_iterations=50, n_classes=4, patience=10, checkpoint_path='model_checkpoint_full.pt', resume_training=False, batch_size=1024):
+def train_gp_model(train_loader, val_loader, num_iterations=50, n_classes=4, patience=10, checkpoint_path='model_checkpoint_full.pt', resume_checkpoint_path=None):
+    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
     model = MultitaskGPModel().to(device)
     likelihood = gpytorch.likelihoods.SoftmaxLikelihood(num_features=4, num_classes=4).to(device)
     optimizer = torch.optim.Adam(model.parameters(), lr=0.1)
     mll = gpytorch.mlls.VariationalELBO(likelihood, model, num_data=len(train_loader.dataset))
+    best_val_loss = float('inf')
+    epochs_no_improve = 0
 
-    # Load checkpoint if resuming training
     start_epoch = 0
-    current_batch_index = 0  # Default value in case it's not found in the checkpoint
-    if resume_training and os.path.exists(checkpoint_path):
-        checkpoint = torch.load(checkpoint_path)
+    start_batch = 0
+
+    # Resume from checkpoint if specified
+    if resume_checkpoint_path is not None and os.path.exists(resume_checkpoint_path):
+        checkpoint = torch.load(resume_checkpoint_path)
         model.load_state_dict(checkpoint['model_state_dict'])
         likelihood.load_state_dict(checkpoint['likelihood_state_dict'])
         optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
-        start_epoch = checkpoint.get('epoch', 0) + 1  # Resume from the next epoch
-        current_batch_index = checkpoint.get('current_batch_index', 0)  # Retrieve the last batch index if it exists
-        print(f"Resuming training from epoch {start_epoch}, batch index {current_batch_index}")
-
-    best_val_loss = float('inf')
-    epochs_no_improve = 0
-    metrics = {'precision': [], 'recall': [], 'f1_score': [], 'auc_roc': [], 'train_loss': []}
-
-    for epoch in range(start_epoch, num_iterations):
+        start_epoch = checkpoint.get('epoch', 0)
+        start_batch = checkpoint.get('batch', 0)
+        print(f"Resuming training from epoch {start_epoch}, batch {start_batch}")
+
+    metrics = {
+        'precision': [],
+        'recall': [],
+        'f1_score': [],
+        'train_loss': []
+    }
+
+    for epoch in tqdm.tqdm(range(start_epoch, num_iterations), desc='Training', unit='epoch', leave=False):
         for batch_index, train_batch in enumerate(train_loader):
+            if epoch == start_epoch and batch_index < start_batch:
+                continue  # Skip batches until the saved batch index is reached
+
             model.train()
             likelihood.train()
             optimizer.zero_grad()
@@ -104,6 +114,17 @@ def train_gp_model(train_loader, val_loader, num_iterations=50, n_classes=4, pat
             loss.backward()
             optimizer.step()
 
+            # Save a checkpoint at regular intervals (or any other condition)
+            if (batch_index + 1) % 100 == 0:  # Example: every 100 batches
+                torch.save({
+                    'epoch': epoch,
+                    'batch': batch_index,
+                    'model_state_dict': model.state_dict(),
+                    'likelihood_state_dict': likelihood.state_dict(),
+                    'optimizer_state_dict': optimizer.state_dict()
+                }, checkpoint_path)
+                print(f"Checkpoint saved at epoch {epoch}, batch {batch_index}")
+
         # Stochastic validation
         model.eval()
         likelihood.eval()
@@ -160,7 +181,7 @@ def train_gp_model(train_loader, val_loader, num_iterations=50, n_classes=4, pat
             print(f"Early stopping triggered at epoch {epoch+1}")
             break
 
-    # Optionally, load the best model at the end of training
+    # Ensure the latest saved model state is loaded after the loop, in case of early stopping
     if os.path.exists(checkpoint_path):
         checkpoint = torch.load(checkpoint_path)
         model.load_state_dict(checkpoint['model_state_dict'])
diff --git a/BML_project/utils_gp/data_loader.py b/BML_project/utils_gp/data_loader.py
index 67169b1..fe3bc7e 100644
--- a/BML_project/utils_gp/data_loader.py
+++ b/BML_project/utils_gp/data_loader.py
@@ -103,34 +103,30 @@ def __init__(self, data_path, labels_path, UIDs, standardize=True, data_format='
         self.data_format = data_format
         self.read_all_labels = read_all_labels
         self.transforms = ToTensor()
+        self.start_idx = start_idx  # Initial segment index to start from, useful for resuming training
         self.refresh_dataset()
-        self.start_idx = start_idx  # Add this line
 
-        # Initialize the current batch index to None
+        # Initialize the current batch index to None; it can be used to track batch progress within the dataset itself
         self.current_batch_index = None
 
     def refresh_dataset(self):
-        # Extract unique segment names and their corresponding labels
         self.segment_names, self.labels = self.extract_segment_names_and_labels()
 
     def add_uids(self, new_uids):
-        # Ensure new UIDs are unique and not already in the dataset
         unique_new_uids = [uid for uid in new_uids if uid not in self.UIDs]
-
-        # Add unique new UIDs and refresh the dataset
         self.UIDs.extend(unique_new_uids)
         self.refresh_dataset()
 
     def __len__(self):
         return len(self.segment_names)
-
-    def save_checkpoint(self, checkpoint_path, current_batch_index=None):
+
+    def save_checkpoint(self, checkpoint_path):
+        # Enhanced to automatically include 'start_idx' in the checkpoint
         checkpoint = {
             'segment_names': self.segment_names,
             'labels': self.labels,
             'UIDs': self.UIDs,
-            # Save the current batch index if provided
-            'current_batch_index': current_batch_index if current_batch_index is not None else self.current_batch_index
+            'start_idx': self.start_idx  # Now also saving start_idx
         }
         torch.save(checkpoint, checkpoint_path)
 
@@ -139,32 +135,28 @@ def load_checkpoint(self, checkpoint_path):
         self.segment_names = checkpoint['segment_names']
         self.labels = checkpoint['labels']
         self.UIDs = checkpoint['UIDs']
+        # Now also load and set start_idx from the checkpoint
+        self.start_idx = checkpoint.get('start_idx', 0)
         self.refresh_dataset()
-        # Load the current batch index if it exists in the checkpoint
-        self.current_batch_index = checkpoint.get('current_batch_index')
 
     def __getitem__(self, idx):
-        actual_idx = idx + self.start_idx
+        actual_idx = (idx + self.start_idx) % len(self.segment_names)  # Offset the index by start_idx and wrap around if needed
         segment_name = self.segment_names[actual_idx]
         label = self.labels[segment_name]
 
         if hasattr(self, 'all_data') and actual_idx < len(self.all_data):
-            # Data is stored in memory
            time_freq_tensor = self.all_data[actual_idx]
         else:
-            # Load data on-the-fly based on the segment_name
            time_freq_tensor = self.load_data(segment_name)
 
         return {'data': time_freq_tensor, 'label': label, 'segment_name': segment_name}
 
-    # New method to set the current batch index
     def set_current_batch_index(self, index):
         self.current_batch_index = index
 
-    # New method to get the current batch index
     def get_current_batch_index(self):
         return self.current_batch_index
-    
+
     def set_start_idx(self, index):
         self.start_idx = index
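
For context, a minimal usage sketch of the resume flow introduced by this diff. It is illustrative only: it assumes train_loader and val_loader are PyTorch DataLoaders built elsewhere in the project around the dataset from data_loader.py, and the import path, the 'dataset_checkpoint.pt' filename, and the example index are placeholders rather than names confirmed by the codebase.

import os
# Assumed import path, based on the file location in this diff
from BML_project.models.ss_gp_model_checkpoint import train_gp_model

checkpoint_path = 'model_checkpoint_full.pt'

# First run: train from scratch; periodic checkpoints are written to checkpoint_path
train_gp_model(train_loader, val_loader, num_iterations=50,
               checkpoint_path=checkpoint_path)

# Later run: resume from the saved epoch/batch if a checkpoint exists
if os.path.exists(checkpoint_path):
    train_gp_model(train_loader, val_loader, num_iterations=50,
                   checkpoint_path=checkpoint_path,
                   resume_checkpoint_path=checkpoint_path)

# The dataset can also persist and restore its own position independently
# (assuming train_loader wraps the dataset defined in data_loader.py)
train_loader.dataset.set_start_idx(0)                            # illustrative index
train_loader.dataset.save_checkpoint('dataset_checkpoint.pt')    # saves segment_names, labels, UIDs, start_idx
train_loader.dataset.load_checkpoint('dataset_checkpoint.pt')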