NEURAL VARIATIONAL INFERENCE #8

Merged
merged 2 commits into from
Nov 27, 2023
3 changes: 3 additions & 0 deletions .gitignore
@@ -0,0 +1,3 @@

error_log.txt
progress.log
163 changes: 55 additions & 108 deletions project_1.py
@@ -16,11 +16,13 @@
from sklearn.decomposition import PCA, IncrementalPCA
from sklearn.manifold import TSNE
from sklearn.cluster import KMeans
from sklearn.cluster import MiniBatchKMeans
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import silhouette_score, adjusted_rand_score, adjusted_mutual_info_score, davies_bouldin_score
import seaborn as sns
from PIL import Image # Import the Image module

os.environ['OMP_NUM_THREADS'] = '3'
# Create a logger
logger = logging.getLogger(__name__)
logging.basicConfig(filename='error_log.txt', level=logging.ERROR)
@@ -35,6 +37,7 @@
progress_logger.addHandler(progress_handler)

def get_data_paths(data_format, is_linux=False, is_hpc=False):
log_progress("Code execution get_data_paths")
if is_linux:
base_path = "/mnt/r/ENGR_Chon/Dong/MATLAB_generate_results/NIH_PulseWatch"
labels_base_path = "/mnt/r/ENGR_Chon/NIH_Pulsewatch_Database/Adjudication_UConn"
@@ -46,8 +49,7 @@ def get_data_paths(data_format, is_linux=False, is_hpc=False):
else:
base_path = "R:\\ENGR_Chon\\Dong\\MATLAB_generate_results\\NIH_PulseWatch"
labels_base_path = "R:\\ENGR_Chon\\NIH_Pulsewatch_Database\\Adjudication_UConn"
saving_base_path = "R:\\ENGR_Chon\\Luis\\Research\\Casseys_case\\Project_1_analysis"

saving_base_path = r"\\grove.ad.uconn.edu\research\ENGR_Chon\Luis\Research\Casseys_case"
if data_format == 'csv':
data_path = os.path.join(base_path, "TFS_csv")
labels_path = os.path.join(labels_base_path, "final_attemp_4_1_Dong_Ohm")
@@ -58,16 +60,13 @@ def get_data_paths(data_format, is_linux=False, is_hpc=False):
saving_path = os.path.join(saving_base_path, "Project_1_analysis")
else:
raise ValueError("Invalid data format. Choose 'csv' or 'png'.")

log_progress("Code execution completed get_data_paths")
return data_path, labels_path, saving_path

# Standardize the data
def standard_scaling(data):
scaler = StandardScaler()
data_shape = data.shape
data = data.view(data_shape[0], -1)
data = scaler.fit_transform(data)
data = data.view(data_shape)
data = scaler.fit_transform(data.reshape(-1, data.shape[-1])).reshape(data.shape)
return torch.Tensor(data)
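A note on the one-line rescaling above: it standardizes each column of the last axis using statistics pooled over all samples and rows. A minimal usage sketch, with a hypothetical tensor name and shape that are not taken from this file:

import torch

segments = torch.randn(8, 128, 128)   # hypothetical batch of (N, H, W) time-frequency plots
scaled = standard_scaling(segments)   # zero mean / unit variance per column of the last axis
print(scaled.shape)                   # torch.Size([8, 128, 128])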

def load_data(data_path, labels_path, dataset_size=2, train=True, standardize=True, data_format='csv', return_all=False):
@@ -108,7 +107,6 @@ def load_data(data_path, labels_path, dataset_size=2, train=True, standardize=Tr
logger.error(f"Error processing segment: {seg} in UID: {UID}. Exception: {str(e)}")
logger.error(f"Error processing segment: {time_freq_plot.size()} in UID: {UID}. Exception: {str(e)}")
# You can also add more information to the error log, such as the value of time_freq_plot.
continue # Continue to the next segment

X_data = torch.cat(X_data, 0)
X_data_original = torch.cat(X_data_original, 0)
@@ -141,18 +139,24 @@ def load_data(data_path, labels_path, dataset_size=2, train=True, standardize=Tr

# Filter out segments that are unlabeled (-1)
filtered_segment_names = [segment_name for segment_name, label in segment_labels.items() if label != -1]

# Filter data to match the filtered segment names
filtered_data = torch.stack([X_data[segment_names.index(segment_name)] for segment_name in filtered_segment_names])
# Check if there are no labeled segments
if not filtered_segment_names:
filtered_data = None # Set filtered_data to None
else:
# Filter data to match the filtered segment names
filtered_data = torch.stack([X_data[segment_names.index(segment_name)] for segment_name in filtered_segment_names])

# Return all segments along with labels
if return_all is True:
return X_data_original, X_data, segment_names, segment_labels, segment_labels.values()

# Return labeled and unlabeled segments along with labels
if return_all == 'labeled':
return X_data_original, filtered_data, filtered_segment_names, {seg: segment_labels[seg] for seg in filtered_segment_names}, {seg: segment_labels[seg] for seg in filtered_segment_names}.values()

if filtered_data is not None:
return X_data_original, filtered_data, filtered_segment_names, {seg: segment_labels[seg] for seg in filtered_segment_names}, {seg: segment_labels[seg] for seg in filtered_segment_names}.values()
else:
return X_data_original, None, [], {}, []

# Return unlabeled segments along with labels
elif return_all == 'unlabeled':
unlabeled_segment_names = [segment_name for segment_name, label in segment_labels.items() if label == -1]
@@ -207,8 +211,8 @@ def visualize_trends(standardized_data, original_data, segment_names, num_plots=
plt.tight_layout()

if save_path:
subject_save_path = os.path.join(save_path, f"trends_visualization_segment_{segment_names[idx]}.png")
plt.savefig(subject_save_path)
subject_save_path = os.path.join(save_path, f"trends_{segment_names[idx]}.png")
plt.savefig(subject_save_path, dpi=400, format='png')
plt.show()
elif data_format == 'png':
print("This is a trend analysis for PNG data format.")
@@ -233,7 +237,7 @@ def perform_pca_sgd(data, num_components=2, num_clusters=4, batch_size=64):
reduced_data = ipca.fit_transform(data_flattened.numpy())

# Cluster the data using K-Means
kmeans = KMeans(n_clusters=num_clusters)
kmeans = KMeans(n_clusters=num_clusters, init='k-means++', n_init='auto')
labels = kmeans.fit_predict(reduced_data)

return reduced_data, ipca, labels
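A hedged usage sketch of the incremental-PCA path above; `segments` is an illustrative stand-in for the standardized (N, H, W) tensor and is not defined in this file:

reduced, ipca, cluster_ids = perform_pca_sgd(segments, num_components=2, num_clusters=4, batch_size=64)
print(reduced.shape)                   # (N, 2): projection onto the first two principal components
print(ipca.explained_variance_ratio_)  # fraction of variance captured by each retained component
print(cluster_ids[:10])                # K-Means assignments for the first ten segments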
@@ -247,98 +251,36 @@ def perform_tsne(data, num_components=2, num_clusters=4):
reduced_data = tsne.fit_transform(data_flattened.numpy())

# Cluster the data using K-Means
kmeans = KMeans(n_clusters=num_clusters)
kmeans = KMeans(n_clusters=num_clusters, init='k-means++', n_init='auto')
labels = kmeans.fit_predict(reduced_data)

return reduced_data, labels

def visualize_correlation_matrix(data, segment_names, subject_mode=True, num_subjects_to_visualize=None, batch_size=32, method='pearson', save_path=None):
data_flattened = data.view(data.size(0), -1).numpy()

if subject_mode:
subject_names = [filename.split('_')[0] for filename in segment_names]
unique_subjects = list(set(subject_names))
else:
subject_names = [filename.split('_')[0] for filename in segment_names]
unique_subjects = list(set(subject_names))

if num_subjects_to_visualize is None:
num_subjects_to_visualize = len(unique_subjects)

for i in range(num_subjects_to_visualize):
subject = unique_subjects[i]
subject_indices = [j for j, name in enumerate(subject_names) if name == subject]
subject_data = data_flattened[subject_indices]

# Shuffle the data to avoid bias
np.random.shuffle(subject_data)

# Calculate the number of batches
num_batches = len(subject_data) // batch_size

batch_correlations = []

for batch_index in range(num_batches):
start = batch_index * batch_size
end = (batch_index + 1) * batch_size
batch = subject_data[start:end]

# Calculate the correlation matrix for the batch
correlation_matrix = np.corrcoef(batch, rowvar=False)

# Calculate the mean or median of the per-batch correlations
if method == 'mean':
batch_correlation = np.mean(correlation_matrix)
elif method == 'median':
batch_correlation = np.median(correlation_matrix)

batch_correlations.append(batch_correlation)

# Aggregate the batch correlations
overall_correlation = np.mean(batch_correlations) # You can use median instead of mean if needed

# Calculate confidence intervals on the aggregated correlation
batch_correlations = np.array(batch_correlations)
ci_lower = np.percentile(batch_correlations, 2.5)
ci_upper = np.percentile(batch_correlations, 97.5)

# Print or save the results
print(f"Overall Correlation for {num_subjects_to_visualize} Subjects {subject}: {overall_correlation:.4f}")
print(f"Confidence Intervals: [{ci_lower:.4f}, {ci_upper:.4f}]")

if save_path:
subject_save_path = os.path.join(save_path, f"correlation_matrix_subject_group_{subject}.png")
plt.figure()
sns.heatmap(correlation_matrix, cmap="coolwarm", xticklabels=False, yticklabels=False)
plt.title(f"Correlation Matrix for {num_subjects_to_visualize} Subjects {subject}")
plt.savefig(subject_save_path)
plt.close()

# Plot the per-batch correlations over time/batches
plt.figure()
plt.plot(batch_correlations)
plt.xlabel("Batch Index")
plt.ylabel("Correlation")
plt.title(f"Per-Batch Correlations for {num_subjects_to_visualize} Subjects {subject}")
plt.show()

return reduced_data, labels

# This function computes the log PDF of a multivariate normal distribution
def multivariate_normal_log_pdf_MFVI(x, mu, sigma_sq):
# x: Data points (N x D)
# x: Data points (N x H x W)
# mu: Means of the components (K x D)
# sigma_sq: Variances of the components (K x D)
N, D = x.shape
K, _ = mu.shape

N, H, W = x.shape # Get the dimensions of the data tensor
K, D = mu.shape

log_p = torch.empty(N, K, dtype=x.dtype, device=x.device)
for k in range(K):
# Create a covariance matrix for each component
cov_matrix = torch.diag(sigma_sq[k])
mvn = MultivariateNormal(mu[k], cov_matrix)
log_p[:, k] = mvn.log_prob(x)

# Calculate the log PDF for each data point
for n in range(N):
data_point = x[n].view(-1) # Flatten the 2D slice to a 1D vector
mvn = MultivariateNormal(mu[k], cov_matrix)
log_p[n, k] = mvn.log_prob(data_point)

return log_p
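Because the covariance here is diagonal, the per-component log density is a sum over dimensions, so the nested loops above can be collapsed into a single broadcasted expression. A vectorized sketch of the same quantity, shown only as an alternative formulation and not as the code under review:

import math
import torch

def diag_gaussian_log_pdf(x, mu, sigma_sq):
    # x: (N, H, W) data, mu: (K, D) means, sigma_sq: (K, D) variances, with D = H * W
    x = x.reshape(x.shape[0], -1)                                 # (N, D)
    diff = x.unsqueeze(1) - mu.unsqueeze(0)                       # (N, K, D)
    log_det = torch.log(2 * math.pi * sigma_sq).sum(dim=1)        # (K,): sum_d log(2*pi*sigma_d^2)
    mahalanobis = (diff ** 2 / sigma_sq.unsqueeze(0)).sum(dim=2)  # (N, K): squared scaled distances
    return -0.5 * (log_det.unsqueeze(0) + mahalanobis)            # (N, K): log N(x_n; mu_k, diag(sigma_sq_k))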

def perform_mfvi(data, K, n_optimization_iterations, convergence_threshold=1e-5, run_until_convergence=True):
import gc # Import the garbage collection module

def perform_mfvi(data, K, n_optimization_iterations, convergence_threshold=1e-5, run_until_convergence=False):
N, D = data.shape[0], data.shape[1] * data.shape[2] # Calculate feature dimension D
# Define the variational parameters for the GMM
miu_variational = torch.randn(K, D, requires_grad=True)
@@ -386,7 +328,11 @@ def perform_mfvi(data, K, n_optimization_iterations, convergence_threshold=1e-5,

prev_elbo = elbo
iteration += 1


# Clean up CPU memory using garbage collection
gc.collect()

print(f"Iteration {iteration}/{n_optimization_iterations}")
# Extract the learned parameters
miu = miu_variational.detach().numpy()
pi = torch.softmax(alpha_variational, dim=0).detach().numpy()
Expand Down Expand Up @@ -485,7 +431,7 @@ def perform_dimensionality_reduction(data, data_type, saving_path):

def perform_clustering(data, labels, pca_reduced_data, pca_labels, tsne_reduced_data, tsne_labels, mfvi_labels, data_type, saving_path):
log_progress(f"Performing clustering for {data_type} data")

print('starting clustering')
if data_type == "labeled":
method ='PCA on Labeled Data'
title_tsne = 't-SNE on Labeled Data'
@@ -508,38 +454,44 @@ def perform_clustering(data, labels, pca_reduced_data, pca_labels, tsne_reduced_
plot_pca(pca_reduced_data, labels, method, save_path=saving_path)
except Exception as e:
logger.error(f"An error occurred while plotting PCA: {str(e)}")
print('Performed PCA on the data')

try:
# Evaluate clustering for PCA results
ari_pca, ami_pca, silhouette_pca, davies_bouldin_pca = evaluate_clustering(data.view(data.size(0), -1).numpy(), labels, pca_labels)
except Exception as e:
logger.error(f"An error occurred while clustering PCA results: {str(e)}")
print('Evaluated clustering for PCA results')

# Perform t-SNE on the data
try:
# Plot t-SNE for the data
plot_clusters(tsne_reduced_data, tsne_labels, title_tsne, save_path=saving_path)
except Exception as e:
logger.error(f"An error occurred while plotting t-SNE: {str(e)}")

print('Performed t-SNE on the data')

try:
# Evaluate clustering for t-SNE results
ari_tsne, ami_tsne, silhouette_tsne, davies_bouldin_tsne = evaluate_clustering(tsne_reduced_data, labels, tsne_labels)
except Exception as e:
logger.error(f"An error occurred while clustering t-SNE results: {str(e)}")

print('Evaluated clustering for t-SNE results')

try:
# Plot MFVI for data
plot_clusters(data.view(data.size(0), -1).numpy(), mfvi_labels, title_mfvi, save_path=saving_path)
except Exception as e:
logger.error(f"An error occurred while plotting Mean-Field Variational Inference in labeled data: {str(e)}")
print('Plotted MFVI for data')

try:
# For MFVI on data
ari_mfvi, ami_mfvi, silhouette_mfvi, davies_bouldin_mfvi = evaluate_clustering(data.view(data.size(0), -1).numpy(), labels, mfvi_labels)
except Exception as e:
logger.error(f"An error occurred while clustering Mean-Field Variational Inference results in labeled data: {str(e)}")

print('Evaluated MFVI on data')

return ari_pca, ami_pca, silhouette_pca, davies_bouldin_pca, ari_tsne, ami_tsne, silhouette_tsne, davies_bouldin_tsne, ari_mfvi, ami_mfvi, silhouette_mfvi, davies_bouldin_mfvi
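The `evaluate_clustering` helper called above is defined outside this hunk; judging from the metrics imported at the top of the file, a minimal sketch of what such a helper could look like follows (the name and signature are assumptions, not the PR's actual implementation):

def evaluate_clustering_sketch(flat_data, true_labels, predicted_labels):
    # Agreement between adjudicated labels and cluster assignments
    ari = adjusted_rand_score(true_labels, predicted_labels)
    ami = adjusted_mutual_info_score(true_labels, predicted_labels)
    # Internal cluster geometry on the flattened feature vectors
    silhouette = silhouette_score(flat_data, predicted_labels)
    davies_bouldin = davies_bouldin_score(flat_data, predicted_labels)
    return ari, ami, silhouette, davies_bouldin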

def visualize_and_analyze_data(data, original_data, segments, labels, data_type, saving_path, data_format):
@@ -550,19 +502,14 @@ def visualize_and_analyze_data(data, original_data, segments, labels, data_type,
except Exception as e:
handle_error(f"Visualizing trends for {data_type} data", e)

try:
visualize_correlation_matrix(data, segments, subject_mode=False, num_subjects_to_visualize=None, batch_size=32, method='pearson', save_path=saving_path)
except Exception as e:
handle_error(f"Visualizing correlation matrix for {data_type} data", e)

try:
pca_reduced_data, pca_labels, tsne_reduced_data, tsne_labels = perform_dimensionality_reduction(data, data_type, saving_path)
except Exception as e:
handle_error(f"Performing dimensionality reduction for {data_type} data", e)

try:
# Perform MFVI on data
miu, pi, resp = perform_mfvi(data, K=4, n_optimization_iterations=300, convergence_threshold=1e-5, run_until_convergence=False)
miu, pi, resp = perform_mfvi(data, K=4, n_optimization_iterations=10, convergence_threshold=1e-5, run_until_convergence=False)

# Extract cluster assignments from MFVI
mfvi_labels = torch.argmax(resp, dim=1).numpy()
@@ -681,7 +628,7 @@ def main(case_to_run):
unlabeled_original_data, unlabeled_data, unlabeled_segments, unlabeled_segments_labels, unlabeled_labels = load_data(data_path, labels_path, dataset_size=10, train=True, data_format=data_format, return_all="unlabeled")
results['unlabeled'] = process_unlabeled_data(unlabeled_data, unlabeled_original_data, unlabeled_segments, unlabeled_labels, saving_path, data_format)
elif case_to_run == "all_data":
all_original_data, all_data, all_segment_names, segment_labels, all_labels = load_data(data_path, labels_path, dataset_size=10, train=True, data_format=data_format, return_all=True)
all_original_data, all_data, all_segment_names, segment_labels, all_labels = load_data(data_path, labels_path, dataset_size=1, train=True, data_format=data_format, return_all=True)
results['all_data'] = process_all_data(all_data, all_original_data, all_segment_names, all_labels, saving_path, data_format)
else:
log_progress("Invalid case specified. Please use 'labeled', 'unlabeled', or 'all_data'.")
@@ -698,5 +645,5 @@ def handle_error(task, error):

if __name__ == "__main__":
# Specify the case you want to run: 'labeled', 'unlabeled', or 'all_data'
case_to_run = "labeled"
case_to_run = "all_data"
main(case_to_run)