diff --git a/project_1.py b/project_1.py
index cfe9c9b..0acb00b 100644
--- a/project_1.py
+++ b/project_1.py
@@ -21,6 +21,19 @@ import seaborn as sns
 from PIL import Image # Import the Image module
 
+# Create a logger
+logger = logging.getLogger(__name__)
+logging.basicConfig(filename='error_log.txt', level=logging.ERROR)
+
+# Create a logger for progress monitoring
+progress_logger = logging.getLogger('progress')
+progress_logger.setLevel(logging.INFO)
+progress_handler = logging.FileHandler('progress.log')
+progress_formatter = logging.Formatter('%(asctime)s - %(name)s - %(message)s')
+progress_handler.setFormatter(progress_formatter)
+progress_logger.addHandler(progress_handler)
+
 def get_data_paths(data_format, is_linux=False, is_hpc=False):
     if is_linux:
         base_path = "/mnt/r/ENGR_Chon/Dong/MATLAB_generate_results/NIH_PulseWatch"
@@ -48,10 +61,6 @@ def get_data_paths(data_format, is_linux=False, is_hpc=False):
 
     return data_path, labels_path, saving_path
 
-# Create a logger
-logger = logging.getLogger(__name)
-logging.basicConfig(filename='error_log.txt', level=logging.ERROR)
-
 # Standardize the data
 def standard_scaling(data):
     scaler = StandardScaler()
@@ -71,7 +80,6 @@ def load_data(data_path, labels_path, dataset_size=2, train=True, standardize=Tr
     X_data = [] # Store all data
     X_data_original = [] # Store original data without standardization
     segment_names = []
-    validated_labels = [] # Store only the label values
 
     for UID in UID_list:
         data_path_UID = os.path.join(data_path, UID)
@@ -98,6 +106,7 @@ def load_data(data_path, labels_path, dataset_size=2, train=True, standardize=Tr
 
         except Exception as e:
             logger.error(f"Error processing segment: {seg} in UID: {UID}. Exception: {str(e)}")
+            # You can also add more information to the error log, such as the shape of time_freq_plot
+            # (guard against the case where the file read itself failed and the array never existed)
+            if 'time_freq_plot' in locals():
+                logger.error(f"Shape of time_freq_plot for segment {seg} in UID {UID}: {time_freq_plot.shape}")
             continue # Continue to the next segment
 
@@ -106,6 +115,7 @@ def load_data(data_path, labels_path, dataset_size=2, train=True, standardize=Tr
 
     if standardize:
         X_data = standard_scaling(X_data) # Standardize the data
 
+    # Extract labels from CSV files
     labels = extract_labels(UID_list, labels_path, segment_names)
 
@@ -129,114 +139,22 @@ def load_data(data_path, labels_path, dataset_size=2, train=True, standardize=Tr
             # Set labels that are not in the important list as -1 (Unlabeled)
             segment_labels[segment_name] = -1
 
-    # Return all segments along with labels
-    if return_all:
-        return X_data_original, X_data, segment_names, segment_labels, segment_labels.values()
-
     # Filter out segments that are unlabeled (-1)
     filtered_segment_names = [segment_name for segment_name, label in segment_labels.items() if label != -1]
 
     # Filter data to match the filtered segment names
     filtered_data = torch.stack([X_data[segment_names.index(segment_name)] for segment_name in filtered_segment_names])
 
-    # Return labeled and unlabeled segments along with labels
-    if return_all == 'labeled':
-        return X_data_original, filtered_data, filtered_segment_names, {seg: segment_labels[seg] for seg in filtered_segment_names}, {seg: segment_labels[seg] for seg in filtered_segment_names}.values()
-
-    # Return unlabeled segments along with labels
-    if return_all == 'unlabeled':
-        unlabeled_segment_names = [segment_name for segment_name, label in segment_labels.items() if label == -1]
-        unlabeled_data = torch.stack([X_data[segment_names.index(segment_name)] for segment_name in unlabeled_segment_names])
-        return X_data_original, unlabeled_data, unlabeled_segment_names, {seg: segment_labels[seg] for seg in unlabeled_segment_names}, {seg: segment_labels[seg] for seg in unlabeled_segment_names}.values()
-
-    # By default, return only labeled segments along with labels
-    return X_data_original, filtered_data, filtered_segment_names, {seg: segment_labels[seg] for seg in filtered_segment_names}, {seg: segment_labels[seg] for seg in filtered_segment_names}.values()
-
-def load_data(data_path, labels_path, dataset_size=2, train=True, standardize=True, data_format='csv', return_all=False):
-    if data_format not in ['csv', 'png']:
-        raise ValueError("Invalid data_format. Choose 'csv' or 'png.")
-
-    dir_list_UID = os.listdir(data_path)
-    UID_list = dir_list_UID[:dataset_size] if train else dir_list_UID[dataset_size:]
-
-    X_data = [] # Store all data
-    X_data_original = [] # Store original data without standardization
-    segment_names = []
-
-    for UID in UID_list:
-        data_path_UID = os.path.join(data_path, UID)
-        dir_list_seg = os.listdir(data_path_UID)
-
-        for seg in dir_list_seg[:len(dir_list_seg)]: # Limiting to 50 segments
-            seg_path = os.path.join(data_path_UID, seg)
-
-            try:
-                if data_format == 'csv' and seg.endswith('.csv'):
-                    time_freq_plot = np.array(pd.read_csv(seg_path, header=None))
-                    time_freq_tensor = torch.Tensor(time_freq_plot).reshape(1, 128, 128)
-                elif data_format == 'png' and seg.endswith('.png'):
-                    img = Image.open(seg_path)
-                    img_data = np.array(img)
-                    time_freq_tensor = torch.Tensor(img_data).unsqueeze(0)
-                else:
-                    continue # Skip other file formats
-
-                X_data.append(time_freq_tensor)
-                X_data_original.append(time_freq_tensor.clone()) # Store a copy of the original data
-
-                segment_names.append(seg.split('_filt')[0]) # Extract and store segment names
-
-            except Exception as e:
-                logger.error(f"Error processing segment: {seg} in UID: {UID}. Exception: {str(e)}")
-                logger.error(f"Error processing segment: {time_freq_plot.size()} in UID: {UID}. Exception: {str(e)}")
-                # You can also add more information to the error log, such as the value of time_freq_plot.
-                continue # Continue to the next segment
-
-    X_data = torch.cat(X_data, 0)
-    X_data_original = torch.cat(X_data_original, 0)
-
-    if standardize:
-        X_data = standard_scaling(X_data) # Standardize the data
-
-    # Extract labels from CSV files
-    labels = extract_labels(UID_list, labels_path, segment_names)
-
-    important_labels = [0.0, 1.0, 2.0, 3.0] # List of important labels
-
-    # Initialize labels for segments as unlabeled (-1)
-    segment_labels = {segment_name: -1 for segment_name in segment_names}
-
-    for UID in labels.keys():
-        if UID not in UID_list:
-            # Skip UIDs that are not in the dataset
-            continue
-
-        label_data, label_segment_names = labels[UID]
-
-        for idx, segment_label in enumerate(label_data):
-            segment_name = label_segment_names[idx]
-            if segment_label in important_labels:
-                segment_labels[segment_name] = segment_label
-            else:
-                # Set labels that are not in the important list as -1 (Unlabeled)
-                segment_labels[segment_name] = -1
-
     # Return all segments along with labels
-    if return_all:
+    if return_all is True:
         return X_data_original, X_data, segment_names, segment_labels, segment_labels.values()
 
-    # Filter out segments that are unlabeled (-1)
-    filtered_segment_names = [segment_name for segment_name, label in segment_labels.items() if label != -1]
-
-    # Filter data to match the filtered segment names
-    filtered_data = torch.stack([X_data[segment_names.index(segment_name)] for segment_name in filtered_segment_names])
-
     # Return labeled and unlabeled segments along with labels
     if return_all == 'labeled':
         return X_data_original, filtered_data, filtered_segment_names, {seg: segment_labels[seg] for seg in filtered_segment_names}, {seg: segment_labels[seg] for seg in filtered_segment_names}.values()
 
     # Return unlabeled segments along with labels
-    if return_all == 'unlabeled':
+    elif return_all == 'unlabeled':
         unlabeled_segment_names = [segment_name for segment_name, label in segment_labels.items() if label == -1]
         unlabeled_data = torch.stack([X_data[segment_names.index(segment_name)] for segment_name in unlabeled_segment_names])
         return X_data_original, unlabeled_data, unlabeled_segment_names, {seg: segment_labels[seg] for seg in unlabeled_segment_names}, {seg: segment_labels[seg] for seg in unlabeled_segment_names}.values()
@@ -244,8 +162,6 @@ def load_data(data_path, labels_path, dataset_size=2, train=True, standardize=Tr
 
     # By default, return only labeled segments along with labels
     return X_data_original, filtered_data, filtered_segment_names, {seg: segment_labels[seg] for seg in filtered_segment_names}, {seg: segment_labels[seg] for seg in filtered_segment_names}.values()
-
-
 def extract_labels(UID_list, labels_path, segment_names):
     labels = {}
     for UID in UID_list:
@@ -299,17 +215,17 @@ def visualize_trends(standardized_data, original_data, segment_names, num_plots=
     else:
         print("This is a trend analysis for data in an unsupported format.")
 
-def perform_pca(data, num_components=2, num_clusters=4):
-    # Perform PCA for dimensionality reduction
-    data_flattened = data.view(data.size(0), -1) # Flatten the data
-    pca = PCA(n_components=num_components)
-    reduced_data = pca.fit_transform(data_flattened.numpy())
+# def perform_pca(data, num_components=2, num_clusters=4):
+#     # Perform PCA for dimensionality reduction
+#     data_flattened = data.view(data.size(0), -1) # Flatten the data
+#     pca = PCA(n_components=num_components)
+#     reduced_data = pca.fit_transform(data_flattened.numpy())
 
-    # Cluster the data using K-Means
-    kmeans = KMeans(n_clusters=num_clusters)
-    labels = kmeans.fit_predict(reduced_data)
+#     # Cluster the data using K-Means
+#     kmeans = KMeans(n_clusters=num_clusters)
+#     labels = kmeans.fit_predict(reduced_data)
 
-    return reduced_data, pca, labels
+#     return reduced_data, pca, labels
 
 def perform_pca_sgd(data, num_components=2, num_clusters=4, batch_size=64):
     data_flattened = data.view(data.size(0), -1)
@@ -336,56 +252,75 @@ def perform_tsne(data, num_components=2, num_clusters=4):
 
     return reduced_data, labels
 
-def visualize_correlation_matrix(data, segment_names, subject_mode=True, num_subjects_to_visualize=None, save_path=None):
-    '''
-    Usage:
-    To visualize the correlation matrix for each subject individually, you can call:
-    visualize_correlation_matrix(train_data, subject_mode=True, save_path="path_to_save_results")
-
-    To visualize the correlation matrix for a specific quantity of subjects (groups), you can call:
-    visualize_correlation_matrix(train_data, subject_mode=False, num_subjects_to_visualize=5, save_path="path_to_save_results")
-    '''
-    # Visualize the correlation matrix for each subject or subgroup
+def visualize_correlation_matrix(data, segment_names, subject_mode=True, num_subjects_to_visualize=None, batch_size=32, method='pearson', save_path=None):
     data_flattened = data.view(data.size(0), -1).numpy()
-    subject_names = [filename.split('_')[0] for filename in segment_names]
-    unique_subjects = list(set(subject_names))
-
     if subject_mode:
-        for subject in unique_subjects:
-            subject_indices = [i for i, name in enumerate(subject_names) if name == subject]
-            subject_data = data_flattened[subject_indices]
-            correlation_matrix = np.corrcoef(subject_data, rowvar=False)
-
-            plt.figure()
-            sns.heatmap(correlation_matrix, cmap="coolwarm", xticklabels=False, yticklabels=False)
-            plt.title(f"Correlation Matrix for Subject {subject}")
-
-            if save_path:
-                subject_save_path = os.path.join(save_path, f"correlation_matrix_subject_{subject}.png")
-                plt.savefig(subject_save_path)
-
-            plt.show()
-
-    else: # Group mode
-        if num_subjects_to_visualize is None:
-            num_subjects_to_visualize = len(unique_subjects)
+        subject_names = [filename.split('_')[0] for filename in segment_names]
+        unique_subjects = list(set(subject_names))
+    else:
+        subject_names = [filename.split('_')[0] for filename in segment_names]
+        unique_subjects = list(set(subject_names))
+
+    if num_subjects_to_visualize is None:
+        num_subjects_to_visualize = len(unique_subjects)
+
+    for i in range(num_subjects_to_visualize):
+        subject = unique_subjects[i]
+        subject_indices = [j for j, name in enumerate(subject_names) if name == subject]
+        subject_data = data_flattened[subject_indices]
 
-        for i in range(num_subjects_to_visualize):
-            subject = unique_subjects[i]
-            subject_indices = [i for i, name in enumerate(subject_names) if name == subject]
-            subject_data = data_flattened[subject_indices]
-            correlation_matrix = np.corrcoef(subject_data, rowvar=False)
-
+        # Shuffle the data to avoid bias
+        np.random.shuffle(subject_data)
+
+        # Calculate the number of batches
+        num_batches = len(subject_data) // batch_size
+
+        batch_correlations = []
+
+        for batch_index in range(num_batches):
+            start = batch_index * batch_size
+            end = (batch_index + 1) * batch_size
+            batch = subject_data[start:end]
+
+            # Calculate the correlation matrix for the batch
+            correlation_matrix = np.corrcoef(batch, rowvar=False)
+
+            # Calculate the mean or median of the per-batch correlations
+            if method == 'mean':
+                batch_correlation = np.mean(correlation_matrix)
+            elif method == 'median':
+                batch_correlation = np.median(correlation_matrix)
+            else:
+                batch_correlation = np.mean(correlation_matrix) # Fall back to the mean for any other method value
+
+            batch_correlations.append(batch_correlation)
+
+        # Aggregate the batch correlations
+        overall_correlation = np.mean(batch_correlations) # You can use median instead of mean if needed
+
+        # Calculate confidence intervals on the aggregated correlation
+        batch_correlations = np.array(batch_correlations)
+        ci_lower = np.percentile(batch_correlations, 2.5)
+        ci_upper = np.percentile(batch_correlations, 97.5)
+
+        # Print or save the results
+        print(f"Overall Correlation for {num_subjects_to_visualize} Subjects {subject}: {overall_correlation:.4f}")
+        print(f"Confidence Intervals: [{ci_lower:.4f}, {ci_upper:.4f}]")
+
+        if save_path:
+            subject_save_path = os.path.join(save_path, f"correlation_matrix_subject_group_{subject}.png")
             plt.figure()
             sns.heatmap(correlation_matrix, cmap="coolwarm", xticklabels=False, yticklabels=False)
             plt.title(f"Correlation Matrix for {num_subjects_to_visualize} Subjects {subject}")
-
-            if save_path:
-                subject_save_path = os.path.join(save_path, f"correlation_matrix_subject_group_{subject}.png")
-                plt.savefig(subject_save_path)
-
-            plt.show()
+            plt.savefig(subject_save_path)
+            plt.close()
+
+        # Plot the per-batch correlations over time/batches
+        plt.figure()
+        plt.plot(batch_correlations)
+        plt.xlabel("Batch Index")
+        plt.ylabel("Correlation")
+        plt.title(f"Per-Batch Correlations for {num_subjects_to_visualize} Subjects {subject}")
+        plt.show()
 
 # This function computes the log PDF of a multivariate normal distribution
 def multivariate_normal_log_pdf_MFVI(x, mu, sigma_sq):
@@ -469,6 +404,7 @@ def evaluate_clustering(data, true_labels, predicted_labels):
     print(f'Adjusted Mutual Info (AMI): {ami}')
     print(f'Silhouette Score: {silhouette}')
    print(f'Davies-Bouldin Index: {davies_bouldin}')
+    return ari, ami, silhouette, davies_bouldin
 
 def plot_pca(reduced_data, labels, method='original_labels', save_path=None):
     """
@@ -493,7 +429,8 @@ def plot_pca(reduced_data, labels, method='original_labels', save_path=None):
     # Save the plot if save_path is provided
     if save_path:
-        plt.savefig(save_path, f"PCA_analysis_using_{method}.png")
+        filename = method.replace(' ', '_') + ".png"
+        plt.savefig(os.path.join(save_path, filename))
 
     plt.show()
 
 # Example usage:
@@ -528,129 +465,238 @@ def plot_clusters(data, zi, title, save_path=None):
         plt.savefig(os.path.join(save_path, filename))
 
     plt.show()
-
-def main():
-    is_linux = False # Set to True if running on Linux, False if on Windows
-    is_hpc = False # Set to True if running on hpc, False if on Windows
+
+def perform_dimensionality_reduction(data, data_type, saving_path):
+    log_progress(f"Performing dimensionality reduction for {data_type} data")
+    try:
+        # Perform PCA on the data
+        pca_reduced_data, pca, pca_labels = perform_pca_sgd(data, num_components=2, num_clusters=4, batch_size=64)
+    except Exception as e:
+        logger.error(f"An error occurred while processing PCA in {data_type} data: {str(e)}")
+
+    try:
+        # Perform t-SNE on the data
+        tsne_reduced_data, tsne_labels = perform_tsne(data)
+    except Exception as e:
+        logger.error(f"An error occurred while processing t-SNE in {data_type} data: {str(e)}")
+
+    return pca_reduced_data, pca_labels, tsne_reduced_data, tsne_labels
+
+def perform_clustering(data, labels, pca_reduced_data, pca_labels, tsne_reduced_data, tsne_labels, mfvi_labels, data_type, saving_path):
log_progress(f"Performing clustering for {data_type} data") + + if data_type == "labeled": + method ='PCA on Labeled Data' + title_tsne = 't-SNE on Labeled Data' + title_mfvi = 'MFVI on Labeled Data' + elif data_type == "unlabeled": + method ='PCA on Unlabeled Data' + title_tsne = 't-SNE on Unlabeled Data' + title_mfvi = 'MFVI on Unabeled Data' + elif data_type == "all": + method ='PCA on All Data' + title_tsne = 't-SNE on All Data' + title_mfvi = 'MFVI on All Data' + else: + log_progress("Invalid data_type specified.") + return None # Return early if the data type is not valid. + + # Perform PCA on the data + try: + # Plot PCA for the data + plot_pca(pca_reduced_data, labels, method, save_path=saving_path) + except Exception as e: + logger.error(f"An error occurred while plotting PCA: {str(e)}") + + try: + # Evaluate clustering for PCA results + ari_pca, ami_pca, silhouette_pca, davies_bouldin_pca = evaluate_clustering(data.view(data.size(0), -1).numpy(), labels, pca_labels) + except Exception as e: + logger.error(f"An error occurred while clustering PCA results: {str(e)}") + + # Perform t-SNE on the data + try: + # Plot t-SNE for the data + plot_clusters(tsne_reduced_data, tsne_labels, title_tsne, save_path=saving_path) + except Exception as e: + logger.error(f"An error occurred while plotting t-SNE: {str(e)}") + + try: + # Evaluate clustering for t-SNE results + ari_tsne, ami_tsne, silhouette_tsne, davies_bouldin_tsne = evaluate_clustering(tsne_reduced_data, labels, tsne_labels) + except Exception as e: + logger.error(f"An error occurred while clustering t-SNE results: {str(e)}") + + try: + # Plot MFVI for data + plot_clusters(data.view(data.size(0), -1).numpy(), mfvi_labels, title_mfvi, save_path=saving_path) + except Exception as e: + logger.error(f"An error occurred while plotting Mean-Field Variational Inference in labeled data: {str(e)}") + + try: + # For MFVI on data + ari_mfvi, ami_mfvi, silhouette_mfvi, davies_bouldin_mfvi = evaluate_clustering(data.view(data.size(0), -1).numpy(), labels, mfvi_labels) + except Exception as e: + logger.error(f"An error occurred while clustering Mean-Field Variational Inference results in labeled data: {str(e)}") + + return ari_pca, ami_pca, silhouette_pca, davies_bouldin_pca, ari_tsne, ami_tsne, silhouette_tsne, davies_bouldin_tsne, ari_mfvi, ami_mfvi, silhouette_mfvi, davies_bouldin_mfvi + +def visualize_and_analyze_data(data, original_data, segments, labels, data_type, saving_path, data_format): + log_progress(f"Visualizing and analyzing {data_type} data") + + try: + visualize_trends(data, original_data, segments, num_plots=10, data_format=data_format, save_path=saving_path) + except Exception as e: + handle_error(f"Visualizing trends for {data_type} data", e) + + try: + visualize_correlation_matrix(data, segments, subject_mode=False, num_subjects_to_visualize=None, batch_size=32, method='pearson', save_path=saving_path) + except Exception as e: + handle_error(f"Visualizing correlation matrix for {data_type} data", e) + + try: + pca_reduced_data, pca_labels, tsne_reduced_data, tsne_labels = perform_dimensionality_reduction(data, data_type, saving_path) + except Exception as e: + handle_error(f"Performing dimensionality reduction for {data_type} data", e) + + try: + # Perform MFVI on data + miu, pi, resp = perform_mfvi(data, K=4, n_optimization_iterations=300, convergence_threshold=1e-5, run_until_convergence=False) + + # Extract cluster assignments from MFVI + mfvi_labels = torch.argmax(resp, dim=1).numpy() + except Exception as e: + 
logger.error(f"An error occurred while Mean-Field Variational Inference in labeled data: {str(e)}") + + try: + ari_pca, ami_pca, silhouette_pca, davies_bouldin_pca, ari_tsne, ami_tsne, silhouette_tsne, davies_bouldin_tsne, ari_mfvi, ami_mfvi, silhouette_mfvi, davies_bouldin_mfvi = perform_clustering(data, labels, pca_reduced_data, pca_labels, tsne_reduced_data, tsne_labels, mfvi_labels, data_type, saving_path) + except Exception as e: + handle_error(f"Performing clustering for {data_type} data", e) + + return ari_pca, ami_pca, silhouette_pca, davies_bouldin_pca, ari_tsne, ami_tsne, silhouette_tsne, davies_bouldin_tsne, ari_mfvi, ami_mfvi, silhouette_mfvi, davies_bouldin_mfvi, miu, pi, resp + +def process_labeled_data(labeled_data, labeled_original_data, labeled_segments, labeled_labels, saving_path, data_format): + log_progress("Label data execution started.") + results = {} # Create a dictionary to store the results. + + try: + ari_pca, ami_pca, silhouette_pca, davies_bouldin_pca, ari_tsne, ami_tsne, silhouette_tsne, davies_bouldin_tsne, ari_mfvi, ami_mfvi, silhouette_mfvi, davies_bouldin_mfvi, miu, pi, resp = visualize_and_analyze_data(labeled_data, labeled_original_data, labeled_segments, labeled_labels, "labeled", saving_path, data_format) + # Store the results in the dictionary + results['ari_pca'] = ari_pca + results['ami_pca'] = ami_pca + results['silhouette_pca'] = silhouette_pca + results['davies_bouldin_pca'] = davies_bouldin_pca + results['ari_tsne'] = ari_tsne + results['ami_tsne'] = ami_tsne + results['silhouette_tsne'] = silhouette_tsne + results['davies_bouldin_tsne'] = davies_bouldin_tsne + results['ari_mfvi'] = ari_mfvi + results['ami_mfvi'] = ami_mfvi + results['silhouette_mfvi'] = silhouette_mfvi + results['davies_bouldin_mfvi'] = davies_bouldin_mfvi + results['miu'] = miu + results['pi'] = pi + results['resp'] = resp + except Exception as e: + logger.error(f"An error occurred during data processing: {str(e)}") + + log_progress("Label data execution completed.") + return ari_pca, ami_pca, silhouette_pca, davies_bouldin_pca, ari_tsne, ami_tsne, silhouette_tsne, davies_bouldin_tsne, ari_mfvi, ami_mfvi, silhouette_mfvi, davies_bouldin_mfvi, miu, pi, resp + +def process_unlabeled_data(unlabeled_data, unlabeled_original_data, unlabeled_segments, unlabeled_labels, saving_path, data_format): + log_progress("Unlabel data execution started.") + results = {} # Create a dictionary to store the results. 
+
+    try:
+        ari_pca, ami_pca, silhouette_pca, davies_bouldin_pca, ari_tsne, ami_tsne, silhouette_tsne, davies_bouldin_tsne, ari_mfvi, ami_mfvi, silhouette_mfvi, davies_bouldin_mfvi, miu, pi, resp = visualize_and_analyze_data(unlabeled_data, unlabeled_original_data, unlabeled_segments, unlabeled_labels, "unlabeled", saving_path, data_format)
+        # Store the results in the dictionary
+        results['ari_pca'] = ari_pca
+        results['ami_pca'] = ami_pca
+        results['silhouette_pca'] = silhouette_pca
+        results['davies_bouldin_pca'] = davies_bouldin_pca
+        results['ari_tsne'] = ari_tsne
+        results['ami_tsne'] = ami_tsne
+        results['silhouette_tsne'] = silhouette_tsne
+        results['davies_bouldin_tsne'] = davies_bouldin_tsne
+        results['ari_mfvi'] = ari_mfvi
+        results['ami_mfvi'] = ami_mfvi
+        results['silhouette_mfvi'] = silhouette_mfvi
+        results['davies_bouldin_mfvi'] = davies_bouldin_mfvi
+        results['miu'] = miu
+        results['pi'] = pi
+        results['resp'] = resp
+
+    except Exception as e:
+        logger.error(f"An error occurred during data processing: {str(e)}")
+
+    log_progress("Unlabeled data execution completed.")
+    return results
+
+def process_all_data(all_data, all_original_data, all_segment_names, all_labels, saving_path, data_format):
+    log_progress("All data execution started.")
+    results = {} # Create a dictionary to store the results.
+
+    try:
+        ari_pca, ami_pca, silhouette_pca, davies_bouldin_pca, ari_tsne, ami_tsne, silhouette_tsne, davies_bouldin_tsne, ari_mfvi, ami_mfvi, silhouette_mfvi, davies_bouldin_mfvi, miu, pi, resp = visualize_and_analyze_data(all_data, all_original_data, all_segment_names, all_labels, "all", saving_path, data_format)
+        # Store the results in the dictionary
+        results['ari_pca'] = ari_pca
+        results['ami_pca'] = ami_pca
+        results['silhouette_pca'] = silhouette_pca
+        results['davies_bouldin_pca'] = davies_bouldin_pca
+        results['ari_tsne'] = ari_tsne
+        results['ami_tsne'] = ami_tsne
+        results['silhouette_tsne'] = silhouette_tsne
+        results['davies_bouldin_tsne'] = davies_bouldin_tsne
+        results['ari_mfvi'] = ari_mfvi
+        results['ami_mfvi'] = ami_mfvi
+        results['silhouette_mfvi'] = silhouette_mfvi
+        results['davies_bouldin_mfvi'] = davies_bouldin_mfvi
+        results['miu'] = miu
+        results['pi'] = pi
+        results['resp'] = resp
+    except Exception as e:
+        logger.error(f"An error occurred during data processing: {str(e)}")
+
+    log_progress("All data execution completed.")
+    return results
-    data_format = 'csv' # Choose 'csv' or 'png'
+def log_progress(message):
+    progress_logger.info(message)
+def main(case_to_run):
+    log_progress("Code execution started.")
+    is_linux = False # Set to True if running on Linux, False if on Windows
+    is_hpc = False # Set to True if running on an HPC, False if on Windows
+    data_format = 'csv' # Choose 'csv' or 'png'
     data_path, labels_path, saving_path = get_data_paths(data_format, is_linux=is_linux, is_hpc=is_hpc)
-
-    # Load data with labels and segment names
-    _, labeled_data, _, _, labeled_labels = load_data(data_path, labels_path, dataset_size=141, train=True, data_format=data_format, return_all="labeled")
-
-    # Load unlabeled data
-    _, unlabeled_data, _, _, unlabeled_labels = load_data(data_path, labels_path, dataset_size=141, train=True, data_format=data_format, return_all="unlabeled")
-
-    # Load all data (labeled and unlabeled)
-    original_data, all_data, segment_names, segment_labels, all_labels = load_data(data_path, labels_path, dataset_size=141, train=True, data_format=data_format, return_all=True)
-
-    # test_data, _, _, _, _ = load_data(data_path, labels_path, dataset_size=30, train=False)
-    train_dataloader = create_dataloader(labeled_data)
-    # test_dataloader = create_dataloader(test_data)
+    results = {} # Create a dictionary to store results.
-    # Visualize random trends/segments
-    visualize_trends(labeled_data, original_data, segment_names, num_plots=10, data_format=data_format, save_path=saving_path)
+    try:
+        if case_to_run == "labeled":
+            labeled_original_data, labeled_data, labeled_segments, labeled_segments_labels, labeled_labels = load_data(data_path, labels_path, dataset_size=15, train=True, data_format=data_format, return_all="labeled")
+            results['labeled'] = process_labeled_data(labeled_data, labeled_original_data, labeled_segments, labeled_labels, saving_path, data_format)
+        elif case_to_run == "unlabeled":
+            unlabeled_original_data, unlabeled_data, unlabeled_segments, unlabeled_segments_labels, unlabeled_labels = load_data(data_path, labels_path, dataset_size=10, train=True, data_format=data_format, return_all="unlabeled")
+            results['unlabeled'] = process_unlabeled_data(unlabeled_data, unlabeled_original_data, unlabeled_segments, unlabeled_labels, saving_path, data_format)
+        elif case_to_run == "all_data":
+            all_original_data, all_data, all_segment_names, segment_labels, all_labels = load_data(data_path, labels_path, dataset_size=10, train=True, data_format=data_format, return_all=True)
+            results['all_data'] = process_all_data(all_data, all_original_data, all_segment_names, all_labels, saving_path, data_format)
+        else:
+            log_progress("Invalid case specified. Please use 'labeled', 'unlabeled', or 'all_data'.")
-    # Visualize the correlation matrix
-    # visualize_correlation_matrix(labeled_data, segment_names, subject_mode=True, num_subjects_to_visualize=None, save_path=saving_path)
-    # visualize_correlation_matrix(unlabeled_data, segment_names, subject_mode=True, num_subjects_to_visualize=None, save_path=saving_path)
-    # visualize_correlation_matrix(all_data, segment_names, subject_mode=False, num_subjects_to_visualize=None, save_path=saving_path)
+    except Exception as e:
+        handle_error(f"{case_to_run} data", e)
-    ##### LABELED ######
-    # Perform PCA on labeled data
-    pca_reduced_data, pca, pca_labeled_labels= perform_pca_sgd(labeled_data, num_components=2, num_clusters=4, batch_size=64)
-
-    # Plot PCA for labeled data
-    plot_pca(pca_reduced_data, labeled_labels, method='PCA on Labeled Data', save_path=saving_path)
+    log_progress("Code execution completed.")
-    # For PCA on labeled data
-    evaluate_clustering(labeled_data.view(labeled_data.size(0), -1).numpy(), labeled_labels, pca_labeled_labels)
+    return results # Return the results dictionary
-    # Perform t-SNE on labeled data
-    tsne_reduced_data, tsne_labels = perform_tsne(labeled_data)
-
-    # Plot t-SNE for labeled data
-    plot_clusters(tsne_reduced_data, tsne_labels, 't-SNE on Labeled Data', save_path=saving_path)
-
-    # For t-SNE on labeled data
-    evaluate_clustering(tsne_reduced_data, labeled_labels, tsne_labels)
-
-    # Perform MFVI on labeled data
-    miu, pi, resp = perform_mfvi(labeled_data, K=4, n_optimization_iterations=300, convergence_threshold=1e-5, run_until_convergence=False)
-
-    # Extract cluster assignments from MFVI
-    mfvi_labels = torch.argmax(resp, dim=1).numpy()
-
-    # Plot MFVI for labeled data
-    plot_clusters(labeled_data.view(labeled_data.size(0), -1).numpy(), mfvi_labels, 'MFVI on Labeled Data', save_path=saving_path)
-
-    # For MFVI on labeled data
-    evaluate_clustering(labeled_data.view(labeled_data.size(0), -1).numpy(), labeled_labels, mfvi_labels)
-
-    ##### UNLABELED ######
-    # Perform PCA on unlabeled data
-    pca_reduced_data, pca, pca_unlabeled_labels = perform_pca_sgd(unlabeled_data, num_components=2, num_clusters=4, batch_size=64)
-
-    # Plot PCA for unlabeled data
-    plot_pca(pca_reduced_data, unlabeled_labels, method='PCA on Unlabeled Data', save_path=saving_path)
-
-    # For PCA on unlabeled data
-    evaluate_clustering(unlabeled_data.view(unlabeled_data.size(0), -1).numpy(), unlabeled_labels, pca_unlabeled_labels)
-
-    # Perform t-SNE on unlabeled data
-    tsne_reduced_data, tsne_labels = perform_tsne(unlabeled_data)
-
-    # Plot t-SNE for unlabeled data
-    plot_clusters(tsne_reduced_data, tsne_labels, 't-SNE on Unlabeled Data', save_path=saving_path)
-
-    # For t-SNE on unlabeled data
-    evaluate_clustering(tsne_reduced_data, unlabeled_labels, tsne_labels)
-
-    # Perform MFVI on unlabeled data
-    miu, pi, resp = perform_mfvi(unlabeled_data, K=4, n_optimization_iterations=300, convergence_threshold=1e-5, run_until_convergence=False)
-
-    # Extract cluster assignments from MFVI
-    mfvi_labels = torch.argmax(resp, dim=1).numpy()
-
-    # Plot MFVI for unlabeled data
-    plot_clusters(unlabeled_data.view(unlabeled_data.size(0), -1).numpy(), mfvi_labels, 'MFVI on Unlabeled Data', save_path=saving_path)
-
-    # For MFVI on unlabeled data
-    evaluate_clustering(unlabeled_data.view(unlabeled_data.size(0), -1).numpy(), unlabeled_labels, mfvi_labels)
-
-    ##### ALL DATA ######
-    # Perform PCA on all data
-    pca_reduced_data, pca, pca_labels= perform_pca_sgd(all_data, num_components=2, num_clusters=4, batch_size=64)
-
-    # Plot PCA for all data
-    plot_pca(pca_reduced_data, all_labels, method='PCA on All Data', save_path=saving_path)
-
-    # For PCA on all data
-    evaluate_clustering(all_data.view(all_data.size(0), -1).numpy(), all_labels, pca_labels)
-
-    # Perform t-SNE on all data
-    tsne_reduced_data, tsne_labels = perform_tsne(all_data)
-
-    # Plot t-SNE for all data
-    plot_clusters(tsne_reduced_data, tsne_labels, 't-SNE on All Data', save_path=saving_path)
-
-    # For t-SNE on all data
-    evaluate_clustering(tsne_reduced_data, all_labels, tsne_labels)
-
-    # Perform MFVI on all data
-    miu, pi, resp = perform_mfvi(all_data, K=4, n_optimization_iterations=300, convergence_threshold=1e-5, run_until_convergence=False)
-
-    # Extract cluster assignments from MFVI
-    mfvi_labels = torch.argmax(resp, dim=1).numpy()
-
-    # Plot MFVI for all data
-    plot_clusters(all_data.view(all_data.size(0), -1).numpy(), mfvi_labels, 'MFVI on All Data', save_path=saving_path)
-
-    # For MFVI on all data
-    evaluate_clustering(all_data.view(all_data.size(0), -1).numpy(), all_labels, mfvi_labels)
+def handle_error(task, error):
+    logger.error(f"An error occurred while {task}: {str(error)}")
 
 if __name__ == "__main__":
-    main()
\ No newline at end of file
+    # Specify the case you want to run: 'labeled', 'unlabeled', or 'all_data'
+    case_to_run = "labeled"
+    main(case_to_run)
\ No newline at end of file
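
Note on the batched correlation summary added in visualize_correlation_matrix: the steps the new code performs (shuffle the rows, split them into fixed-size batches, compute np.corrcoef per batch, summarize each batch, then report the mean of the per-batch summaries with 2.5th/97.5th percentile bounds) can be exercised in isolation. The sketch below is illustrative only — the function name check_batched_correlation and the synthetic data are assumptions, not part of this patch — but it mirrors the same procedure using plain NumPy.

import numpy as np

def check_batched_correlation(data, batch_size=32, seed=0):
    # data: 2-D array, rows = segments, columns = flattened features
    rng = np.random.default_rng(seed)
    shuffled = data.copy()
    rng.shuffle(shuffled)  # shuffle rows to avoid ordering bias

    num_batches = len(shuffled) // batch_size
    batch_correlations = []
    for b in range(num_batches):
        batch = shuffled[b * batch_size:(b + 1) * batch_size]
        corr = np.corrcoef(batch, rowvar=False)   # feature-by-feature correlations within the batch
        batch_correlations.append(np.mean(corr))  # summarize the batch by its mean correlation

    batch_correlations = np.array(batch_correlations)
    overall = batch_correlations.mean()
    ci_lower = np.percentile(batch_correlations, 2.5)
    ci_upper = np.percentile(batch_correlations, 97.5)
    return overall, (ci_lower, ci_upper)

# Example on synthetic data: 256 segments with 64 features each
overall, (lo, hi) = check_batched_correlation(np.random.randn(256, 64), batch_size=32)
print(f"overall={overall:.4f}, 95% interval=[{lo:.4f}, {hi:.4f}]")

The reported interval reflects variability of the per-batch summaries, not a formal confidence interval for any single feature-pair correlation.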