
Commit

Update for machine learning way
Luis Roberto Mercado Diaz committed Feb 7, 2024
1 parent 8462e1a commit 3881e71
Showing 2 changed files with 46 additions and 33 deletions.
51 changes: 36 additions & 15 deletions BML_project/models/ss_gp_model_checkpoint.py
@@ -69,30 +69,40 @@ def forward(self, x):

return latent_pred

def train_gp_model(train_loader, val_loader, num_iterations=50, n_classes=4, patience=10, checkpoint_path='model_checkpoint_full.pt', resume_training=False, batch_size=1024):
def train_gp_model(train_loader, val_loader, num_iterations=50, n_classes=4, patience=10, checkpoint_path='model_checkpoint_full.pt', resume_checkpoint_path=None):
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = MultitaskGPModel().to(device)
likelihood = gpytorch.likelihoods.SoftmaxLikelihood(num_features=4, num_classes=4).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.1)
mll = gpytorch.mlls.VariationalELBO(likelihood, model, num_data=len(train_loader.dataset))
best_val_loss = float('inf')
epochs_no_improve = 0

# Load checkpoint if resuming training
start_epoch = 0
current_batch_index = 0 # Default value in case it's not found in the checkpoint
if resume_training and os.path.exists(checkpoint_path):
checkpoint = torch.load(checkpoint_path)
start_batch = 0

# Resume from checkpoint if specified
if resume_checkpoint_path is not None and os.path.exists(resume_checkpoint_path):
checkpoint = torch.load(resume_checkpoint_path)
model.load_state_dict(checkpoint['model_state_dict'])
likelihood.load_state_dict(checkpoint['likelihood_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
start_epoch = checkpoint.get('epoch', 0) + 1 # Resume from the next epoch
current_batch_index = checkpoint.get('current_batch_index', 0) # Retrieve the last batch index if it exists
print(f"Resuming training from epoch {start_epoch}, batch index {current_batch_index}")

best_val_loss = float('inf')
epochs_no_improve = 0
metrics = {'precision': [], 'recall': [], 'f1_score': [], 'auc_roc': [], 'train_loss': []}

for epoch in range(start_epoch, num_iterations):
start_epoch = checkpoint.get('epoch', 0)
start_batch = checkpoint.get('batch', 0)
print(f"Resuming training from epoch {start_epoch}, batch {start_batch}")

metrics = {
'precision': [],
'recall': [],
'f1_score': [],
'train_loss': []
}

for epoch in tqdm.tqdm(range(start_epoch, num_iterations), desc='Training', unit='epoch', leave=False):
for batch_index, train_batch in enumerate(train_loader):
if epoch == start_epoch and batch_index < start_batch:
continue # Skip batches until the saved batch index

model.train()
likelihood.train()
optimizer.zero_grad()
@@ -104,6 +114,17 @@ def train_gp_model(train_loader, val_loader, num_iterations=50, n_classes=4, pat
loss.backward()
optimizer.step()

# Save checkpoint at intervals or based on other conditions
if (batch_index + 1) % 100 == 0: # Example condition
torch.save({
'epoch': epoch,
'batch': batch_index,
'model_state_dict': model.state_dict(),
'likelihood_state_dict': likelihood.state_dict(),
'optimizer_state_dict': optimizer.state_dict()
}, checkpoint_path)
print(f"Checkpoint saved at epoch {epoch}, batch {batch_index}")

# Stochastic validation
model.eval()
likelihood.eval()
@@ -160,7 +181,7 @@ def train_gp_model(train_loader, val_loader, num_iterations=50, n_classes=4, pat
print(f"Early stopping triggered at epoch {epoch+1}")
break

# Optionally, load the best model at the end of training
# Ensure the latest saved model state is loaded after the loop in case of early stopping
if os.path.exists(checkpoint_path):
checkpoint = torch.load(checkpoint_path)
model.load_state_dict(checkpoint['model_state_dict'])
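For context, a minimal usage sketch of the revised training entry point follows. It assumes train_loader and val_loader are already-constructed PyTorch DataLoaders and that the function is importable from BML_project.models.ss_gp_model_checkpoint; the function's return value is not shown in this diff, so it is not captured here.

    from BML_project.models.ss_gp_model_checkpoint import train_gp_model

    # First run: no resume checkpoint, so training starts from scratch.
    # Batch-level checkpoints are written to checkpoint_path every 100 batches.
    train_gp_model(train_loader, val_loader,
                   num_iterations=50, n_classes=4, patience=10,
                   checkpoint_path='model_checkpoint_full.pt',
                   resume_checkpoint_path=None)

    # Later run: pass the saved file as resume_checkpoint_path to restore the
    # model, likelihood, and optimizer states and skip already-seen batches.
    train_gp_model(train_loader, val_loader,
                   num_iterations=50, n_classes=4, patience=10,
                   checkpoint_path='model_checkpoint_full.pt',
                   resume_checkpoint_path='model_checkpoint_full.pt')
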
28 changes: 10 additions & 18 deletions BML_project/utils_gp/data_loader.py
@@ -103,34 +103,30 @@ def __init__(self, data_path, labels_path, UIDs, standardize=True, data_format='
self.data_format = data_format
self.read_all_labels = read_all_labels
self.transforms = ToTensor()
self.start_idx = start_idx # Initial batch index to start from, useful for resuming training
self.refresh_dataset()
self.start_idx = start_idx # Add this line

# Initialize the current batch index to None
# Initialize the current batch index to None; this can be used to track batch progress within the dataset itself
self.current_batch_index = None

def refresh_dataset(self):
# Extract unique segment names and their corresponding labels
self.segment_names, self.labels = self.extract_segment_names_and_labels()

def add_uids(self, new_uids):
# Ensure new UIDs are unique and not already in the dataset
unique_new_uids = [uid for uid in new_uids if uid not in self.UIDs]

# Add unique new UIDs and refresh the dataset
self.UIDs.extend(unique_new_uids)
self.refresh_dataset()

def __len__(self):
return len(self.segment_names)

def save_checkpoint(self, checkpoint_path, current_batch_index=None):

def save_checkpoint(self, checkpoint_path):
# Enhanced to automatically include 'start_idx' in the checkpoint
checkpoint = {
'segment_names': self.segment_names,
'labels': self.labels,
'UIDs': self.UIDs,
# Save the current batch index if provided
'current_batch_index': current_batch_index if current_batch_index is not None else self.current_batch_index
'start_idx': self.start_idx # Now also saving start_idx
}
torch.save(checkpoint, checkpoint_path)

@@ -139,32 +135,28 @@ def load_checkpoint(self, checkpoint_path):
self.segment_names = checkpoint['segment_names']
self.labels = checkpoint['labels']
self.UIDs = checkpoint['UIDs']
# Now also loading and setting start_idx from checkpoint
self.start_idx = checkpoint.get('start_idx', 0)
self.refresh_dataset()
# Load the current batch index if it exists in the checkpoint
self.current_batch_index = checkpoint.get('current_batch_index')

def __getitem__(self, idx):
actual_idx = idx + self.start_idx
actual_idx = (idx + self.start_idx) % len(self.segment_names) # Adjust index based on start_idx and wrap around if needed
segment_name = self.segment_names[actual_idx]
label = self.labels[segment_name]

if hasattr(self, 'all_data') and actual_idx < len(self.all_data):
# Data is stored in memory
time_freq_tensor = self.all_data[actual_idx]
else:
# Load data on-the-fly based on the segment_name
time_freq_tensor = self.load_data(segment_name)

return {'data': time_freq_tensor, 'label': label, 'segment_name': segment_name}

# New method to set the current batch index
def set_current_batch_index(self, index):
self.current_batch_index = index

# New method to get the current batch index
def get_current_batch_index(self):
return self.current_batch_index

def set_start_idx(self, index):
self.start_idx = index

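The dataset-side changes pair with the training loop above: start_idx is now persisted in the dataset checkpoint and re-applied on load. A hypothetical sketch is shown below; the dataset class name (here CustomDataset), the concrete paths, and the UIDs are placeholders, since only the method bodies appear in this diff.

    from BML_project.utils_gp.data_loader import CustomDataset  # class name assumed

    dataset = CustomDataset(data_path='data/', labels_path='labels.csv',
                            UIDs=['uid_001', 'uid_002'], standardize=True)

    # Record where iteration should resume, then persist segment names, labels,
    # UIDs, and start_idx in one checkpoint file.
    dataset.set_start_idx(512)
    dataset.save_checkpoint('dataset_checkpoint.pt')

    # In a later run, restore the same state; __getitem__ then offsets every
    # index by start_idx, wrapping around the end of segment_names.
    dataset.load_checkpoint('dataset_checkpoint.pt')
    sample = dataset[0]  # returns the segment at position 512 (modulo dataset length)
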
