From cf71da9fd552711b36dc2435fdbca2e97dee764a Mon Sep 17 00:00:00 2001 From: Luis Roberto Mercado Diaz Date: Mon, 18 Dec 2023 13:35:19 -0500 Subject: [PATCH 1/2] Adding an example of checkpoint This is a good guide if you like to try using checkpoints over the data loaders. --- GP_Original_checkpoint.py | 334 +++++++++++++++++++++++++++++++++ option.py | 91 ++++++++- project_2.py | 384 +++++++++++++++++--------------------- project_NVIL.py | 32 ++-- 4 files changed, 605 insertions(+), 236 deletions(-) create mode 100644 GP_Original_checkpoint.py diff --git a/GP_Original_checkpoint.py b/GP_Original_checkpoint.py new file mode 100644 index 0000000..0bdceae --- /dev/null +++ b/GP_Original_checkpoint.py @@ -0,0 +1,334 @@ +# -*- coding: utf-8 -*- +""" +Created on Mon Dec 18 12:34:01 2023 + +@author: lrm22005 + +This approach keeps the checkpointing logic separate from the data loading logic, +which is a good practice in terms of code organization and reusability. +The should_save_checkpoint function is a placeholder for whatever logic you want to use to determine when to save a checkpoint. +It could be based on time, number of batches processed, or any other criterion. +Remember that while this method saves the progress in terms of batch index, +it does not automatically save the state of your model, optimizer, or any other components of your training loop. You should handle those separately as needed. + +The code is designed to train a machine learning model (Gaussian Process) using an active learning approach. +It heavily relies on custom functions like load_data_split_batched, train_gp_model, uncertainty_sampling, update_train_loader_with_uncertain_samples, evaluate_model_on_all_data, and plotting functions, which are not defined in this snippet. +Active learning is used to iteratively select the most informative samples to improve the model iteratively. 
+Checkpoints are used to save the state of the model at each iteration, allowing for recovery and continuation of training in case of interruption. +The code evaluates model performance on both validation and test datasets, storing various metrics for analysis. +""" +import os +import torch +import pandas as pd +import numpy as np +from torch.utils.data import DataLoader, Dataset +from torchvision.transforms import ToTensor +from sklearn.preprocessing import StandardScaler +from PIL import Image +import pickle + +from tqdm import tqdm +from GP_original_data import map_samples_to_uids, MultitaskGPModel, train_gp_model, parse_classification_report +from GP_original_data import evaluate_model_on_all_data, uncertainty_sampling, label_samples +from GP_original_data import update_train_loader_with_uncertain_samples, plot_training_performance, plot_results + +def get_data_paths(data_format, is_linux=False, is_hpc=False): + if is_linux: + base_path = "/mnt/r/ENGR_Chon/Dong/MATLAB_generate_results/NIH_PulseWatch" + labels_base_path = "/mnt/r/ENGR_Chon/NIH_Pulsewatch_Database/Adjudication_UConn" + saving_base_path = "/mnt/r/ENGR_Chon/Luis/Research/Casseys_case/Project_1_analysis" + elif is_hpc: + base_path = "/gpfs/scratchfs1/kic14002/doh16101" + labels_base_path = "/gpfs/scratchfs1/hfp14002/lrm22005" + saving_base_path = "/gpfs/scratchfs1/hfp14002/lrm22005/Casseys_case/Project_1_analysis" + else: + # R:\ENGR_Chon\Dong\MATLAB_generate_results\NIH_PulseWatch + base_path = "R:\ENGR_Chon\Dong\MATLAB_generate_results\\NIH_PulseWatch" + labels_base_path = "R:\ENGR_Chon\\NIH_Pulsewatch_Database\Adjudication_UConn" + saving_base_path = r"\\grove.ad.uconn.edu\research\ENGR_Chon\Luis\Research\Casseys_case" + if data_format == 'csv': + data_path = os.path.join(base_path, "TFS_csv") + labels_path = os.path.join(labels_base_path, "final_attemp_4_1_Dong_Ohm") + saving_path = os.path.join(saving_base_path, "Project_1_analysis") + elif data_format == 'png': + data_path = 
os.path.join(base_path, "TFS_plots") + labels_path = os.path.join(labels_base_path, "final_attemp_4_1_Dong_Ohm") + saving_path = os.path.join(saving_base_path, "Project_1_analysis") + else: + raise ValueError("Invalid data format. Choose 'csv' or 'png.") + return data_path, labels_path, saving_path + +class CustomDataset(Dataset): + def __init__(self, data_path, labels_path, UIDs, standardize=True, data_format='csv', read_all_labels=False): + self.data_path = data_path + self.labels_path = labels_path + self.UIDs = UIDs + self.standardize = standardize + self.data_format = data_format + self.read_all_labels = read_all_labels + self.transforms = ToTensor() + self.refresh_dataset() + + def refresh_dataset(self): + # Extract unique segment names and their corresponding labels + self.segment_names, self.labels = self.extract_segment_names_and_labels() + + def add_uids(self, new_uids): + # Ensure new UIDs are unique and not already in the dataset + unique_new_uids = [uid for uid in new_uids if uid not in self.UIDs] + + # Add unique new UIDs and refresh the dataset + self.UIDs.extend(unique_new_uids) + self.refresh_dataset() + + def __len__(self): + return len(self.segment_names) + + def __getitem__(self, idx): + segment_name = self.segment_names[idx] + label = self.labels[segment_name] + + # Load data on-the-fly based on the segment_name + time_freq_tensor = self.load_data(segment_name) + + return {'data': time_freq_tensor, 'label': label, 'segment_name': segment_name} + + def extract_segment_names_and_labels(self): + segment_names = [] + labels = {} + + for UID in self.UIDs: + label_file = os.path.join(self.labels_path, UID + "_final_attemp_4_1_Dong.csv") + if os.path.exists(label_file): + label_data = pd.read_csv(label_file, sep=',', header=0, names=['segment', 'label']) + label_segment_names = label_data['segment'].apply(lambda x: x.split('.')[0]) + for idx, segment_name in enumerate(label_segment_names): + label_val = label_data['label'].values[idx] + if 
self.read_all_labels: + # Assign -1 if label is not in [0, 1, 2, 3] + labels[segment_name] = label_val if label_val in [0, 1, 2, 3] else -1 + if segment_name not in segment_names: + segment_names.append(segment_name) + else: + # Only add segments with labels in [0, 1, 2, 3] + if label_val in [0, 1, 2, 3] and segment_name not in segment_names: + segment_names.append(segment_name) + labels[segment_name] = label_val + + return segment_names, labels + + def load_data(self, segment_name): + data_path_UID = os.path.join(self.data_path, segment_name.split('_')[0]) + seg_path = os.path.join(data_path_UID, segment_name + '_filt_STFT.csv') + + try: + if self.data_format == 'csv' and seg_path.endswith('.csv'): + time_freq_plot = np.array(pd.read_csv(seg_path, header=None)) + time_freq_tensor = torch.Tensor(time_freq_plot).reshape(1, 128, 128) + elif self.data_format == 'png' and seg_path.endswith('.png'): + img = Image.open(seg_path) + img_data = np.array(img) + time_freq_tensor = torch.Tensor(img_data).unsqueeze(0) + else: + raise ValueError("Unsupported file format") + + if self.standardize: + time_freq_tensor = self.standard_scaling(time_freq_tensor) # Standardize the data + + return time_freq_tensor.clone() + + except Exception as e: + print(f"Error processing segment: {segment_name}. 
Exception: {str(e)}") + return torch.zeros((1, 128, 128)) # Return zeros in case of an error + + def standard_scaling(self, data): + scaler = StandardScaler() + data = scaler.fit_transform(data.reshape(-1, data.shape[-1])).reshape(data.shape) + return torch.Tensor(data) + +def load_data_split_batched(data_path, labels_path, UIDs, batch_size, standardize=False, data_format='csv', read_all_labels=True, drop_last=False): + dataset = CustomDataset(data_path, labels_path, UIDs, standardize, data_format, read_all_labels) + dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False, drop_last=drop_last) + return dataloader + +class CheckpointManager: + def __init__(self, checkpoint_dir): + self.checkpoint_dir = checkpoint_dir # Store the directory path for checkpoints + if not os.path.exists(checkpoint_dir): # Check if the directory exists + os.makedirs(checkpoint_dir) # Create the directory if it does not exist + + def save_checkpoint(self, loader_name, iteration, additional_state): + # Construct the checkpoint file path using the loader name + checkpoint_path = os.path.join(self.checkpoint_dir, f"{loader_name}_checkpoint.pkl") + checkpoint = { + 'iteration': iteration, # Store the current iteration + 'additional_state': additional_state # Store any additional state information + } + with open(checkpoint_path, 'wb') as f: # Open the file in write-binary mode + pickle.dump(checkpoint, f) # Serialize the checkpoint dictionary to the file + + def load_checkpoint(self, loader_name): + # Construct the checkpoint file path using the loader name + checkpoint_path = os.path.join(self.checkpoint_dir, f"{loader_name}_checkpoint.pkl") + try: + with open(checkpoint_path, 'rb') as f: # Open the file in read-binary mode + return pickle.load(f) # Deserialize the checkpoint file and return it + except FileNotFoundError: # Handle the case where the checkpoint file does not exist + return None # Return None if the file is not found + + +# ====== Load the per subject arrythmia 
summary ====== +df_summary = pd.read_csv(r'\\grove.ad.uconn.edu\research\ENGR_Chon\NIH_Pulsewatch_Database\Adjudication_UConn\final_attemp_4_1_Dong_Ohm_summary_20231025.csv') +df_summary['UID'] = df_summary['UID'].astype(str).str.zfill(3) + +df_summary['sample_nonAF'] = df_summary['NSR'] + df_summary['PACPVC'] + df_summary['SVT'] +df_summary['sample_AF'] = df_summary['AF'] + +df_summary['sample_nonAF_ratio'] = df_summary['sample_nonAF'] / (df_summary['sample_AF'] + df_summary['sample_nonAF']) + +all_UIDs = df_summary['UID'].unique() +# ==================================================== +# ====== AF trial separation ====== +# R:\ENGR_Chon\Dong\Numbers\Pulsewatch_numbers\Fahimeh_CNNED_general_ExpertSystemwApplication\tbl_file_name\TrainingSet_final_segments +AF_trial_Fahimeh_train = ['402','410'] +AF_trial_Fahimeh_test = ['301', '302', '305', '306', '307', '310', '311', + '312', '318', '319', '320', '321', '322', '324', + '325', '327', '329', '400', '406', '407', '409', + '414'] +AF_trial_Fahimeh_did_not_use = ['405', '413', '415', '416', '420', '421', '422', '423'] +AF_trial_paroxysmal_AF = ['408','419'] + +AF_trial_train = AF_trial_Fahimeh_train +AF_trial_test = AF_trial_Fahimeh_test +AF_trial_unlabeled = AF_trial_Fahimeh_did_not_use + AF_trial_paroxysmal_AF +print(f'AF trial: {len(AF_trial_train)} training subjects {AF_trial_train}') +print(f'AF trial: {len(AF_trial_test)} testing subjects {AF_trial_test}') +print(f'AF trial: {len(AF_trial_unlabeled)} unlabeled subjects {AF_trial_unlabeled}') +# ================================= +# === Clinical trial AF subjects separation === +clinical_trial_AF_subjects = ['005', '017', '026', '051', '075', '082'] + +remaining_UIDs = [] +count_NSR = [] +import math +for index, row in df_summary.iterrows(): + UID = row['UID'] + this_NSR = row['sample_nonAF'] + if math.isnan(this_NSR): + # There is no segment in this subject, skip this UID. 
+ print(f'---------UID {UID} has no segments.------------') + continue + if UID not in AF_trial_train and UID not in AF_trial_test and UID not in clinical_trial_AF_subjects \ + and not UID[0] == '3' and not UID[0] == '4': + remaining_UIDs.append(UID) + count_NSR.append(this_NSR) + +from numpy import random +random.seed(seed=42) +from numpy.random import choice +list_of_candidates = remaining_UIDs +number_of_items_to_pick = round(len(list_of_candidates) * 0.15) # 10% labeled for training, 5% for testing. +temp_sum = sum(count_NSR) +probability_distribution = [x/temp_sum for x in count_NSR] +probability_distribution = [(1-x/temp_sum)/ (len(count_NSR)-1) for x in count_NSR]# Subjects with fewer segments have higher chance to be selected. Make sure the sum is one. +draw = choice(list_of_candidates, number_of_items_to_pick, + p=probability_distribution, replace=False) + +clinical_trial_train = list(draw[:round(len(list_of_candidates) * 0.1)]) +clinical_trial_test_nonAF = list(draw[round(len(list_of_candidates) * 0.1):]) +clinical_trial_test_temp = clinical_trial_test_nonAF + clinical_trial_AF_subjects +clinical_trial_test = [] +for UID in clinical_trial_test_temp: + # UID 051 and maybe other UIDs had no segments (unknown reason). 
+ if UID in all_UIDs: + clinical_trial_test.append(UID) + +clinical_trial_unlabeled = [] +for UID in all_UIDs: + if UID not in clinical_trial_train and UID not in clinical_trial_test and not UID[0] == '3' and not UID[0] == '4': + clinical_trial_unlabeled.append(UID) +print(f'Clinical trial: selected {len(clinical_trial_train)} UIDs for training {clinical_trial_train}') +print(f'Clinical trial: selected {len(clinical_trial_test)} UIDs for testing {clinical_trial_test}') +print(f'Clinical trial: selected {len(clinical_trial_unlabeled)} UIDs for unlabeled {clinical_trial_unlabeled}') + +# Global parameters related to the machine learning model +num_latents = 6 # Define the number of latents +num_tasks = 4 # Define the number of tasks +num_inducing_points = 50 # Define the number of inducing points + +# Initialize a dictionary to store various metrics +results = { + 'train_loss': [], + 'validation_metrics': {'precision': [], 'recall': [], 'f1': [], 'auc_roc': []}, + 'test_metrics': None # Placeholder for final test metrics +} + +# Set the device to CUDA if available, otherwise use CPU +device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + +# Start of the main execution block +if __name__ == "__main__": + # Set the number of classes for the model + n_classes = 4 + + # Flags for running environment + is_linux = False + is_hpc = False + data_format = 'csv' # Set the data format + # Function call to get data paths based on the environment and format + data_path, labels_path, saving_path = get_data_paths(data_format, is_linux=is_linux, is_hpc=is_hpc) + + # Define batch size for loading data + batch_size = 512 + # Load the training, validation, and test data + train_loader = load_data_split_batched(data_path, labels_path, clinical_trial_train, batch_size, standardize=True, data_format='csv', read_all_labels=False, drop_last=True) + val_loader = load_data_split_batched(data_path, labels_path, clinical_trial_test, batch_size, standardize=True, 
data_format='csv', read_all_labels=False, drop_last=True) + test_loader = load_data_split_batched(data_path, labels_path, clinical_trial_unlabeled, batch_size, standardize=True, data_format='csv', read_all_labels=False, drop_last=True) + + # Initialize the CheckpointManager + checkpoint_manager = CheckpointManager(saving_path) + + # Attempt to load a training checkpoint + train_checkpoint = checkpoint_manager.load_checkpoint('train') + start_iteration = train_checkpoint['iteration'] if train_checkpoint else 0 + + # Active learning iterations loop + active_learning_iterations = 10 + n_samples = batch_size # Set the number of samples for uncertainty sampling + for iteration in tqdm(range(start_iteration, active_learning_iterations), desc='Active Learning', unit='iteration'): + + # Training and active learning logic + for train_batch in train_loader: + train_x = train_batch['data'].view(train_batch['data'].size(0), -1).to(device) + train_y = train_batch['label'].to(device) + model, likelihood = train_gp_model(train_x, train_y, val_loader, num_iterations=10, n_classes=n_classes) + # Save checkpoint at the end of each iteration + + uncertain_sample_indices = uncertainty_sampling(model, likelihood, val_loader, n_samples, n_components=2) + accumulated_indices = [idx for idx in uncertain_sample_indices] + train_loader = update_train_loader_with_uncertain_samples(train_loader, accumulated_indices, data_path, labels_path, batch_size) + + for train_batch in tqdm(train_loader, desc='Batch Training', leave=False): + train_x = train_batch['data'].view(train_batch['data'].size(0), -1).to(device) + train_y = train_batch['label'].to(device) + model, likelihood = train_gp_model(train_x, train_y, val_loader, num_iterations=10, n_classes=n_classes) + + val_metrics = evaluate_model_on_all_data(model, likelihood, val_loader, device, n_classes) + for metric in ['precision', 'recall', 'f1', 'auc_roc']: + results['validation_metrics'][metric].append(val_metrics[metric]) + + # Save 
checkpoint at the end of each iteration + additional_state = { + 'model_state': model.state_dict(), + # Include other states like optimizer, scheduler, etc. + } + checkpoint_manager.save_checkpoint('train', iteration, additional_state) + + # Plot the training performance based on stored metrics + plot_training_performance(results['train_loss'], results['validation_metrics']) + + # Final evaluation on test set + classification_result = evaluate_model_on_all_data(model, likelihood, test_loader, device, n_classes=n_classes) + results['test_metrics'] = classification_result + plot_results(results) + print("Final Test Metrics:", results['test_metrics']) diff --git a/option.py b/option.py index c1d8352..adf4f29 100644 --- a/option.py +++ b/option.py @@ -13,14 +13,66 @@ from torchvision.transforms import ToTensor from sklearn.preprocessing import StandardScaler from PIL import Image +import os +import torch +import pandas as pd +import numpy as np +from torch.utils.data import DataLoader, Dataset, TensorDataset +from torchvision.transforms import ToTensor +from sklearn.preprocessing import StandardScaler +from PIL import Image + +import logging +import numpy as np +import pandas as pd +import torch +from torch.distributions import MultivariateNormal +from torch.utils.data import DataLoader, TensorDataset +import matplotlib.pyplot as plt +from sklearn.decomposition import PCA, IncrementalPCA +from sklearn.manifold import TSNE +from sklearn.cluster import KMeans +from sklearn.cluster import MiniBatchKMeans +from sklearn.preprocessing import StandardScaler +from sklearn.metrics import silhouette_score, adjusted_rand_score, adjusted_mutual_info_score, davies_bouldin_score +import seaborn as sns +from PIL import Image # Import the Image module + + +def get_data_paths(data_format, is_linux=False, is_hpc=False): + if is_linux: + base_path = "/mnt/r/ENGR_Chon/Dong/MATLAB_generate_results/NIH_PulseWatch" + labels_base_path = 
"/mnt/r/ENGR_Chon/NIH_Pulsewatch_Database/Adjudication_UConn" + saving_base_path = "/mnt/r/ENGR_Chon/Luis/Research/Casseys_case/Project_1_analysis" + elif is_hpc: + base_path = "/gpfs/scratchfs1/kic14002/doh16101" + labels_base_path = "/gpfs/scratchfs1/hfp14002/lrm22005" + saving_base_path = "/gpfs/scratchfs1/hfp14002/lrm22005/Casseys_case/Project_1_analysis" + else: + # R:\ENGR_Chon\Dong\MATLAB_generate_results\NIH_PulseWatch + base_path = "\\grove.ad.uconn.edu\research\ENGR_Chon\Dong\MATLAB_generate_results\\NIH_PulseWatch" + labels_base_path = "\\grove.ad.uconn.edu\research\ENGR_Chon\\NIH_Pulsewatch_Database\Adjudication_UConn" + saving_base_path = r"\\grove.ad.uconn.edu\research\ENGR_Chon\Luis\Research\Casseys_case" + if data_format == 'csv': + data_path = os.path.join(base_path, "TFS_csv") + labels_path = os.path.join(labels_base_path, "final_attemp_4_1_Dong_Ohm") + saving_path = os.path.join(saving_base_path, "Project_1_analysis") + elif data_format == 'png': + data_path = os.path.join(base_path, "TFS_plots") + labels_path = os.path.join(labels_base_path, "final_attemp_4_1_Dong_Ohm") + saving_path = os.path.join(saving_base_path, "Project_1_analysis") + else: + raise ValueError("Invalid data format. 
Choose 'csv' or 'png.") + return data_path, labels_path, saving_path class CustomDataset(Dataset): - def __init__(self, data_path, labels_path, UIDs, standardize=True, data_format='csv'): + def __init__(self, data_path, labels_path, UIDs, standardize=True, data_format='csv', read_all_labels=False): self.data_path = data_path self.labels_path = labels_path self.UIDs = UIDs self.standardize = standardize self.data_format = data_format + self.read_all_labels = read_all_labels self.transforms = ToTensor() # Extract unique segment names and their corresponding labels @@ -36,7 +88,7 @@ def __getitem__(self, idx): # Load data on-the-fly based on the segment_name time_freq_tensor = self.load_data(segment_name) - return {'data': time_freq_tensor.unsqueeze(0), 'label': label, 'segment_name': segment_name} + return {'data': time_freq_tensor, 'label': label, 'segment_name': segment_name} def extract_segment_names_and_labels(self): segment_names = [] @@ -48,15 +100,23 @@ def extract_segment_names_and_labels(self): label_data = pd.read_csv(label_file, sep=',', header=0, names=['segment', 'label']) label_segment_names = label_data['segment'].apply(lambda x: x.split('.')[0]) for idx, segment_name in enumerate(label_segment_names): - if segment_name not in segment_names: - segment_names.append(segment_name) - labels[segment_name] = label_data['label'].values[idx] + label_val = label_data['label'].values[idx] + if self.read_all_labels: + # Assign -1 if label is not in [0, 1, 2, 3] + labels[segment_name] = label_val if label_val in [0, 1, 2, 3] else -1 + if segment_name not in segment_names: + segment_names.append(segment_name) + else: + # Only add segments with labels in [0, 1, 2, 3] + if label_val in [0, 1, 2, 3] and segment_name not in segment_names: + segment_names.append(segment_name) + labels[segment_name] = label_val return segment_names, labels def load_data(self, segment_name): data_path_UID = os.path.join(self.data_path, segment_name.split('_')[0]) - seg_path = 
os.path.join(data_path_UID, segment_name + '_filt.csv') + seg_path = os.path.join(data_path_UID, segment_name + '_filt_STFT.csv') try: if self.data_format == 'csv' and seg_path.endswith('.csv'): @@ -83,8 +143,8 @@ def standard_scaling(self, data): data = scaler.fit_transform(data.reshape(-1, data.shape[-1])).reshape(data.shape) return torch.Tensor(data) -def load_data_split_batched(data_path, labels_path, UIDs, batch_size, standardize=False, data_format='csv'): - dataset = CustomDataset(data_path, labels_path, UIDs, standardize, data_format) +def load_data_split_batched(data_path, labels_path, UIDs, batch_size, standardize=False, data_format='csv', read_all_labels=True): + dataset = CustomDataset(data_path, labels_path, UIDs, standardize, data_format, read_all_labels) dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False) return dataloader @@ -93,3 +153,18 @@ def load_data_split_batched(data_path, labels_path, UIDs, batch_size, standardiz # len(train_loader) # number of batches # len(train_loader) + +is_linux = False # Set to True if running on Linux, False if on Windows +is_hpc = False # Set to True if running on an HPC, False if on Windows +data_format = 'csv' # Choose 'csv' or 'png' +data_path, labels_path, saving_path = get_data_paths(data_format, is_linux=is_linux, is_hpc=is_hpc) + +clinical_trial_train = [clinical_trial_train[0]] +clinical_trial_test = [clinical_trial_test[0]] +clinical_trial_unlabeled = [clinical_trial_unlabeled[0]] + +# Example usage: +batch_size = 64 +train_loader = load_data_split_batched(data_path, labels_path, clinical_trial_train, batch_size, standardize=False, data_format='csv', read_all_labels=False) +val_loader = load_data_split_batched(data_path, labels_path, clinical_trial_test, batch_size, standardize=False, data_format='csv', read_all_labels=False) +test_loader = load_data_split_batched(data_path, labels_path, clinical_trial_unlabeled, batch_size, standardize=False, data_format='csv', read_all_labels=False) diff 
--git a/project_2.py b/project_2.py index e25fe0a..8647a54 100644 --- a/project_2.py +++ b/project_2.py @@ -14,7 +14,7 @@ ############################################################################################################################################## ############################################################################################################################################## # ====== Load the per subject arrythmia summary ====== -df_summary = pd.read_csv(r'R:\ENGR_Chon\NIH_Pulsewatch_Database\Adjudication_UConn\final_attemp_4_1_Dong_Ohm_summary_20231025.csv') +df_summary = pd.read_csv(r'\\grove.ad.uconn.edu\research\ENGR_Chon\NIH_Pulsewatch_Database\Adjudication_UConn\final_attemp_4_1_Dong_Ohm_summary_20231025.csv') df_summary['UID'] = df_summary['UID'].astype(str).str.zfill(3) df_summary['sample_nonAF'] = df_summary['NSR'] + df_summary['PACPVC'] + df_summary['SVT'] @@ -98,7 +98,8 @@ import pandas as pd import torch from torch.distributions import MultivariateNormal -from torch.utils.data import DataLoader, TensorDataset +from torch.utils.data import DataLoader, TensorDataset, Dataset +from torchvision.transforms import ToTensor import matplotlib.pyplot as plt from sklearn.decomposition import PCA, IncrementalPCA from sklearn.manifold import TSNE @@ -110,24 +111,8 @@ from PIL import Image # Import the Image module os.environ['OMP_NUM_THREADS'] = '3' -# Create a logger -logger = logging.getLogger(__name__) -logging.basicConfig(filename='error_log.txt', level=logging.ERROR) -logging.basicConfig(level=logging.INFO) - -# Create a logger for progress monitoring -progress_logger = logging.getLogger('progress') -progress_logger.setLevel(logging.INFO) -progress_handler = logging.FileHandler('progress.log') -progress_formatter = logging.Formatter('%(asctime)s - %(name)s - %(message)s') -progress_handler.setFormatter(progress_formatter) -progress_logger.addHandler(progress_handler) - -def log_progress(message): - progress_logger.info(message) 
def get_data_paths(data_format, is_linux=False, is_hpc=False): - log_progress("Code execution get_data_paths") if is_linux: base_path = "/mnt/r/ENGR_Chon/Dong/MATLAB_generate_results/NIH_PulseWatch" labels_base_path = "/mnt/r/ENGR_Chon/NIH_Pulsewatch_Database/Adjudication_UConn" @@ -137,8 +122,8 @@ def get_data_paths(data_format, is_linux=False, is_hpc=False): labels_base_path = "/gpfs/scratchfs1/hfp14002/lrm22005" saving_base_path = "/gpfs/scratchfs1/hfp14002/lrm22005/Casseys_case/Project_1_analysis" else: - base_path = "R:\\ENGR_Chon\\Dong\\MATLAB_generate_results\\NIH_PulseWatch" - labels_base_path = "R:\\ENGR_Chon\\NIH_Pulsewatch_Database\\Adjudication_UConn" + base_path = "\\grove.ad.uconn.edu\research\\ENGR_Chon\\Dong\\MATLAB_generate_results\\NIH_PulseWatch" + labels_base_path = "\\grove.ad.uconn.edu\research\\ENGR_Chon\\NIH_Pulsewatch_Database\\Adjudication_UConn" saving_base_path = r"\\grove.ad.uconn.edu\research\ENGR_Chon\Luis\Research\Casseys_case" if data_format == 'csv': data_path = os.path.join(base_path, "TFS_csv") @@ -150,123 +135,109 @@ def get_data_paths(data_format, is_linux=False, is_hpc=False): saving_path = os.path.join(saving_base_path, "Project_1_analysis") else: raise ValueError("Invalid data format. Choose 'csv' or 'png.") - log_progress("Code execution completed get_data_paths") return data_path, labels_path, saving_path -def load_data_split(data_path, labels_path, UIDs, standardize=True, data_format='csv'): - if data_format not in ['csv', 'png']: - raise ValueError("Invalid data_format. 
Choose 'csv' or 'png.") +class CustomDataset(Dataset): + def __init__(self, data_path, labels_path, UIDs, standardize=True, data_format='csv', read_all_labels=False): + self.data_path = data_path + self.labels_path = labels_path + self.UIDs = UIDs + self.standardize = standardize + self.data_format = data_format + self.read_all_labels = read_all_labels + self.transforms = ToTensor() + + # Extract unique segment names and their corresponding labels + self.segment_names, self.labels = self.extract_segment_names_and_labels() + + def __len__(self): + return len(self.segment_names) + + def __getitem__(self, idx): + segment_name = self.segment_names[idx] + label = self.labels[segment_name] + + # Load data on-the-fly based on the segment_name + time_freq_tensor = self.load_data(segment_name) + + return {'data': time_freq_tensor, 'label': label, 'segment_name': segment_name} + + def extract_segment_names_and_labels(self): + segment_names = [] + labels = {} + + for UID in self.UIDs: + label_file = os.path.join(self.labels_path, UID + "_final_attemp_4_1_Dong.csv") + if os.path.exists(label_file): + label_data = pd.read_csv(label_file, sep=',', header=0, names=['segment', 'label']) + label_segment_names = label_data['segment'].apply(lambda x: x.split('.')[0]) + for idx, segment_name in enumerate(label_segment_names): + label_val = label_data['label'].values[idx] + if self.read_all_labels: + # Assign -1 if label is not in [0, 1, 2, 3] + labels[segment_name] = label_val if label_val in [0, 1, 2, 3] else -1 + if segment_name not in segment_names: + segment_names.append(segment_name) + else: + # Only add segments with labels in [0, 1, 2, 3] + if label_val in [0, 1, 2, 3] and segment_name not in segment_names: + segment_names.append(segment_name) + labels[segment_name] = label_val + + return segment_names, labels + + def load_data(self, segment_name): + data_path_UID = os.path.join(self.data_path, segment_name.split('_')[0]) + seg_path = os.path.join(data_path_UID, segment_name + 
'_filt_STFT.csv') + + try: + if self.data_format == 'csv' and seg_path.endswith('.csv'): + time_freq_plot = np.array(pd.read_csv(seg_path, header=None)) + time_freq_tensor = torch.Tensor(time_freq_plot).reshape(1, 128, 128) + elif self.data_format == 'png' and seg_path.endswith('.png'): + img = Image.open(seg_path) + img_data = np.array(img) + time_freq_tensor = torch.Tensor(img_data).unsqueeze(0) + else: + raise ValueError("Unsupported file format") + + if self.standardize: + time_freq_tensor = self.standard_scaling(time_freq_tensor) # Standardize the data + + return time_freq_tensor.clone() + + except Exception as e: + print(f"Error processing segment: {segment_name}. Exception: {str(e)}") + return torch.zeros((1, 128, 128)) # Return zeros in case of an error + + def standard_scaling(self, data): + scaler = StandardScaler() + data = scaler.fit_transform(data.reshape(-1, data.shape[-1])).reshape(data.shape) + return torch.Tensor(data) + +def load_data_split_batched(data_path, labels_path, UIDs, batch_size, standardize=False, data_format='csv', read_all_labels=True): + dataset = CustomDataset(data_path, labels_path, UIDs, standardize, data_format, read_all_labels) + dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False) + return dataloader - data, labels, segment_names = load_split_data(data_path, labels_path, UIDs, standardize, data_format) - - return data, labels, segment_names - -def extract_labels(UID_list, labels_path, segment_names): - labels = {} - for UID in UID_list: - label_file = os.path.join(labels_path, UID + "_final_attemp_4_1_Dong.csv") - if os.path.exists(label_file): - label_data = pd.read_csv(label_file, sep=',', header=0, names=['segment', 'label']) - label_segment_names = label_data['segment'].apply(lambda x: x.split('.')[0]) - labels[UID] = (label_data['label'].values, label_segment_names.values) - - return labels - -# Standardize the data -def standard_scaling(data): - scaler = StandardScaler() - data = 
scaler.fit_transform(data.reshape(-1, data.shape[-1])).reshape(data.shape) - return torch.Tensor(data) - -def load_split_data(data_path, labels_path, UIDs, standardize=True, data_format='csv'): - X_data_original = [] # Store original data without standardization - segment_names = [] - - for UID in UIDs: - data_path_UID = os.path.join(data_path, UID) - dir_list_seg = os.listdir(data_path_UID) - - # for seg in dir_list_seg[:len(dir_list_seg)]: # Limiting to 50 segments - for seg in dir_list_seg[:500]: # Limiting to the first 50 segments - seg_path = os.path.join(data_path_UID, seg) - - try: - if data_format == 'csv' and seg.endswith('.csv'): - time_freq_plot = np.array(pd.read_csv(seg_path, header=None)) - time_freq_tensor = torch.Tensor(time_freq_plot).reshape(1, 128, 128) - elif data_format == 'png' and seg.endswith('.png'): - img = Image.open(seg_path) - img_data = np.array(img) - time_freq_tensor = torch.Tensor(img_data).unsqueeze(0) - else: - continue # Skip other file formats - - X_data_original.append(time_freq_tensor.clone()) # Store a copy of the original data - - # Extract and store segment names - # Change here: Use the segment name from the CSV file directly - segment_names.append(seg.split('_filt')[0]) - - except Exception as e: - logger.error(f"Error processing segment: {seg} in UID: {UID}. Exception: {str(e)}") - # You can also add more information to the error log, such as the value of time_freq_plot. 
- - X_data_original = torch.cat(X_data_original, 0) - - if standardize: - X_data_original = standard_scaling(X_data_original) # Standardize the data - - # Extract labels from CSV files - labels = extract_labels(UIDs, labels_path, segment_names) - - important_labels = [0, 1, 2, 3] # List of important labels - - # Initialize labels for segments as unlabeled (-1) - segment_labels = {segment_name: -1 for segment_name in segment_names} - - for UID in labels.keys(): - if UID not in UIDs: - # Skip UIDs that are not in the dataset - continue - - label_data, label_segment_names = labels[UID] - - for idx, segment_label in enumerate(label_data): - segment_name = label_segment_names[idx] - - # Change here: Only update labels for segments present in the loaded data - if segment_name in segment_names: - if segment_label in important_labels: - segment_labels[segment_name] = segment_label - else: - # Set labels that are not in the important list as -1 (Unlabeled) - segment_labels[segment_name] = -1 - - return X_data_original, segment_names, {seg: segment_labels[seg] for seg in segment_names} ############################################################################################################################################## ############################################################################################################################################## ############################################################################################################################################## -log_progress("Code execution started.") is_linux = False # Set to True if running on Linux, False if on Windows is_hpc = False # Set to True if running on an HPC, False if on Windows data_format = 'csv' # Choose 'csv' or 'png' data_path, labels_path, saving_path = get_data_paths(data_format, is_linux=is_linux, is_hpc=is_hpc) -# Load data for the training set -train_data, train_segment_names, train_labels = load_data_split( - data_path=data_path, labels_path=labels_path, 
UIDs=clinical_trial_train, standardize=True, data_format='csv' -) - -# Load data for the validation set -val_data, val_segment_names, val_labels = load_data_split( - data_path, labels_path, UIDs=clinical_trial_test, standardize=True, data_format='csv' -) +clinical_trial_train = [clinical_trial_train[0]] +clinical_trial_test = [clinical_trial_test[0]] +clinical_trial_unlabeled = [clinical_trial_unlabeled[0]] -# Load data for the test set (unlabeled) -test_data, test_segment_names, test_labels = load_data_split( - data_path, labels_path, UIDs=clinical_trial_unlabeled, standardize=True, data_format='csv' -) +# Example usage: +batch_size = 32 +train_loader = load_data_split_batched(data_path, labels_path, clinical_trial_train, batch_size, standardize=False, data_format='csv', read_all_labels=False) +val_loader = load_data_split_batched(data_path, labels_path, clinical_trial_test, batch_size, standardize=False, data_format='csv', read_all_labels=False) +test_loader = load_data_split_batched(data_path, labels_path, clinical_trial_unlabeled, batch_size, standardize=False, data_format='csv', read_all_labels=False) ############################################################################################################################################## ############################################################################################################################################## ############################################################################################################################################## @@ -322,8 +293,6 @@ def visualize_pacmap(X_transformed, y, title="Scatter Plot"): ############################################################################################################################################## ############################################################################################################################################## 
import gc
import math


def multivariate_normal_log_pdf_MFVI(x, mu, sigma_sq):
    """Return the log-density of every sample under every diagonal Gaussian.

    Args:
        x: Data tensor whose first dimension indexes the N samples; all
            remaining dimensions are flattened into D features.
        mu: Component means, shape (K, D).
        sigma_sq: Per-dimension component variances, shape (K, D).

    Returns:
        Tensor of shape (N, K); entry [n, k] is
        log N(x_n | mu_k, diag(sigma_sq_k)).
    """
    n_samples = x.shape[0]
    x_flat = x.reshape(n_samples, -1)  # (N, D)
    n_features = x_flat.shape[1]

    # A diagonal covariance factorises over dimensions, so the density has a
    # closed form.  This avoids materialising a dense D x D covariance matrix
    # and constructing one distribution object per (sample, component) pair.
    squared_diff = (x_flat.unsqueeze(1) - mu.unsqueeze(0)) ** 2  # (N, K, D)
    mahalanobis = (squared_diff / sigma_sq.unsqueeze(0)).sum(dim=-1)  # (N, K)
    log_det = torch.log(sigma_sq).sum(dim=-1)  # (K,)
    log_two_pi = math.log(2.0 * math.pi)
    return -0.5 * (n_features * log_two_pi + log_det.unsqueeze(0) + mahalanobis)


def perform_mfvi(dataloader, K, n_optimization_iterations, convergence_threshold=1e-5, run_until_convergence=False):
    """Fit a K-component diagonal Gaussian mixture with minibatch Adam steps.

    Args:
        dataloader: Re-iterable collection of batches.  Each batch is a
            mapping whose 'data' entry is a tensor with samples along the
            first dimension.
        K: Number of mixture components.
        n_optimization_iterations: Number of passes over ``dataloader`` to run
            when ``run_until_convergence`` is False.
        convergence_threshold: Absolute change of the summed per-pass
            objective below which the fit counts as converged.
        run_until_convergence: When True the iteration budget is ignored and
            the loop stops once the objective change drops below
            ``convergence_threshold``.

    Returns:
        Tuple ``(miu, pi)`` of NumPy arrays: the component means with shape
        (K, D) and the softmax of the mixing parameters with shape (K,).
    """
    # The feature dimension is taken from the first batch.
    first_batch = next(iter(dataloader))
    D = math.prod(first_batch['data'].shape[1:])

    # Initialize the variational parameters for the GMM
    miu_variational = torch.randn(K, D, requires_grad=True)
    log_sigma_variational = torch.randn(K, D, requires_grad=True)
    alpha_variational = torch.randn(K, requires_grad=True)

    optimizer = torch.optim.Adam([miu_variational, log_sigma_variational, alpha_variational], lr=0.001)

    prev_elbo = float('-inf')
    iteration = 0
    # Without run_until_convergence the loop is bounded by the iteration
    # budget; previously that branch only printed a message and never exited.
    while run_until_convergence or iteration < n_optimization_iterations:
        batch_elbo = 0.0
        diverged = False
        for batch in dataloader:
            data = batch['data']

            optimizer.zero_grad()
            sigma_variational_sq = torch.exp(log_sigma_variational)
            # NOTE(review): alpha is initialised with randn and can be
            # negative, while the returned pi uses softmax(alpha) -- confirm
            # that the digamma-based weights are the intended parameterisation.
            log_pi_variational = torch.digamma(alpha_variational) - torch.digamma(alpha_variational.sum())
            log_resp = log_pi_variational.unsqueeze(0) + multivariate_normal_log_pdf_MFVI(
                data, miu_variational, sigma_variational_sq
            )
            # log_softmax keeps the log-responsibilities finite even when a
            # responsibility underflows to zero, where log(resp) would give
            # 0 * -inf = NaN.
            log_resp_normalized = torch.log_softmax(log_resp, dim=1)
            resp = torch.exp(log_resp_normalized)
            elbo = -torch.sum(resp * log_resp) + torch.sum(resp * log_resp_normalized)

            if not torch.isfinite(elbo):
                # Stop before the parameters are updated with NaN gradients.
                diverged = True
                break

            elbo.backward()
            optimizer.step()
            batch_elbo += elbo.item()

        gc.collect()

        if diverged:
            print("Non-finite ELBO encountered. Stopping optimization.")
            break

        if run_until_convergence and iteration > 0 and abs(batch_elbo - prev_elbo) < convergence_threshold:
            print(f"Converged after {iteration + 1} iterations")
            break

        prev_elbo = batch_elbo
        iteration += 1
        print(f"Iteration {iteration}/{n_optimization_iterations}, ELBO: {batch_elbo}")
    else:
        print("Reached the specified number of iterations.")

    miu = miu_variational.detach().numpy()
    pi = torch.softmax(alpha_variational, dim=0).detach().numpy()

    return miu, pi
perform_mfvi(train_loader, K=4, n_optimization_iterations=20, run_until_convergence=True) +miu_val, pi_val = perform_mfvi(val_loader, K=4, n_optimization_iterations=1000, run_until_convergence=True) +miu_test, pi_test = perform_mfvi(test_loader, K=4, n_optimization_iterations=1000, run_until_convergence=True) ############################################################################################################################################## ############################################################################################################################################## diff --git a/project_NVIL.py b/project_NVIL.py index d54789e..313c1aa 100644 --- a/project_NVIL.py +++ b/project_NVIL.py @@ -105,12 +105,13 @@ def get_data_paths(data_format, is_linux=False, is_hpc=False): return data_path, labels_path, saving_path class CustomDataset(Dataset): - def __init__(self, data_path, labels_path, UIDs, standardize=True, data_format='csv'): + def __init__(self, data_path, labels_path, UIDs, standardize=True, data_format='csv', read_all_labels=False): self.data_path = data_path self.labels_path = labels_path self.UIDs = UIDs self.standardize = standardize self.data_format = data_format + self.read_all_labels = read_all_labels self.transforms = ToTensor() # Extract unique segment names and their corresponding labels @@ -126,7 +127,7 @@ def __getitem__(self, idx): # Load data on-the-fly based on the segment_name time_freq_tensor = self.load_data(segment_name) - return {'data': time_freq_tensor.unsqueeze(0), 'label': label, 'segment_name': segment_name} + return {'data': time_freq_tensor, 'label': label, 'segment_name': segment_name} def extract_segment_names_and_labels(self): segment_names = [] @@ -138,9 +139,17 @@ def extract_segment_names_and_labels(self): label_data = pd.read_csv(label_file, sep=',', header=0, names=['segment', 'label']) label_segment_names = label_data['segment'].apply(lambda x: x.split('.')[0]) for idx, segment_name in 
enumerate(label_segment_names): - if segment_name not in segment_names: - segment_names.append(segment_name) - labels[segment_name] = label_data['label'].values[idx] + label_val = label_data['label'].values[idx] + if self.read_all_labels: + # Assign -1 if label is not in [0, 1, 2, 3] + labels[segment_name] = label_val if label_val in [0, 1, 2, 3] else -1 + if segment_name not in segment_names: + segment_names.append(segment_name) + else: + # Only add segments with labels in [0, 1, 2, 3] + if label_val in [0, 1, 2, 3] and segment_name not in segment_names: + segment_names.append(segment_name) + labels[segment_name] = label_val return segment_names, labels @@ -173,11 +182,12 @@ def standard_scaling(self, data): data = scaler.fit_transform(data.reshape(-1, data.shape[-1])).reshape(data.shape) return torch.Tensor(data) -def load_data_split_batched(data_path, labels_path, UIDs, batch_size, standardize=False, data_format='csv'): - dataset = CustomDataset(data_path, labels_path, UIDs, standardize, data_format) +def load_data_split_batched(data_path, labels_path, UIDs, batch_size, standardize=False, data_format='csv', read_all_labels=True): + dataset = CustomDataset(data_path, labels_path, UIDs, standardize, data_format, read_all_labels) dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False) return dataloader + is_linux = False # Set to True if running on Linux, False if on Windows is_hpc = False # Set to True if running on an HPC, False if on Windows data_format = 'csv' # Choose 'csv' or 'png' @@ -188,10 +198,10 @@ def load_data_split_batched(data_path, labels_path, UIDs, batch_size, standardiz clinical_trial_unlabeled = [clinical_trial_unlabeled[0]] # Example usage: -batch_size = 64 -train_loader = load_data_split_batched(data_path, labels_path, clinical_trial_train, batch_size) -val_loader = load_data_split_batched(data_path, labels_path, clinical_trial_test, batch_size) -test_loader = load_data_split_batched(data_path, labels_path, 
clinical_trial_unlabeled, batch_size) +batch_size = 128 +train_loader = load_data_split_batched(data_path, labels_path, clinical_trial_train, batch_size, standardize=False, data_format='csv', read_all_labels=False) +val_loader = load_data_split_batched(data_path, labels_path, clinical_trial_test, batch_size, standardize=False, data_format='csv', read_all_labels=False) +test_loader = load_data_split_batched(data_path, labels_path, clinical_trial_unlabeled, batch_size, standardize=False, data_format='csv', read_all_labels=False) # def imshow(image, ax=None, title=None, normalize=True): From cad90d49d88c74302ab8cbb7e803274247f0d5c7 Mon Sep 17 00:00:00 2001 From: Luis Roberto Mercado Diaz Date: Mon, 18 Dec 2023 17:29:00 -0500 Subject: [PATCH 2/2] Update --- .gitignore | 10 +- semisupervised_method.py | 779 +++++++++++++++++++++++++++++++++++++++ 2 files changed, 787 insertions(+), 2 deletions(-) create mode 100644 semisupervised_method.py diff --git a/.gitignore b/.gitignore index 0abf303..11d1435 100644 --- a/.gitignore +++ b/.gitignore @@ -1,3 +1,9 @@ -error_log.txt -progress.log +__pycache__/GP_original_data.cpython-311.pyc +model_checkpoint_tsne.pt +model_checkpoint_full.pt +simple_CNN.py +VAE.py +model_checkpoint.pt +GP_original_data.py +Attention_network.py diff --git a/semisupervised_method.py b/semisupervised_method.py new file mode 100644 index 0000000..547ffcd --- /dev/null +++ b/semisupervised_method.py @@ -0,0 +1,779 @@ +# -*- coding: utf-8 -*- +""" +Created on Mon Dec 18 13:49:26 2023 + +@author: lrm22005 +""" +import os +import torch +import pandas as pd +import numpy as np +from torch.utils.data import DataLoader, Dataset, TensorDataset +from torchvision.transforms import ToTensor +from sklearn.preprocessing import StandardScaler +from PIL import Image + +import logging +import numpy as np +import pandas as pd +import torch +from torch.distributions import MultivariateNormal +from torch.utils.data import DataLoader, TensorDataset +import 
def get_data_paths(data_format, is_linux=False, is_hpc=False):
    """Return the data, label and output directories for the chosen platform.

    Args:
        data_format: 'csv' selects the "TFS_csv" data folder, 'png' selects
            "TFS_plots".
        is_linux: Use the /mnt/r mount points.  Takes precedence over
            ``is_hpc`` when both are True.
        is_hpc: Use the /gpfs scratch locations.

    Returns:
        Tuple ``(data_path, labels_path, saving_path)``.

    Raises:
        ValueError: If ``data_format`` is neither 'csv' nor 'png'.
    """
    if is_linux:
        base_path = "/mnt/r/ENGR_Chon/Dong/MATLAB_generate_results/NIH_PulseWatch"
        labels_base_path = "/mnt/r/ENGR_Chon/NIH_Pulsewatch_Database/Adjudication_UConn"
        saving_base_path = "/mnt/r/ENGR_Chon/Luis/Research/Casseys_case/Project_1_analysis"
    elif is_hpc:
        base_path = "/gpfs/scratchfs1/kic14002/doh16101"
        labels_base_path = "/gpfs/scratchfs1/hfp14002/lrm22005"
        saving_base_path = "/gpfs/scratchfs1/hfp14002/lrm22005/Casseys_case/Project_1_analysis"
    else:
        # Raw strings: sequences such as "\E" or "\D" are invalid escapes in
        # ordinary literals and only resolved to a backslash by accident.
        base_path = r"R:\ENGR_Chon\Dong\MATLAB_generate_results\NIH_PulseWatch"
        labels_base_path = r"R:\ENGR_Chon\NIH_Pulsewatch_Database\Adjudication_UConn"
        saving_base_path = r"\\grove.ad.uconn.edu\research\ENGR_Chon\Luis\Research\Casseys_case"

    # Only the data folder depends on the format; labels and output do not.
    if data_format == 'csv':
        data_folder = "TFS_csv"
    elif data_format == 'png':
        data_folder = "TFS_plots"
    else:
        raise ValueError("Invalid data format. Choose 'csv' or 'png'.")

    data_path = os.path.join(base_path, data_folder)
    labels_path = os.path.join(labels_base_path, "final_attemp_4_1_Dong_Ohm")
    # NOTE(review): the Linux and HPC bases already end in
    # "Project_1_analysis", so this yields a doubled folder name there --
    # confirm whether that is intended.
    saving_path = os.path.join(saving_base_path, "Project_1_analysis")
    return data_path, labels_path, saving_path
class CustomDataset(Dataset):
    """Dataset that loads time-frequency segments lazily, one per item.

    Segment names and labels are read from the per-UID label CSV files when
    the dataset is built (and again after ``add_uids``); the segment data
    itself is only read inside ``__getitem__``.
    """

    # Labels outside this set are mapped to -1 or skipped, depending on
    # ``read_all_labels``.
    _VALID_LABELS = (0, 1, 2, 3)

    def __init__(self, data_path, labels_path, UIDs, standardize=True, data_format='csv', read_all_labels=False):
        """Store the configuration and index the segments of ``UIDs``.

        Args:
            data_path: Root folder holding one sub-folder per UID.
            labels_path: Folder holding "<UID>_final_attemp_4_1_Dong.csv".
            UIDs: Iterable of subject identifiers (strings).
            standardize: Standardize every loaded segment when True.
            data_format: 'csv' or 'png'.
            read_all_labels: Keep segments with labels outside 0-3 (as -1)
                when True; drop them when False.
        """
        self.data_path = data_path
        self.labels_path = labels_path
        # Copy so that add_uids() never mutates the caller's list.
        self.UIDs = list(UIDs)
        self.standardize = standardize
        self.data_format = data_format
        self.read_all_labels = read_all_labels
        self.transforms = ToTensor()
        self.refresh_dataset()

    def refresh_dataset(self):
        """Rebuild ``segment_names`` and ``labels`` from the current UIDs."""
        self.segment_names, self.labels = self.extract_segment_names_and_labels()

    def add_uids(self, new_uids):
        """Append UIDs that are not present yet and re-index the dataset."""
        known_uids = set(self.UIDs)
        for uid in new_uids:
            # Also skips repeats inside new_uids itself.
            if uid not in known_uids:
                known_uids.add(uid)
                self.UIDs.append(uid)
        self.refresh_dataset()

    def __len__(self):
        return len(self.segment_names)

    def __getitem__(self, idx):
        """Return a dict with the segment tensor, its label and its name."""
        segment_name = self.segment_names[idx]
        label = self.labels[segment_name]

        # Load data on-the-fly based on the segment_name
        time_freq_tensor = self.load_data(segment_name)

        return {'data': time_freq_tensor, 'label': label, 'segment_name': segment_name}

    def extract_segment_names_and_labels(self):
        """Read the label files of all UIDs.

        Returns:
            Tuple ``(segment_names, labels)``: the segment names in first-seen
            order and a dict mapping each name to its label.  UIDs without a
            label file are skipped.
        """
        segment_names = []
        labels = {}

        for UID in self.UIDs:
            label_file = os.path.join(self.labels_path, UID + "_final_attemp_4_1_Dong.csv")
            if not os.path.exists(label_file):
                continue

            label_data = pd.read_csv(label_file, sep=',', header=0, names=['segment', 'label'])
            label_segment_names = label_data['segment'].apply(lambda x: x.split('.')[0])

            for segment_name, label_val in zip(label_segment_names, label_data['label'].values):
                is_valid = label_val in self._VALID_LABELS
                # A name is in segment_names exactly when it is a key of
                # labels, so the dict gives the membership test in O(1).
                first_seen = segment_name not in labels
                if self.read_all_labels:
                    # Assign -1 if label is not in [0, 1, 2, 3]; a repeated
                    # segment keeps its position but takes the latest label.
                    labels[segment_name] = label_val if is_valid else -1
                    if first_seen:
                        segment_names.append(segment_name)
                elif is_valid and first_seen:
                    # Only add segments with labels in [0, 1, 2, 3]
                    segment_names.append(segment_name)
                    labels[segment_name] = label_val

        return segment_names, labels

    def load_data(self, segment_name):
        """Load one segment as a tensor; return zeros if loading fails.

        The UID folder is the part of ``segment_name`` before the first
        underscore.  Any exception while reading is printed and replaced by a
        (1, 128, 128) tensor of zeros.
        """
        data_path_UID = os.path.join(self.data_path, segment_name.split('_')[0])
        # The extension follows data_format; with a fixed ".csv" suffix the
        # png branch below could never be reached.
        # NOTE(review): assumes PNG files share the "<segment>_filt_STFT"
        # stem of the CSV files -- TODO confirm against the TFS_plots folder.
        seg_path = os.path.join(data_path_UID, f"{segment_name}_filt_STFT.{self.data_format}")

        try:
            if self.data_format == 'csv' and seg_path.endswith('.csv'):
                time_freq_plot = np.array(pd.read_csv(seg_path, header=None))
                time_freq_tensor = torch.Tensor(time_freq_plot).reshape(1, 128, 128)
            elif self.data_format == 'png' and seg_path.endswith('.png'):
                img = Image.open(seg_path)
                img_data = np.array(img)
                time_freq_tensor = torch.Tensor(img_data).unsqueeze(0)
            else:
                raise ValueError("Unsupported file format")

            if self.standardize:
                time_freq_tensor = self.standard_scaling(time_freq_tensor)  # Standardize the data

            return time_freq_tensor.clone()

        except Exception as e:
            # Deliberate best effort: one unreadable segment must not abort a
            # whole pass over the data.
            print(f"Error processing segment: {segment_name}. Exception: {str(e)}")
            return torch.zeros((1, 128, 128))  # Return zeros in case of an error

    def standard_scaling(self, data):
        """Standardize ``data`` column-wise with a freshly fitted StandardScaler."""
        scaler = StandardScaler()
        # The scaler needs a 2-D array: collapse all leading dimensions, fit
        # and transform, then restore the original shape.
        data = scaler.fit_transform(data.reshape(-1, data.shape[-1])).reshape(data.shape)
        return torch.Tensor(data)


def load_data_split_batched(data_path, labels_path, UIDs, batch_size, standardize=False, data_format='csv', read_all_labels=True, drop_last=False, num_workers=4):
    """Build a ``CustomDataset`` for ``UIDs`` and wrap it in an unshuffled DataLoader.

    ``drop_last`` and ``num_workers`` are forwarded to the DataLoader; the
    remaining arguments configure the dataset.
    """
    dataset = CustomDataset(data_path, labels_path, UIDs, standardize, data_format, read_all_labels)
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False, drop_last=drop_last, num_workers=num_workers)
    return dataloader
import math

from numpy import random
from numpy.random import choice

AF_trial_train = AF_trial_Fahimeh_train
AF_trial_test = AF_trial_Fahimeh_test
AF_trial_unlabeled = AF_trial_Fahimeh_did_not_use + AF_trial_paroxysmal_AF
print(f'AF trial: {len(AF_trial_train)} training subjects {AF_trial_train}')
print(f'AF trial: {len(AF_trial_test)} testing subjects {AF_trial_test}')
print(f'AF trial: {len(AF_trial_unlabeled)} unlabeled subjects {AF_trial_unlabeled}')
# =================================
# === Clinical trial AF subjects separation ===
clinical_trial_AF_subjects = ['005', '017', '026', '051', '075', '082']

# UIDs that may be drawn for the labeled clinical-trial sets, together with
# their non-AF segment counts.  UIDs starting with '3' or '4' are excluded,
# as are the AF-trial subjects and the known clinical-trial AF subjects.
excluded_UIDs = set(AF_trial_train) | set(AF_trial_test) | set(clinical_trial_AF_subjects)
remaining_UIDs = []
count_NSR = []
for index, row in df_summary.iterrows():
    UID = row['UID']
    this_NSR = row['sample_nonAF']
    if math.isnan(this_NSR):
        # There is no segment in this subject, skip this UID.
        print(f'---------UID {UID} has no segments.------------')
        continue
    if UID not in excluded_UIDs and UID[0] not in ('3', '4'):
        remaining_UIDs.append(UID)
        count_NSR.append(this_NSR)

random.seed(seed=42)
list_of_candidates = remaining_UIDs
number_of_items_to_pick = round(len(list_of_candidates) * 0.15)  # 10% labeled for training, 5% for testing.
temp_sum = sum(count_NSR)
# Subjects with fewer segments have a higher chance to be selected.  Dividing
# by (n - 1) makes the weights sum to one.  (The plain proportional weights
# that used to be computed first were overwritten immediately and are gone.)
probability_distribution = [(1 - x / temp_sum) / (len(count_NSR) - 1) for x in count_NSR]
draw = choice(list_of_candidates, number_of_items_to_pick,
              p=probability_distribution, replace=False)

n_train = round(len(list_of_candidates) * 0.1)
clinical_trial_train = list(draw[:n_train])
clinical_trial_test_nonAF = list(draw[n_train:])
clinical_trial_test_temp = clinical_trial_test_nonAF + clinical_trial_AF_subjects
# UID 051 and maybe other UIDs had no segments (unknown reason), so keep only
# UIDs that appear in the summary table.
clinical_trial_test = [UID for UID in clinical_trial_test_temp if UID in all_UIDs]

clinical_trial_unlabeled = [
    UID for UID in all_UIDs
    if UID not in clinical_trial_train and UID not in clinical_trial_test and UID[0] not in ('3', '4')
]
print(f'Clinical trial: selected {len(clinical_trial_train)} UIDs for training {clinical_trial_train}')
print(f'Clinical trial: selected {len(clinical_trial_test)} UIDs for testing {clinical_trial_test}')
print(f'Clinical trial: selected {len(clinical_trial_unlabeled)} UIDs for unlabeled {clinical_trial_unlabeled}')

is_linux = False  # Set to True if running on Linux, False if on Windows
is_hpc = False  # Set to True if running on an HPC, False if on Windows
data_format = 'csv'  # Choose 'csv' or 'png'
data_path, labels_path, saving_path = get_data_paths(data_format, is_linux=is_linux, is_hpc=is_hpc)

# Debug switch: uncomment to run on a small subset of subjects.
# clinical_trial_train = [clinical_trial_train[0]]
# clinical_trial_test = [clinical_trial_test[0]]
# clinical_trial_unlabeled = clinical_trial_unlabeled[0:4]

batch_size = 512
train_loader = load_data_split_batched(data_path, labels_path, clinical_trial_train, batch_size, standardize=True, data_format='csv', read_all_labels=False, drop_last=True)
val_loader = load_data_split_batched(data_path, labels_path, clinical_trial_test, batch_size, standardize=True, data_format='csv', read_all_labels=False, drop_last=True)
test_loader = load_data_split_batched(data_path, labels_path, clinical_trial_unlabeled, batch_size, standardize=True, data_format='csv', read_all_labels=False, drop_last=True)
standardize=True, data_format='csv', read_all_labels=False, drop_last=True) + +# data_iter = iter(train_loader) +# data = next(data_iter) +# data["label"] + + +''' +Key Points: +Initialization: The script starts by initializing the data loaders for training, validation, and testing. +Initial Training: The model is initially trained on the training set. +Active Learning Loop: During each iteration of active learning, the script: +Performs uncertainty sampling on the validation set. +Updates the training loader with the uncertain samples. +Retrains the model with this updated training set. +Final Evaluation: After the active learning iterations, the model is evaluated on the test set to assess its performance on unseen data. +Notes: +Data Preprocessing: The script assumes that train_batch['data'] from your data loaders is already appropriately preprocessed for your model. If additional preprocessing is required, make sure to include it. +Active Learning Iterations: The number of active learning iterations (active_learning_iterations) and the number of samples per iteration (n_samples) are parameters you can adjust based on your scenario and dataset. +Test Evaluation: The final evaluation on the test set gives you an understanding of how well the model generalizes to new, unseen data. 
def map_samples_to_uids(uncertain_sample_indices, dataset):
    """
    Translate dataset indices of uncertain samples into segment names.

    Args:
    - uncertain_sample_indices: Iterable of positions into ``dataset.segment_names``.
    - dataset: Object exposing a ``segment_names`` sequence.

    Returns:
    - List with one segment name per index, in the order the indices were given.
    """
    selected_names = []
    for sample_index in uncertain_sample_indices:
        selected_names.append(dataset.segment_names[sample_index])
    return selected_names
def apply_tsne(data, n_components=2, perplexity=30, random_state=None):
    """Embed ``data`` with t-SNE, capping the perplexity for small inputs.

    Args:
        data: Array-like whose first dimension indexes the samples.
        n_components: Dimension of the embedding.
        perplexity: Upper bound for the perplexity; the value actually used
            is ``min(perplexity, n_samples - 1)`` so that it stays below the
            number of samples.
        random_state: Forwarded to ``TSNE`` to make the embedding
            reproducible; ``None`` keeps the default behaviour.

    Returns:
        The result of ``TSNE.fit_transform(data)``.

    Raises:
        ValueError: If ``data`` holds fewer than two samples, for which no
            positive perplexity below the sample count exists.
    """
    n_samples = data.shape[0]
    if n_samples < 2:
        raise ValueError(f"t-SNE needs at least 2 samples, got {n_samples}")

    # Ensure perplexity is less than the number of samples
    effective_perplexity = min(perplexity, n_samples - 1)
    tsne = TSNE(n_components=n_components, perplexity=effective_perplexity, random_state=random_state)
    return tsne.fit_transform(data)
output + variational_strategy = gpytorch.variational.LMCVariationalStrategy( + gpytorch.variational.VariationalStrategy( + self, inducing_points, variational_distribution, learn_inducing_locations=True + ), + num_tasks=num_tasks, + num_latents=num_latents, + latent_dim=-1 + ) + + super().__init__(variational_strategy) + + # The mean and covariance modules should be marked as batch + # so we learn a different set of hyperparameters + self.mean_module = gpytorch.means.ConstantMean(batch_shape=torch.Size([num_latents])) + self.covar_module = gpytorch.kernels.ScaleKernel( + gpytorch.kernels.RBFKernel(batch_shape=torch.Size([num_latents])), + batch_shape=torch.Size([num_latents]) + ) + + def forward(self, x): + # The forward function should be written as if we were dealing with each output + # dimension in batch + # Ensure x is correctly shaped. It should have the same last dimension size as inducing_points + # x should be reshaped or sliced to have the shape [?, 1] where ? can be any size + # For example, if x originally has shape [N, D], and D != 1, you need to modify x accordingly + # print(f"Input shape: {x.shape}") + # x = x.view(x.size(0), -1) # Flattening the images + # print(f"Input shape after flattening: {x.shape}") # Debugging input shape + mean_x = self.mean_module(x) + covar_x = self.covar_module(x) + + # Debugging: Print shapes of intermediate outputs + # print(f"Mean shape: {mean_x.shape}, Covariance shape: {covar_x.shape}") + latent_pred = gpytorch.distributions.MultivariateNormal(mean_x, covar_x) + # print(f"Latent prediction shape: {latent_pred.mean.shape}, {latent_pred.covariance_matrix.shape}") + + return latent_pred + +# Usage +# model = MultitaskGPModel() +# likelihood = gpytorch.likelihoods.SoftmaxLikelihood(num_features=2, num_classes=4, mixing_weights = False) + +from tqdm import tqdm + +# Initialize result storage +results = { + 'train_loss': [], + 'validation_metrics': {'precision': [], 'recall': [], 'f1': [], 'auc_roc': []}, + 'test_metrics': 
None # This will be filled in with the final test metrics +} + +import torch +import gpytorch +from gpytorch.likelihoods import SoftmaxLikelihood +from gpytorch.functions import log_normal_cdf + +def train_gp_model(train_x, train_y, val_loader, num_iterations=50, n_classes=4, patience=10, checkpoint_path='model_checkpoint_full.pt'): + model = MultitaskGPModel().to(device) + likelihood = gpytorch.likelihoods.SoftmaxLikelihood(num_features=4, num_classes=4).to(device) + model.train() + likelihood.train() + optimizer = torch.optim.Adam(model.parameters(), lr=0.1) + mll = gpytorch.mlls.VariationalELBO(likelihood, model, num_data=train_y.size(0)) + + best_val_loss = float('inf') + epochs_no_improve = 0 + + for i in tqdm(range(num_iterations), desc='Training', unit='iter', leave=False): + optimizer.zero_grad() + output = model(train_x) + loss = -mll(output, train_y) + scalar_loss = loss.sum() if loss.numel() > 1 else loss + scalar_loss.backward() + optimizer.step() + + # Validation step + model.eval() + likelihood.eval() + with torch.no_grad(): + val_loss = 0.0 + for val_batch in val_loader: + val_x, val_y = val_batch['data'].view(val_batch['data'].size(0), -1).to(device), val_batch['label'].to(device) + val_output = model(val_x) + val_loss += -mll(val_output, val_y).item() + val_loss /= len(val_loader) + + model.train() + likelihood.train() + + # Early stopping and checkpointing based on validation loss + if val_loss < best_val_loss: + best_val_loss = val_loss + epochs_no_improve = 0 + torch.save({'model_state_dict': model.state_dict(), + 'likelihood_state_dict': likelihood.state_dict(), + 'optimizer_state_dict': optimizer.state_dict()}, checkpoint_path) + else: + epochs_no_improve += 1 + if epochs_no_improve == patience: + print(f"Early stopping triggered at iteration {i+1}") + break + + # Load the best model before return + checkpoint = torch.load(checkpoint_path) + model.load_state_dict(checkpoint['model_state_dict']) + 
likelihood.load_state_dict(checkpoint['likelihood_state_dict']) + optimizer.load_state_dict(checkpoint['optimizer_state_dict']) + + return model, likelihood + +def parse_classification_report(report): + """Parse a classification report into a dictionary of metrics.""" + lines = report.split('\n') + main_metrics = lines[-2].split() + + # Assuming the last line is like "accuracy: x macro avg y1 y2 y3 y4" + return { + 'precision': float(main_metrics[3]), + 'recall': float(main_metrics[4]), + 'f1': float(main_metrics[5]), + 'auc_roc': None # AUC-ROC is not part of the classification report by default + } + +from sklearn.metrics import precision_score, recall_score, f1_score, roc_auc_score +from sklearn.preprocessing import label_binarize +from sklearn.metrics import auc +from sklearn.metrics import roc_curve +from sklearn.metrics import classification_report +from sklearn.metrics import precision_recall_fscore_support + +def evaluate_model_on_all_data(model, likelihood, data_loader, device, n_classes): + model.eval() + likelihood.eval() + + all_predicted_labels = [] + all_test_labels = [] + + with torch.no_grad(), gpytorch.settings.fast_pred_var(): + for i, batch in enumerate(data_loader): + test_data = batch['data'].view(batch['data'].size(0), -1).to(device) + test_labels = batch['label'].to(device) + # print(f"Test data shape before t-SNE: {test_data.shape}") + + predictions = likelihood(model(test_data)).mean + # Debugging - check shape of predictions + # print(f"Predictions shape: {predictions.shape}") + predicted_labels = predictions.argmax(dim=0) + + # Add debugging information + # print(f"Batch {i}: Predicted Labels Shape: {predicted_labels.shape}, Actual Labels Shape: {test_labels.shape}") + + all_predicted_labels.append(predicted_labels.cpu().numpy()) + all_test_labels.append(test_labels.numpy()) + + # Debug the accumulation of labels + current_predicted = np.concatenate(all_predicted_labels, axis=0) + current_actual = np.concatenate(all_test_labels, axis=0) + 
# print(f"After Batch {i}: Accumulated Predicted Labels: {current_predicted.shape[0]}, Accumulated Actual Labels: {current_actual.shape[0]}") + + # Concatenate all batch results + all_predicted_labels = np.concatenate(all_predicted_labels, axis=0) + all_test_labels = np.concatenate(all_test_labels, axis=0) + + # Final check + # print(f"Final: Total Predicted Labels: {all_predicted_labels.shape[0]}, Total Actual Labels: {all_test_labels.shape[0]}") + + # Verify if the shapes match before proceeding to calculate metrics + if all_predicted_labels.shape[0] != all_test_labels.shape[0]: + raise ValueError("Mismatch in the number of samples between predicted and actual labels") + + # Compute overall evaluation metrics + precision, recall, f1, _ = precision_recall_fscore_support(all_test_labels, all_predicted_labels, average='macro') + # For AUC-ROC, you need the predicted probabilities and true labels in a one-hot encoded format + test_labels_one_hot = label_binarize(all_test_labels, classes=np.arange(n_classes)) + auc_roc = roc_auc_score(test_labels_one_hot, predictions.softmax(dim=-1).cpu().numpy(), multi_class='ovr') + return { + 'precision': precision, + 'recall': recall, + 'f1': f1, + 'auc_roc': auc_roc + } + +import random + +def stochastic_uncertainty_sampling(gp_model, gp_likelihood, val_loader, n_samples, n_batches, n_components=2): + gp_model.eval() + gp_likelihood.eval() + uncertain_sample_indices = [] + sampled_batches = random.sample(list(val_loader), n_batches) # Randomly sample n_batches from val_loader + + with torch.no_grad(): + for batch in sampled_batches: + # reduced_data = apply_tsne(batch['data'].reshape(batch['data'].size(0), -1), n_components=n_components) + # reduced_data_tensor = torch.Tensor(reduced_data).to(device) + reduced_data_tensor = batch['data'].view(batch['data'].size(0), -1).to(device) + predictions = gp_likelihood(gp_model(reduced_data_tensor)) + var = predictions.variance + top_indices = torch.argsort(-var.flatten())[:n_samples] + 
uncertain_sample_indices.extend(top_indices.cpu().numpy()) + + return uncertain_sample_indices[:n_samples] + +# def uncertainty_sampling(gp_model, gp_likelihood, val_loader, n_samples, n_components=2): +# gp_model.eval() +# gp_likelihood.eval() +# uncertain_sample_indices = [] +# with torch.no_grad(): +# for batch_idx, batch in tqdm(enumerate(val_loader), desc='Uncertainty Sampling', unit='batch'): +# reduced_data_tensor = batch['data'].view(batch['data'].size(0), -1).to(device) +# predictions = gp_likelihood(gp_model(reduced_data_tensor)) +# var = predictions.variance +# top_indices = torch.argsort(-var.flatten())[:n_samples] +# batch_uncertain_indices = [batch_idx * val_loader.batch_size + idx for idx in top_indices] +# uncertain_sample_indices.extend(batch_uncertain_indices) +# return uncertain_sample_indices[:n_samples] + +def label_samples(uncertain_samples, validation_data): + labels = [validation_data[sample_id]['label'] for sample_id in uncertain_samples] + return uncertain_samples, labels + +def update_train_loader_with_uncertain_samples(current_train_loader, new_sample_indices, data_path, labels_path, batch_size, standardize=False, data_format='csv', read_all_labels=True): + # Extract current UIDs from the current_train_loader + current_dataset = current_train_loader.dataset + # Map new_samples back to their corresponding segment names or UIDs + new_uids = map_samples_to_uids(new_sample_indices, current_dataset) + # Add new UIDs to the current dataset and refresh it + current_dataset.add_uids(new_uids) + # Create new DataLoader with the updated dataset + updated_train_loader = DataLoader(current_dataset, batch_size=batch_size, shuffle=False) + return updated_train_loader + +def plot_training_performance(train_loss, validation_metrics): + epochs = range(1, len(train_loss) + 1) + + # Plot training loss + plt.figure(figsize=(14, 6)) + plt.subplot(1, 2, 1) + plt.plot(epochs, train_loss, 'b-', label='Training Loss') + plt.title('Training Loss') + 
plt.xlabel('Epochs') + plt.ylabel('Loss') + plt.legend() + + # Plot validation metrics + plt.subplot(1, 2, 2) + plt.plot(epochs, validation_metrics['precision'], 'r-', label='Precision') + plt.plot(epochs, validation_metrics['recall'], 'g-', label='Recall') + plt.plot(epochs, validation_metrics['f1'], 'b-', label='F1 Score') + plt.plot(epochs, validation_metrics['auc_roc'], 'y-', label='AUC-ROC') + plt.title('Validation Metrics') + plt.xlabel('Epochs') + plt.ylabel('Metrics') + plt.legend() + + plt.tight_layout() + plt.show() + +def plot_results(results): + plt.figure(figsize=(12, 5)) + plt.subplot(1, 2, 1) + plt.plot(results['train_loss'], label='Train Loss') + plt.title('Training Loss Over Time') + plt.legend() + + plt.subplot(1, 2, 2) + for metric in ['precision', 'recall', 'f1']: + plt.plot(results['validation_metrics'][metric], label=metric.title()) + plt.title('Validation Metrics Over Time') + plt.legend() + plt.show() + + test_metrics = results['test_metrics'] + print("Test Metrics:") + print(f"Precision: {test_metrics['precision']}") + print(f"Recall: {test_metrics['recall']}") + print(f"F1 Score: {test_metrics['f1']}") + print(f"AUC-ROC: {test_metrics['auc_roc']}") + +from sklearn.cluster import KMeans + +# K-Means Validation Function +def kmeans_validation(model, data_loader, n_clusters, device): + model.eval() + all_data = [] + all_predictions = [] + with torch.no_grad(): + for batch in data_loader: + data = batch['data'].view(batch['data'].size(0), -1).to(device) + labels = batch['label'].to(device) + predictions = model(data).mean.argmax(dim=-1).cpu().numpy() + all_data.extend(data.cpu().numpy()) + all_predictions.extend(predictions) + + all_data = np.array(all_data) + all_predictions = np.array(all_predictions) + + # Perform KMeans clustering + kmeans = KMeans(n_clusters=n_clusters, random_state=0).fit(all_data) + cluster_labels = kmeans.labels_ + + # Check label consistency within each cluster + for i in range(n_clusters): + cluster_indices = 
np.where(cluster_labels == i)[0] + cluster_pred_labels = all_predictions[cluster_indices] + most_common_label = np.bincount(cluster_pred_labels).argmax() + # Compare most_common_label with actual labels in the cluster + # Adjust the model's predictions if necessary + + return kmeans + +from sklearn.cluster import MiniBatchKMeans + +def run_minibatch_kmeans(data_loader, n_clusters, device, batch_size=100): + # Initialize MiniBatchKMeans + minibatch_kmeans = MiniBatchKMeans(n_clusters=n_clusters, random_state=0, batch_size=batch_size) + + # Iterate through data_loader and fit MiniBatchKMeans + for batch in data_loader: + data = batch['data'].view(batch['data'].size(0), -1).to(device).cpu().numpy() + minibatch_kmeans.partial_fit(data) + + return minibatch_kmeans + +# def compare_kmeans_gp_predictions(kmeans_model, gp_model, data_loader, device): +# # Compare K-Means with GP model predictions +# all_data, all_labels = [], [] +# for batch in data_loader: +# data = batch['data'].view(batch['data'].size(0), -1).to(device) +# labels = batch['label'].to(device) +# gp_predictions = gp_model(data).mean.argmax(dim=0).cpu().numpy() +# kmeans_predictions = kmeans_model.predict(data.cpu().numpy()) +# all_labels.append(labels.cpu().numpy()) +# all_data.append((gp_predictions, kmeans_predictions)) +# return all_data, np.concatenate(all_labels) + +def stochastic_compare_kmeans_gp_predictions(kmeans_model, gp_model, data_loader, n_batches, device): + all_data, all_labels = [], [] + sampled_batches = random.sample(list(data_loader), n_batches) # Randomly sample n_batches from data_loader + + for batch in sampled_batches: + data = batch['data'].view(batch['data'].size(0), -1).to(device) + labels = batch['label'].to(device) + gp_predictions = gp_model(data).mean.argmax(dim=0).cpu().numpy() + kmeans_predictions = kmeans_model.predict(data.cpu().numpy()) + all_labels.append(labels.cpu().numpy()) + all_data.append((gp_predictions, kmeans_predictions)) + + return all_data, 
np.concatenate(all_labels) + +import matplotlib.pyplot as plt +from sklearn.metrics import confusion_matrix +import seaborn as sns + +def plot_comparative_results(gp_vs_kmeans_data, original_labels): + fig, axes = plt.subplots(1, 2, figsize=(14, 7)) + + # Plot 1: Confusion Matrix for GP Predictions vs Original Labels + gp_predictions = [pair[0] for pair in gp_vs_kmeans_data] + gp_predictions = np.concatenate(gp_predictions) + cm_gp = confusion_matrix(original_labels, gp_predictions) + sns.heatmap(cm_gp, annot=True, ax=axes[0], fmt='g') + axes[0].set_title('GP Model Predictions vs Original Labels') + axes[0].set_xlabel('Predicted Labels') + axes[0].set_ylabel('True Labels') + + # Plot 2: Confusion Matrix for K-Means Predictions vs Original Labels + kmeans_predictions = [pair[1] for pair in gp_vs_kmeans_data] + kmeans_predictions = np.concatenate(kmeans_predictions) + cm_kmeans = confusion_matrix(original_labels, kmeans_predictions) + sns.heatmap(cm_kmeans, annot=True, ax=axes[1], fmt='g') + axes[1].set_title('K-Means Predictions vs Original Labels') + axes[1].set_xlabel('Predicted Labels') + axes[1].set_ylabel('True Labels') + + plt.tight_layout() + plt.show() + +device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + +# Main execution +if __name__ == "__main__": + # Define the number of output features and classes first + n_classes = 4 # Assuming you have 4 classes + + # Initialize data loaders + train_loader = load_data_split_batched(data_path, labels_path, clinical_trial_train, batch_size, standardize=False, data_format='csv', read_all_labels=False, drop_last=True) + val_loader = load_data_split_batched(data_path, labels_path, clinical_trial_test, batch_size, standardize=False, data_format='csv', read_all_labels=True, drop_last=True) + test_loader = load_data_split_batched(data_path, labels_path, clinical_trial_unlabeled, batch_size, standardize=False, data_format='csv', read_all_labels=True, drop_last=True) + + kmeans_model = 
run_minibatch_kmeans(train_loader, n_clusters=n_classes, device=device) + + # Initialize result storage + results = { + 'train_loss': [], + 'validation_metrics': {'precision': [], 'recall': [], 'f1': [], 'auc_roc': []}, + 'test_metrics': None # This will be filled in with the final test metrics + } + + # Initial model training + for train_batch in train_loader: + train_x = train_batch['data'].view(train_batch['data'].size(0), -1).to(device) + train_y = train_batch['label'].to(device) + model, likelihood = train_gp_model(train_x, train_y, val_loader, num_iterations=10, n_classes=n_classes) + + active_learning_iterations = 10 + n_samples = batch_size # Number of uncertain samples to accumulate + for iteration in tqdm(range(active_learning_iterations), desc='Active Learning', unit='iteration', leave=True): + uncertain_sample_indices = stochastic_uncertainty_sampling(model, likelihood, val_loader, n_samples, n_batches=5, n_components=2) + + # Accumulate indices of uncertain samples + accumulated_indices = [] + for idx in uncertain_sample_indices: + accumulated_indices.append(idx) + + # Update the training loader with indices of uncertain samples + train_loader = update_train_loader_with_uncertain_samples(train_loader, accumulated_indices, data_path, labels_path, batch_size) + + # Re-train the model with the updated train_loader + for train_batch in tqdm(train_loader, desc='Batch Training', leave=False): + train_x = train_batch['data'].view(train_batch['data'].size(0), -1).to(device) # Flatten the image + train_y = train_batch['label'].to(device) + model, likelihood = train_gp_model(train_x, train_y, val_loader, num_iterations=10, n_classes=n_classes) + val_metrics = evaluate_model_on_all_data(model, likelihood, val_loader, device, n_classes) + for metric in ['precision', 'recall', 'f1', 'auc_roc']: + results['validation_metrics'][metric].append(val_metrics[metric]) + + # Compare K-Means with GP model predictions after retraining + gp_vs_kmeans_data, original_labels = 
stochastic_compare_kmeans_gp_predictions(kmeans_model, model, train_loader, n_batches=5, device=device) + plot_comparative_results(gp_vs_kmeans_data, original_labels) + + plot_training_performance(results['train_loss'], results['validation_metrics']) + + # Final evaluation on test set + classification_result = evaluate_model_on_all_data(model, likelihood, test_loader, device, n_classes=n_classes) + # Store test metrics + results['test_metrics'] = classification_result + # Now results dictionary is ready to be used for plotting + plot_results(results) + # You might also want to print or log the final test metrics + print("Final Test Metrics:", results['test_metrics']) +