Skip to content

Commit

Permalink
Adding an example of checkpoint
Browse files Browse the repository at this point in the history
This is a good guide if you like to try using checkpoints over the data loaders.
  • Loading branch information
lrm22005 committed Dec 18, 2023
1 parent 123c76f commit cf71da9
Show file tree
Hide file tree
Showing 4 changed files with 605 additions and 236 deletions.
334 changes: 334 additions & 0 deletions GP_Original_checkpoint.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,334 @@
# -*- coding: utf-8 -*-
"""
Created on Mon Dec 18 12:34:01 2023
@author: lrm22005
This approach keeps the checkpointing logic separate from the data loading logic,
which is a good practice in terms of code organization and reusability.
The should_save_checkpoint function is a placeholder for whatever logic you want to use to determine when to save a checkpoint.
It could be based on time, number of batches processed, or any other criterion.
Remember that while this method saves the progress in terms of batch index,
it does not automatically save the state of your model, optimizer, or any other components of your training loop. You should handle those separately as needed.
The code is designed to train a machine learning model (Gaussian Process) using an active learning approach.
It heavily relies on custom functions like load_data_split_batched, train_gp_model, uncertainty_sampling, update_train_loader_with_uncertain_samples, evaluate_model_on_all_data, and plotting functions, which are not defined in this snippet.
Active learning is used to iteratively select the most informative samples to improve the model iteratively.
Checkpoints are used to save the state of the model at each iteration, allowing for recovery and continuation of training in case of interruption.
The code evaluates model performance on both validation and test datasets, storing various metrics for analysis.
"""
import os
import torch
import pandas as pd
import numpy as np
from torch.utils.data import DataLoader, Dataset
from torchvision.transforms import ToTensor
from sklearn.preprocessing import StandardScaler
from PIL import Image
import pickle

from tqdm import tqdm
from GP_original_data import map_samples_to_uids, MultitaskGPModel, train_gp_model, parse_classification_report
from GP_original_data import evaluate_model_on_all_data, uncertainty_sampling, label_samples
from GP_original_data import update_train_loader_with_uncertain_samples, plot_training_performance, plot_results

def get_data_paths(data_format, is_linux=False, is_hpc=False):
if is_linux:
base_path = "/mnt/r/ENGR_Chon/Dong/MATLAB_generate_results/NIH_PulseWatch"
labels_base_path = "/mnt/r/ENGR_Chon/NIH_Pulsewatch_Database/Adjudication_UConn"
saving_base_path = "/mnt/r/ENGR_Chon/Luis/Research/Casseys_case/Project_1_analysis"
elif is_hpc:
base_path = "/gpfs/scratchfs1/kic14002/doh16101"
labels_base_path = "/gpfs/scratchfs1/hfp14002/lrm22005"
saving_base_path = "/gpfs/scratchfs1/hfp14002/lrm22005/Casseys_case/Project_1_analysis"
else:
# R:\ENGR_Chon\Dong\MATLAB_generate_results\NIH_PulseWatch
base_path = "R:\ENGR_Chon\Dong\MATLAB_generate_results\\NIH_PulseWatch"
labels_base_path = "R:\ENGR_Chon\\NIH_Pulsewatch_Database\Adjudication_UConn"
saving_base_path = r"\\grove.ad.uconn.edu\research\ENGR_Chon\Luis\Research\Casseys_case"
if data_format == 'csv':
data_path = os.path.join(base_path, "TFS_csv")
labels_path = os.path.join(labels_base_path, "final_attemp_4_1_Dong_Ohm")
saving_path = os.path.join(saving_base_path, "Project_1_analysis")
elif data_format == 'png':
data_path = os.path.join(base_path, "TFS_plots")
labels_path = os.path.join(labels_base_path, "final_attemp_4_1_Dong_Ohm")
saving_path = os.path.join(saving_base_path, "Project_1_analysis")
else:
raise ValueError("Invalid data format. Choose 'csv' or 'png.")
return data_path, labels_path, saving_path

class CustomDataset(Dataset):
def __init__(self, data_path, labels_path, UIDs, standardize=True, data_format='csv', read_all_labels=False):
self.data_path = data_path
self.labels_path = labels_path
self.UIDs = UIDs
self.standardize = standardize
self.data_format = data_format
self.read_all_labels = read_all_labels
self.transforms = ToTensor()
self.refresh_dataset()

