From b3b4b4d963aa87a845ac2cd3add2135144ad8258 Mon Sep 17 00:00:00 2001 From: Luis Roberto Mercado Diaz Date: Mon, 27 Nov 2023 12:28:57 -0500 Subject: [PATCH] Update the model structure and goals I just generate a new method for the evaluation and loading of the data, a new method to process and the variational method, just trying to simplify the understanding and looking for a more confidence and stable method for the results Co-Authored-By: Dong Han --- project_1.py | 28 +- project_2.py | 525 +++++++++++++++ project_NVIL.py | 1643 +++++++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 2188 insertions(+), 8 deletions(-) create mode 100644 project_2.py create mode 100644 project_NVIL.py diff --git a/project_1.py b/project_1.py index e86a906..9758139 100644 --- a/project_1.py +++ b/project_1.py @@ -278,7 +278,9 @@ def multivariate_normal_log_pdf_MFVI(x, mu, sigma_sq): return log_p -def perform_mfvi(data, K, n_optimization_iterations, convergence_threshold=1e-5, run_until_convergence=True): +import gc # Import the garbage collection module + +def perform_mfvi(data, K, n_optimization_iterations, convergence_threshold=1e-5, run_until_convergence=False): N, D = data.shape[0], data.shape[1] * data.shape[2] # Calculate feature dimension D # Define the variational parameters for the GMM miu_variational = torch.randn(K, D, requires_grad=True) @@ -326,7 +328,11 @@ def perform_mfvi(data, K, n_optimization_iterations, convergence_threshold=1e-5, prev_elbo = elbo iteration += 1 - + + # Clean up CPU memory using garbage collection + gc.collect() + + print(f"Iteration {iteration}/{n_optimization_iterations}") # Extract the learned parameters miu = miu_variational.detach().numpy() pi = torch.softmax(alpha_variational, dim=0).detach().numpy() @@ -425,7 +431,7 @@ def perform_dimensionality_reduction(data, data_type, saving_path): def perform_clustering(data, labels, pca_reduced_data, pca_labels, tsne_reduced_data, tsne_labels, mfvi_labels, data_type, saving_path): log_progress(f"Performing clustering for {data_type} data") - + print('starting clustering') if data_type == "labeled": method ='PCA on Labeled Data' title_tsne = 't-SNE on Labeled Data' @@ -448,12 +454,14 @@ def perform_clustering(data, labels, pca_reduced_data, pca_labels, tsne_reduced_ plot_pca(pca_reduced_data, labels, method, save_path=saving_path) except Exception as e: logger.error(f"An error occurred while plotting PCA: {str(e)}") + print('Performed PCA on the data') try: # Evaluate clustering for PCA results ari_pca, ami_pca, silhouette_pca, davies_bouldin_pca = evaluate_clustering(data.view(data.size(0), -1).numpy(), labels, pca_labels) except Exception as e: logger.error(f"An error occurred while clustering PCA results: {str(e)}") + print('Evaluated clustering for PCA results') # Perform t-SNE on the data try: @@ -461,25 +469,29 @@ def perform_clustering(data, labels, pca_reduced_data, pca_labels, tsne_reduced_ plot_clusters(tsne_reduced_data, tsne_labels, title_tsne, save_path=saving_path) except Exception as e: logger.error(f"An error occurred while plotting t-SNE: {str(e)}") - + print('Performed t-SNE on the data') + try: # Evaluate clustering for t-SNE results ari_tsne, ami_tsne, silhouette_tsne, davies_bouldin_tsne = evaluate_clustering(tsne_reduced_data, labels, tsne_labels) except Exception as e: logger.error(f"An error occurred while clustering t-SNE results: {str(e)}") - + print('Evaluated clustering for t-SNE results') + try: # Plot MFVI for data plot_clusters(data.view(data.size(0), -1).numpy(), mfvi_labels, title_mfvi, save_path=saving_path) except Exception as e: logger.error(f"An error occurred while plotting Mean-Field Variational Inference in labeled data: {str(e)}") + print('Plotted MFVI for data') try: # For MFVI on data ari_mfvi, ami_mfvi, silhouette_mfvi, davies_bouldin_mfvi = evaluate_clustering(data.view(data.size(0), -1).numpy(), labels, mfvi_labels) except Exception as e: logger.error(f"An error occurred while clustering Mean-Field Variational Inference results in labeled data: {str(e)}") - + print('Evaluated MFVI on data') + return ari_pca, ami_pca, silhouette_pca, davies_bouldin_pca, ari_tsne, ami_tsne, silhouette_tsne, davies_bouldin_tsne, ari_mfvi, ami_mfvi, silhouette_mfvi, davies_bouldin_mfvi def visualize_and_analyze_data(data, original_data, segments, labels, data_type, saving_path, data_format): @@ -497,7 +509,7 @@ def visualize_and_analyze_data(data, original_data, segments, labels, data_type, try: # Perform MFVI on data - miu, pi, resp = perform_mfvi(data, K=4, n_optimization_iterations=300, convergence_threshold=1e-5, run_until_convergence=False) + miu, pi, resp = perform_mfvi(data, K=4, n_optimization_iterations=10, convergence_threshold=1e-5, run_until_convergence=False) # Extract cluster assignments from MFVI mfvi_labels = torch.argmax(resp, dim=1).numpy() @@ -633,5 +645,5 @@ def handle_error(task, error): if __name__ == "__main__": # Specify the case you want to run: 'labeled', 'unlabeled', or 'all_data' - case_to_run = "labeled" + case_to_run = "all_data" main(case_to_run) \ No newline at end of file diff --git a/project_2.py b/project_2.py new file mode 100644 index 0000000..e25fe0a --- /dev/null +++ b/project_2.py @@ -0,0 +1,525 @@ +# -*- coding: utf-8 -*- +""" +Created on Sat Nov 18 23:01:01 2023 + +@author: lrm22005 + +THE FIST PART OF THIS CODE IS BASED ON DONG HANG CODE TO SPLIT THE DATA IN TRAINING, TEST AND UNLABELED. + + +""" +import pandas as pd + +############################################################################################################################################## +############################################################################################################################################## +############################################################################################################################################## +# ====== Load the per subject arrythmia summary ====== +df_summary = pd.read_csv(r'R:\ENGR_Chon\NIH_Pulsewatch_Database\Adjudication_UConn\final_attemp_4_1_Dong_Ohm_summary_20231025.csv') +df_summary['UID'] = df_summary['UID'].astype(str).str.zfill(3) + +df_summary['sample_nonAF'] = df_summary['NSR'] + df_summary['PACPVC'] + df_summary['SVT'] +df_summary['sample_AF'] = df_summary['AF'] + +df_summary['sample_nonAF_ratio'] = df_summary['sample_nonAF'] / (df_summary['sample_AF'] + df_summary['sample_nonAF']) + +all_UIDs = df_summary['UID'].unique() +# ==================================================== +# ====== AF trial separation ====== +# R:\ENGR_Chon\Dong\Numbers\Pulsewatch_numbers\Fahimeh_CNNED_general_ExpertSystemwApplication\tbl_file_name\TrainingSet_final_segments +AF_trial_Fahimeh_train = ['402','410'] +AF_trial_Fahimeh_test = ['301', '302', '305', '306', '307', '310', '311', + '312', '318', '319', '320', '321', '322', '324', + '325', '327', '329', '400', '406', '407', '409', + '414'] +AF_trial_Fahimeh_did_not_use = ['405', '413', '415', '416', '420', '421', '422', '423'] +AF_trial_paroxysmal_AF = ['408','419'] + +AF_trial_train = AF_trial_Fahimeh_train +AF_trial_test = AF_trial_Fahimeh_test +AF_trial_unlabeled = AF_trial_Fahimeh_did_not_use + AF_trial_paroxysmal_AF +print(f'AF trial: {len(AF_trial_train)} training subjects {AF_trial_train}') +print(f'AF trial: {len(AF_trial_test)} testing subjects {AF_trial_test}') +print(f'AF trial: {len(AF_trial_unlabeled)} unlabeled subjects {AF_trial_unlabeled}') +# ================================= +# === Clinical trial AF subjects separation === +clinical_trial_AF_subjects = ['005', '017', '026', '051', '075', '082'] + +remaining_UIDs = [] +count_NSR = [] +import math +for index, row in df_summary.iterrows(): + UID = row['UID'] + this_NSR = row['sample_nonAF'] + if math.isnan(this_NSR): + # There is no segment in this subject, skip this UID. + print(f'---------UID {UID} has no segments.------------') + continue + if UID not in AF_trial_train and UID not in AF_trial_test and UID not in clinical_trial_AF_subjects \ + and not UID[0] == '3' and not UID[0] == '4': + remaining_UIDs.append(UID) + count_NSR.append(this_NSR) + +from numpy import random +random.seed(seed=42) +from numpy.random import choice +list_of_candidates = remaining_UIDs +number_of_items_to_pick = round(len(list_of_candidates) * 0.15) # 10% labeled for training, 5% for testing. +temp_sum = sum(count_NSR) +probability_distribution = [x/temp_sum for x in count_NSR] +probability_distribution = [(1-x/temp_sum)/ (len(count_NSR)-1) for x in count_NSR]# Subjects with fewer segments have higher chance to be selected. Make sure the sum is one. +draw = choice(list_of_candidates, number_of_items_to_pick, + p=probability_distribution, replace=False) + +clinical_trial_train = list(draw[:round(len(list_of_candidates) * 0.1)]) +clinical_trial_test_nonAF = list(draw[round(len(list_of_candidates) * 0.1):]) +clinical_trial_test_temp = clinical_trial_test_nonAF + clinical_trial_AF_subjects +clinical_trial_test = [] +for UID in clinical_trial_test_temp: + # UID 051 and maybe other UIDs had no segments (unknown reason). + if UID in all_UIDs: + clinical_trial_test.append(UID) + +clinical_trial_unlabeled = [] +for UID in all_UIDs: + if UID not in clinical_trial_train and UID not in clinical_trial_test and not UID[0] == '3' and not UID[0] == '4': + clinical_trial_unlabeled.append(UID) +print(f'Clinical trial: selected {len(clinical_trial_train)} UIDs for training {clinical_trial_train}') +print(f'Clinical trial: selected {len(clinical_trial_test)} UIDs for testing {clinical_trial_test}') +print(f'Clinical trial: selected {len(clinical_trial_unlabeled)} UIDs for unlabeled {clinical_trial_unlabeled}') + + +############################################################################################################################################## +############################################################################################################################################## +############################################################################################################################################## + +import os +import logging +import numpy as np +import pandas as pd +import torch +from torch.distributions import MultivariateNormal +from torch.utils.data import DataLoader, TensorDataset +import matplotlib.pyplot as plt +from sklearn.decomposition import PCA, IncrementalPCA +from sklearn.manifold import TSNE +from sklearn.cluster import KMeans +from sklearn.cluster import MiniBatchKMeans +from sklearn.preprocessing import StandardScaler +from sklearn.metrics import silhouette_score, adjusted_rand_score, adjusted_mutual_info_score, davies_bouldin_score +import seaborn as sns +from PIL import Image # Import the Image module + +os.environ['OMP_NUM_THREADS'] = '3' +# Create a logger +logger = logging.getLogger(__name__) +logging.basicConfig(filename='error_log.txt', level=logging.ERROR) +logging.basicConfig(level=logging.INFO) + +# Create a logger for progress monitoring +progress_logger = logging.getLogger('progress') +progress_logger.setLevel(logging.INFO) +progress_handler = logging.FileHandler('progress.log') +progress_formatter = logging.Formatter('%(asctime)s - %(name)s - %(message)s') +progress_handler.setFormatter(progress_formatter) +progress_logger.addHandler(progress_handler) + +def log_progress(message): + progress_logger.info(message) + +def get_data_paths(data_format, is_linux=False, is_hpc=False): + log_progress("Code execution get_data_paths") + if is_linux: + base_path = "/mnt/r/ENGR_Chon/Dong/MATLAB_generate_results/NIH_PulseWatch" + labels_base_path = "/mnt/r/ENGR_Chon/NIH_Pulsewatch_Database/Adjudication_UConn" + saving_base_path = "/mnt/r/ENGR_Chon/Luis/Research/Casseys_case/Project_1_analysis" + elif is_hpc: + base_path = "/gpfs/scratchfs1/kic14002/doh16101" + labels_base_path = "/gpfs/scratchfs1/hfp14002/lrm22005" + saving_base_path = "/gpfs/scratchfs1/hfp14002/lrm22005/Casseys_case/Project_1_analysis" + else: + base_path = "R:\\ENGR_Chon\\Dong\\MATLAB_generate_results\\NIH_PulseWatch" + labels_base_path = "R:\\ENGR_Chon\\NIH_Pulsewatch_Database\\Adjudication_UConn" + saving_base_path = r"\\grove.ad.uconn.edu\research\ENGR_Chon\Luis\Research\Casseys_case" + if data_format == 'csv': + data_path = os.path.join(base_path, "TFS_csv") + labels_path = os.path.join(labels_base_path, "final_attemp_4_1_Dong_Ohm") + saving_path = os.path.join(saving_base_path, "Project_1_analysis") + elif data_format == 'png': + data_path = os.path.join(base_path, "TFS_plots") + labels_path = os.path.join(labels_base_path, "final_attemp_4_1_Dong_Ohm") + saving_path = os.path.join(saving_base_path, "Project_1_analysis") + else: + raise ValueError("Invalid data format. Choose 'csv' or 'png.") + log_progress("Code execution completed get_data_paths") + return data_path, labels_path, saving_path + +def load_data_split(data_path, labels_path, UIDs, standardize=True, data_format='csv'): + if data_format not in ['csv', 'png']: + raise ValueError("Invalid data_format. Choose 'csv' or 'png.") + + data, labels, segment_names = load_split_data(data_path, labels_path, UIDs, standardize, data_format) + + return data, labels, segment_names + +def extract_labels(UID_list, labels_path, segment_names): + labels = {} + for UID in UID_list: + label_file = os.path.join(labels_path, UID + "_final_attemp_4_1_Dong.csv") + if os.path.exists(label_file): + label_data = pd.read_csv(label_file, sep=',', header=0, names=['segment', 'label']) + label_segment_names = label_data['segment'].apply(lambda x: x.split('.')[0]) + labels[UID] = (label_data['label'].values, label_segment_names.values) + + return labels + +# Standardize the data +def standard_scaling(data): + scaler = StandardScaler() + data = scaler.fit_transform(data.reshape(-1, data.shape[-1])).reshape(data.shape) + return torch.Tensor(data) + +def load_split_data(data_path, labels_path, UIDs, standardize=True, data_format='csv'): + X_data_original = [] # Store original data without standardization + segment_names = [] + + for UID in UIDs: + data_path_UID = os.path.join(data_path, UID) + dir_list_seg = os.listdir(data_path_UID) + + # for seg in dir_list_seg[:len(dir_list_seg)]: # Limiting to 50 segments + for seg in dir_list_seg[:500]: # Limiting to the first 50 segments + seg_path = os.path.join(data_path_UID, seg) + + try: + if data_format == 'csv' and seg.endswith('.csv'): + time_freq_plot = np.array(pd.read_csv(seg_path, header=None)) + time_freq_tensor = torch.Tensor(time_freq_plot).reshape(1, 128, 128) + elif data_format == 'png' and seg.endswith('.png'): + img = Image.open(seg_path) + img_data = np.array(img) + time_freq_tensor = torch.Tensor(img_data).unsqueeze(0) + else: + continue # Skip other file formats + + X_data_original.append(time_freq_tensor.clone()) # Store a copy of the original data + + # Extract and store segment names + # Change here: Use the segment name from the CSV file directly + segment_names.append(seg.split('_filt')[0]) + + except Exception as e: + logger.error(f"Error processing segment: {seg} in UID: {UID}. Exception: {str(e)}") + # You can also add more information to the error log, such as the value of time_freq_plot. + + X_data_original = torch.cat(X_data_original, 0) + + if standardize: + X_data_original = standard_scaling(X_data_original) # Standardize the data + + # Extract labels from CSV files + labels = extract_labels(UIDs, labels_path, segment_names) + + important_labels = [0, 1, 2, 3] # List of important labels + + # Initialize labels for segments as unlabeled (-1) + segment_labels = {segment_name: -1 for segment_name in segment_names} + + for UID in labels.keys(): + if UID not in UIDs: + # Skip UIDs that are not in the dataset + continue + + label_data, label_segment_names = labels[UID] + + for idx, segment_label in enumerate(label_data): + segment_name = label_segment_names[idx] + + # Change here: Only update labels for segments present in the loaded data + if segment_name in segment_names: + if segment_label in important_labels: + segment_labels[segment_name] = segment_label + else: + # Set labels that are not in the important list as -1 (Unlabeled) + segment_labels[segment_name] = -1 + + return X_data_original, segment_names, {seg: segment_labels[seg] for seg in segment_names} + +############################################################################################################################################## +############################################################################################################################################## +############################################################################################################################################## +log_progress("Code execution started.") +is_linux = False # Set to True if running on Linux, False if on Windows +is_hpc = False # Set to True if running on an HPC, False if on Windows +data_format = 'csv' # Choose 'csv' or 'png' +data_path, labels_path, saving_path = get_data_paths(data_format, is_linux=is_linux, is_hpc=is_hpc) + +# Load data for the training set +train_data, train_segment_names, train_labels = load_data_split( + data_path=data_path, labels_path=labels_path, UIDs=clinical_trial_train, standardize=True, data_format='csv' +) + +# Load data for the validation set +val_data, val_segment_names, val_labels = load_data_split( + data_path, labels_path, UIDs=clinical_trial_test, standardize=True, data_format='csv' +) + +# Load data for the test set (unlabeled) +test_data, test_segment_names, test_labels = load_data_split( + data_path, labels_path, UIDs=clinical_trial_unlabeled, standardize=True, data_format='csv' +) +############################################################################################################################################## +############################################################################################################################################## +############################################################################################################################################## +import pacmap +import numpy as np +import matplotlib.pyplot as plt + +def visualize_pacmap(X_transformed, y, title="Scatter Plot"): + fig, ax = plt.subplots() + scatter = ax.scatter(X_transformed[:, 0], X_transformed[:, 1], cmap="Spectral", c=list(y.values()), s=0.6) + ax.set_title(title) + plt.show() + +# Convert PyTorch tensor to NumPy array +train_data_numpy = train_data.numpy() +# Flatten the data +train_data_flattened = train_data_numpy.reshape((500, -1)) +# Example for the training set +train_embedding = pacmap.PaCMAP(n_components=4, n_neighbors=None, MN_ratio=0.5, FP_ratio=2.0) +train_X_transformed = train_embedding.fit_transform(train_data_flattened, init="pca") + +# Convert PyTorch tensor to NumPy array +val_data_numpy = val_data.numpy() +# Flatten the data +val_data_flattened = val_data_numpy.reshape((500, -1)) +# Example for the validation set +val_embedding = pacmap.PaCMAP(n_components=4, n_neighbors=None, MN_ratio=0.5, FP_ratio=2.0) +val_X_transformed = val_embedding.fit_transform(val_data_flattened, init="pca") + +# Convert PyTorch tensor to NumPy array +test_data_numpy = test_data.numpy() +# Flatten the data +test_data_flattened = test_data_numpy.reshape((500, -1)) +# Example for the test set (unlabeled) +test_embedding = pacmap.PaCMAP(n_components=4, n_neighbors=None, MN_ratio=0.5, FP_ratio=2.0) +test_X_transformed = test_embedding.fit_transform(test_data_flattened, init="pca") + +# Concatenate the labeled and unlabeled data for a combined visualization +combined_X = np.concatenate([train_data_flattened, val_data_flattened, test_data_flattened]) +combined_embedding = pacmap.PaCMAP(n_components=4, n_neighbors=None, MN_ratio=0.5, FP_ratio=2.0) +combined_X_transformed = combined_embedding.fit_transform(combined_X, init="pca") + +# Plot individual datasets +visualize_pacmap(train_X_transformed, train_labels, title="Training Set") +visualize_pacmap(val_X_transformed, val_labels, title="Validation Set") +visualize_pacmap(test_X_transformed, test_labels, title="Test Set (Unlabeled)") + +# Plot combined dataset with labeled and unlabeled points +visualize_pacmap(combined_X_transformed, np.concatenate([train_labels, val_labels, test_labels]), title="Combined Set") + +plt.show() + +############################################################################################################################################## +############################################################################################################################################## +############################################################################################################################################## + + +import torch +import torch.nn as nn +import torch.optim as optim +import torch.distributions as dist +from torch.nn.functional import softplus +from sklearn.preprocessing import StandardScaler + +# Standardize the data +def standard_scaling(data): + scaler = StandardScaler() + data = scaler.fit_transform(data.reshape(-1, data.shape[-1])).reshape(data.shape) + return torch.Tensor(data) + +# Function to initialize the GMM parameters +def init_gmm_parameters(K, D): + miu_variational = torch.randn(K, D, requires_grad=True) + log_sigma_variational = torch.randn(K, D, requires_grad=True) + alpha_variational = torch.randn(K, requires_grad=True) + return miu_variational, log_sigma_variational, alpha_variational + +# Mean-Field Variational Inference for GMM with labeled/unlabeled data +def perform_mfvi(data, labels, K, n_optimization_iterations, convergence_threshold=1e-5, run_until_convergence=False): + N, D = data.shape[0], data.shape[1] * data.shape[2] # Calculate feature dimension D + + # Initialize GMM parameters + miu_variational, log_sigma_variational, alpha_variational = init_gmm_parameters(K, D) + + # Define the optimizer for gradient descent + optimizer = torch.optim.Adam([miu_variational, log_sigma_variational, alpha_variational], lr=0.001) + + prev_elbo = float('-inf') + iteration = 0 + while True: + # Initialize gradients + optimizer.zero_grad() + + # Compute the Gaussian means and covariances from variational parameters + sigma_variational_sq = torch.exp(log_sigma_variational.clone()) + + # Calculate the responsibilities (E[zi]) + log_pi_variational = torch.log_softmax(alpha_variational, dim=0) + log_resp = log_pi_variational.unsqueeze(0) + multivariate_normal_log_pdf_MFVI(data, miu_variational, sigma_variational_sq) + log_resp_max = log_resp.max(dim=1, keepdim=True).values.clone() + resp = torch.exp(log_resp - log_resp_max).clone() + resp /= resp.sum(dim=1, keepdim=True) + + # Add labels to the ELBO for the labeled data + labeled_mask = labels >= 0 + elbo_labeled = -torch.sum(resp[labeled_mask] * log_resp[labeled_mask]) + torch.sum(resp[labeled_mask] * torch.log(resp[labeled_mask])) + + # Add entropy regularization for the unlabeled data + unlabeled_mask = labels == -1 + entropy_regularizer = -torch.sum(resp[unlabeled_mask] * torch.log(resp[unlabeled_mask])) + + # Combine the ELBO terms + elbo = elbo_labeled + entropy_regularizer + + # Perform backpropagation with retain_graph=True + elbo.backward(retain_graph=True) + + # Check for NaN values + if torch.isnan(miu_variational).any() or torch.isnan(log_sigma_variational).any() or torch.isnan(alpha_variational).any(): + print("NaN values detected in parameters. Stopping optimization.") + break + # Clip gradients to prevent exploding gradients + torch.nn.utils.clip_grad_norm_([miu_variational, log_sigma_variational, alpha_variational], max_norm=5.0) + + # Update the variational parameters + optimizer.step() + + # Print progress + if (iteration + 1) % 100 == 0: + print(f"Iteration {iteration + 1}/{n_optimization_iterations}") + + # Check for convergence based on change in ELBO and gradient clipping + if iteration > 0 and abs(elbo - prev_elbo) < convergence_threshold: + print(f"Converged after {iteration + 1} iterations") + break + + # Check for gradient clipping + if any(param.grad is not None and torch.isnan(param.grad).any() for param in [miu_variational, log_sigma_variational, alpha_variational]): + print("Gradient contains NaN values. Stopping optimization.") + break + + prev_elbo = elbo + iteration += 1 + + # Extract the learned parameters + miu = miu_variational.detach().numpy() + pi = torch.softmax(alpha_variational, dim=0).detach().numpy() + + return miu, pi, resp + +# This function computes the log PDF of a multivariate normal distribution for MFVI +def multivariate_normal_log_pdf_MFVI(x, mu, sigma_sq): + # x: Data points (N x D) + # mu: Means of the components (K x D) + # sigma_sq: Variances of the components (K x D) + N, D = x.shape[0], x.shape[1] # Corrected line + K, _ = mu.shape + + log_p = torch.empty(N, K, dtype=x.dtype, device=x.device) + for k in range(K): + cov_matrix = torch.diag(sigma_sq[k]) + mvn = dist.MultivariateNormal(mu[k], cov_matrix) + log_p[:, k] = mvn.log_prob(x.view(N, -1)) # Reshape x to match event_shape + + return log_p + +# Example usage for training, validation, and testing +# Assuming train_data, val_data, test_data are already loaded and standardized +train_labels_tensor = torch.Tensor(list(train_labels.values())) +val_labels_tensor = torch.Tensor(list(val_labels.values())) +test_labels_tensor = torch.Tensor(list(test_labels.values())) + +# Perform MFVI on training data +miu_train, pi_train, resp_train = perform_mfvi(train_data, train_labels_tensor, K=4, n_optimization_iterations=1000, run_until_convergence=True) +mfvi_labels_train = torch.argmax(resp_train, dim=1).numpy() + +# Perform MFVI on validation data +miu_val, pi_val, resp_val = perform_mfvi(val_data, val_labels_tensor, K=4, n_optimization_iterations=1000, run_until_convergence=True) +mfvi_labels_val = torch.argmax(resp_val, dim=1).numpy() +# Perform MFVI on test data (unlabeled) +miu_test, pi_test, resp_test = perform_mfvi(test_data, torch.full_like(test_labels_tensor, -1), K=4, n_optimization_iterations=1000, run_until_convergence=True) +mfvi_labels_test = torch.argmax(resp_test, dim=1).numpy() + +############################################################################################################################################## +############################################################################################################################################## +############################################################################################################################################## + +# Function to evaluate clustering and print multiple metrics +def evaluate_clustering(data, true_labels, predicted_labels): + ari = adjusted_rand_score(true_labels, predicted_labels) + ami = adjusted_mutual_info_score(true_labels, predicted_labels) + silhouette = silhouette_score(data, predicted_labels) + davies_bouldin = davies_bouldin_score(data, predicted_labels) + + print(f'Adjusted Rand Index (ARI): {ari}') + print(f'Adjusted Mutual Info (AMI): {ami}') + print(f'Silhouette Score: {silhouette}') + print(f'Davies-Bouldin Index: {davies_bouldin}') + return ari, ami, silhouette, davies_bouldin + +# Function to plot clusters +def plot_clusters(data, zi, title, save_path=None): + """ + Plot the data points colored by cluster assignment. + + Args: + data (torch.Tensor): The data points. + zi (torch.Tensor): The cluster assignments. + title (str): The title for the plot. + """ + unique_clusters = torch.unique(zi) + colors = plt.cm.viridis(torch.linspace(0, 1, len(unique_clusters))) + + plt.figure(figsize=(8, 6)) + for i, cluster in enumerate(unique_clusters): + cluster_data = data[zi == cluster] + plt.scatter(cluster_data[:, 0], cluster_data[:, 1], c=colors[i], label=f'Cluster {int(cluster)}') + + plt.title(title) + plt.legend() + plt.xlabel('Feature 1') + plt.ylabel('Feature 2') + + # Save the plot if save_path is provided + if save_path: + filename = title.replace(' ', '_') + ".png" + plt.savefig(os.path.join(save_path, filename)) + + plt.show() + +######---------------------- Train +title_mfvi_train = 'MFVI on training' + +# Evaluate and print metrics for MFVI +ari_mfvi_trai, ami_mfvi_train, silhouette_mfvi_train, davies_bouldin_mfvi_train = evaluate_clustering(train_data.view(train_data.size(0), -1).numpy(), np.array(list(train_labels.values())), mfvi_labels_train) + +# Plot clusters for MFVI +plot_clusters(train_data.view(train_data.size(0), -1).numpy(), torch.argmax(resp_train, dim=1), title_mfvi_train, save_path=saving_path) + +######-------------------- Val +title_mfvi_val = 'MFVI on validation' + +# Evaluate and print metrics for MFVI +ari_mfvi_val, ami_mfvi_training, silhouette_mfvi_training, davies_bouldin_mfvi_training = evaluate_clustering(val_data.view(val_data.size(0), -1).numpy(), np.array(list(val_labels.values())), mfvi_labels_val) + +# Plot clusters for MFVI +plot_clusters(val_data.view(val_data.size(0), -1).numpy(), torch.argmax(resp_val, dim=1), title_mfvi_val, save_path=saving_path) + +######---------------------- Test (Unlabeled) +title_mfvi_test = 'MFVI on testing Unlabeled' + +# Evaluate and print metrics for MFVI +ari_mfvi_test, ami_mfvi_test, silhouette_mfvi_test, davies_bouldin_mfvi_test = evaluate_clustering(test_data.view(test_data.size(0), -1).numpy(), np.array(list(test_labels.values())), mfvi_labels_test) + +# Plot clusters for MFVI +plot_clusters(test_data.view(test_data.size(0), -1).numpy(), torch.argmax(resp_test, dim=1), title_mfvi_test, save_path=saving_path) diff --git a/project_NVIL.py b/project_NVIL.py new file mode 100644 index 0000000..7996ed7 --- /dev/null +++ b/project_NVIL.py @@ -0,0 +1,1643 @@ +# -*- coding: utf-8 -*- +""" +Created on Sat Nov 25 16:47:46 2023 + +@author: lrm22005 +""" +import numpy as np +import torch +import torch.nn as nn +import torch.distributions as td +import torch.nn.functional as F +import torch.distributions as td +import torch.optim as opt +from torchvision import transforms +from torchvision.transforms import ToTensor +from torch.utils.data import random_split, Dataset +from torch.utils.data.dataloader import DataLoader +from torch.utils.data import DataLoader, TensorDataset +from torchvision.utils import make_grid +import matplotlib.pyplot as plt +from tqdm.auto import tqdm +from collections import defaultdict, OrderedDict +import random + + +from functools import partial +from itertools import chain + +def seed_all(seed=42): + np.random.seed(seed) + random.seed(seed) + torch.manual_seed(seed) + +seed_all() + +# Generative Model (gen_model): + +# Prior Net: We begin by specifying the component that parameterises the prior pz + +# I need that this not be fixed, I'm dealing with the image representation of signals in Time Frequency domain, for that reason is important to define an adecuate prior. In total the data has 4 labels 0,1,2,3 and one additional label that I added to the data that is unlabeled (-1). + + +# This Prior is fixed but you can implement a better prior. We know that the data labeled is distributed in the following way it mean the data would has a similar distribution +# Normal sinus rhythm (NSR, class 0) +# 133,149 +# Atrial fibrillation (AF, class 1) +# 24,555 +# Premature atrial contraction / premature ventricular contraction (PAC/PVC, class 2) +# 19,491 +# Supraventricular tachycardia (SVT, class 3) +# 432 +# Total +# 177,627 + +import os +import torch +import pandas as pd +import numpy as np +from torch.utils.data import DataLoader, Dataset, TensorDataset +from torchvision.transforms import ToTensor +from sklearn.preprocessing import StandardScaler +from PIL import Image + +class CustomDataset(Dataset): + def __init__(self, data_path, labels_path, UIDs, standardize=True, data_format='csv'): + self.data_path = data_path + self.labels_path = labels_path + self.UIDs = UIDs + self.standardize = standardize + self.data_format = data_format + self.transforms = ToTensor() + + self.X_data_original, self.segment_names, self.labels = self.load_split_data() + + def __len__(self): + return len(self.segment_names) + + def __getitem__(self, idx): + segment_name = self.segment_names[idx] + time_freq_tensor = self.X_data_original[idx] + + label = self.labels[segment_name] + + return {'data': time_freq_tensor.unsqueeze(0), 'label': label, 'segment_name': segment_name} + + def load_split_data(self): + X_data_original = [] # Store original data without standardization + segment_names = [] + + for UID in self.UIDs: + data_path_UID = os.path.join(self.data_path, UID) + dir_list_seg = os.listdir(data_path_UID) + + for seg in dir_list_seg[:500]: # Limiting to the first 500 segments + seg_path = os.path.join(data_path_UID, seg) + + try: + if self.data_format == 'csv' and seg.endswith('.csv'): + time_freq_plot = np.array(pd.read_csv(seg_path, header=None)) + time_freq_tensor = torch.Tensor(time_freq_plot).reshape(1, 128, 128) + elif self.data_format == 'png' and seg.endswith('.png'): + img = Image.open(seg_path) + img_data = np.array(img) + time_freq_tensor = torch.Tensor(img_data).unsqueeze(0) + else: + continue # Skip other file formats + + X_data_original.append(time_freq_tensor.clone()) # Store a copy of the original data + + # Extract and store segment names + # Change here: Use the segment name from the CSV file directly + segment_names.append(seg.split('_filt')[0]) + + except Exception as e: + print(f"Error processing segment: {seg} in UID: {UID}. Exception: {str(e)}") + # You can also add more information to the error log, such as the value of time_freq_plot. + + X_data_original = torch.cat(X_data_original, 0) + + if self.standardize: + X_data_original = self.standard_scaling(X_data_original) # Standardize the data + + # Extract labels from CSV files + labels = self.extract_labels() + + important_labels = [0, 1, 2, 3] # List of important labels + + # Initialize labels for segments as unlabeled (-1) + segment_labels = {segment_name: -1 for segment_name in segment_names} + + for UID in labels.keys(): + if UID not in self.UIDs: + # Skip UIDs that are not in the dataset + continue + + label_data, label_segment_names = labels[UID] + + for idx, segment_label in enumerate(label_data): + segment_name = label_segment_names[idx] + + # Change here: Only update labels for segments present in the loaded data + if segment_name in segment_names: + if segment_label in important_labels: + segment_labels[segment_name] = segment_label + else: + # Set labels that are not in the important list as -1 (Unlabeled) + segment_labels[segment_name] = -1 + + return X_data_original, segment_names, segment_labels + + def extract_labels(self): + labels = {} + for UID in self.UIDs: + label_file = os.path.join(self.labels_path, UID + "_final_attemp_4_1_Dong.csv") + if os.path.exists(label_file): + label_data = pd.read_csv(label_file, sep=',', header=0, names=['segment', 'label']) + label_segment_names = label_data['segment'].apply(lambda x: x.split('.')[0]) + labels[UID] = (label_data['label'].values, label_segment_names.values) + + return labels + + def standard_scaling(self, data): + scaler = StandardScaler() + data = scaler.fit_transform(data.reshape(-1, data.shape[-1])).reshape(data.shape) + return torch.Tensor(data) + +def load_data_split_batched(data_path, labels_path, UIDs, batch_size, standardize=True, data_format='csv'): + dataset = CustomDataset(data_path, labels_path, UIDs, standardize, data_format) + dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True) + return dataloader + +is_linux = False # Set to True if running on Linux, False if on Windows +is_hpc = False # Set to True if running on an HPC, False if on Windows +data_format = 'csv' # Choose 'csv' or 'png' +data_path, labels_path, saving_path = get_data_paths(data_format, is_linux=is_linux, is_hpc=is_hpc) + +# Example usage: +batch_size = 32 +train_loader = load_data_split_batched(data_path, labels_path, clinical_trial_train, batch_size) +val_loader = load_data_split_batched(data_path, labels_path, clinical_trial_test, batch_size) +test_loader = load_data_split_batched(data_path, labels_path, clinical_trial_unlabeled, batch_size) + + +def imshow(image, ax=None, title=None, normalize=True): + """Imshow for Tensor.""" + if ax is None: + fig, ax = plt.subplots() + image = image.numpy().transpose((1, 2, 0)) + + if normalize: + mean = np.array([0.485, 0.456, 0.406]) + std = np.array([0.229, 0.224, 0.225]) + image = std * image + mean + image = np.clip(image, 0, 1) + + ax.imshow(image) + ax.spines['top'].set_visible(False) + ax.spines['right'].set_visible(False) + ax.spines['left'].set_visible(False) + ax.spines['bottom'].set_visible(False) + ax.tick_params(axis='both', length=0) + ax.set_xticklabels('') + ax.set_yticklabels('') + + return ax + +# change this to the trainloader or testloader +data_iter = iter(train_loader) +data = next(data_iter) +fig, axes = plt.subplots(figsize=(20,15), ncols=6) +for ii in range(6): + images, label, names = data["data"], data["label"], data["segment_name"] + ax = axes[ii] +# helper.imshow(images[ii], ax=ax, normalize=False) + imshow(images[ii], ax=ax, normalize=False) + + + + +fig, axes = plt.subplots(figsize=(20,15), ncols=6) + +for i, batch in enumerate(train_loader): + images, label, names = batch["data"], batch["label"], batch["segment_name"] + ax = axes[i] +# helper.imshow(images[ii], ax=ax, normalize=False) + imshow(images[i], ax=ax, normalize=False) + if i==6: + break + + + + + +class PriorNet(nn.Module): + """ + An NN that parameterises a prior distribution. + + For this lab, our priors are fixed, so this NN's forward pass + simply returns a fixed prior with a given batch_shape. + """ + + def __init__(self, outcome_shape: tuple): + """ + outcome_shape: this is the shape of a single outcome + if you use a single integer k, we will turn it into (k,) + """ + super().__init__() + if isinstance(outcome_shape, int): + outcome_shape = (outcome_shape,) + self.outcome_shape = outcome_shape + + def forward(self, batch_shape): + """ + Returns a td object for the batch. + """ + raise NotImplementedError("Implement me!") + +class MoGPriorNet(PriorNet): + """ + For z a K-dimensional code: + + p(z|w_1...w_C, u_1...u_C, s_1...s_C) + = \sum_c w_c prod_k Normal(z[k]|u[c], s[c]^2) + """ + + def __init__(self, outcome_shape, num_components, lbound=-10, rbound=10): + super().__init__(outcome_shape) + # [C] + self.logits = nn.Parameter(torch.rand(num_components, requires_grad=True), requires_grad=True) + # (C,) + outcome_shape + shape = (num_components,) + self.outcome_shape + self.locs = nn.Parameter(torch.rand(shape, requires_grad=True), requires_grad=True) + self.scales = nn.Parameter(1 + torch.rand(shape, requires_grad=True), requires_grad=True) + self.num_components = num_components + + def forward(self, batch_shape): + # e.g., with batch_shape (B,) and outcome_shape (K,) this is + # [B, C, K] + shape = batch_shape + (self.num_components,) + self.outcome_shape + # we wrap around td.Independent to obtain a pdf over multivariate draws + # (note that C is not part of the event_shape, thus td.Independent will + # should not treat that dimension as part of the outcome) + # in our example, a draw from independent would return [B, C] draws of K-dimensional outcomes + comps = td.Independent(td.Normal(loc=self.locs.expand(shape), scale=self.scales.expand(shape)), len(self.outcome_shape)) + # a batch of component selectors + pc = td.Categorical(logits=self.logits.expand(batch_shape + (self.num_components,))) + # and finally, a mixture + return td.MixtureSameFamily(pc, comps) + +def test_priors(batch_size=2, latent_dim=3, num_comps=5): + prior_net = MoGPriorNet(latent_dim, num_comps) + print("\nMixture of Gaussian") + print(" trainable parameters") + print(list(prior_net.parameters())) + print(f" outcome_shape={prior_net.outcome_shape}") + p = prior_net(batch_shape=(batch_size,)) + print(f" distribution: {p}") + z = p.sample() + print(f" sample: {z}") + print(f" shapes: sample={z.shape} log_prob={p.log_prob(z).shape}") + + +test_priors() + + +# Conditional probability distributions +# Observation Model (cpd_net). + +class CPDNet(nn.Module): + """ + Let L be a choice of distribution + and x ~ L is an outcome with shape outcome_shape + + This is an NN whose forward method maps from a number of inputs to the + parameters of L's pmf/pdf and returns a torch.distributions + object representing L's pmf/pdf. + """ + + def __init__(self, outcome_shape): + """ + outcome_shape: this is the shape of a single outcome + if you use a single integer k, we will turn it into (k,) + """ + super().__init__() + if isinstance(outcome_shape, int): + outcome_shape = (outcome_shape,) + self.outcome_shape = outcome_shape + + def forward(self, inputs): + """ + Return a torch.distribution object predicted from `inputs`. + + inputs: a tensor with shape batch_shape + (num_inputs,) + """ + raise NotImplementedError("Implemented me") + +# Observational model +class ReshapeLast(nn.Module): + """ + Helper layer to reshape the rightmost dimension of a tensor. + + This can be used as a component of nn.Sequential. + """ + + def __init__(self, shape: tuple): + """ + shape: desired rightmost shape + """ + super().__init__() + self._shape = shape + + def forward(self, input): + # reshapes the last dimension into self.shape + return input.reshape(input.shape[:-1] + self._shape) + + +# class MySequential(nn.Sequential): +# """ +# This is a version of nn.Sequential that works with structured batches +# (i.e., batches that have multiple dimensions) +# even when some of the nn layers in it does not. + +# The idea is to just wrap nn.Sequential around two calls to reshape +# which remove and restore the batch dimensions. +# """ + +# def __init__(self, *args, event_dims=1): +# super().__init__(*args) +# self._event_dims = event_dims + +# def forward(self, input): +# # memorize batch shape +# batch_shape = input.shape[:-self._event_dims] +# # memorize latent shape +# event_shape = input.shape[-self._event_dims:] +# # flatten batch shape and obtain outputs +# output = super().forward(input.reshape((-1,) + event_shape)) +# # restore batch shape +# return output.reshape(batch_shape + output.shape[1:]) + +class MySequential(nn.Sequential): + """ + This is a version of nn.Sequential that works with structured batches + (i.e., batches that have multiple dimensions) + even when some of the nn layers in it do not. + + The idea is to just wrap nn.Sequential around two calls to reshape + which remove and restore the batch dimensions. + """ + + def __init__(self, *args, event_dims=1): + super().__init__(*args) + self._event_dims = event_dims + + def forward(self, input): + # memorize batch shape + batch_shape = input.shape[:-self._event_dims] + # memorize latent shape + event_shape = input.shape[-self._event_dims:] + + # flatten batch shape and obtain outputs + output = super().forward(input.reshape((-1,) + event_shape)) + print("Shape after FlattenImage:", input.shape) + + # restore batch shape + output_shape = batch_shape + output.shape[1:] + return output.reshape(output_shape) + +# class MySequential(nn.Sequential): +# """ +# This is a version of nn.Sequential that works with structured batches +# (i.e., batches that have multiple dimensions) +# even when some of the nn layers in it does not. + +# The idea is to just wrap nn.Sequential around two calls to reshape +# which remove and restore the batch dimensions. +# """ + +# def __init__(self, *args, event_dims=1): +# super().__init__(*args) +# self._event_dims = event_dims + +# def forward(self, input): +# # memorize batch shape +# batch_shape = input.shape[:-self._event_dims] +# # memorize latent shape +# event_shape = input.shape[-self._event_dims:] + +# # Print the shape of the tensor after the FlattenImage layer +# input = super().forward(input) +# print("Shape after FlattenImage:", input.shape) + +# # flatten batch shape and obtain outputs +# output = input.reshape((-1,) + event_shape) +# # restore batch shape +# return output.reshape(batch_shape + output.shape[1:]) + +def build_cnn_decoder(latent_size, num_channels, width=64, height=64, hidden_size=1024, p_drop=0.): + """ + Map the latent code to a tensor with shape [num_channels, width, height]. + + latent_size: size of latent code + num_channels: number of channels in the output + width: width of the output image + height: height of the output image + hidden_size: we first map from latent_size to hidden_size and + then use transposed 2d convolutions to [num_channels, width, height] + p_drop: dropout rate before linear layers + """ + decoder = MySequential( + nn.Dropout(p_drop), + nn.Linear(latent_size, hidden_size), + ReshapeLast((hidden_size, 1, 1)), + nn.ConvTranspose2d(hidden_size, 128, 5, 2), + nn.ReLU(), + nn.ConvTranspose2d(128, 64, 5, 2), + nn.ReLU(), + nn.ConvTranspose2d(64, 32, 6, 2), + nn.ReLU(), + nn.ConvTranspose2d(32, num_channels, 6, 2, output_padding=1), # Adjusted for variable input size + nn.Upsample(size=(width, height), mode='bilinear', align_corners=False), # Upsample to the desired size + event_dims=1 + ) + return decoder + +# # Test the modified function +img_shape = train_data[0].shape +num_channels, width, height = img_shape[0], img_shape[1], img_shape[2] +latent_size = 128 + +output_shape = build_cnn_decoder(latent_size=latent_size, num_channels=num_channels, width=width, height=height)( + torch.zeros((128, latent_size)) +).shape + +print(output_shape) + +# note that because we use MySequential, +# we can have a batch of [3, 5] assignments +# (this is useful, for example, when we have multiple draws of the latent +# variable for each of the data points in the batch) +# build_cnn_decoder(latent_size=latent_size, num_channels=num_channels, width=width, height=height)( +# torch.zeros((3, 5, latent_size)) +# ).shape + +class ContinuousImageModel(CPDNet): + + def __init__(self, num_channels, width, height, latent_size, decoder_type=build_cnn_decoder, p_drop=0.1): + super().__init__((num_channels, width, height)) + self.decoder = decoder_type( + latent_size=latent_size, + num_channels=num_channels, + width=width, + height=height, + p_drop=p_drop + ) + + def forward(self, z): + """ + Return the cpd X|Z=z + + z: batch_shape + (latent_dim,) + """ + # batch_shape + (num_channels, width, height) + h = self.decoder(z) + return td.Independent(td.ContinuousBernoulli(logits=h), len(self.outcome_shape)) + +obs_model = ContinuousImageModel( + num_channels=img_shape[0], + width=img_shape[1], + height=img_shape[2], + latent_size=128, + p_drop=0.1, + decoder_type=build_cnn_decoder +) +print(obs_model) +# a batch of five zs is mapped to 5 distributions over [1,64,64]-dimensional +# binary tensors +print(obs_model(torch.zeros([128, 128]))) + +# Joint distribution + +class JointDistribution(nn.Module): + """ + A wrapper to combine a prior net and a cpd net into a joint distribution. + """ + + def __init__(self, prior_net: PriorNet, cpd_net: CPDNet): + """ + prior_net: object to parameterise p_Z + cpd_net: object to parameterise p_{X|Z=z} + """ + super().__init__() + self.prior_net = prior_net + self.cpd_net = cpd_net + + def prior(self, shape): + return self.prior_net(shape) + + def obs_model(self, z): + return self.cpd_net(z) + + def sample(self, shape): + """ + Return z via prior_net(shape).sample() + and x via cpd_net(z).sample() + """ + pz = self.prior_net(shape) + z = pz.sample() + px_z = self.cpd_net(z) + x = px_z.sample() + return z, x + + def log_prob(self, z, x): + """ + Assess the log density of the joint outcome. + """ + batch_shape = z.shape[:-len(self.prior_net.outcome_shape)] + pz = self.prior_net(batch_shape) + px_z = self.cpd_net(z) + return pz.log_prob(z) + px_z.log_prob(x) + + def log_marginal(self, x, enumerate_fn): + """ + Return log marginal density of x. + + enumerate_fn: function that enumerates the support of the prior + (this is needed for marginalisation p(x) = \int p(z, x) dz) + + This only really makes sense if the support is a + (small) countably finite set. In such cases, you can use + enumerate=lambda p: p.enumerate_support() + which is supported, for example, by Categorical and OneHotCategorical. + + If the support is discrete (eg, bit vectors) you can still dare to + enumerate it explicitly, but you will need to write cutomised code, + as torch.distributions will not offer that functionality for you. + + If the support is uncountable, countably infinite, or just large + anyway, you need approximate tools (such as VI, importance sampling, etc) + """ + batch_shape = x.shape[:-len(self.cpd_net.outcome_shape)] + pz = self.prior_net(batch_shape) + log_joint = [] + # (support_size,) + batch_shape + z = enumerate_fn(pz) + px_z = self.cpd_net(z) + # (support_size,) + batch_shape + log_joint = pz.log_prob(z) + px_z.log_prob(x.unsqueeze(0)) + # batch_shape + return torch.logsumexp(log_joint, 0) + + def posterior(self, x, enumerate_fn): + """ + Return the posterior distribution Z|X=x. + + As the code is discrete, we return a discrete distribution over + the complete space of all possible latent codes. This is done via + exhaustive enumeration provided by `enumerate_fn`. + """ + batch_shape = x.shape[:-len(self.cpd_net.outcome_shape)] + pz = self.prior_net(batch_shape) + # (support_size,) + batch_shape + z = enumerate_fn(pz) + px_z = self.cpd_net(z) + # (support_size,) + batch_shape + log_joint = pz.log_prob(z) + px_z.log_prob(x.unsqueeze(0)) + # batch_shape + (support_size,) + log_joint = torch.swapaxes(log_joint, 0, -1) + return td.Categorical(logits=log_joint) + + def naive_lowerbound(self, x, num_samples: int): + """ + Return an MC lowerbound on log marginal density of x: + log p(x) >= 1/S \sum_s log p(x|z[s]) + with z[s] ~ p_Z + """ + batch_shape = x.shape[:-len(self.cpd_net.outcome_shape)] + pz = self.prior_net(batch_shape) + # (num_samples,) + batch_shape + prior_outcome_shape + log_probs = [] + # I'm using a for loop, but note that with enough GPU memory + # one could parallelise this step + for z in pz.sample((num_samples,)): + px_z = self.cpd_net(z) + log_probs.append(px_z.log_prob(x)) + # (num_samples,) + batch_shape + log_probs = torch.stack(log_probs) + # batch_shape + return torch.mean(log_probs, 0) + +def test_joint_dist(latent_size=10, num_comps=3, data_shape=(1, 128, 128), batch_size=2, hidden_size=32): + p = JointDistribution( + prior_net=MoGPriorNet(latent_size, num_comps), + cpd_net=ContinuousImageModel( + num_channels=data_shape[0], + width=data_shape[1], + height=data_shape[2], + latent_size=latent_size, + decoder_type=build_cnn_decoder + ) + ) + print("Model for continuous data") + print(p) + z, x = p.sample((batch_size,)) + print("sampled z") + print(z) + print("sampled x") + print(x) + print("MC lowerbound") + print(" 1:", p.naive_lowerbound(x, 10)) + print(" 2:", p.naive_lowerbound(x, 10)) + +test_joint_dist(10) + + +##### PART 2 + +# Inference Model (inf_model): + +# Encoder: Either a feedforward neural network (FFNN) or a convolutional neural network (CNN) for mapping input images to a latent representation. In this case we choose CNN + +class FlattenImage(nn.Module): + def forward(self, input): + return input.reshape(input.shape[:-3] + (-1,)) + +def build_cnn_encoder(num_channels, width=64, height=64, output_size=1024, p_drop=0.): + # if width != 64: + # raise ValueError("The width is hardcoded") + # if height != 64: + # raise ValueError("The height is hardcoded") + # if output_size != 1024: + # raise ValueError("The output_size is hardcoded") + # TODO: change the architecture so width, height and output_size are not hardcoded + encoder = MySequential( + nn.Conv2d(num_channels, 32, 4, 2), + nn.LeakyReLU(0.2), + nn.Conv2d(32, 64, 4, 2), + nn.LeakyReLU(0.2), + nn.Conv2d(64, 128, 4, 2), + nn.LeakyReLU(0.2), + nn.Conv2d(128, 256, 4, 2), + nn.LeakyReLU(0.2), + nn.Conv2d(256, 512, 2, 2), + nn.LeakyReLU(0.2), + nn.Conv2d(512, output_size, 2, 2), + nn.LeakyReLU(0.2), + FlattenImage(), + event_dims=3 + ) + return encoder + +# Example usage with width=128, height=128, and variable output_size +width = 128 +height = 128 +output_size = 1024 # You can set this to any desired output size +encoder = build_cnn_encoder(num_channels=1, width=width, height=height, output_size=output_size) +# a batch of five [1, 128, 128]-dimensional images is encoded into +# five `output_size`-dimensional vectors +encoder(torch.zeros((5, 1, width, height))).shape +# and, again, since we use MySequential we can have structured batches +# (here trying with (3,5)) +build_cnn_encoder(num_channels=1, width=width, height=height, output_size=output_size)(torch.zeros((3, 5, 1, 128, 128))).shape + + +# Mixture of Gaussian mean fields + +# This can also be used to parameterise a cpd over real vectors of fixed dimensionality, but it achieves a more complex density (e.g., multimodal). + +# Conditional Probability Distribution (cpd_net): For Mixture of Gaussian Models + +class MoGCPDNet(CPDNet): + """ + Output distribution is a mixture of products of Gaussian distributions + """ + + def __init__(self, outcome_shape, num_inputs: int, hidden_size: int=None, p_drop: float=0., num_components=2): + """ + outcome_shape: shape of the outcome (int or tuple) + if int, we turn it into a singleton tuple + num_inputs: rightmost dimensionality of the inputs to forward + hidden_size: size of hidden layers for the CPDNet (use None to skip) + p_drop: configure dropout before every Linear layer + num_components: number of Gaussians to be mixed + """ + super().__init__(outcome_shape) + self.num_components = num_components + + num_outputs = num_components * np.prod(self.outcome_shape) + + if hidden_size: + self.encoder = nn.Sequential( + nn.Dropout(p_drop), + nn.Linear(num_inputs, hidden_size), + nn.ReLU() + ) + else: + self.encoder = nn.Identity() + hidden_size = num_inputs + + self.locs = nn.Sequential( + nn.Dropout(p_drop), + nn.Linear(hidden_size, num_outputs), + ReshapeLast((num_components,) + self.outcome_shape) + ) + self.scales = nn.Sequential( + nn.Dropout(p_drop), + nn.Linear(hidden_size, num_outputs), + nn.Softplus(), # we use the softplus activations for the scales + ReshapeLast((num_components,) + self.outcome_shape) + ) + self.logits = nn.Sequential( + nn.Dropout(p_drop), + nn.Linear(hidden_size, num_components), + ReshapeLast((num_components,)) + ) + + def forward(self, inputs): + h = self.encoder(inputs) + comps = td.Independent(td.Normal(loc=self.locs(h), scale=self.scales(h)), len(self.outcome_shape)) + pc = td.Categorical(logits=self.logits(h)) + return td.MixtureSameFamily(pc, comps) + +def test_cpds(outcome_shape, num_comps=2, batch_size=3, input_dim=5, hidden_size=2): + + cpd_net = MoGCPDNet(outcome_shape, num_inputs=input_dim, hidden_size=hidden_size, num_components=num_comps) + print("\nMixture of Gaussians") + print(cpd_net) + print(f" outcome_shape={cpd_net.outcome_shape}") + inputs = torch.from_numpy(np.random.uniform(size=(batch_size, input_dim))).float() + print(f" shape of inputs: {inputs.shape}") + p = cpd_net(inputs) + print(f" distribution: {p}") + z = p.sample() + print(f" sample: {z}") + print(f" shapes: sample={z.shape} log_prob={p.log_prob(z).shape}") + + +# Try a few +# test_cpds(12) +# test_cpds(12, hidden_size=None) +# your latent code could be a metrix (we talk about it as a "vector" for convenience) +# test_cpds((4, 5)) + +# Last, but certainly not least, we can combine our encoder and a choice of CPD net. + +class InferenceModel(CPDNet): + + def __init__( + self, cpd_net_type, + latent_size, num_channels=1, width=128, height=128, + hidden_size=1024, p_drop=0., + encoder_type=build_cnn_encoder): + + super().__init__(latent_size) + + self.latent_size = latent_size + # encodes an image to a hidden_size-dimensional vector + self.encoder = encoder_type( + num_channels=num_channels, + width=width, + height=height, + output_size=hidden_size, + p_drop=p_drop + ) + # maps from a hidden_size-dimensional encoding + # to a cpd for Z|X=x + self.cpd_net = cpd_net_type( + latent_size, + num_inputs=hidden_size, + hidden_size=2*latent_size, + p_drop=p_drop + ) + + def forward(self, x): + h = self.encoder(x) + return self.cpd_net(h) + +InferenceModel(partial(MoGCPDNet, num_components=3), latent_size=10)(torch.zeros(5, 1, 128, 128)) +InferenceModel(partial(MoGCPDNet, num_components=3), latent_size=10)(torch.zeros(5, 1, 128, 128)) + + +# Neural Variational Inference + +class VarianceReduction(nn.Module): + """ + We will be using simple forms of control variates for variance reduction. + These are transformations of the reward that are independent of the sampled + latent variable, but they can, in principle, depend on x, and on the + parameters of the generative and inference model. + + Some of these are trainable components, thus they also contribute to the loss. + """ + + def __init__(self): + super().__init__() + + def forward(self, r, x, q, r_fn): + """ + Return the transformed reward and a contribution to the loss. + + r: a batch of rewards + x: a batch of observations + q: policy + r_fn: reward function + """ + return r, torch.zeros_like(r) + +# Now we can work on our general NVIL model. The following class implements the NVIL objective as well as a lot of helper code to manipulate the model components in interesting ways (e.g., sampling, sampling conditionally, estimating marginal density, etc.) + +import torch +from torch.distributions import kl_divergence, Categorical + +def kl_mixture_same_family(p, q, num_samples=10): + """ + Monte Carlo approximation of KL divergence between two MixtureSameFamily distributions. + + Parameters: + - p (MixtureSameFamily): First distribution + - q (MixtureSameFamily): Second distribution + - num_samples (int): Number of samples for Monte Carlo approximation + + Returns: + - kl_div (float): Approximated KL divergence + """ + # Sample from both distributions + p_samples = p.sample((num_samples,)) + q_samples = q.sample((num_samples,)) + + # Evaluate log probabilities + log_p = p.log_prob(p_samples) + log_q = q.log_prob(q_samples) + + # Compute Monte Carlo approximation of KL divergence + kl_div = torch.mean(log_p - log_q) + + return kl_div + +class MixtureSameFamilyKLDivergence: + """ + Computes KL Divergence between two MixtureSameFamily distributions. + """ + + def __init__(self, p, q, num_samples=10): + self.p = p + self.q = q + self.num_samples = num_samples + + def _compute_kl(self): + # Use the provided kl_mixture_same_family function + return kl_mixture_same_family(self.p, self.q, num_samples=self.num_samples) + +class NVIL(nn.Module): + """ + A generative model p(z)p(x|z) and an approximation q(z|x) to that + model's true posterior. + + The approximation is estimated to maximise the ELBO, and so is the joint + distribution. + """ + + def __init__(self, gen_model: JointDistribution, inf_model: InferenceModel, cv_model: VarianceReduction): + """ + gen_model: p(z)p(x|z) + inf_model: q(z|x) which approximates p(z|x) + cv_model: optional transformations of the reward + """ + super().__init__() + self.gen_model = gen_model + self.inf_model = inf_model + self.cv_model = cv_model + + def gen_params(self): + return self.gen_model.parameters() + + def inf_params(self): + return self.inf_model.parameters() + + def cv_params(self): + return self.cv_model.parameters() + + def sample(self, batch_size, sample_size=None, oversample=False): + """ + A sample from the joint distribution: + z ~ prior + x|z ~ obs model + batch_size: number of samples in a batch + sample_size: if None, the output tensor has shape [batch_size] + data_shape + if 1 or more, the output tensor has shape [sample_size, batch_size] + data_shape + while batch_size controls a parallel computation, + sample_size controls a sequential computation (a for loop) + oversample: if True, samples z (batch_size times), hold it fixed, + and sample x (sample_size times) + """ + pz = self.gen_model.prior((batch_size,)) + samples = [None] * (sample_size or 1) + px_z = self.gen_model.obs_model(pz.sample()) if oversample else None + for k in range(sample_size or 1): + if not oversample: + px_z = self.gen_model.obs_model(pz.sample()) + samples[k] = px_z.sample() + x = torch.stack(samples) + return x if sample_size else x.squeeze(0) + + def cond_sample(self, x, sample_size=None, oversample=False): + """ + Condition on x and draw a sample: + z|x ~ inf model + x'|z ~ obs model + + x: a batch of seed data samples + sample_size: if None, the output tensor has shape [batch_size] + data_shape + if 1 or more, the output tensor has shape [sample_size, batch_size] + data_shape + sample_size controls a sequential computation (a for loop) + oversample: if True, samples z (batch_size times), hold it fixed, + and sample x' (sample_size times) + """ + qz = self.inf_model(x) + samples = [None] * (sample_size or 1) + px_z = self.gen_model.obs_model(qz.sample()) if oversample else None + for k in range(sample_size or 1): + if not oversample: + px_z = self.gen_model.obs_model(qz.sample()) + samples[k] = px_z.sample() + x = torch.stack(samples) + return x if sample_size else x.squeeze(0) + + def log_prob(self, z, x): + """ + The log density of the joint outcome under the generative model + z: [batch_size, latent_dim] + x: [batch_size] + data_shape + """ + return self.gen_model.log_prob(z=z, x=x) + + def DRL(self, x, sample_size=None): + """ + MC estimates of a model's + * distortion D + * rate R + * and log-likelihood L + The estimates are based on single data points + but multiple latent samples. + + x: batch_shape + data_shape + sample_size: if 1 or more, we use multiple samples + sample_size controls a sequential computation (a for loop) + """ + sample_size = sample_size or 1 + obs_dims = len(self.gen_model.cpd_net.outcome_shape) + batch_shape = x.shape[:-obs_dims] + with torch.no_grad(): + qz = self.inf_model(x) + pz = self.gen_model.prior(batch_shape) + try: # not every design admits tractable KL + R = td.kl_divergence(qz, pz) + except NotImplementedError: + # MC estimation of KL(q(z|x)||p(z)) + z = qz.sample((sample_size,)) + R = (qz.log_prob(z) - pz.log_prob(z)).mean(0) + D = 0 + ratios = [None] * sample_size + for k in range(sample_size): + z = qz.sample() + px_z = self.gen_model.obs_model(z) + ratios[k] = pz.log_prob(z) + px_z.log_prob(x) - qz.log_prob(z) + D = D - px_z.log_prob(x) + ratios = torch.stack(ratios, dim=-1) + L = torch.logsumexp(ratios, dim=-1) - np.log(sample_size) + D = D / sample_size + return D, R, L + + def elbo(self, x, sample_size=None): + """ + An MC estimate of ELBO = -D -R + + x: [batch_size] + data_shape + sample_size: if 1 or more, we use multiple samples + sample_size controls a sequential computation (a for loop) + """ + D, R, _ = self.DRL(x, sample_size=sample_size) + return -D -R + + def log_prob_estimate(self, x, sample_size=None): + """ + An importance sampling estimate of log p(x) + + x: [batch_size] + data_shape + sample_size: if 1 or more, we use multiple samples + sample_size controls a sequential computation (a for loop) + """ + _, _, L = self.DRL(x, sample_size=sample_size) + return L + + def forward(self, x, sample_size=None, rate_weight=1.): + """ + A surrogate for an MC estimate of - grad ELBO + + x: [batch_size] + data_shape + sample_size: if 1 or more, we use multiple samples + sample_size controls a sequential computation (a for loop) + cv: optional module for variance reduction + """ + sample_size = sample_size or 1 + obs_dims = len(self.gen_model.cpd_net.outcome_shape) + batch_shape = x.shape[:-obs_dims] + + qz = self.inf_model(x) + pz = self.gen_model.prior(batch_shape) + + # we can *always* make use of the score function estimator (SFE) + use_sfe = True + + # these 3 log densities will contribute to the different parts of the objective + log_p_x_z = 0. + log_p_z = 0. + log_q_z_x = 0. + + # these quantities will help us compute the SFE part of the objective + # (if needed) + sfe = 0 + reward = 0 + cv_reward = 0 + raw_r = 0 + cv_loss = 0 + + for _ in range(sample_size): + + # Obtain a sample + if qz.has_rsample: # this is how td objects tell us whether they are continuously reparameterisable + z = qz.rsample() + use_sfe = False # with path derivatives, we do not need SFE + else: + z = qz.sample() + + # Parameterise the observational model + px_z = self.gen_model.obs_model(z) + + # Compute all three relevant densities: + # p(x|z,theta) + log_p_x_z = log_p_x_z + px_z.log_prob(x) + # q(z|x,lambda) + log_q_z_x = log_q_z_x + qz.log_prob(z) + # p(z|theta) + log_p_z = log_p_z + pz.log_prob(z) + + # Compute the "reward" for SFE + raw_r = log_p_x_z + log_p_z - log_q_z_x + + # Apply variance reduction techniques + r, l = self.cv_model(raw_r.detach(), x=x, q=qz, r_fn=lambda a: self.gen_model(a).log_prob(x)) + cv_loss = cv_loss + l + + # SFE part for updating lambda + sfe = sfe + r.detach() * qz.log_prob(z) + + # Compute the sample mean for the different terms + sfe = (sfe / sample_size) + cv_loss = cv_loss / sample_size + log_p_x_z = log_p_x_z / sample_size + log_p_z = log_p_z / sample_size + log_q_z_x = log_q_z_x / sample_size + + D = - log_p_x_z + try: # not every design admits tractable KL + R = td.kl_divergence(qz, pz) + except NotImplementedError: + R = log_q_z_x - log_p_z + + if use_sfe: + # the first two terms update theta + # the last term updates lambda + elbo_grad_surrogate = log_p_x_z + log_p_z + sfe + # note that the term (log_p_x_z + log_p_z) is also part of sfe + # but there it is detached, meaning that it won't contribute to + # grad theta + else: + # without SFE, we can use the classic form of the ELBO + elbo_grad_surrogate = -D - R + + loss = -elbo_grad_surrogate + cv_loss + + return {'loss': loss.mean(0), 'ELBO': (-D -R).mean(0).item(), 'D': D.mean(0).item(), 'R': R.mean(0).item(), 'cv_loss': cv_loss.mean(0).item()} + +num_classes = 4 # Change this to the actual number of classes +vae = NVIL( + JointDistribution( + MoGPriorNet(128, 128), + MoGCPDNet( + outcome_shape=img_shape, + num_inputs=batch_size, + hidden_size=32, + num_components=2 + ) + ), + InferenceModel( + cpd_net_type=partial(MoGCPDNet), + latent_size=128, + num_channels=img_shape[0], + width=img_shape[1], + height=img_shape[2], + ), + VarianceReduction() +) +vae + +for x, y in train_loader: + print('x.shape:', x.shape) + print(vae(x)) + break + +# Training algorithm + +class OptCollection: + + def __init__(self, gen, inf, cv=None): + self.gen = gen + self.inf = inf + self.cv = cv + + def zero_grad(self): + self.gen.zero_grad() + self.inf.zero_grad() + if self.cv: + self.cv.zero_grad() + + def step(self): + self.gen.step() + self.inf.step() + if self.cv: + self.cv.step() + +from collections import defaultdict, OrderedDict +from tqdm.auto import tqdm + + +def assess(model, sample_size, dl, device): + """ + Wrapper for estimating a model's ELBO, distortion, rate, and log-likelihood + using all data points in a data loader. + """ + D = 0 + R = 0 + L = 0 + data_size = 0 + with torch.no_grad(): + for batch_x, batch_y in dl: + Dx, Rx, Lx = model.DRL(batch_x.to(device), sample_size=sample_size) + D = D + Dx.sum(0) + R = R + Rx.sum(0) + L = L + Lx.sum(0) + data_size += batch_x.shape[0] + D = D / data_size + R = R / data_size + L = L / data_size + return {'ELBO': (-D -R).item(), 'D': D.item(), 'R': R.item(), 'L': L.item()} + + +def train_vae(model: NVIL, opts: OptCollection, + training_data, dev_data, + batch_size=64, num_epochs=10, check_every=10, + sample_size_training=1, + sample_size_eval=10, + grad_clip=5., + num_workers=2, + device=torch.device('cuda:0') + ): + """ + model: pytorch model + optimiser: pytorch optimiser + training_corpus: a TaggedCorpus for trianing + dev_corpus: a TaggedCorpus for dev + batch_size: use more if you have more memory + num_epochs: use more for improved convergence + check_every: use less to check performance on dev set more often + device: where we run the experiment + + Return a log of quantities computed during training (for plotting) + """ + batcher = DataLoader(training_data, batch_size, shuffle=True, num_workers=num_workers, pin_memory=True) + dev_batcher = DataLoader(dev_data, batch_size, num_workers=num_workers, pin_memory=True) + + total_steps = num_epochs * len(batcher) + log = defaultdict(list) + + step = 0 + model.eval() + for k, v in assess(model, sample_size_eval, dev_batcher, device=device).items(): + log[f"dev.{k}"].append((step, v)) + + with tqdm(range(total_steps)) as bar: + for epoch in range(num_epochs): + for batch_x, batch_y in batcher: + model.train() + opts.zero_grad() + + loss_dict = model( + batch_x.to(device), + sample_size=sample_size_training, + ) + for metric, value in loss_dict.items(): + log[f'training.{metric}'].append((step, value)) + + loss_dict['loss'].backward() + + nn.utils.clip_grad_norm_( + model.parameters(), + grad_clip + ) + opts.step() + + bar_dict = OrderedDict() + for metric, value in loss_dict.items(): + bar_dict[f'training.{metric}'] = f"{loss_dict[metric]:.2f}" + for metric in ['ELBO', 'D', 'R', 'L']: + bar_dict[f"dev.{metric}"] = "{:.2f}".format(log[f"dev.{metric}"][-1][1]) + bar.set_postfix(bar_dict) + bar.update() + + if step % check_every == 0: + model.eval() + for k, v in assess(model, sample_size_eval, dev_batcher, device=device).items(): + log[f"dev.{k}"].append((step, v)) + + step += 1 + + + model.eval() + for k, v in assess(model, sample_size_eval, dev_batcher, device=device).items(): + log[f"dev.{k}"].append((step, v)) + + return log + +# And, finally, some code to help inspect samples + +def inspect_lvm(model, dl, device): + for x, y in dl: + + x_ = model.sample(16, 4, oversample=True).cpu().reshape(-1, 1, 64, 64) + plt.figure(figsize=(16,8)) + plt.axis('off') + plt.imshow(make_grid(x_, nrow=16).permute((1, 2, 0))) + plt.title("Prior samples") + plt.show() + + plt.figure(figsize=(16,8)) + plt.axis('off') + plt.imshow(make_grid(x, nrow=16).permute((1, 2, 0))) + plt.title("Observations") + plt.show() + + x_ = model.cond_sample(x.to(device)).cpu().reshape(-1, 1, 64, 64) + plt.figure(figsize=(16,8)) + plt.axis('off') + plt.imshow(make_grid(x_, nrow=16).permute((1, 2, 0))) + plt.title("Conditional samples") + plt.show() + + break + +# Variance reduction +# Here are some concrete strategies for variance reduction. You can skip those in a first pass. + +class CentredReward(VarianceReduction): + """ + This control variate does not have trainable parameters, + it maintains a running estimate of the average reward and updates + a batch of rewards by computing reward - avg. + """ + + def __init__(self, alpha=0.9): + super().__init__() + self._alpha = alpha + self._r_mean = 0. + + def forward(self, r, x=None, q=None, r_fn=None): + """ + Centre the reward and update running estimates of mean. + """ + with torch.no_grad(): + # sufficient statistics for next updates + r_mean = torch.mean(r, dim=0) + # centre the signal + r = r - self._r_mean + # update running estimate of mean + self._r_mean = (1-self._alpha) * self._r_mean + self._alpha * r_mean.item() + return r, torch.zeros_like(r) + + +class ScaledReward(VarianceReduction): + """ + This control variate does not have trainable parameters, + it maintains a running estimate of the reward's standard deviation and + updates a batch of rewards by computing reward / maximum(stddev, 1). + """ + + def __init__(self, alpha=0.9): + super().__init__() + self._alpha = alpha + self._r_std = 1.0 + + def forward(self, r, x=None, q=None, r_fn=None): + """ + Scale the reward by a running estimate of std, and also update the estimate. + """ + with torch.no_grad(): + # sufficient statistics for next updates + r_std = torch.std(r, dim=0) + # standardise the signal + r = r / self._r_std + # update running estimate of std + self._r_std = (1-self._alpha) * self._r_std + self._alpha * r_std.item() + # it's not safe to standardise with scales less than 1 + self._r_std = np.maximum(self._r_std, 1.) + return r, torch.zeros_like(r) + + +class SelfCritic(VarianceReduction): + """ + This control variate does not have trainable parameters, + it updates a batch of rewards by computing reward - reward', where + reward' is (log p(X=x|Z=z')).detach() assessed for a novel sample + z' ~ Z|X=x. + """ + + def __init__(self): + super().__init__() + + def forward(self, r, x, q, r_fn): + """ + Standardise the reward and update running estimates of mean/std. + """ + with torch.no_grad(): + z = q.sample() + r = r - r_fn(z, x) + return r, torch.zeros_like(r) + + +class Baseline(VarianceReduction): + """ + An input-dependent baseline implemented as an MLP. + The trainable parameters are adjusted via MSE. + """ + + def __init__(self, num_inputs, hidden_size, p_drop=0.): + super().__init__() + self.baseline = nn.Sequential( + FlattenImage(), + nn.Dropout(p_drop), + nn.Linear(num_inputs, hidden_size), + nn.ReLU(), + nn.Dropout(p_drop), + nn.Linear(hidden_size, 1) + ) + + def forward(self, r, x, q=None, r_fn=None): + """ + Return r - baseline(x) and Baseline's loss. + """ + # batch_shape + (1,) + r_hat = self.baseline(x) + # batch_shape + r_hat = r_hat.squeeze(-1) + loss = (r - r_hat)**2 + return r - r_hat.detach(), loss + + +class CVChain(VarianceReduction): + + def __init__(self, *args): + super().__init__() + if len(args) == 1 and isinstance(args[0], OrderedDict): + for key, module in args[0].items(): + self.add_module(key, module) + else: + for idx, module in enumerate(args): + self.add_module(str(idx), module) + + def forward(self, r, x, q, r_fn): + loss = 0 + for cv in self._modules.values(): + r, l = cv(r, x=x, q=q, r_fn=r_fn) + loss = loss + l + return r, loss + +seed_all() +my_device = torch.device('cuda:0') + + +model = NVIL( + JointDistribution( + MoGPriorNet(128, 128), + MoGCPDNet( + outcome_shape=img_shape, + num_inputs=batch_size, + hidden_size=32, + num_components=2 + ) + ), + InferenceModel( + latent_size=128, + num_channels=img_shape[0], + width=img_shape[1], + height=img_shape[2], + cpd_net_type=partial(MoGCPDNet) # Gaussian prior and Gaussian posterior: this is a classic VAE + ), + # VarianceReduction(), # no variance reduction is needed for a VAE + CVChain( # variance reduction helps SFE + CentredReward(), + Baseline(np.prod(img_shape), 512), # this is how you would use a trained baselined + ScaledReward() + ) +).to(my_device) + +opts = OptCollection( + # Tips based on empirical practice: + + # Adam is the go-to choice for (reparameterised) VAEs + opt.Adam(model.gen_params(), lr=5e-4, weight_decay=1e-6), + opt.Adam(model.inf_params(), lr=1e-4), + + + # Adam is not often a good choice for SFE-based optimisation + # a possible reason: SFE is too noisy and the design choices behind Adam + # were made having reparameterised gradients in mind + #opt.RMSprop(model.gen_params(), lr=5e-4, weight_decay=1e-6), + #opt.RMSprop(model.inf_params(), lr=1e-4), + #opt.RMSprop(model.cv_params(), lr=1e-4, weight_decay=1e-6) # you need this if your baseline has trainable parameters +) + +model + + +log = train_vae( + model=model, + opts=opts, + training_data=combined_train_data, + dev_data=combined_val_data, + batch_size=64, + num_epochs=100, # use more for better models + check_every=100, + sample_size_training=1, + sample_size_eval=1, + grad_clip=5., + device=my_device +) + +log.keys() + +fig, axs = plt.subplots(1, 3 + int('training.cv_loss' in log), sharex=True, sharey=False, figsize=(20, 5)) + +_ = axs[0].plot(np.array(log['training.ELBO'])[:,0], np.array(log['training.ELBO'])[:,1]) +_ = axs[0].set_ylabel("training ELBO") +_ = axs[0].set_xlabel("steps") + +_ = axs[1].plot(np.array(log['training.D'])[:,0], np.array(log['training.D'])[:,1]) +_ = axs[1].set_ylabel("training D") +_ = axs[1].set_xlabel("steps") + +_ = axs[2].plot(np.array(log['training.R'])[:,0], np.array(log['training.R'])[:,1]) +_ = axs[2].set_ylabel("training R") +_ = axs[2].set_xlabel("steps") + +if 'training.cv_loss' in log: + _ = axs[3].plot(np.array(log['training.cv_loss'])[:,0], np.array(log['training.cv_loss'])[:,1]) + _ = axs[3].set_ylabel("cv loss") + _ = axs[3].set_xlabel("steps") + +fig.tight_layout(h_pad=2, w_pad=2) + + + +fig, axs = plt.subplots(1, 4, sharex=True, sharey=False, figsize=(20, 5)) + +_ = axs[0].plot(np.array(log['dev.ELBO'])[:,0], np.array(log['dev.ELBO'])[:,1]) +_ = axs[0].set_ylabel("dev ELBO") +_ = axs[0].set_xlabel("steps") + +_ = axs[1].plot(np.array(log['dev.D'])[:,0], np.array(log['dev.D'])[:,1]) +_ = axs[1].set_ylabel("dev D") +_ = axs[1].set_xlabel("steps") + +_ = axs[2].plot(np.array(log['dev.R'])[:,0], np.array(log['dev.R'])[:,1]) +_ = axs[2].set_ylabel("dev R") +_ = axs[2].set_xlabel("steps") + +_ = axs[3].plot(np.array(log['dev.L'])[:,0], np.array(log['dev.L'])[:,1]) +_ = axs[3].set_ylabel("dev L") +_ = axs[3].set_xlabel("steps") + +fig.tight_layout(h_pad=2, w_pad=2) + +inspect_lvm(model, DataLoader(combined_val_data, 128, num_workers=2, pin_memory=True), my_device) + + + +# I generate three different sets of data (these sets have segments), one set is the training data, the other set is a set of validation data and the last set is the test data. The objectives are using the codes that I did before, review those codes and edit those to developt the learning using train, validation and test. The training set has segments with labels (0, 1, 2, 3) and segments unlabeled (-1), implementing a semisupervised latent code for this (I tried a class but I think this need some changes but my knowledge is limited.) + +# Assuming x_labeled, y_labeled, x_unlabeled are your labeled and unlabeled datasets +# Set up your NVI model + + +# Set up your optimizer +gen_optimizer = opt.Adam(model.gen_params(), lr=5e-4, weight_decay=1e-6) +inf_optimizer = opt.Adam(model.inf_params(), lr=1e-4) +cv_optimizer = opt.Adam(model.cv_params(), lr=1e-4, weight_decay=1e-6) + +opts = OptCollection(gen_optimizer, inf_optimizer, cv_optimizer) + + +def train_semisupervised_vae(model: NVIL, opts: OptCollection, train_loader, val_loader, test_loader, + num_epochs=10, check_every=10, + sample_size_training=1, + sample_size_eval=10, + grad_clip=5., + device=torch.device('cuda:0') + ): + """ + Train a semi-supervised NVIL model. + + model: NVIL model + opts: OptCollection containing optimizers + train_loader: DataLoader for combined labeled and unlabeled data + val_loader: DataLoader for validation data + test_loader: DataLoader for test data + num_epochs: number of training epochs + check_every: frequency to check performance on dev set + sample_size_training: number of samples for training + sample_size_eval: number of samples for evaluation + grad_clip: gradient clipping value + device: device for training + + Returns a log of quantities computed during training (for plotting) + """ + + total_steps = num_epochs * len(train_loader) + log = defaultdict(list) + + step = 0 + model.eval() + for k, v in assess(model, sample_size_eval, val_loader, device=device).items(): + log[f"dev.{k}"].append((step, v)) + + with tqdm(range(total_steps)) as bar: + for epoch in range(num_epochs): + for batch_x, batch_y in train_loader: + model.train() + opts.zero_grad() + + # Check if the data is labeled or unlabeled + labeled_mask = (batch_y != -1).to(device) + + # Forward pass for unlabeled data + unlabeled_loss_dict = model( + batch_x.to(device), + sample_size=sample_size_training, + ) + + # Forward pass for labeled data with classification loss + labeled_loss_dict = model( + batch_x[labeled_mask].to(device), + sample_size=sample_size_training, + ) + + # Classification loss only for labeled data + classification_loss = F.cross_entropy( + labeled_loss_dict['logits'].squeeze(), batch_y[labeled_mask].to(device) + ) + + # Combine losses + total_loss = unlabeled_loss_dict['loss'] + classification_loss + + total_loss.backward() + + nn.utils.clip_grad_norm_( + model.parameters(), + grad_clip + ) + opts.step() + + bar_dict = OrderedDict() + bar_dict['training.loss'] = f"{total_loss:.2f}" + for metric, value in unlabeled_loss_dict.items(): + bar_dict[f'training.{metric}'] = f"{unlabeled_loss_dict[metric]:.2f}" + for metric in ['ELBO', 'D', 'R', 'L']: + bar_dict[f"dev.{metric}"] = "{:.2f}".format(log[f"dev.{metric}"][-1][1]) + bar.set_postfix(bar_dict) + bar.update() + + if step % check_every == 0: + model.eval() + for k, v in assess(model, sample_size_eval, val_loader, device=device).items(): + log[f"dev.{k}"].append((step, v)) + + step += 1 + + model.eval() + for k, v in assess(model, sample_size_eval, test_loader, device=device).items(): + log[f"test.{k}"].append((step, v)) + + return log + +# Assuming you have the NVIL model (nvil), optimizers (opts), and data loaders (train_loader, val_loader, test_loader) +log = train_semisupervised_vae(nvil, opts, train_loader, val_loader, test_loader)