diff --git a/project_1.py b/project_1.py index 90d6e78..caf0e18 100644 --- a/project_1.py +++ b/project_1.py @@ -105,7 +105,7 @@ def create_dataloader(data, batch_size=64, shuffle=True): data_loader = DataLoader(dataset, batch_size=batch_size, shuffle=shuffle) return data_loader -def visualize_trends(data, segment_names, num_plots=3): +def visualize_trends(data, segment_names, num_plots=3, save_path=None): # Visualize random trends/segments num_samples, _, _ = data.shape for _ in range(num_plots): @@ -114,6 +114,9 @@ def visualize_trends(data, segment_names, num_plots=3): plt.imshow(data[idx].numpy()) plt.title(f"Segment: {segment_names[idx]}") plt.colorbar() + if save_path: + subject_save_path = os.path.join(save_path, f"trends_visualization_segment_{segment_names}.png") + plt.savefig(subject_save_path) plt.show() def perform_pca(data, num_components=2): @@ -169,7 +172,7 @@ def visualize_correlation_matrix(data, segment_names, subject_mode=True, num_sub plt.title(f"Correlation Matrix for {num_subjects_to_visualize} Subjects {subject}") if save_path: - subject_save_path = os.path.join(save_path, f"correlation_matrix_subject_group{subject}.png") + subject_save_path = os.path.join(save_path, f"correlation_matrix_subject_group_{subject}.png") plt.savefig(subject_save_path) plt.show() @@ -312,7 +315,7 @@ def main(): data_path, labels_path, saving_path = get_data_paths(data_format, is_linux=is_linux, is_hpc=is_hpc) - train_data, segment_names, labels = load_data(data_path, labels_path, dataset_size=20, train=True, data_format=data_format) + train_data, segment_names, labels = load_data(data_path, labels_path, dataset_size=10, train=True, data_format=data_format) # test_data, _, _ = load_data(data_path, labels_path, dataset_size=30, train=False) train_dataloader = create_dataloader(train_data) @@ -327,12 +330,11 @@ def main(): # Visualize the correlation matrix visualize_correlation_matrix(train_data, segment_names, subject_mode=True, num_subjects_to_visualize=None, save_path=saving_path) - visualize_correlation_matrix(train_data, segment_names, subject_mode=False, num_subjects_to_visualize=None, save_path=saving_path) + # visualize_correlation_matrix(train_data, segment_names, subject_mode=False, num_subjects_to_visualize=None, save_path=saving_path) # visualize_correlation_matrix(train_data, segment_names, subject_mode=False, num_subjects_to_visualize=None, save_path=saving_path) # Perform MFVI for your data K = 4 # Number of clusters - n_optimization_iterations = 100 miu_mfvi, pi_mfvi, resp_mfvi = perform_mfvi(train_data, K, n_optimization_iterations=1000, convergence_threshold=1e-5, run_until_convergence=False) # Calculate clustering metrics for MFVI