def refresh_dataset(self):
# Extract unique segment names and their corresponding labels
self.segment_names, self.labels = self.extract_segment_names_and_labels()

def add_uids(self, new_uids):
# Ensure new UIDs are unique and not already in the dataset
unique_new_uids = [uid for uid in new_uids if uid not in self.UIDs]

# Add unique new UIDs and refresh the dataset
self.UIDs.extend(unique_new_uids)
self.refresh_dataset()

def __len__(self):
return len(self.segment_names)

def __getitem__(self, idx):
segment_name = self.segment_names[idx]
label = self.labels[segment_name]

# Load data on-the-fly based on the segment_name
time_freq_tensor = self.load_data(segment_name)

return {'data': time_freq_tensor, 'label': label, 'segment_name': segment_name}

def extract_segment_names_and_labels(self):
segment_names = []
labels = {}

for UID in self.UIDs:
label_file = os.path.join(self.labels_path, UID + "_final_attemp_4_1_Dong.csv")
if os.path.exists(label_file):
label_data = pd.read_csv(label_file, sep=',', header=0, names=['segment', 'label'])
label_segment_names = label_data['segment'].apply(lambda x: x.split('.')[0])
for idx, segment_name in enumerate(label_segment_names):
label_val = label_data['label'].values[idx]
if self.read_all_labels:
# Assign -1 if label is not in [0, 1, 2, 3]
labels[segment_name] = label_val if label_val in [0, 1, 2, 3] else -1
if segment_name not in segment_names:
segment_names.append(segment_name)
else:
# Only add segments with labels in [0, 1, 2, 3]
if label_val in [0, 1, 2, 3] and segment_name not in segment_names:
segment_names.append(segment_name)
labels[segment_name] = label_val

return segment_names, labels

def load_data(self, segment_name):
data_path_UID = os.path.join(self.data_path, segment_name.split('_')[0])
seg_path = os.path.join(data_path_UID, segment_name + '_filt_STFT.csv')

try:
if self.data_format == 'csv' and seg_path.endswith('.csv'):
time_freq_plot = np.array(pd.read_csv(seg_path, header=None))
time_freq_tensor = torch.Tensor(time_freq_plot).reshape(1, 128, 128)
elif self.data_format == 'png' and seg_path.endswith('.png'):
img = Image.open(seg_path)
img_data = np.array(img)
time_freq_tensor = torch.Tensor(img_data).unsqueeze(0)
else:
raise ValueError("Unsupported file format")

if self.standardize:
time_freq_tensor = self.standard_scaling(time_freq_tensor) # Standardize the data

return time_freq_tensor.clone()

except Exception as e:
print(f"Error processing segment: {segment_name}. Exception: {str(e)}")
return torch.zeros((1, 128, 128)) # Return zeros in case of an error

def standard_scaling(self, data):
scaler = StandardScaler()
data = scaler.fit_transform(data.reshape(-1, data.shape[-1])).reshape(data.shape)
return torch.Tensor(data)

def load_data_split_batched(data_path, labels_path, UIDs, batch_size, standardize=False, data_format='csv', read_all_labels=True, drop_last=False):
dataset = CustomDataset(data_path, labels_path, UIDs, standardize, data_format, read_all_labels)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False, drop_last=drop_last)
return dataloader

class CheckpointManager:
def __init__(self, checkpoint_dir):
self.checkpoint_dir = checkpoint_dir # Store the directory path for checkpoints
if not os.path.exists(checkpoint_dir): # Check if the directory exists
os.makedirs(checkpoint_dir) # Create the directory if it does not exist

def save_checkpoint(self, loader_name, iteration, additional_state):
# Construct the checkpoint file path using the loader name
checkpoint_path = os.path.join(self.checkpoint_dir, f"{loader_name}_checkpoint.pkl")
checkpoint = {
'iteration': iteration, # Store the current iteration
'additional_state': additional_state # Store any additional state information
}
with open(checkpoint_path, 'wb') as f: # Open the file in write-binary mode
pickle.dump(checkpoint, f) # Serialize the checkpoint dictionary to the file

def load_checkpoint(self, loader_name):
# Construct the checkpoint file path using the loader name
checkpoint_path = os.path.join(self.checkpoint_dir, f"{loader_name}_checkpoint.pkl")
try:
with open(checkpoint_path, 'rb') as f: # Open the file in read-binary mode
return pickle.load(f) # Deserialize the checkpoint file and return it
except FileNotFoundError: # Handle the case where the checkpoint file does not exist
return None # Return None if the file is not found


