Merge pull request #13 from lrm22005/Luis

Luis

Showing 6 changed files with 1,392 additions and 238 deletions.
.gitignore
@@ -1,3 +1,9 @@
error_log.txt
progress.log
__pycache__/GP_original_data.cpython-311.pyc
model_checkpoint_tsne.pt
model_checkpoint_full.pt
simple_CNN.py
VAE.py
model_checkpoint.pt
GP_original_data.py
Attention_network.py
@@ -0,0 +1,334 @@
# -*- coding: utf-8 -*-
"""
Created on Mon Dec 18 12:34:01 2023

@author: lrm22005

This approach keeps the checkpointing logic separate from the data loading logic,
which is good practice for code organization and reusability.
The should_save_checkpoint function is a placeholder for whatever criterion you want
to use to decide when to save a checkpoint: elapsed time, number of batches processed,
or anything else (a sketch of such a function follows the CheckpointManager class below).
Note that while this method saves progress in terms of batch index, it does not
automatically save the state of the model, optimizer, or other components of the
training loop; handle those separately as needed (a resume sketch follows
load_data_split_batched below).

The code trains a machine learning model (a Gaussian Process) using an active learning
approach. It relies heavily on custom functions (load_data_split_batched, train_gp_model,
uncertainty_sampling, update_train_loader_with_uncertain_samples,
evaluate_model_on_all_data, and the plotting functions), which are imported from
GP_original_data rather than defined in this file.
Active learning iteratively selects the most informative samples to improve the model.
Checkpoints save the state of the model at each iteration, so training can be recovered
and continued after an interruption.
The code evaluates model performance on both the validation and test datasets, storing
various metrics for analysis.
"""
import os
import pickle

import numpy as np
import pandas as pd
import torch
from PIL import Image
from sklearn.preprocessing import StandardScaler
from torch.utils.data import DataLoader, Dataset
from torchvision.transforms import ToTensor
from tqdm import tqdm

from GP_original_data import map_samples_to_uids, MultitaskGPModel, train_gp_model, parse_classification_report
from GP_original_data import evaluate_model_on_all_data, uncertainty_sampling, label_samples
from GP_original_data import update_train_loader_with_uncertain_samples, plot_training_performance, plot_results

def get_data_paths(data_format, is_linux=False, is_hpc=False):
    if is_linux:
        base_path = "/mnt/r/ENGR_Chon/Dong/MATLAB_generate_results/NIH_PulseWatch"
        labels_base_path = "/mnt/r/ENGR_Chon/NIH_Pulsewatch_Database/Adjudication_UConn"
        saving_base_path = "/mnt/r/ENGR_Chon/Luis/Research/Casseys_case/Project_1_analysis"
    elif is_hpc:
        base_path = "/gpfs/scratchfs1/kic14002/doh16101"
        labels_base_path = "/gpfs/scratchfs1/hfp14002/lrm22005"
        saving_base_path = "/gpfs/scratchfs1/hfp14002/lrm22005/Casseys_case/Project_1_analysis"
    else:
        # Windows paths; raw strings keep the backslashes from being treated as escapes.
        base_path = r"R:\ENGR_Chon\Dong\MATLAB_generate_results\NIH_PulseWatch"
        labels_base_path = r"R:\ENGR_Chon\NIH_Pulsewatch_Database\Adjudication_UConn"
        saving_base_path = r"\\grove.ad.uconn.edu\research\ENGR_Chon\Luis\Research\Casseys_case"
    if data_format == 'csv':
        data_path = os.path.join(base_path, "TFS_csv")
        labels_path = os.path.join(labels_base_path, "final_attemp_4_1_Dong_Ohm")
        saving_path = os.path.join(saving_base_path, "Project_1_analysis")
    elif data_format == 'png':
        data_path = os.path.join(base_path, "TFS_plots")
        labels_path = os.path.join(labels_base_path, "final_attemp_4_1_Dong_Ohm")
        saving_path = os.path.join(saving_base_path, "Project_1_analysis")
    else:
        raise ValueError("Invalid data format. Choose 'csv' or 'png'.")
    return data_path, labels_path, saving_path

class CustomDataset(Dataset):
    def __init__(self, data_path, labels_path, UIDs, standardize=True, data_format='csv', read_all_labels=False):
        self.data_path = data_path
        self.labels_path = labels_path
        self.UIDs = UIDs
        self.standardize = standardize
        self.data_format = data_format
        self.read_all_labels = read_all_labels
        self.transforms = ToTensor()
        self.refresh_dataset()

    def refresh_dataset(self):
        # Extract unique segment names and their corresponding labels
        self.segment_names, self.labels = self.extract_segment_names_and_labels()

    def add_uids(self, new_uids):
        # Ensure new UIDs are unique and not already in the dataset
        unique_new_uids = [uid for uid in new_uids if uid not in self.UIDs]

        # Add unique new UIDs and refresh the dataset
        self.UIDs.extend(unique_new_uids)
        self.refresh_dataset()

    def __len__(self):
        return len(self.segment_names)
    def __getitem__(self, idx):
        segment_name = self.segment_names[idx]
        label = self.labels[segment_name]

        # Load data on the fly based on the segment_name
        time_freq_tensor = self.load_data(segment_name)

        return {'data': time_freq_tensor, 'label': label, 'segment_name': segment_name}

    def extract_segment_names_and_labels(self):
        segment_names = []
        labels = {}

        for UID in self.UIDs:
            label_file = os.path.join(self.labels_path, UID + "_final_attemp_4_1_Dong.csv")
            if os.path.exists(label_file):
                label_data = pd.read_csv(label_file, sep=',', header=0, names=['segment', 'label'])
                label_segment_names = label_data['segment'].apply(lambda x: x.split('.')[0])
                for idx, segment_name in enumerate(label_segment_names):
                    label_val = label_data['label'].values[idx]
                    if self.read_all_labels:
                        # Assign -1 if the label is not in [0, 1, 2, 3]
                        labels[segment_name] = label_val if label_val in [0, 1, 2, 3] else -1
                        if segment_name not in segment_names:
                            segment_names.append(segment_name)
                    else:
                        # Only add segments with labels in [0, 1, 2, 3]
                        if label_val in [0, 1, 2, 3] and segment_name not in segment_names:
                            segment_names.append(segment_name)
                            labels[segment_name] = label_val

        return segment_names, labels
||
def load_data(self, segment_name): | ||
data_path_UID = os.path.join(self.data_path, segment_name.split('_')[0]) | ||
seg_path = os.path.join(data_path_UID, segment_name + '_filt_STFT.csv') | ||
|
||
try: | ||
if self.data_format == 'csv' and seg_path.endswith('.csv'): | ||
time_freq_plot = np.array(pd.read_csv(seg_path, header=None)) | ||
time_freq_tensor = torch.Tensor(time_freq_plot).reshape(1, 128, 128) | ||
elif self.data_format == 'png' and seg_path.endswith('.png'): | ||
img = Image.open(seg_path) | ||
img_data = np.array(img) | ||
time_freq_tensor = torch.Tensor(img_data).unsqueeze(0) | ||
else: | ||
raise ValueError("Unsupported file format") | ||
|
||
if self.standardize: | ||
time_freq_tensor = self.standard_scaling(time_freq_tensor) # Standardize the data | ||
|
||
return time_freq_tensor.clone() | ||
|
||
except Exception as e: | ||
print(f"Error processing segment: {segment_name}. Exception: {str(e)}") | ||
return torch.zeros((1, 128, 128)) # Return zeros in case of an error | ||
|
||
def standard_scaling(self, data): | ||
scaler = StandardScaler() | ||
data = scaler.fit_transform(data.reshape(-1, data.shape[-1])).reshape(data.shape) | ||
return torch.Tensor(data) | ||
|
||
def load_data_split_batched(data_path, labels_path, UIDs, batch_size, standardize=False, data_format='csv', read_all_labels=True, drop_last=False):
    dataset = CustomDataset(data_path, labels_path, UIDs, standardize, data_format, read_all_labels)
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False, drop_last=drop_last)
    return dataloader
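
# A minimal sketch (an assumption, not part of the original pipeline): the module
# docstring mentions resuming data loading from a saved batch index. Because the
# loaders here use shuffle=False, batch order is deterministic, so a resumed run can
# skip the batches it has already processed. The helper name `iterate_from` is
# hypothetical.
from itertools import islice

def iterate_from(dataloader, start_batch=0):
    # Yield (batch_index, batch) pairs, skipping the first `start_batch` batches.
    yield from enumerate(islice(dataloader, start_batch, None), start=start_batch)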

class CheckpointManager:
    def __init__(self, checkpoint_dir):
        self.checkpoint_dir = checkpoint_dir  # Store the directory path for checkpoints
        if not os.path.exists(checkpoint_dir):  # Check if the directory exists
            os.makedirs(checkpoint_dir)  # Create the directory if it does not exist

    def save_checkpoint(self, loader_name, iteration, additional_state):
        # Construct the checkpoint file path using the loader name
        checkpoint_path = os.path.join(self.checkpoint_dir, f"{loader_name}_checkpoint.pkl")
        checkpoint = {
            'iteration': iteration,  # Store the current iteration
            'additional_state': additional_state  # Store any additional state information
        }
        with open(checkpoint_path, 'wb') as f:  # Open the file in write-binary mode
            pickle.dump(checkpoint, f)  # Serialize the checkpoint dictionary to the file

    def load_checkpoint(self, loader_name):
        # Construct the checkpoint file path using the loader name
        checkpoint_path = os.path.join(self.checkpoint_dir, f"{loader_name}_checkpoint.pkl")
        try:
            with open(checkpoint_path, 'rb') as f:  # Open the file in read-binary mode
                return pickle.load(f)  # Deserialize the checkpoint file and return it
        except FileNotFoundError:  # Handle the case where the checkpoint file does not exist
            return None  # Return None if the file is not found
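
# A minimal sketch of the should_save_checkpoint placeholder mentioned in the module
# docstring (an assumption, not the author's implementation). Here the criterion is
# "every `every_n` batches, or once `min_seconds` of wall-clock time have elapsed";
# any other criterion would work equally well.
import time

def should_save_checkpoint(batch_index, last_save_time, every_n=100, min_seconds=600):
    # Save on every Nth batch, or when enough time has passed since the last save.
    return batch_index % every_n == 0 or (time.time() - last_save_time) >= min_seconds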

# ====== Load the per-subject arrhythmia summary ======
df_summary = pd.read_csv(r'\\grove.ad.uconn.edu\research\ENGR_Chon\NIH_Pulsewatch_Database\Adjudication_UConn\final_attemp_4_1_Dong_Ohm_summary_20231025.csv')
df_summary['UID'] = df_summary['UID'].astype(str).str.zfill(3)

df_summary['sample_nonAF'] = df_summary['NSR'] + df_summary['PACPVC'] + df_summary['SVT']
df_summary['sample_AF'] = df_summary['AF']

df_summary['sample_nonAF_ratio'] = df_summary['sample_nonAF'] / (df_summary['sample_AF'] + df_summary['sample_nonAF'])

all_UIDs = df_summary['UID'].unique()
# ====================================================
# ====== AF trial separation ======
# R:\ENGR_Chon\Dong\Numbers\Pulsewatch_numbers\Fahimeh_CNNED_general_ExpertSystemwApplication\tbl_file_name\TrainingSet_final_segments
AF_trial_Fahimeh_train = ['402', '410']
AF_trial_Fahimeh_test = ['301', '302', '305', '306', '307', '310', '311',
                         '312', '318', '319', '320', '321', '322', '324',
                         '325', '327', '329', '400', '406', '407', '409',
                         '414']
AF_trial_Fahimeh_did_not_use = ['405', '413', '415', '416', '420', '421', '422', '423']
AF_trial_paroxysmal_AF = ['408', '419']

AF_trial_train = AF_trial_Fahimeh_train
AF_trial_test = AF_trial_Fahimeh_test
AF_trial_unlabeled = AF_trial_Fahimeh_did_not_use + AF_trial_paroxysmal_AF
print(f'AF trial: {len(AF_trial_train)} training subjects {AF_trial_train}')
print(f'AF trial: {len(AF_trial_test)} testing subjects {AF_trial_test}')
print(f'AF trial: {len(AF_trial_unlabeled)} unlabeled subjects {AF_trial_unlabeled}')
# =================================
# === Clinical trial AF subjects separation ===
clinical_trial_AF_subjects = ['005', '017', '026', '051', '075', '082']

remaining_UIDs = []
count_NSR = []
import math
for index, row in df_summary.iterrows():
    UID = row['UID']
    this_NSR = row['sample_nonAF']
    if math.isnan(this_NSR):
        # There are no segments for this subject; skip this UID.
        print(f'---------UID {UID} has no segments.------------')
        continue
    if UID not in AF_trial_train and UID not in AF_trial_test and UID not in clinical_trial_AF_subjects \
            and not UID[0] == '3' and not UID[0] == '4':
        remaining_UIDs.append(UID)
        count_NSR.append(this_NSR)

from numpy import random
random.seed(seed=42)
from numpy.random import choice
list_of_candidates = remaining_UIDs
number_of_items_to_pick = round(len(list_of_candidates) * 0.15)  # 10% labeled for training, 5% for testing.
temp_sum = sum(count_NSR)
# Inverse-frequency weights: subjects with fewer segments have a higher chance of being
# selected. Each weight is (1 - x/temp_sum) / (n - 1), so the n weights sum to
# (n - temp_sum/temp_sum) / (n - 1) = 1, as numpy.random.choice requires.
probability_distribution = [(1 - x/temp_sum) / (len(count_NSR) - 1) for x in count_NSR]
draw = choice(list_of_candidates, number_of_items_to_pick,
              p=probability_distribution, replace=False)

clinical_trial_train = list(draw[:round(len(list_of_candidates) * 0.1)])
clinical_trial_test_nonAF = list(draw[round(len(list_of_candidates) * 0.1):])
clinical_trial_test_temp = clinical_trial_test_nonAF + clinical_trial_AF_subjects
clinical_trial_test = []
for UID in clinical_trial_test_temp:
    # UID 051 (and possibly other UIDs) had no segments, for unknown reasons.
    if UID in all_UIDs:
        clinical_trial_test.append(UID)

clinical_trial_unlabeled = []
for UID in all_UIDs:
    if UID not in clinical_trial_train and UID not in clinical_trial_test and not UID[0] == '3' and not UID[0] == '4':
        clinical_trial_unlabeled.append(UID)
print(f'Clinical trial: selected {len(clinical_trial_train)} UIDs for training {clinical_trial_train}')
print(f'Clinical trial: selected {len(clinical_trial_test)} UIDs for testing {clinical_trial_test}')
print(f'Clinical trial: selected {len(clinical_trial_unlabeled)} UIDs for unlabeled {clinical_trial_unlabeled}')

# Global parameters related to the machine learning model
num_latents = 6  # Number of latents
num_tasks = 4  # Number of tasks
num_inducing_points = 50  # Number of inducing points

# Initialize a dictionary to store various metrics
results = {
    'train_loss': [],
    'validation_metrics': {'precision': [], 'recall': [], 'f1': [], 'auc_roc': []},
    'test_metrics': None  # Placeholder for final test metrics
}

# Set the device to CUDA if available, otherwise use the CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Start of the main execution block
if __name__ == "__main__":
    # Set the number of classes for the model
    n_classes = 4

    # Flags for the running environment
    is_linux = False
    is_hpc = False
    data_format = 'csv'  # Set the data format
    # Get the data paths based on the environment and format
    data_path, labels_path, saving_path = get_data_paths(data_format, is_linux=is_linux, is_hpc=is_hpc)

    # Define the batch size for loading data
    batch_size = 512
    # Load the training, validation, and test data
    train_loader = load_data_split_batched(data_path, labels_path, clinical_trial_train, batch_size, standardize=True, data_format='csv', read_all_labels=False, drop_last=True)
    val_loader = load_data_split_batched(data_path, labels_path, clinical_trial_test, batch_size, standardize=True, data_format='csv', read_all_labels=False, drop_last=True)
    test_loader = load_data_split_batched(data_path, labels_path, clinical_trial_unlabeled, batch_size, standardize=True, data_format='csv', read_all_labels=False, drop_last=True)

    # Initialize the CheckpointManager
    checkpoint_manager = CheckpointManager(saving_path)

    # Attempt to load a training checkpoint to resume from
    train_checkpoint = checkpoint_manager.load_checkpoint('train')
    start_iteration = train_checkpoint['iteration'] if train_checkpoint else 0

    # Active learning iterations loop
    active_learning_iterations = 10
    n_samples = batch_size  # Number of samples for uncertainty sampling
    for iteration in tqdm(range(start_iteration, active_learning_iterations), desc='Active Learning', unit='iteration'):

        # Training and active learning logic: fit the model on the current training pool
        for train_batch in train_loader:
            train_x = train_batch['data'].view(train_batch['data'].size(0), -1).to(device)
            train_y = train_batch['label'].to(device)
            model, likelihood = train_gp_model(train_x, train_y, val_loader, num_iterations=10, n_classes=n_classes)

        # Select the most uncertain validation samples and fold them into the training set
        uncertain_sample_indices = uncertainty_sampling(model, likelihood, val_loader, n_samples, n_components=2)
        accumulated_indices = list(uncertain_sample_indices)
        train_loader = update_train_loader_with_uncertain_samples(train_loader, accumulated_indices, data_path, labels_path, batch_size)

        # Retrain on the updated training set
        for train_batch in tqdm(train_loader, desc='Batch Training', leave=False):
            train_x = train_batch['data'].view(train_batch['data'].size(0), -1).to(device)
            train_y = train_batch['label'].to(device)
            model, likelihood = train_gp_model(train_x, train_y, val_loader, num_iterations=10, n_classes=n_classes)

        # Evaluate on the validation set and record the metrics
        val_metrics = evaluate_model_on_all_data(model, likelihood, val_loader, device, n_classes)
        for metric in ['precision', 'recall', 'f1', 'auc_roc']:
            results['validation_metrics'][metric].append(val_metrics[metric])

        # Save a checkpoint at the end of each iteration
        additional_state = {
            'model_state': model.state_dict(),
            # Include other states like optimizer, scheduler, etc.
        }
        checkpoint_manager.save_checkpoint('train', iteration, additional_state)

    # Plot the training performance based on the stored metrics
    plot_training_performance(results['train_loss'], results['validation_metrics'])

    # Final evaluation on the test set
    classification_result = evaluate_model_on_all_data(model, likelihood, test_loader, device, n_classes=n_classes)
    results['test_metrics'] = classification_result
    plot_results(results)
    print("Final Test Metrics:", results['test_metrics'])