Merge pull request #1 from lrm22005/Luis
Final version 1
lrm22005 committed Oct 24, 2023
2 parents 4b6b50e + 49f69ae commit 61f7a2d
Showing 1 changed file with 292 additions and 28 deletions.
320 changes: 292 additions & 28 deletions project_1.py
@@ -9,36 +9,89 @@
import numpy as np
import pandas as pd
import torch
from torch.distributions import MultivariateNormal
from torch.utils.data import DataLoader, TensorDataset
import matplotlib.pyplot as plt
from sklearn.decomposition import PCA
from sklearn.metrics import silhouette_score, adjusted_rand_score
import seaborn as sns
from PIL import Image # Import the Image module

def get_data_paths(data_format, is_linux=False, is_hpc=False):
    if is_linux:
        base_path = "/mnt/r/ENGR_Chon/Dong/MATLAB_generate_results/NIH_PulseWatch"
        labels_base_path = "/mnt/r/ENGR_Chon/NIH_Pulsewatch_Database/Adjudication_UConn"
        saving_base_path = "/mnt/r/ENGR_Chon/Luis/Research/Casseys_case/Project_1_analysis"
    elif is_hpc:
        base_path = "/gpfs/scratchfs1/kic14002/doh16101"
        labels_base_path = "/gpfs/scratchfs1/hfp14002/lrm22005"
        saving_base_path = "/gpfs/scratchfs1/hfp14002/lrm22005/Casseys_case/Project_1_analysis"
    else:
        base_path = "R:\\ENGR_Chon\\Dong\\MATLAB_generate_results\\NIH_PulseWatch"
        labels_base_path = "R:\\ENGR_Chon\\NIH_Pulsewatch_Database\\Adjudication_UConn"
        saving_base_path = "R:\\ENGR_Chon\\Luis\\Research\\Casseys_case\\Project_1_analysis"

    if data_format == 'csv':
        data_path = os.path.join(base_path, "TFS_csv")
        labels_path = os.path.join(labels_base_path, "final_attemp_4_1_Dong_Ohm")
        saving_path = os.path.join(saving_base_path, "Project_1_analysis")
    elif data_format == 'png':
        data_path = os.path.join(base_path, "TFS_plots")
        labels_path = os.path.join(labels_base_path, "final_attemp_4_1_Dong_Ohm")
        saving_path = os.path.join(saving_base_path, "Project_1_analysis")
    else:
        raise ValueError("Invalid data format. Choose 'csv' or 'png'.")

    return data_path, labels_path, saving_path
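# Example usage (a sketch; assumes the Windows R: network drive mapping above):
# data_path, labels_path, saving_path = get_data_paths('csv', is_linux=False, is_hpc=False)
# data_path   -> R:\ENGR_Chon\Dong\MATLAB_generate_results\NIH_PulseWatch\TFS_csv
# labels_path -> R:\ENGR_Chon\NIH_Pulsewatch_Database\Adjudication_UConn\final_attemp_4_1_Dong_Ohm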

def load_data(data_path, labels_path, dataset_size=10, train=True, standardize=True, data_format='csv'):
    if data_format not in ['csv', 'png']:
        raise ValueError("Invalid data_format. Choose 'csv' or 'png'.")

    # Load data from the specified data_path
    dir_list_UID = os.listdir(data_path)
    UID_list = dir_list_UID[:dataset_size] if train else dir_list_UID[dataset_size:]

    X_data = []
    segment_names = []

    for UID in UID_list:
        data_path_UID = os.path.join(data_path, UID)
        dir_list_seg = os.listdir(data_path_UID)

        for seg in dir_list_seg[:50]:  # Limiting to 50 segments per UID
            seg_path = os.path.join(data_path_UID, seg)

            if data_format == 'csv' and seg.endswith('.csv'):
                time_freq_plot = np.array(pd.read_csv(seg_path, header=None))
                time_freq_tensor = torch.Tensor(time_freq_plot).reshape(1, 128, 128)
            elif data_format == 'png' and seg.endswith('.png'):
                img = Image.open(seg_path)
                img_data = np.array(img)
                time_freq_tensor = torch.Tensor(img_data).unsqueeze(0)
            else:
                continue  # Skip other file formats

            X_data.append(time_freq_tensor)
            segment_names.append(seg)  # Store segment names

    X_data = torch.cat(X_data, 0)

    if standardize:
        X_data = standard_scaling(X_data)

    # Extract labels from CSV files
    labels = extract_labels(UID_list, labels_path)

    return X_data, segment_names, labels
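# Example (a sketch; assumes 128x128 CSV spectrograms on disk):
# X, names, labs = load_data(data_path, labels_path, dataset_size=10, train=True, data_format='csv')
# X.shape -> torch.Size([n_segments, 128, 128]); names is a list of segment file
# names; labs is a dict keyed by UID (see extract_labels below).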

def extract_labels(UID_list, labels_path):
    labels = {}
    for UID in UID_list:
        label_file = os.path.join(labels_path, UID + "_final_attemp_4_1_Dong.csv")
        if os.path.exists(label_file):
            label_data = pd.read_csv(label_file, sep='\t', header=None, names=['segment', 'label'])
            labels[UID] = label_data['label'].values
    return labels
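# Example of the returned structure (a sketch; assumes tab-separated label files
# named "<UID>_final_attemp_4_1_Dong.csv" exist under labels_path):
# extract_labels(['001', '002'], labels_path)
# -> {'001': array([0, 1, ...]), '002': array([2, 0, ...])}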

def standard_scaling(tensor):
    # Z-score normalization (standardization)
@@ -52,14 +105,18 @@ def create_dataloader(data, batch_size=64, shuffle=True):
    data_loader = DataLoader(dataset, batch_size=batch_size, shuffle=shuffle)
    return data_loader
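# Example usage (a sketch; assumes the lines hidden by the diff wrap `data` in a
# TensorDataset, which the import above suggests):
# train_dataloader = create_dataloader(train_data, batch_size=64, shuffle=True)
# for (batch,) in train_dataloader:
#     print(batch.shape)  # -> torch.Size([64, 128, 128]) for 128x128 csv segments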

def visualize_trends(data, segment_names, num_plots=3, save_path=None):
    # Visualize random trends/segments
    num_samples, _, _ = data.shape
    for _ in range(num_plots):
        idx = np.random.randint(0, num_samples)
        plt.figure()  # Create a new figure for each plot
        plt.imshow(data[idx].numpy())
        plt.title(f"Segment: {segment_names[idx]}")
        plt.colorbar()
        if save_path:
            subject_save_path = os.path.join(save_path, f"trends_visualization_segment_{segment_names[idx]}.png")
            plt.savefig(subject_save_path)
        plt.show()

def perform_pca(data, num_components=2):
@@ -69,36 +126,243 @@ def perform_pca(data, num_components=2):
    reduced_data = pca.fit_transform(data_flattened.numpy())
    return reduced_data, pca
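# Example usage (a sketch): project the flattened segments onto the top two PCs.
# reduced_data, pca = perform_pca(train_data, num_components=2)
# reduced_data.shape -> (n_segments, 2); pca.explained_variance_ratio_ gives the
# fraction of variance captured by each retained component.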

def visualize_correlation_matrix(data, segment_names, subject_mode=True, num_subjects_to_visualize=None, save_path=None):
    '''
    Usage:
    To visualize the correlation matrix for each subject individually:
        visualize_correlation_matrix(train_data, segment_names, subject_mode=True, save_path="path_to_save_results")
    To visualize the correlation matrix for a specific number of subjects (groups):
        visualize_correlation_matrix(train_data, segment_names, subject_mode=False, num_subjects_to_visualize=5, save_path="path_to_save_results")
    '''
    # Visualize the correlation matrix for each subject or subgroup
    data_flattened = data.view(data.size(0), -1).numpy()

    subject_names = [filename.split('_')[0] for filename in segment_names]
    unique_subjects = list(set(subject_names))

    if subject_mode:
        for subject in unique_subjects:
            subject_indices = [i for i, name in enumerate(subject_names) if name == subject]
            subject_data = data_flattened[subject_indices]
            correlation_matrix = np.corrcoef(subject_data, rowvar=False)

            plt.figure()
            sns.heatmap(correlation_matrix, cmap="coolwarm", xticklabels=False, yticklabels=False)
            plt.title(f"Correlation Matrix for Subject {subject}")

            if save_path:
                subject_save_path = os.path.join(save_path, f"correlation_matrix_subject_{subject}.png")
                plt.savefig(subject_save_path)

            plt.show()

    else:  # Group mode
        if num_subjects_to_visualize is None:
            num_subjects_to_visualize = len(unique_subjects)

        for i in range(num_subjects_to_visualize):
            subject = unique_subjects[i]
            subject_indices = [j for j, name in enumerate(subject_names) if name == subject]
            subject_data = data_flattened[subject_indices]
            correlation_matrix = np.corrcoef(subject_data, rowvar=False)

            plt.figure()
            sns.heatmap(correlation_matrix, cmap="coolwarm", xticklabels=False, yticklabels=False)
            plt.title(f"Correlation Matrix for Subject {subject} ({i + 1} of {num_subjects_to_visualize})")

            if save_path:
                subject_save_path = os.path.join(save_path, f"correlation_matrix_subject_group_{subject}.png")
                plt.savefig(subject_save_path)

            plt.show()

# This function computes the log PDF of a multivariate normal distribution
def multivariate_normal_log_pdf_MFVI(x, mu, sigma_sq):
    # x: Data points (N x D)
    # mu: Means of the components (K x D)
    # sigma_sq: Variances of the components (K x D)
    N, D = x.shape
    K, _ = mu.shape

    log_p = torch.empty(N, K, dtype=x.dtype, device=x.device)
    for k in range(K):
        cov_matrix = torch.diag(sigma_sq[k])
        mvn = MultivariateNormal(mu[k], cov_matrix)
        log_p[:, k] = mvn.log_prob(x)

    return log_p
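# Worked shape example (a sketch): with N=3 points in D=2 dimensions and K=2
# diagonal components, the result holds one log-density per (point, component) pair.
# x = torch.zeros(3, 2); mu = torch.zeros(2, 2); sigma_sq = torch.ones(2, 2)
# multivariate_normal_log_pdf_MFVI(x, mu, sigma_sq).shape -> torch.Size([3, 2])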

def perform_mfvi(data, K, n_optimization_iterations, convergence_threshold=1e-5, run_until_convergence=True):
    N, D = data.shape[0], data.shape[1] * data.shape[2]  # Calculate feature dimension D
    data = data.view(N, D)  # Flatten each segment so the log-pdf receives N x D points

    # Define the variational parameters for the GMM
    miu_variational = torch.randn(K, D, requires_grad=True)
    log_sigma_variational = torch.randn(K, D, requires_grad=True)
    alpha_variational = torch.randn(K, requires_grad=True)

    # Define the optimizer for gradient descent
    optimizer = torch.optim.Adam([miu_variational, log_sigma_variational, alpha_variational], lr=0.001)

    prev_elbo = float('-inf')
    iteration = 0
    while True:
        # Initialize gradients
        optimizer.zero_grad()

        # Compute the Gaussian variances from the variational parameters
        sigma_variational_sq = torch.exp(log_sigma_variational.clone())

        # Calculate the responsibilities (E[z_i]) via a log-sum-exp normalization
        log_pi_variational = torch.digamma(alpha_variational) - torch.digamma(alpha_variational.sum())
        log_resp = log_pi_variational.unsqueeze(0) + multivariate_normal_log_pdf_MFVI(data, miu_variational, sigma_variational_sq)
        log_resp_max = log_resp.max(dim=1, keepdim=True).values.clone()
        resp = torch.exp(log_resp - log_resp_max).clone()
        resp /= resp.sum(dim=1, keepdim=True)

        # Compute the ELBO and perform backpropagation
        elbo = -torch.sum(resp * log_resp) + torch.sum(resp * torch.log(resp))

        # Perform backpropagation with retain_graph=True
        elbo.backward(retain_graph=True)

        # Update the variational parameters
        optimizer.step()

        # Print progress
        if (iteration + 1) % 100 == 0:
            print(f"Iteration {iteration + 1}/{n_optimization_iterations}")

        if run_until_convergence:
            if iteration > 0 and abs(elbo.item() - prev_elbo) < convergence_threshold:
                print(f"Converged after {iteration + 1} iterations")
                break
        elif iteration == n_optimization_iterations - 1:
            print("Reached the specified number of iterations.")
            break

        prev_elbo = elbo.item()
        iteration += 1

    # Extract the learned parameters
    miu = miu_variational.detach().numpy()
    pi = torch.softmax(alpha_variational, dim=0).detach().numpy()

    # Detach the responsibilities so downstream numpy code can use them
    return miu, pi, resp.detach()
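# Example usage (a sketch; a small fixed iteration budget keeps the demo fast):
# miu, pi, resp = perform_mfvi(train_data, K=4, n_optimization_iterations=200,
#                              run_until_convergence=False)
# miu: (K x D) component means; pi: (K,) mixture weights; resp: (N x K) responsibilities.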

def plot_pca(reduced_data, labels, method='original_labels', save_path=None):
    """
    Plot the PCA results, and optionally save the plot.
    Args:
        reduced_data (np.ndarray): The data after performing PCA (N x 2).
        labels (list or np.ndarray): The labels or cluster assignments for the data.
        method (str): A tag for which labels are shown, used in the title and filename.
        save_path (str, optional): If provided, save the PCA plot under this directory.
    """

    # Create a scatter plot of PCA results
    plt.figure(figsize=(8, 6))
    scatter = plt.scatter(reduced_data[:, 0], reduced_data[:, 1], c=labels, cmap=plt.cm.viridis)
    plt.colorbar(scatter, label='Labels')
    plt.title(f'PCA Plot ({method})')
    plt.xlabel('Principal Component 1')
    plt.ylabel('Principal Component 2')

    # Save the plot if save_path is provided
    if save_path:
        plt.savefig(os.path.join(save_path, f"PCA_analysis_using_{method}.png"))

    plt.show()
# Example usage:
# train_data, segment_names, labels = load_data(data_path, labels_path, dataset_size=141, train=True)
# plot_pca(reduced_data, labels, method='original_labels', save_path=saving_path)

def plot_clusters(data, zi, title, save_path=None):
    """
    Plot the data points colored by cluster assignment.
    Args:
        data (np.ndarray or torch.Tensor): The data points (N x 2).
        zi (torch.Tensor): The cluster assignments.
        title (str): The title for the plot.
        save_path (str, optional): If provided, save the plot under this directory.
    """
    unique_clusters = torch.unique(zi)
    colors = plt.cm.viridis(np.linspace(0, 1, len(unique_clusters)))

    plt.figure(figsize=(8, 6))
    for i, cluster in enumerate(unique_clusters):
        cluster_data = data[zi == cluster]
        plt.scatter(cluster_data[:, 0], cluster_data[:, 1], c=[colors[i]], label=f'Cluster {int(cluster)}')

    plt.title(title)
    plt.legend()
    plt.xlabel('Feature 1')
    plt.ylabel('Feature 2')

    # Save the plot if save_path is provided
    if save_path:
        filename = title.replace(' ', '_') + ".png"
        plt.savefig(os.path.join(save_path, filename))

    plt.show()

def main():
    is_linux = False  # Set to True if running on Linux, False if on Windows
    is_hpc = False  # Set to True if running on HPC, False if on Windows

    data_format = 'png'  # Choose 'csv' or 'png'

    data_path, labels_path, saving_path = get_data_paths(data_format, is_linux=is_linux, is_hpc=is_hpc)

    train_data, segment_names, labels = load_data(data_path, labels_path, dataset_size=10, train=True, data_format=data_format)
    # test_data, _, _ = load_data(data_path, labels_path, dataset_size=30, train=False)

    train_dataloader = create_dataloader(train_data)
    # test_dataloader = create_dataloader(test_data)

    # Visualize random trends/segments
    visualize_trends(train_data, segment_names, num_plots=20)

    # Perform PCA for dimensionality reduction (done below, after MFVI)
    # reduced_data, pca = perform_pca(train_data, num_components=2)
    # print("Explained variance ratio:", pca.explained_variance_ratio_)

    # Visualize the correlation matrix
    visualize_correlation_matrix(train_data, segment_names, subject_mode=True, num_subjects_to_visualize=None, save_path=saving_path)
    # visualize_correlation_matrix(train_data, segment_names, subject_mode=False, num_subjects_to_visualize=None, save_path=saving_path)

    # Perform MFVI on the data
    K = 4  # Number of clusters
    miu_mfvi, pi_mfvi, resp_mfvi = perform_mfvi(train_data, K, n_optimization_iterations=1000, convergence_threshold=1e-5, run_until_convergence=False)

    # Hard cluster assignments from the MFVI responsibilities
    zi_mfvi = np.argmax(resp_mfvi.numpy(), axis=1)

    # Perform PCA for dimensionality reduction
    reduced_data, pca = perform_pca(train_data, num_components=2)
    print("Explained variance ratio:", pca.explained_variance_ratio_)

    # labels is a dict keyed by UID; flatten it to one label per loaded segment
    # (assumes the label rows line up with the <=50 segments loaded per UID)
    flat_labels = np.concatenate([labels[UID][:50] for UID in labels])

    # Create two plots: PCA results colored by MFVI clusters and by original labels
    plot_pca(reduced_data, zi_mfvi, method='MFVI', save_path=saving_path)
    plot_pca(reduced_data, flat_labels, method="original_labels", save_path=saving_path)

    # Calculate clustering metrics for the PCA-reduced data
    silhouette_pca = silhouette_score(reduced_data, zi_mfvi)
    ari_pca = adjusted_rand_score(flat_labels, zi_mfvi)

    # Print and compare clustering metrics for PCA
    print("PCA Clustering Metrics Comparison:")
    print(f"Silhouette Score (PCA): {silhouette_pca}")
    print(f"Adjusted Rand Index (PCA vs. True Labels): {ari_pca}")

    # Plot clusters for MFVI results
    plot_clusters(reduced_data, torch.from_numpy(zi_mfvi), title="MFVI Clustering Results (Train Data)", save_path=saving_path)

if __name__ == "__main__":
    main()