# ====== Load the per subject arrythmia summary ======
df_summary = pd.read_csv(r'\\grove.ad.uconn.edu\research\ENGR_Chon\NIH_Pulsewatch_Database\Adjudication_UConn\final_attemp_4_1_Dong_Ohm_summary_20231025.csv')
df_summary['UID'] = df_summary['UID'].astype(str).str.zfill(3)

df_summary['sample_nonAF'] = df_summary['NSR'] + df_summary['PACPVC'] + df_summary['SVT']
df_summary['sample_AF'] = df_summary['AF']

df_summary['sample_nonAF_ratio'] = df_summary['sample_nonAF'] / (df_summary['sample_AF'] + df_summary['sample_nonAF'])

all_UIDs = df_summary['UID'].unique()
# ====================================================
# ====== AF trial separation ======
# R:\ENGR_Chon\Dong\Numbers\Pulsewatch_numbers\Fahimeh_CNNED_general_ExpertSystemwApplication\tbl_file_name\TrainingSet_final_segments
AF_trial_Fahimeh_train = ['402','410']
AF_trial_Fahimeh_test = ['301', '302', '305', '306', '307', '310', '311',
'312', '318', '319', '320', '321', '322', '324',
'325', '327', '329', '400', '406', '407', '409',
'414']
AF_trial_Fahimeh_did_not_use = ['405', '413', '415', '416', '420', '421', '422', '423']
AF_trial_paroxysmal_AF = ['408','419']

AF_trial_train = AF_trial_Fahimeh_train
AF_trial_test = AF_trial_Fahimeh_test
AF_trial_unlabeled = AF_trial_Fahimeh_did_not_use + AF_trial_paroxysmal_AF
print(f'AF trial: {len(AF_trial_train)} training subjects {AF_trial_train}')
print(f'AF trial: {len(AF_trial_test)} testing subjects {AF_trial_test}')
print(f'AF trial: {len(AF_trial_unlabeled)} unlabeled subjects {AF_trial_unlabeled}')
# =================================
# === Clinical trial AF subjects separation ===
clinical_trial_AF_subjects = ['005', '017', '026', '051', '075', '082']

remaining_UIDs = []
count_NSR = []
import math
for index, row in df_summary.iterrows():
UID = row['UID']
this_NSR = row['sample_nonAF']
if math.isnan(this_NSR):
# There is no segment in this subject, skip this UID.
print(f'---------UID {UID} has no segments.------------')
continue
if UID not in AF_trial_train and UID not in AF_trial_test and UID not in clinical_trial_AF_subjects \
and not UID[0] == '3' and not UID[0] == '4':
remaining_UIDs.append(UID)
count_NSR.append(this_NSR)

from numpy import random
random.seed(seed=42)
from numpy.random import choice
list_of_candidates = remaining_UIDs
number_of_items_to_pick = round(len(list_of_candidates) * 0.15) # 10% labeled for training, 5% for testing.
temp_sum = sum(count_NSR)
probability_distribution = [x/temp_sum for x in count_NSR]
probability_distribution = [(1-x/temp_sum)/ (len(count_NSR)-1) for x in count_NSR]# Subjects with fewer segments have higher chance to be selected. Make sure the sum is one.
draw = choice(list_of_candidates, number_of_items_to_pick,
p=probability_distribution, replace=False)

clinical_trial_train = list(draw[:round(len(list_of_candidates) * 0.1)])
clinical_trial_test_nonAF = list(draw[round(len(list_of_candidates) * 0.1):])
clinical_trial_test_temp = clinical_trial_test_nonAF + clinical_trial_AF_subjects
clinical_trial_test = []
for UID in clinical_trial_test_temp:
# UID 051 and maybe other UIDs had no segments (unknown reason).
if UID in all_UIDs:
clinical_trial_test.append(UID)

clinical_trial_unlabeled = []
for UID in all_UIDs:
if UID not in clinical_trial_train and UID not in clinical_trial_test and not UID[0] == '3' and not UID[0] == '4':
clinical_trial_unlabeled.append(UID)
print(f'Clinical trial: selected {len(clinical_trial_train)} UIDs for training {clinical_trial_train}')
print(f'Clinical trial: selected {len(clinical_trial_test)} UIDs for testing {clinical_trial_test}')
print(f'Clinical trial: selected {len(clinical_trial_unlabeled)} UIDs for unlabeled {clinical_trial_unlabeled}')

