Skip to content

Commit

Permalink
PNG analysis
Browse files Browse the repository at this point in the history
Adding a way to analyze time series or plots.
  • Loading branch information
lrm22005 committed Oct 24, 2023
1 parent 0537059 commit 6edf0c3
Showing 1 changed file with 57 additions and 27 deletions.
84 changes: 57 additions & 27 deletions project_1.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,34 +15,73 @@
from sklearn.decomposition import PCA
from sklearn.metrics import silhouette_score, adjusted_rand_score
import seaborn as sns
from PIL import Image # Import the Image module

def get_data_paths(data_format, is_linux=False, is_hpc=False):
if is_linux:
base_path = "/mnt/r/ENGR_Chon/Dong/MATLAB_generate_results/NIH_PulseWatch"
labels_base_path = "/mnt/r/ENGR_Chon/NIH_Pulsewatch_Database/Adjudication_UConn"
saving_base_path = "/mnt/r/ENGR_Chon/Luis/Research/Casseys_case/Project_1_analysis"
elif is_hpc:
base_path = "/gpfs/scratchfs1/kic14002/doh16101"
labels_base_path = "/gpfs/scratchfs1/hfp14002/lrm22005"
saving_base_path = "/gpfs/scratchfs1/hfp14002/lrm22005/Casseys_case/Project_1_analysis"
else:
base_path = "R:\\ENGR_Chon\\Dong\\MATLAB_generate_results\\NIH_PulseWatch"
labels_base_path = "R:\\ENGR_Chon\\NIH_Pulsewatch_Database\\Adjudication_UConn"
saving_base_path = "R:\\ENGR_Chon\\Luis\\Research\\Casseys_case\\Project_1_analysis"

if data_format == 'csv':
data_path = os.path.join(base_path, "TFS_csv")
labels_path = os.path.join(labels_base_path, "final_attemp_4_1_Dong_Ohm")
saving_path = os.path.join(saving_base_path, "Project_1_analysis")
elif data_format == 'png':
data_path = os.path.join(base_path, "TFS_plots")
labels_path = os.path.join(labels_base_path, "final_attemp_4_1_Dong_Ohm")
saving_path = os.path.join(saving_base_path, "Project_1_analysis")
else:
raise ValueError("Invalid data format. Choose 'csv' or 'png.")

return data_path, labels_path, saving_path

def load_data(data_path, labels_path, dataset_size=10, train=True, standardize=True, data_format='csv'):
if data_format not in ['csv', 'png']:
raise ValueError("Invalid data_format. Choose 'csv' or 'png'.")

def load_data(data_path, labels_path, dataset_size=10, train=True, standardize=True):
# Load data from the specified data_path
dir_list_UID = os.listdir(data_path)
UID_list = dir_list_UID[:dataset_size] if train else dir_list_UID[dataset_size:]

X_data = []
segment_names = []

for UID in UID_list:
data_path_UID = os.path.join(data_path, UID)
dir_list_seg = os.listdir(data_path_UID)

for seg in dir_list_seg[:50]: # Limiting to 50 segments
seg_path = os.path.join(data_path_UID, seg)
time_freq_plot = np.array(pd.read_csv(seg_path, header=None))
time_freq_tensor = torch.Tensor(time_freq_plot).reshape(1, 128, 128)

if data_format == 'csv' and seg.endswith('.csv'):
time_freq_plot = np.array(pd.read_csv(seg_path, header=None))
time_freq_tensor = torch.Tensor(time_freq_plot).reshape(1, 128, 128)
elif data_format == 'png' and seg.endswith('.png'):
img = Image.open(seg_path)
img_data = np.array(img)
time_freq_tensor = torch.Tensor(img_data).unsqueeze(0)
else:
continue # Skip other file formats

X_data.append(time_freq_tensor)
segment_names.append(seg) # Store segment names

X_data = torch.cat(X_data, 0)

if standardize:
X_data = standard_scaling(X_data)

# Extract labels from CSV files
labels = extract_labels(UID_list, labels_path)

return X_data, segment_names, labels

def extract_labels(UID_list, labels_path):
Expand Down Expand Up @@ -269,24 +308,15 @@ def main():
is_linux = False # Set to True if running on Linux, False if on Windows
is_hpc = False # Set to True if running on hpc, False if on Windows

if is_linux:
data_path = "/mnt/r/ENGR_Chon/Dong/MATLAB_generate_results/NIH_PulseWatch/TFS_csv"
labels_path = "/mnt/r/ENGR_Chon/NIH_Pulsewatch_Database/Adjudication_UConn/final_attemp_4_1_Dong_Ohm"
saving_path = "/mnt/r/ENGR_Chon/Luis/Research/Casseys_case/Project_1_analysis"
elif is_hpc:
data_path = "/gpfs/scratchfs1/kic14002/doh16101/TFS_csv"
labels_path = "/gpfs/scratchfs1/hfp14002/lrm22005/final_attemp_4_1_Dong_Ohm"
saving_path = "/gpfs/scratchfs1/hfp14002/lrm22005/Casseys_case/Project_1_analysis"
else:
data_path = r"R:\ENGR_Chon\Dong\MATLAB_generate_results\NIH_PulseWatch\TFS_csv"
labels_path = r"R:\ENGR_Chon\NIH_Pulsewatch_Database\Adjudication_UConn\final_attemp_4_1_Dong_Ohm"
saving_path = r"R:\ENGR_Chon\Luis\Research\Casseys_case\Project_1_analysis"
data_format = 'csv' # Choose 'csv' or 'png'

data_path, labels_path, saving_path = get_data_paths(data_format, is_linux=is_linux, is_hpc=is_hpc)

train_data, segment_names, labels = load_data(data_path, labels_path, dataset_size=20, train=True)
test_data, _, _ = load_data(data_path, labels_path, dataset_size=30, train=False)
train_data, segment_names, labels = load_data(data_path, labels_path, dataset_size=20, train=True, data_format=data_format)
# test_data, _, _ = load_data(data_path, labels_path, dataset_size=30, train=False)

train_dataloader = create_dataloader(train_data)
test_dataloader = create_dataloader(test_data)
# test_dataloader = create_dataloader(test_data)

# Visualize random trends/segments
visualize_trends(train_data, segment_names, num_plots=20)
Expand All @@ -297,7 +327,7 @@ def main():

# Visualize the correlation matrix
visualize_correlation_matrix(train_data, segment_names, subject_mode=True, num_subjects_to_visualize=None, save_path=saving_path)
visualize_correlation_matrix(train_data, segment_names, subject_mode=False, num_subjects_to_visualize=10, save_path=saving_path)
visualize_correlation_matrix(train_data, segment_names, subject_mode=False, num_subjects_to_visualize=20, save_path=saving_path)
# visualize_correlation_matrix(train_data, segment_names, subject_mode=False, num_subjects_to_visualize=None, save_path=saving_path)

# Perform MFVI for your data
Expand Down

0 comments on commit 6edf0c3

Please sign in to comment.