Skip to content

Commit

Permalink
defining better parameters
Browse files Browse the repository at this point in the history
Non relevant update
  • Loading branch information
lrm22005 committed Oct 24, 2023
1 parent 07b28c3 commit 49f69ae
Showing 1 changed file with 7 additions and 5 deletions.
12 changes: 7 additions & 5 deletions project_1.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ def create_dataloader(data, batch_size=64, shuffle=True):
data_loader = DataLoader(dataset, batch_size=batch_size, shuffle=shuffle)
return data_loader

def visualize_trends(data, segment_names, num_plots=3):
def visualize_trends(data, segment_names, num_plots=3, save_path=None):
# Visualize random trends/segments
num_samples, _, _ = data.shape
for _ in range(num_plots):
Expand All @@ -114,6 +114,9 @@ def visualize_trends(data, segment_names, num_plots=3):
plt.imshow(data[idx].numpy())
plt.title(f"Segment: {segment_names[idx]}")
plt.colorbar()
if save_path:
subject_save_path = os.path.join(save_path, f"trends_visualization_segment_{segment_names}.png")
plt.savefig(subject_save_path)
plt.show()

def perform_pca(data, num_components=2):
Expand Down Expand Up @@ -169,7 +172,7 @@ def visualize_correlation_matrix(data, segment_names, subject_mode=True, num_sub
plt.title(f"Correlation Matrix for {num_subjects_to_visualize} Subjects {subject}")

if save_path:
subject_save_path = os.path.join(save_path, f"correlation_matrix_subject_group{subject}.png")
subject_save_path = os.path.join(save_path, f"correlation_matrix_subject_group_{subject}.png")
plt.savefig(subject_save_path)

plt.show()
Expand Down Expand Up @@ -312,7 +315,7 @@ def main():

data_path, labels_path, saving_path = get_data_paths(data_format, is_linux=is_linux, is_hpc=is_hpc)

train_data, segment_names, labels = load_data(data_path, labels_path, dataset_size=20, train=True, data_format=data_format)
train_data, segment_names, labels = load_data(data_path, labels_path, dataset_size=10, train=True, data_format=data_format)
# test_data, _, _ = load_data(data_path, labels_path, dataset_size=30, train=False)

train_dataloader = create_dataloader(train_data)
Expand All @@ -327,12 +330,11 @@ def main():

# Visualize the correlation matrix
visualize_correlation_matrix(train_data, segment_names, subject_mode=True, num_subjects_to_visualize=None, save_path=saving_path)
visualize_correlation_matrix(train_data, segment_names, subject_mode=False, num_subjects_to_visualize=None, save_path=saving_path)
# visualize_correlation_matrix(train_data, segment_names, subject_mode=False, num_subjects_to_visualize=None, save_path=saving_path)
# visualize_correlation_matrix(train_data, segment_names, subject_mode=False, num_subjects_to_visualize=None, save_path=saving_path)

# Perform MFVI for your data
K = 4 # Number of clusters
n_optimization_iterations = 100
miu_mfvi, pi_mfvi, resp_mfvi = perform_mfvi(train_data, K, n_optimization_iterations=1000, convergence_threshold=1e-5, run_until_convergence=False)

# Calculate clustering metrics for MFVI
Expand Down

0 comments on commit 49f69ae

Please sign in to comment.