# Global parameters related to the machine learning model
num_latents = 6 # Define the number of latents
num_tasks = 4 # Define the number of tasks
num_inducing_points = 50 # Define the number of inducing points

# Initialize a dictionary to store various metrics
results = {
'train_loss': [],
'validation_metrics': {'precision': [], 'recall': [], 'f1': [], 'auc_roc': []},
'test_metrics': None # Placeholder for final test metrics
}

# Set the device to CUDA if available, otherwise use CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Start of the main execution block
if __name__ == "__main__":
# Set the number of classes for the model
n_classes = 4

# Flags for running environment
is_linux = False
is_hpc = False
data_format = 'csv' # Set the data format
# Function call to get data paths based on the environment and format
data_path, labels_path, saving_path = get_data_paths(data_format, is_linux=is_linux, is_hpc=is_hpc)

# Define batch size for loading data
batch_size = 512
# Load the training, validation, and test data
train_loader = load_data_split_batched(data_path, labels_path, clinical_trial_train, batch_size, standardize=True, data_format='csv', read_all_labels=False, drop_last=True)
val_loader = load_data_split_batched(data_path, labels_path, clinical_trial_test, batch_size, standardize=True, data_format='csv', read_all_labels=False, drop_last=True)
test_loader = load_data_split_batched(data_path, labels_path, clinical_trial_unlabeled, batch_size, standardize=True, data_format='csv', read_all_labels=False, drop_last=True)

# Initialize the CheckpointManager
checkpoint_manager = CheckpointManager(saving_path)

# Attempt to load a training checkpoint
train_checkpoint = checkpoint_manager.load_checkpoint('train')
start_iteration = train_checkpoint['iteration'] if train_checkpoint else 0

# Active learning iterations loop
active_learning_iterations = 10
n_samples = batch_size # Set the number of samples for uncertainty sampling
for iteration in tqdm(range(start_iteration, active_learning_iterations), desc='Active Learning', unit='iteration'):

# Training and active learning logic
for train_batch in train_loader:
train_x = train_batch['data'].view(train_batch['data'].size(0), -1).to(device)
train_y = train_batch['label'].to(device)
model, likelihood = train_gp_model(train_x, train_y, val_loader, num_iterations=10, n_classes=n_classes)
# Save checkpoint at the end of each iteration

uncertain_sample_indices = uncertainty_sampling(model, likelihood, val_loader, n_samples, n_components=2)
accumulated_indices = [idx for idx in uncertain_sample_indices]
train_loader = update_train_loader_with_uncertain_samples(train_loader, accumulated_indices, data_path, labels_path, batch_size)

for train_batch in tqdm(train_loader, desc='Batch Training', leave=False):
train_x = train_batch['data'].view(train_batch['data'].size(0), -1).to(device)
train_y = train_batch['label'].to(device)
model, likelihood = train_gp_model(train_x, train_y, val_loader, num_iterations=10, n_classes=n_classes)

val_metrics = evaluate_model_on_all_data(model, likelihood, val_loader, device, n_classes)
for metric in ['precision', 'recall', 'f1', 'auc_roc']:
results['validation_metrics'][metric].append(val_metrics[metric])

# Save checkpoint at the end of each iteration
additional_state = {
'model_state': model.state_dict(),
# Include other states like optimizer, scheduler, etc.
}
checkpoint_manager.save_checkpoint('train', iteration, additional_state)

This comment has been minimized.

Copy link
@doh16101

doh16101 Jan 23, 2024

Collaborator

Dear @lrm22005 ,

I think this code is the checkpoint saver, correct? I am going to apply it to the ss_main.py code because I need some of the variables at the end of each iteration so I can resume the training or replicate the error.


# Plot the training performance based on stored metrics
plot_training_performance(results['train_loss'], results['validation_metrics'])

# Final evaluation on test set
classification_result = evaluate_model_on_all_data(model, likelihood, test_loader, device, n_classes=n_classes)
results['test_metrics'] = classification_result
plot_results(results)
print("Final Test Metrics:", results['test_metrics'])
Loading

0 comments on commit cf71da9

Please sign in to comment.