From 6edf0c39d27c5999245d2a535c650476dd1a9099 Mon Sep 17 00:00:00 2001 From: Luis Roberto Mercado Diaz Date: Tue, 24 Oct 2023 14:19:23 -0400 Subject: [PATCH] PNG analysis Adding a way to analyze time series or plots. --- project_1.py | 84 +++++++++++++++++++++++++++++++++++----------------- 1 file changed, 57 insertions(+), 27 deletions(-) diff --git a/project_1.py b/project_1.py index 5a3d472..9c52fcf 100644 --- a/project_1.py +++ b/project_1.py @@ -15,34 +15,73 @@ from sklearn.decomposition import PCA from sklearn.metrics import silhouette_score, adjusted_rand_score import seaborn as sns +from PIL import Image # Import the Image module + +def get_data_paths(data_format, is_linux=False, is_hpc=False): + if is_linux: + base_path = "/mnt/r/ENGR_Chon/Dong/MATLAB_generate_results/NIH_PulseWatch" + labels_base_path = "/mnt/r/ENGR_Chon/NIH_Pulsewatch_Database/Adjudication_UConn" + saving_base_path = "/mnt/r/ENGR_Chon/Luis/Research/Casseys_case/Project_1_analysis" + elif is_hpc: + base_path = "/gpfs/scratchfs1/kic14002/doh16101" + labels_base_path = "/gpfs/scratchfs1/hfp14002/lrm22005" + saving_base_path = "/gpfs/scratchfs1/hfp14002/lrm22005/Casseys_case/Project_1_analysis" + else: + base_path = "R:\\ENGR_Chon\\Dong\\MATLAB_generate_results\\NIH_PulseWatch" + labels_base_path = "R:\\ENGR_Chon\\NIH_Pulsewatch_Database\\Adjudication_UConn" + saving_base_path = "R:\\ENGR_Chon\\Luis\\Research\\Casseys_case\\Project_1_analysis" + + if data_format == 'csv': + data_path = os.path.join(base_path, "TFS_csv") + labels_path = os.path.join(labels_base_path, "final_attemp_4_1_Dong_Ohm") + saving_path = os.path.join(saving_base_path, "Project_1_analysis") + elif data_format == 'png': + data_path = os.path.join(base_path, "TFS_plots") + labels_path = os.path.join(labels_base_path, "final_attemp_4_1_Dong_Ohm") + saving_path = os.path.join(saving_base_path, "Project_1_analysis") + else: + raise ValueError("Invalid data format. Choose 'csv' or 'png.") + + return data_path, labels_path, saving_path + +def load_data(data_path, labels_path, dataset_size=10, train=True, standardize=True, data_format='csv'): + if data_format not in ['csv', 'png']: + raise ValueError("Invalid data_format. Choose 'csv' or 'png'.") -def load_data(data_path, labels_path, dataset_size=10, train=True, standardize=True): - # Load data from the specified data_path dir_list_UID = os.listdir(data_path) UID_list = dir_list_UID[:dataset_size] if train else dir_list_UID[dataset_size:] - + X_data = [] segment_names = [] - + for UID in UID_list: data_path_UID = os.path.join(data_path, UID) dir_list_seg = os.listdir(data_path_UID) - + for seg in dir_list_seg[:50]: # Limiting to 50 segments seg_path = os.path.join(data_path_UID, seg) - time_freq_plot = np.array(pd.read_csv(seg_path, header=None)) - time_freq_tensor = torch.Tensor(time_freq_plot).reshape(1, 128, 128) + + if data_format == 'csv' and seg.endswith('.csv'): + time_freq_plot = np.array(pd.read_csv(seg_path, header=None)) + time_freq_tensor = torch.Tensor(time_freq_plot).reshape(1, 128, 128) + elif data_format == 'png' and seg.endswith('.png'): + img = Image.open(seg_path) + img_data = np.array(img) + time_freq_tensor = torch.Tensor(img_data).unsqueeze(0) + else: + continue # Skip other file formats + X_data.append(time_freq_tensor) segment_names.append(seg) # Store segment names - + X_data = torch.cat(X_data, 0) - + if standardize: X_data = standard_scaling(X_data) - + # Extract labels from CSV files labels = extract_labels(UID_list, labels_path) - + return X_data, segment_names, labels def extract_labels(UID_list, labels_path): @@ -269,24 +308,15 @@ def main(): is_linux = False # Set to True if running on Linux, False if on Windows is_hpc = False # Set to True if running on hpc, False if on Windows - if is_linux: - data_path = "/mnt/r/ENGR_Chon/Dong/MATLAB_generate_results/NIH_PulseWatch/TFS_csv" - labels_path = "/mnt/r/ENGR_Chon/NIH_Pulsewatch_Database/Adjudication_UConn/final_attemp_4_1_Dong_Ohm" - saving_path = "/mnt/r/ENGR_Chon/Luis/Research/Casseys_case/Project_1_analysis" - elif is_hpc: - data_path = "/gpfs/scratchfs1/kic14002/doh16101/TFS_csv" - labels_path = "/gpfs/scratchfs1/hfp14002/lrm22005/final_attemp_4_1_Dong_Ohm" - saving_path = "/gpfs/scratchfs1/hfp14002/lrm22005/Casseys_case/Project_1_analysis" - else: - data_path = r"R:\ENGR_Chon\Dong\MATLAB_generate_results\NIH_PulseWatch\TFS_csv" - labels_path = r"R:\ENGR_Chon\NIH_Pulsewatch_Database\Adjudication_UConn\final_attemp_4_1_Dong_Ohm" - saving_path = r"R:\ENGR_Chon\Luis\Research\Casseys_case\Project_1_analysis" + data_format = 'csv' # Choose 'csv' or 'png' + + data_path, labels_path, saving_path = get_data_paths(data_format, is_linux=is_linux, is_hpc=is_hpc) - train_data, segment_names, labels = load_data(data_path, labels_path, dataset_size=20, train=True) - test_data, _, _ = load_data(data_path, labels_path, dataset_size=30, train=False) + train_data, segment_names, labels = load_data(data_path, labels_path, dataset_size=20, train=True, data_format=data_format) + # test_data, _, _ = load_data(data_path, labels_path, dataset_size=30, train=False) train_dataloader = create_dataloader(train_data) - test_dataloader = create_dataloader(test_data) + # test_dataloader = create_dataloader(test_data) # Visualize random trends/segments visualize_trends(train_data, segment_names, num_plots=20) @@ -297,7 +327,7 @@ def main(): # Visualize the correlation matrix visualize_correlation_matrix(train_data, segment_names, subject_mode=True, num_subjects_to_visualize=None, save_path=saving_path) - visualize_correlation_matrix(train_data, segment_names, subject_mode=False, num_subjects_to_visualize=10, save_path=saving_path) + visualize_correlation_matrix(train_data, segment_names, subject_mode=False, num_subjects_to_visualize=20, save_path=saving_path) # visualize_correlation_matrix(train_data, segment_names, subject_mode=False, num_subjects_to_visualize=None, save_path=saving_path) # Perform MFVI for your data