Skip to content

Commit

Permalink
MFVI Convergence
Browse files Browse the repository at this point in the history
This code has the Mean-Field convergence criteria to choose if use it or not and stop the code just if it converge.
  • Loading branch information
lrm22005 committed Oct 24, 2023
1 parent 5fc6745 commit f7b5b20
Showing 1 changed file with 15 additions and 3 deletions.
18 changes: 15 additions & 3 deletions project_1.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,7 @@ def multivariate_normal_log_pdf_MFVI(x, mu, sigma_sq):

return log_p

def perform_mfvi(data, K, n_optimization_iterations):
def perform_mfvi(data, K, n_optimization_iterations, convergence_threshold=1e-5, run_until_convergence=True):
N, D = data.shape[0], data.shape[1] * data.shape[2] # Calculate feature dimension D
# Define the variational parameters for the GMM
miu_variational = torch.randn(K, D, requires_grad=True)
Expand All @@ -160,7 +160,9 @@ def perform_mfvi(data, K, n_optimization_iterations):
# Define the optimizer for gradient descent
optimizer = torch.optim.Adam([miu_variational, log_sigma_variational, alpha_variational], lr=0.001)

for iteration in range(n_optimization_iterations):
prev_elbo = float('-inf')
iteration = 0
while True:
# Initialize gradients
optimizer.zero_grad()

Expand All @@ -187,6 +189,16 @@ def perform_mfvi(data, K, n_optimization_iterations):
if (iteration + 1) % 100 == 0:
print(f"Iteration {iteration + 1}/{n_optimization_iterations}")

if run_until_convergence:
if iteration > 0 and abs(elbo - prev_elbo) < convergence_threshold:
print(f"Converged after {iteration + 1} iterations")
break
elif iteration == n_optimization_iterations - 1:
print("Reached the specified number of iterations.")

prev_elbo = elbo
iteration += 1

# Extract the learned parameters
miu = miu_variational.detach().numpy()
pi = torch.softmax(alpha_variational, dim=0).detach().numpy()
Expand Down Expand Up @@ -290,7 +302,7 @@ def main():
# Perform MFVI for your data
K = 4 # Number of clusters
n_optimization_iterations = 100
miu_mfvi, pi_mfvi, resp_mfvi = perform_mfvi(train_data, K, n_optimization_iterations)
miu_mfvi, pi_mfvi, resp_mfvi = perform_mfvi(train_data, K, n_optimization_iterations, convergence_threshold=1e-5, run_until_convergence=False)

# Calculate clustering metrics for MFVI
zi_mfvi = np.argmax(resp_mfvi, axis=1)
Expand Down

0 comments on commit f7b5b20

Please sign in to comment.