Adding the MFVI and PCA
The Mean-Field Variational Inference (MFVI) method was added to this code; the analysis fits a probabilistic model to compare the distributions of the data. In a new version I will add a convergence criterion to evaluate the data in a more rigorous and statistically meaningful way.
lrm22005 committed Oct 24, 2023
1 parent 4b6b50e commit 5fc6745
Showing 1 changed file with 233 additions and 14 deletions.
project_1.py
@@ -9,12 +9,13 @@
import numpy as np
import pandas as pd
import torch
from torch.distributions import MultivariateNormal
from torch.utils.data import DataLoader, TensorDataset
import matplotlib.pyplot as plt
from sklearn.decomposition import PCA
from sklearn.metrics import silhouette_score, adjusted_rand_score
import seaborn as sns

def load_data(data_path, labels_path, dataset_size=10, train=True, standardize=True):
# Load data from the specified data_path
dir_list_UID = os.listdir(data_path)
UID_list = dir_list_UID[:dataset_size] if train else dir_list_UID[dataset_size:]
@@ -38,7 +39,19 @@ def load_data(data_path, dataset_size=10, train=True, standardize=True):
if standardize:
X_data = standard_scaling(X_data)

# Extract labels from CSV files
labels = extract_labels(UID_list, labels_path)

return X_data, segment_names, labels

def extract_labels(UID_list, labels_path):
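    # Each UID's labels are read from "<UID>_final_attemp_4_1_Dong.csv".
    # Assumed tab-separated layout with no header (the example rows below are
    # hypothetical): a segment-name column followed by an integer rhythm label:
    #   <segment_name_1>    0
    #   <segment_name_2>    1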
labels = {}
for UID in UID_list:
label_file = os.path.join(labels_path, UID + "_final_attemp_4_1_Dong.csv")
if os.path.exists(label_file):
label_data = pd.read_csv(label_file, sep='\t', header=None, names=['segment', 'label'])
labels[UID] = label_data['label'].values
return labels

def standard_scaling(tensor):
# Z-score normalization (standardization)
@@ -57,6 +70,7 @@ def visualize_trends(data, segment_names, num_plots=3):
num_samples, _, _ = data.shape
for _ in range(num_plots):
idx = np.random.randint(0, num_samples)
plt.figure() # Create a new figure for each plot
plt.imshow(data[idx].numpy())
plt.title(f"Segment: {segment_names[idx]}")
plt.colorbar()
@@ -69,36 +83,241 @@ def perform_pca(data, num_components=2):
reduced_data = pca.fit_transform(data_flattened.numpy())
return reduced_data, pca

def visualize_correlation_matrix(data, segment_names, subject_mode=True, num_subjects_to_visualize=None, save_path=None):
    '''
    Usage:
    To visualize the correlation matrix for each subject individually:
        visualize_correlation_matrix(train_data, segment_names, subject_mode=True, save_path="path_to_save_results")
    To visualize the correlation matrix for a chosen number of subjects (group mode):
        visualize_correlation_matrix(train_data, segment_names, subject_mode=False, num_subjects_to_visualize=5, save_path="path_to_save_results")
    '''
# Visualize the correlation matrix for each subject or subgroup
data_flattened = data.view(data.size(0), -1).numpy()
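    # data_flattened: (N, H*W) -- each (H, W) segment flattened to one feature vector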

subject_names = [filename.split('_')[0] for filename in segment_names]
unique_subjects = list(set(subject_names))

if subject_mode:
for subject in unique_subjects:
subject_indices = [i for i, name in enumerate(subject_names) if name == subject]
subject_data = data_flattened[subject_indices]
correlation_matrix = np.corrcoef(subject_data, rowvar=False)

plt.figure()
sns.heatmap(correlation_matrix, cmap="coolwarm", xticklabels=False, yticklabels=False)
plt.title(f"Correlation Matrix for Subject {subject}")

if save_path:
subject_save_path = os.path.join(save_path, f"correlation_matrix_subject_{subject}.png")
plt.savefig(subject_save_path)

plt.show()

else: # Group mode
if num_subjects_to_visualize is None:
num_subjects_to_visualize = len(unique_subjects)

        for i in range(num_subjects_to_visualize):
            subject = unique_subjects[i]
            subject_indices = [j for j, name in enumerate(subject_names) if name == subject]
subject_data = data_flattened[subject_indices]
correlation_matrix = np.corrcoef(subject_data, rowvar=False)

plt.figure()
sns.heatmap(correlation_matrix, cmap="coolwarm", xticklabels=False, yticklabels=False)
plt.title(f"Correlation Matrix for {num_subjects_to_visualize} Subjects {subject}")

if save_path:
subject_save_path = os.path.join(save_path, f"correlation_matrix_subject_group{subject}.png")
plt.savefig(subject_save_path)

plt.show()

# This function computes the log PDF of a multivariate normal distribution
def multivariate_normal_log_pdf_MFVI(x, mu, sigma_sq):
# x: Data points (N x D)
# mu: Means of the components (K x D)
# sigma_sq: Variances of the components (K x D)
N, D = x.shape
K, _ = mu.shape

log_p = torch.empty(N, K, dtype=x.dtype, device=x.device)
for k in range(K):
cov_matrix = torch.diag(sigma_sq[k])
mvn = MultivariateNormal(mu[k], cov_matrix)
log_p[:, k] = mvn.log_prob(x)

return log_p
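
# A vectorized alternative (a sketch; equivalent here because the covariance
# is diagonal, so the log-density factorizes over dimensions and
# torch.distributions.Normal can score all K components without a Python loop):
#
#   from torch.distributions import Normal
#   normal = Normal(mu.unsqueeze(0), sigma_sq.sqrt().unsqueeze(0))  # (1, K, D)
#   log_p = normal.log_prob(x.unsqueeze(1)).sum(dim=-1)             # (N, K)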

def perform_mfvi(data, K, n_optimization_iterations):
    N, D = data.shape[0], data.shape[1] * data.shape[2]  # Feature dimension D = H * W
    data = data.reshape(N, D)  # Flatten so the log-pdf below receives (N, D)
# Define the variational parameters for the GMM
miu_variational = torch.randn(K, D, requires_grad=True)
log_sigma_variational = torch.randn(K, D, requires_grad=True)
alpha_variational = torch.randn(K, requires_grad=True)

# Define the optimizer for gradient descent
optimizer = torch.optim.Adam([miu_variational, log_sigma_variational, alpha_variational], lr=0.001)

for iteration in range(n_optimization_iterations):
# Initialize gradients
optimizer.zero_grad()

# Compute the Gaussian means and covariances from variational parameters
sigma_variational_sq = torch.exp(log_sigma_variational.clone())

# Calculate the responsibilities (E[zi])
log_pi_variational = torch.digamma(alpha_variational) - torch.digamma(alpha_variational.sum())
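        # digamma(alpha_k) - digamma(alpha.sum()) is E[log pi_k] under a
        # Dirichlet(alpha) posterior over the mixing weights; alpha is left
        # unconstrained here, so this expectation is only a heuristic early on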
log_resp = log_pi_variational.unsqueeze(0) + multivariate_normal_log_pdf_MFVI(data, miu_variational, sigma_variational_sq)
log_resp_max = log_resp.max(dim=1, keepdim=True).values.clone()
resp = torch.exp(log_resp - log_resp_max).clone()
resp /= resp.sum(dim=1, keepdim=True)

        # Compute the objective: sum(resp * log resp) - sum(resp * log_resp),
        # i.e. E_q[log q(z)] - E_q[log p~(x, z)] -- the negative ELBO, so
        # minimizing it with Adam maximizes the ELBO
        elbo = -torch.sum(resp * log_resp) + torch.sum(resp * torch.log(resp))

# Perform backpropagation with retain_graph=True
elbo.backward(retain_graph=True)

# Update the variational parameters
optimizer.step()

# Print progress
if (iteration + 1) % 100 == 0:
print(f"Iteration {iteration + 1}/{n_optimization_iterations}")

    # Extract the learned parameters
    miu = miu_variational.detach().numpy()
    pi = torch.softmax(alpha_variational, dim=0).detach().numpy()
    resp = resp.detach().numpy()  # Detach so NumPy ops (e.g., np.argmax) work downstream

    return miu, pi, resp
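
# A possible convergence criterion (a sketch of the follow-up mentioned in the
# commit message; not yet wired into perform_mfvi): track the objective and
# stop once its relative change falls below a tolerance instead of always
# running a fixed number of iterations.
#
#   prev = None
#   for iteration in range(n_optimization_iterations):
#       ...  # one optimization step as above, producing `elbo`
#       curr = elbo.item()
#       if prev is not None and abs(curr - prev) <= 1e-6 * (abs(prev) + 1e-12):
#           print(f"Converged after {iteration + 1} iterations")
#           break
#       prev = curr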

def plot_pca(reduced_data, labels, method='original_labels', save_path=None):
"""
Plot the PCA results, and optionally save the plot.
Args:
data (torch.Tensor): The data after perform PCA.
labels (list or np.ndarray): The labels or class information for data.
save_path (str, optional): If provided, save the PCA plot to this path.
Returns:
sklearn.decomposition.PCA: The PCA object containing the results.
"""

# Create a scatter plot of PCA results
plt.figure(figsize=(8, 6))
scatter = plt.scatter(reduced_data[:, 0], reduced_data[:, 1], c=labels, cmap=plt.cm.viridis)
plt.colorbar(scatter, label='Labels')
    plt.title(f'PCA Plot ({method})')
plt.xlabel('Principal Component 1')
plt.ylabel('Principal Component 2')

    # Save the plot if save_path is provided
    if save_path:
        plt.savefig(os.path.join(save_path, f"PCA_analysis_using_{method}.png"))

plt.show()
# Example usage:
# train_data, segment_names, labels = load_data(data_path, labels_path, dataset_size=141, train=True)
# plot_pca(reduced_data, labels, method='original_labels', save_path="path_to_save_results")

def plot_clusters(data, zi, title, save_path=None):
"""
Plot the data points colored by cluster assignment.
Args:
data (torch.Tensor): The data points.
zi (torch.Tensor): The cluster assignments.
title (str): The title for the plot.
"""
unique_clusters = torch.unique(zi)
colors = plt.cm.viridis(torch.linspace(0, 1, len(unique_clusters)))

plt.figure(figsize=(8, 6))
    for i, cluster in enumerate(unique_clusters):
        cluster_data = data[(zi == cluster).numpy()]  # Boolean mask per cluster
        plt.scatter(cluster_data[:, 0], cluster_data[:, 1], color=colors[i], label=f'Cluster {int(cluster)}')

plt.title(title)
plt.legend()
plt.xlabel('Feature 1')
plt.ylabel('Feature 2')

# Save the plot if save_path is provided
if save_path:
filename = title.replace(' ', '_') + ".png"
plt.savefig(os.path.join(save_path, filename))

plt.show()

def main():
is_linux = False # Set to True if running on Linux, False if on Windows
is_hpc = False # Set to True if running on hpc, False if on Windows

if is_linux:
data_path = "/mnt/r/ENGR_Chon/Dong/MATLAB_generate_results/NIH_PulseWatch/TFS_csv"
labels_path = "/mnt/r/ENGR_Chon/NIH_Pulsewatch_Database/Adjudication_UConn/final_attemp_4_1_Dong_Ohm"
saving_path = "/mnt/r/ENGR_Chon/Luis/Research/Casseys_case/Project_1_analysis"
elif is_hpc:
data_path = "/gpfs/scratchfs1/kic14002/doh16101/TFS_csv"
labels_path = "/gpfs/scratchfs1/hfp14002/lrm22005/final_attemp_4_1_Dong_Ohm"
saving_path = "/gpfs/scratchfs1/hfp14002/lrm22005/Casseys_case/Project_1_analysis"
else:
data_path = r"R:\ENGR_Chon\Dong\MATLAB_generate_results\NIH_PulseWatch\TFS_csv"
labels_path = r"R:\ENGR_Chon\NIH_Pulsewatch_Database\Adjudication_UConn\final_attemp_4_1_Dong_Ohm"
saving_path = r"R:\ENGR_Chon\Luis\Research\Casseys_case\Project_1_analysis"

train_data, segment_names, labels = load_data(data_path, labels_path, dataset_size=141, train=True)
test_data, _, _ = load_data(data_path, labels_path, dataset_size=141, train=False)

train_dataloader = create_dataloader(train_data)
test_dataloader = create_dataloader(test_data)

# Visualize random trends/segments
visualize_trends(train_data, segment_names, num_plots=20)

# Perform PCA for dimensionality reduction
# reduced_data, pca = perform_pca(train_data, num_components=2)
# print("Explained variance ratio:", pca.explained_variance_ratio_)

# Visualize the correlation matrix
visualize_correlation_matrix(train_data, segment_names, subject_mode=True, num_subjects_to_visualize=None, save_path=saving_path)
visualize_correlation_matrix(train_data, segment_names, subject_mode=False, num_subjects_to_visualize=10, save_path=saving_path)
# visualize_correlation_matrix(train_data, segment_names, subject_mode=False, num_subjects_to_visualize=None, save_path=saving_path)

# Perform MFVI for your data
K = 4 # Number of clusters
n_optimization_iterations = 100
miu_mfvi, pi_mfvi, resp_mfvi = perform_mfvi(train_data, K, n_optimization_iterations)

# Calculate clustering metrics for MFVI
zi_mfvi = np.argmax(resp_mfvi, axis=1)
# Perform PCA for dimensionality reduction
reduced_data, pca = perform_pca(train_data, num_components=2)
print("Explained variance ratio:", pca.explained_variance_ratio_)

    # Create two plots: MFVI cluster assignments and original labels in PCA space
    # (plot_pca opens its own figure for each call)

    # Plot PCA results colored by the MFVI cluster assignments
    plot_pca(reduced_data, zi_mfvi, method='MFVI', save_path=saving_path)

    # Flatten the per-UID label dict into one array aligned with segment order
    # (assumes insertion order matches loading order and one label per segment)
    label_array = np.concatenate([labels[UID] for UID in labels])

    # Plot original labels
    plot_pca(reduced_data, label_array, method="original_labels", save_path=saving_path)

# Calculate clustering metrics for PCA results
    silhouette_pca = silhouette_score(reduced_data, zi_mfvi)
    ari_pca = adjusted_rand_score(label_array, zi_mfvi)
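    # Silhouette uses only the geometry of the reduced space (range [-1, 1],
    # higher is better); ARI compares the MFVI assignments to the adjudicated
    # labels (1.0 = perfect agreement, ~0 = chance level)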

# Print and compare clustering metrics for PCA
print("PCA Clustering Metrics Comparison:")
print(f"Silhouette Score (PCA): {silhouette_pca}")
print(f"Adjusted Rand Index (PCA vs. True Labels): {ari_pca}")

# Plot clusters for MFVI results
plot_clusters(reduced_data, torch.from_numpy(zi_mfvi), title="MFVI Clustering Results (Train Data)", save_path=saving_path)

if __name__ == "__main__":
main()
