diff --git a/project_NVIL.py b/project_NVIL.py
index 2654939..eddc23a 100644
--- a/project_NVIL.py
+++ b/project_NVIL.py
@@ -61,6 +61,49 @@ def seed_all(seed=42):
 from sklearn.preprocessing import StandardScaler
 from PIL import Image
+import logging
+import numpy as np
+import pandas as pd
+import torch
+from torch.distributions import MultivariateNormal
+from torch.utils.data import DataLoader, TensorDataset
+import matplotlib.pyplot as plt
+from sklearn.decomposition import PCA, IncrementalPCA
+from sklearn.manifold import TSNE
+from sklearn.cluster import KMeans, MiniBatchKMeans
+from sklearn.metrics import silhouette_score, adjusted_rand_score, adjusted_mutual_info_score, davies_bouldin_score
+import seaborn as sns
+
+
+def get_data_paths(data_format, is_linux=False, is_hpc=False):
+    if is_linux:
+        base_path = "/mnt/r/ENGR_Chon/Dong/MATLAB_generate_results/NIH_PulseWatch"
+        labels_base_path = "/mnt/r/ENGR_Chon/NIH_Pulsewatch_Database/Adjudication_UConn"
+        saving_base_path = "/mnt/r/ENGR_Chon/Luis/Research/Casseys_case/Project_1_analysis"
+    elif is_hpc:
+        base_path = "/gpfs/scratchfs1/kic14002/doh16101"
+        labels_base_path = "/gpfs/scratchfs1/hfp14002/lrm22005"
+        saving_base_path = "/gpfs/scratchfs1/hfp14002/lrm22005/Casseys_case/Project_1_analysis"
+    else:
+        # Raw strings keep the Windows backslashes from being read as escape sequences.
+        base_path = r"R:\ENGR_Chon\Dong\MATLAB_generate_results\NIH_PulseWatch"
+        labels_base_path = r"R:\ENGR_Chon\NIH_Pulsewatch_Database\Adjudication_UConn"
+        saving_base_path = r"\\grove.ad.uconn.edu\research\ENGR_Chon\Luis\Research\Casseys_case"
+    if data_format == 'csv':
+        data_path = os.path.join(base_path, "TFS_csv")
+        labels_path = os.path.join(labels_base_path, "final_attemp_4_1_Dong_Ohm")
+        saving_path = os.path.join(saving_base_path, "Project_1_analysis")
+    elif data_format == 'png':
+        data_path = os.path.join(base_path, "TFS_plots")
+        labels_path = os.path.join(labels_base_path, "final_attemp_4_1_Dong_Ohm")
+        saving_path = os.path.join(saving_base_path, "Project_1_analysis")
+    else:
+        raise ValueError("Invalid data format. Choose 'csv' or 'png'.")
+    return data_path, labels_path, saving_path
+
 class CustomDataset(Dataset):
     def __init__(self, data_path, labels_path, UIDs, standardize=True, data_format='csv'):
         self.data_path = data_path
@@ -70,103 +113,69 @@ def __init__(self, data_path, labels_path, UIDs, standardize=True, data_format='
         self.data_format = data_format
         self.transforms = ToTensor()
 
-        self.X_data_original, self.segment_names, self.labels = self.load_split_data()
+        # Extract unique segment names and their corresponding labels
+        self.segment_names, self.labels = self.extract_segment_names_and_labels()
 
     def __len__(self):
         return len(self.segment_names)
 
     def __getitem__(self, idx):
         segment_name = self.segment_names[idx]
-        time_freq_tensor = self.X_data_original[idx]
         label = self.labels[segment_name]
+        # Load data on-the-fly based on the segment_name
+        time_freq_tensor = self.load_data(segment_name)
 
-        return {'data': time_freq_tensor.unsqueeze(0), 'label': label, 'segment_name': segment_name}
+        # load_data already returns a [1, 128, 128] tensor, so no extra unsqueeze is needed
+        return {'data': time_freq_tensor, 'label': label, 'segment_name': segment_name}
 
-    def load_split_data(self):
-        X_data_original = []  # Store original data without standardization
+    def extract_segment_names_and_labels(self):
         segment_names = []
-
-        for UID in self.UIDs:
-            data_path_UID = os.path.join(self.data_path, UID)
-            dir_list_seg = os.listdir(data_path_UID)
-
-            for seg in dir_list_seg[:500]:  # Limiting to the first 500 segments
-                seg_path = os.path.join(data_path_UID, seg)
-
-                try:
-                    if self.data_format == 'csv' and seg.endswith('.csv'):
-                        time_freq_plot = np.array(pd.read_csv(seg_path, header=None))
-                        time_freq_tensor = torch.Tensor(time_freq_plot).reshape(1, 128, 128)
-                    elif self.data_format == 'png' and seg.endswith('.png'):
-                        img = Image.open(seg_path)
-                        img_data = np.array(img)
-                        time_freq_tensor = torch.Tensor(img_data).unsqueeze(0)
-                    else:
-                        continue  # Skip other file formats
-
-                    X_data_original.append(time_freq_tensor.clone())  # Store a copy of the original data
-
-                    # Extract and store segment names
-                    # Change here: Use the segment name from the CSV file directly
-                    segment_names.append(seg.split('_filt')[0])
-
-                except Exception as e:
-                    print(f"Error processing segment: {seg} in UID: {UID}. Exception: {str(e)}")
-                    # You can also add more information to the error log, such as the value of time_freq_plot.
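+        # Each UID's adjudication CSV pairs a segment name with its class label;
+        # only that (name, label) map is built here, and the tensors themselves
+        # are read lazily in load_data() below.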
-
-        X_data_original = torch.cat(X_data_original, 0)
-
-        if self.standardize:
-            X_data_original = self.standard_scaling(X_data_original)  # Standardize the data
-
-        # Extract labels from CSV files
-        labels = self.extract_labels()
-
-        important_labels = [0, 1, 2, 3]  # List of important labels
-
-        # Initialize labels for segments as unlabeled (-1)
-        segment_labels = {segment_name: -1 for segment_name in segment_names}
-
-        for UID in labels.keys():
-            if UID not in self.UIDs:
-                # Skip UIDs that are not in the dataset
-                continue
-
-            label_data, label_segment_names = labels[UID]
-
-            for idx, segment_label in enumerate(label_data):
-                segment_name = label_segment_names[idx]
-
-                # Change here: Only update labels for segments present in the loaded data
-                if segment_name in segment_names:
-                    if segment_label in important_labels:
-                        segment_labels[segment_name] = segment_label
-                    else:
-                        # Set labels that are not in the important list as -1 (Unlabeled)
-                        segment_labels[segment_name] = -1
-
-        return X_data_original, segment_names, segment_labels
-
-    def extract_labels(self):
         labels = {}
+
         for UID in self.UIDs:
             label_file = os.path.join(self.labels_path, UID + "_final_attemp_4_1_Dong.csv")
             if os.path.exists(label_file):
                 label_data = pd.read_csv(label_file, sep=',', header=0, names=['segment', 'label'])
                 label_segment_names = label_data['segment'].apply(lambda x: x.split('.')[0])
-                labels[UID] = (label_data['label'].values, label_segment_names.values)
+                for idx, segment_name in enumerate(label_segment_names):
+                    if segment_name not in segment_names:
+                        segment_names.append(segment_name)
+                        labels[segment_name] = label_data['label'].values[idx]
+
+        return segment_names, labels
+
+    def load_data(self, segment_name):
+        data_path_UID = os.path.join(self.data_path, segment_name.split('_')[0])
+        # Build the filename from the configured format so the PNG branch below is reachable
+        seg_path = os.path.join(data_path_UID, f"{segment_name}_filt.{self.data_format}")
+
+        try:
+            if self.data_format == 'csv' and seg_path.endswith('.csv'):
+                time_freq_plot = np.array(pd.read_csv(seg_path, header=None))
+                time_freq_tensor = torch.Tensor(time_freq_plot).reshape(1, 128, 128)
+            elif self.data_format == 'png' and seg_path.endswith('.png'):
+                img = Image.open(seg_path)
+                img_data = np.array(img)
+                time_freq_tensor = torch.Tensor(img_data).unsqueeze(0)
+            else:
+                raise ValueError("Unsupported file format")
+
+            if self.standardize:
+                time_freq_tensor = self.standard_scaling(time_freq_tensor)  # Standardize the data
 
-        return labels
+            return time_freq_tensor.clone()
+
+        except Exception as e:
+            print(f"Error processing segment: {segment_name}. Exception: {str(e)}")
+            return torch.zeros((1, 128, 128))  # Return zeros in case of an error
 
     def standard_scaling(self, data):
         scaler = StandardScaler()
         data = scaler.fit_transform(data.reshape(-1, data.shape[-1])).reshape(data.shape)
         return torch.Tensor(data)
 
-def load_data_split_batched(data_path, labels_path, UIDs, batch_size, standardize=True, data_format='csv'):
+def load_data_split_batched(data_path, labels_path, UIDs, batch_size, standardize=False, data_format='csv'):
     dataset = CustomDataset(data_path, labels_path, UIDs, standardize, data_format)
-    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
+    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False)
     return dataloader
 
 is_linux = False  # Set to True if running on Linux, False if on Windows
@@ -174,45 +183,49 @@ def load_data_split_batched(data_path, labels_path, UIDs, batch_size, standardiz
 data_format = 'csv'  # Choose 'csv' or 'png'
 data_path, labels_path, saving_path = get_data_paths(data_format, is_linux=is_linux, is_hpc=is_hpc)
 
+clinical_trial_train = [clinical_trial_train[0]]
+clinical_trial_test = [clinical_trial_test[0]]
+clinical_trial_unlabeled = [clinical_trial_unlabeled[0]]
+
 # Example usage:
-batch_size = 32
+batch_size = 64
 train_loader = load_data_split_batched(data_path, labels_path, clinical_trial_train, batch_size)
 val_loader = load_data_split_batched(data_path, labels_path, clinical_trial_test, batch_size)
 test_loader = load_data_split_batched(data_path, labels_path, clinical_trial_unlabeled, batch_size)
 
-def imshow(image, ax=None, title=None, normalize=True):
-    """Imshow for Tensor."""
-    if ax is None:
-        fig, ax = plt.subplots()
-    image = image.numpy().transpose((1, 2, 0))
+# def imshow(image, ax=None, title=None, normalize=True):
+#     """Imshow for Tensor."""
+#     if ax is None:
+#         fig, ax = plt.subplots()
+#     image = image.numpy().transpose((1, 2, 0))
 
-    if normalize:
-        mean = np.array([0.485, 0.456, 0.406])
-        std = np.array([0.229, 0.224, 0.225])
-        image = std * image + mean
-        image = np.clip(image, 0, 1)
+#     if normalize:
+#         mean = np.array([0.485, 0.456, 0.406])
+#         std = np.array([0.229, 0.224, 0.225])
+#         image = std * image + mean
+#         image = np.clip(image, 0, 1)
 
-    ax.imshow(image)
-    ax.spines['top'].set_visible(False)
-    ax.spines['right'].set_visible(False)
-    ax.spines['left'].set_visible(False)
-    ax.spines['bottom'].set_visible(False)
-    ax.tick_params(axis='both', length=0)
-    ax.set_xticklabels('')
-    ax.set_yticklabels('')
+#     ax.imshow(image)
+#     ax.spines['top'].set_visible(False)
+#     ax.spines['right'].set_visible(False)
+#     ax.spines['left'].set_visible(False)
+#     ax.spines['bottom'].set_visible(False)
+#     ax.tick_params(axis='both', length=0)
+#     ax.set_xticklabels('')
+#     ax.set_yticklabels('')
 
-    return ax
+#     return ax
 
-# change this to the trainloader or testloader
-data_iter = iter(train_loader)
-data = next(data_iter)
-fig, axes = plt.subplots(figsize=(20,15), ncols=6)
-for ii in range(6):
-    images, label, names = data["data"], data["label"], data["segment_name"]
-    ax = axes[ii]
-#     helper.imshow(images[ii], ax=ax, normalize=False)
-    imshow(images[ii], ax=ax, normalize=False)
+# # change this to the trainloader or testloader
+# data_iter = iter(train_loader)
+# data = next(data_iter)
+# fig, axes = plt.subplots(figsize=(20,15), ncols=6)
+# for ii in range(6):
+#     images, label, names = data["data"], data["label"], data["segment_name"]
+#     ax = axes[ii]
+#     # helper.imshow(images[ii], ax=ax, normalize=False)
+#     imshow(images[ii], ax=ax, normalize=False)
 
@@ -248,7 +261,7 @@ class MoGPriorNet(PriorNet):
         = \sum_c w_c prod_k Normal(z[k]|u[c], s[c]^2)
     """
 
-    def __init__(self, outcome_shape, num_components, lbound=-10, rbound=10):
+    def __init__(self, outcome_shape, num_components, lbound=-83.1480, rbound=1.3803):
         super().__init__(outcome_shape)
         # [C]
         self.logits = nn.Parameter(torch.rand(num_components, requires_grad=True), requires_grad=True)
@@ -272,6 +285,22 @@ def forward(self, batch_shape):
         # and finally, a mixture
         return td.MixtureSameFamily(pc, comps)
 
+    # def sample(self, shape):
+    #     """
+    #     Return a sample from the distribution parameterized by the model.
+    #     """
+    #     # Draw a sample from the categorical distribution over components
+    #     pc = td.Categorical(logits=self.logits)
+    #     component_indices = pc.sample(shape)
+
+    #     # Draw samples from the component distributions
+    #     comps = td.Normal(
+    #         loc=self.locs[component_indices],
+    #         scale=self.scales[component_indices]
+    #     )
+
+    #     return comps.sample()
+
 def test_priors(batch_size=2, latent_dim=3, num_comps=5):
     prior_net = MoGPriorNet(latent_dim, num_comps)
     print("\nMixture of Gaussian")
@@ -285,7 +314,7 @@ def test_priors(batch_size=2, latent_dim=3, num_comps=5):
     print(f"  shapes: sample={z.shape} log_prob={p.log_prob(z).shape}")
 
-test_priors()
+# test_priors()
 
 # Conditional probability distributions
 
@@ -385,41 +414,12 @@ def forward(self, input):
         # flatten batch shape and obtain outputs
         output = super().forward(input.reshape((-1,) + event_shape))
-        print("Shape after FlattenImage:", input.shape)
+        # print("Shape after FlattenImage:", input.shape)
         # restore batch shape
         output_shape = batch_shape + output.shape[1:]
         return output.reshape(output_shape)
 
-# class MySequential(nn.Sequential):
-#     """
-#     This is a version of nn.Sequential that works with structured batches
-#     (i.e., batches that have multiple dimensions)
-#     even when some of the nn layers in it does not.
-#
-#     The idea is to just wrap nn.Sequential around two calls to reshape
-#     which remove and restore the batch dimensions.
-#     """
-
-#     def __init__(self, *args, event_dims=1):
-#         super().__init__(*args)
-#         self._event_dims = event_dims
-
-#     def forward(self, input):
-#         # memorize batch shape
-#         batch_shape = input.shape[:-self._event_dims]
-#         # memorize latent shape
-#         event_shape = input.shape[-self._event_dims:]
-
-#         # Print the shape of the tensor after the FlattenImage layer
-#         input = super().forward(input)
-#         print("Shape after FlattenImage:", input.shape)
-
-#         # flatten batch shape and obtain outputs
-#         output = input.reshape((-1,) + event_shape)
-#         # restore batch shape
-#         return output.reshape(batch_shape + output.shape[1:])
-
 def build_cnn_decoder(latent_size, num_channels, width=64, height=64, hidden_size=1024, p_drop=0.):
     """
     Map the latent code to a tensor with shape [num_channels, width, height].
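+
+    Illustrative shape check (an assumption mirroring the commented-out test
+    further down, not a guarantee of this decoder's API):
+
+        decoder = build_cnn_decoder(latent_size=128, num_channels=1, width=128, height=128)
+        decoder(torch.zeros((16, 128))).shape  # expected: torch.Size([16, 1, 128, 128])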
@@ -448,16 +448,18 @@ def build_cnn_decoder(latent_size, num_channels, width=64, height=64, hidden_siz
     )
     return decoder
 
-# # Test the modified function
-img_shape = train_data[0].shape
-num_channels, width, height = img_shape[0], img_shape[1], img_shape[2]
-latent_size = 128
+# data_iter = iter(train_loader)
+# data = next(data_iter)
+# # # Test the modified function
+# img_shape = data["data"][0].shape
+# num_channels, width, height = img_shape[0], img_shape[1], img_shape[2]
+# latent_size = 128
 
-output_shape = build_cnn_decoder(latent_size=latent_size, num_channels=num_channels, width=width, height=height)(
-    torch.zeros((128, latent_size))
-).shape
+# output_shape = build_cnn_decoder(latent_size=latent_size, num_channels=num_channels, width=width, height=height)(
+#     torch.zeros((128, latent_size))
+# ).shape
 
-print(output_shape)
+# print(output_shape)
 
 # note that because we use MySequential,
 # we can have a batch of [3, 5] assignments
@@ -489,18 +491,18 @@ def forward(self, z):
         h = self.decoder(z)
         return td.Independent(td.ContinuousBernoulli(logits=h), len(self.outcome_shape))
 
-obs_model = ContinuousImageModel(
-    num_channels=img_shape[0],
-    width=img_shape[1],
-    height=img_shape[2],
-    latent_size=128,
-    p_drop=0.1,
-    decoder_type=build_cnn_decoder
-)
-print(obs_model)
-# a batch of five zs is mapped to 5 distributions over [1,64,64]-dimensional
-# binary tensors
-print(obs_model(torch.zeros([128, 128])))
+# obs_model = ContinuousImageModel(
+#     num_channels=img_shape[0],
+#     width=img_shape[1],
+#     height=img_shape[2],
+#     latent_size=128,
+#     p_drop=0.1,
+#     decoder_type=build_cnn_decoder
+# )
+# print(obs_model)
+# # a batch of five zs is mapped to 5 distributions over [1,64,64]-dimensional
+# # binary tensors
+# print(obs_model(torch.zeros([128, 128])))
 
 # Joint distribution
 
@@ -635,7 +637,7 @@ def test_joint_dist(latent_size=10, num_comps=3, data_shape=(1, 128, 128), batch
     print(" 1:", p.naive_lowerbound(x, 10))
     print(" 2:", p.naive_lowerbound(x, 10))
 
-test_joint_dist(10)
+# test_joint_dist(10)
 
 ##### PART 2
 
@@ -675,16 +677,16 @@ def build_cnn_encoder(num_channels, width=64, height=64, output_size=1024, p_dro
     return encoder
 
 # Example usage with width=128, height=128, and variable output_size
-width = 128
-height = 128
-output_size = 1024  # You can set this to any desired output size
-encoder = build_cnn_encoder(num_channels=1, width=width, height=height, output_size=output_size)
-# a batch of five [1, 128, 128]-dimensional images is encoded into
-# five `output_size`-dimensional vectors
-encoder(torch.zeros((5, 1, width, height))).shape
-# and, again, since we use MySequential we can have structured batches
-# (here trying with (3,5))
-build_cnn_encoder(num_channels=1, width=width, height=height, output_size=output_size)(torch.zeros((3, 5, 1, 128, 128))).shape
+# width = 128
+# height = 128
+# output_size = 1024  # You can set this to any desired output size
+# encoder = build_cnn_encoder(num_channels=1, width=width, height=height, output_size=output_size)
+# # a batch of five [1, 128, 128]-dimensional images is encoded into
+# # five `output_size`-dimensional vectors
+# encoder(torch.zeros((5, 1, width, height))).shape
+# # and, again, since we use MySequential we can have structured batches
+# # (here trying with (3,5))
+# build_cnn_encoder(num_channels=1, width=width, height=height, output_size=output_size)(torch.zeros((3, 5, 1, 128, 128))).shape
 
 # Mixture of Gaussian mean fields
 
@@ -698,7 +700,7 @@ class MoGCPDNet(CPDNet):
     Output distribution is a mixture of products of Gaussian distributions
     """
 
-    def __init__(self, outcome_shape, num_inputs: int, hidden_size: int=None, p_drop: float=0., num_components=2):
+    def __init__(self, outcome_shape, num_inputs: int, hidden_size: int=None, p_drop: float=0., num_components=4):
         """
         outcome_shape: shape of the outcome (int or tuple)
             if int, we turn it into a singleton tuple
@@ -800,8 +802,8 @@ def forward(self, x):
         h = self.encoder(x)
         return self.cpd_net(h)
 
-InferenceModel(partial(MoGCPDNet, num_components=3), latent_size=10)(torch.zeros(5, 1, 128, 128))
-InferenceModel(partial(MoGCPDNet, num_components=3), latent_size=10)(torch.zeros(5, 1, 128, 128))
+# InferenceModel(partial(MoGCPDNet, num_components=3), latent_size=10)(torch.zeros(5, 1, 128, 128))
+# InferenceModel(partial(MoGCPDNet, num_components=3), latent_size=10)(torch.zeros(5, 1, 128, 128))
 
 # Neural Variational Inference
 
@@ -1016,10 +1018,11 @@ def log_prob_estimate(self, x, sample_size=None):
         _, _, L = self.DRL(x, sample_size=sample_size)
         return L
 
+
     def forward(self, x, sample_size=None, rate_weight=1.):
         """
         A surrogate for an MC estimate of - grad ELBO
-
+
         x: [batch_size] + data_shape
         sample_size: if 1 or more, we use multiple samples
             sample_size controls a sequential computation (a for loop)
@@ -1028,18 +1031,18 @@ def forward(self, x, sample_size=None, rate_weight=1.):
         sample_size = sample_size or 1
         obs_dims = len(self.gen_model.cpd_net.outcome_shape)
         batch_shape = x.shape[:-obs_dims]
-
+
         qz = self.inf_model(x)
         pz = self.gen_model.prior(batch_shape)
-
+
         # we can *always* make use of the score function estimator (SFE)
         use_sfe = True
-
+
        # these 3 log densities will contribute to the different parts of the objective
         log_p_x_z = 0.
         log_p_z = 0.
         log_q_z_x = 0.
-
+
         # these quantities will help us compute the SFE part of the objective
         # (if needed)
         sfe = 0
@@ -1047,19 +1050,24 @@ def forward(self, x, sample_size=None, rate_weight=1.):
         cv_reward = 0
         raw_r = 0
         cv_loss = 0
-
+
         for _ in range(sample_size):
-
+
             # Obtain a sample
             if qz.has_rsample:  # this is how td objects tell us whether they are continuously reparameterisable
                 z = qz.rsample()
                 use_sfe = False  # with path derivatives, we do not need SFE
             else:
                 z = qz.sample()
-
+
             # Parameterise the observational model
             px_z = self.gen_model.obs_model(z)
-
+
+            # Print shapes for debugging
+            print("x.shape:", x.shape)
+            print("z.shape:", z.shape)
+            print("px_z.log_prob(x).shape:", px_z.log_prob(x).shape)
+
             # Compute all three relevant densities:
             # p(x|z,theta)
             log_p_x_z = log_p_x_z + px_z.log_prob(x)
@@ -1067,30 +1075,37 @@ def forward(self, x, sample_size=None, rate_weight=1.):
             log_q_z_x = log_q_z_x + qz.log_prob(z)
             # p(z|theta)
             log_p_z = log_p_z + pz.log_prob(z)
-
+
             # Compute the "reward" for SFE
             raw_r = log_p_x_z + log_p_z - log_q_z_x
-
+
             # Apply variance reduction techniques
             r, l = self.cv_model(raw_r.detach(), x=x, q=qz, r_fn=lambda a: self.gen_model(a).log_prob(x))
             cv_loss = cv_loss + l
-
+
             # SFE part for updating lambda
             sfe = sfe + r.detach() * qz.log_prob(z)
-
+
+        # Print shapes for debugging
+        print("sfe.shape:", sfe.shape)
+        print("cv_loss.shape:", cv_loss.shape)
+        print("log_p_x_z.shape:", log_p_x_z.shape)
+        print("log_p_z.shape:", log_p_z.shape)
+        print("log_q_z_x.shape:", log_q_z_x.shape)
+
         # Compute the sample mean for the different terms
         sfe = (sfe / sample_size)
         cv_loss = cv_loss / sample_size
         log_p_x_z = log_p_x_z / sample_size
         log_p_z = log_p_z / sample_size
         log_q_z_x = log_q_z_x / sample_size
-
+
         D = - log_p_x_z
         try:  # not every design admits tractable KL
             R = td.kl_divergence(qz, pz)
         except NotImplementedError:
             R = log_q_z_x - log_p_z
-
+
         if use_sfe:
             # the first two terms update theta
             # the last term updates lambda
@@ -1101,37 +1116,40 @@ def forward(self, x, sample_size=None, rate_weight=1.):
         else:
             # without SFE, we can use the classic form of the ELBO
             elbo_grad_surrogate = -D - R
-
+
         loss = -elbo_grad_surrogate + cv_loss
-
+
         return {'loss': loss.mean(0), 'ELBO': (-D -R).mean(0).item(), 'D': D.mean(0).item(),
                 'R': R.mean(0).item(), 'cv_loss': cv_loss.mean(0).item()}
 
-num_classes = 4  # Change this to the actual number of classes
-vae = NVIL(
-    JointDistribution(
-        MoGPriorNet(128, 128),
-        MoGCPDNet(
-            outcome_shape=img_shape,
-            num_inputs=batch_size,
-            hidden_size=32,
-            num_components=2
-        )
-    ),
-    InferenceModel(
-        cpd_net_type=partial(MoGCPDNet),
-        latent_size=128,
-        num_channels=img_shape[0],
-        width=img_shape[1],
-        height=img_shape[2],
-    ),
-    VarianceReduction()
-)
-vae
-for x, y in train_loader:
-    print('x.shape:', x.shape)
-    print(vae(x))
-    break
+# num_classes = 4  # Change this to the actual number of classes
+# vae = NVIL(
+#     JointDistribution(
+#         MoGPriorNet(32, 32),
+#         MoGCPDNet(
+#             outcome_shape=img_shape,
+#             num_inputs=batch_size,
+#             hidden_size=32,
+#             num_components=2
+#         )
+#     ),
+#     InferenceModel(
+#         cpd_net_type=partial(MoGCPDNet),
+#         latent_size=32,
+#         num_channels=img_shape[0],
+#         width=img_shape[1],
+#         height=img_shape[2],
+#     ),
+#     VarianceReduction()
+# )
+# vae
+
+
+# for x in train_loader:
+#     x_1 = x["data"]
+#     print('x_1.shape:', x_1.shape)
+#     print(vae(x_1))
+#     break
 
 # Training algorithm
 
@@ -1168,8 +1186,11 @@ def assess(model, sample_size, dl, device):
     L = 0
     data_size = 0
     with torch.no_grad():
-        for batch_x, batch_y in dl:
-            Dx, Rx, Lx = model.DRL(batch_x.to(device), sample_size=sample_size)
+        for batch in dl:
+            batch_x = batch['data'].to(device)
+            batch_y = batch['label'].to(device)
+
+            Dx, Rx, Lx = model.DRL(batch_x, sample_size=sample_size)
             D = D + Dx.sum(0)
             R = R + Rx.sum(0)
             L = L + Lx.sum(0)
@@ -1180,46 +1201,46 @@ def assess(model, sample_size, dl, device):
     return {'ELBO': (-D -R).item(), 'D': D.item(), 'R': R.item(), 'L': L.item()}
 
+from tqdm import tqdm
+
 def train_vae(model: NVIL, opts: OptCollection,
-              training_data, dev_data,
-              batch_size=64, num_epochs=10, check_every=10,
+              training_loader, dev_loader,
+              num_epochs=10, check_every=10,
               sample_size_training=1,
               sample_size_eval=10,
               grad_clip=5.,
-              num_workers=2,
               device=torch.device('cuda:0')
              ):
     """
     model: pytorch model
     optimiser: pytorch optimiser
-    training_corpus: a TaggedCorpus for trianing
-    dev_corpus: a TaggedCorpus for dev
-    batch_size: use more if you have more memory
+    training_loader: DataLoader for training
+    dev_loader: DataLoader for dev
     num_epochs: use more for improved convergence
     check_every: use less to check performance on dev set more often
     device: where we run the experiment
 
    Return a log of quantities computed during training (for plotting)
    """
-    batcher = DataLoader(training_data, batch_size, shuffle=True, num_workers=num_workers, pin_memory=True)
-    dev_batcher = DataLoader(dev_data, batch_size, num_workers=num_workers, pin_memory=True)
-
-    total_steps = num_epochs * len(batcher)
+    total_steps = num_epochs * len(training_loader)
     log = defaultdict(list)
 
     step = 0
     model.eval()
-    for k, v in assess(model, sample_size_eval, dev_batcher, device=device).items():
+    for k, v in assess(model, sample_size_eval, dev_loader, device=device).items():
        log[f"dev.{k}"].append((step, v))
 
     with tqdm(range(total_steps)) as bar:
         for epoch in range(num_epochs):
-            for batch_x, batch_y in batcher:
+            for batch in training_loader:
                 model.train()
                 opts.zero_grad()
 
+                batch_x = batch['data'].to(device)
+                batch_y = batch['label'].to(device)
+
                 loss_dict = model(
-                    batch_x.to(device),
+                    batch_x,
                     sample_size=sample_size_training,
                 )
                 for metric, value in loss_dict.items():
@@ -1243,40 +1264,82 @@ def train_vae(model: NVIL, opts: OptCollection,
 
                 if step % check_every == 0:
                     model.eval()
-                    for k, v in assess(model, sample_size_eval, dev_batcher, device=device).items():
+                    for k, v in assess(model, sample_size_eval, dev_loader, device=device).items():
                         log[f"dev.{k}"].append((step, v))
 
                 step += 1
 
-
     model.eval()
-    for k, v in assess(model, sample_size_eval, dev_batcher, device=device).items():
+    for k, v in assess(model, sample_size_eval, dev_loader, device=device).items():
         log[f"dev.{k}"].append((step, v))
 
     return log
 
 # And, finally, some code to help inspect samples
 
+def show_image(image, ax=None, title=None, normalize=False):
+    """Show image for Tensor."""
+    if ax is None:
+        fig, ax = plt.subplots()
+    image = image.numpy().transpose((1, 2, 0))
+
+    if normalize:
+        mean = np.array([0.485, 0.456, 0.406])
+        std = np.array([0.229, 0.224, 0.225])
+        image = std * image + mean
+        image = np.clip(image, 0, 1)
+
+    ax.imshow(image)
+    ax.spines['top'].set_visible(False)
+    ax.spines['right'].set_visible(False)
+    ax.spines['left'].set_visible(False)
+    ax.spines['bottom'].set_visible(False)
+    ax.tick_params(axis='both', length=0)
+    ax.set_xticklabels('')
+    ax.set_yticklabels('')
+
+    return ax
+
+# # change this to the trainloader or testloader
+# data_iter = iter(train_loader)
+# data = next(data_iter)
+# fig, axes = plt.subplots(figsize=(20, 15), nrows=5, ncols=4)
+# axes = axes.flatten()  # Flatten the 2D array of axes
+# for ii in range(20):
+#     images, label, names = data["data"], data["label"], data["segment_name"]
+#     ax = axes[ii]
+#     show_image(images[ii], ax=ax, normalize=False)
 
 def inspect_lvm(model, dl, device):
-    for x, y in dl:
-
-        x_ = model.sample(16, 4, oversample=True).cpu().reshape(-1, 1, 64, 64)
-        plt.figure(figsize=(16,8))
+    for batch in dl:
+        batch_x = batch['data']
+        batch_y = batch['label']
+        x_ = model.sample(16, 4, oversample=True).cpu().reshape(-1, 1, 128, 128)
+        fig, axes = plt.subplots(figsize=(20, 15), nrows=2, ncols=4)
         plt.axis('off')
-        plt.imshow(make_grid(x_, nrow=16).permute((1, 2, 0)))
+        axes = axes.flatten()  # Flatten the 2D array of axes
+        for ii in range(8):
+            ax = axes[ii]
+            show_image(x_[ii], ax=ax, normalize=False)
         plt.title("Prior samples")
         plt.show()
 
-        plt.figure(figsize=(16,8))
+        fig, axes = plt.subplots(figsize=(20, 15), nrows=2, ncols=4)
         plt.axis('off')
-        plt.imshow(make_grid(x, nrow=16).permute((1, 2, 0)))
+        axes = axes.flatten()  # Flatten the 2D array of axes
+        for ii in range(8):
+            ax = axes[ii]
+            show_image(batch_x[ii], ax=ax, normalize=False)
         plt.title("Observations")
         plt.show()
 
-        x_ = model.cond_sample(x.to(device)).cpu().reshape(-1, 1, 64, 64)
-        plt.figure(figsize=(16,8))
+        x_ = model.cond_sample(batch_x.to(device)).cpu().reshape(-1, 1, 128, 128)
+        fig, axes = plt.subplots(figsize=(20, 15), nrows=2, ncols=4)
         plt.axis('off')
-        plt.imshow(make_grid(x_, nrow=16).permute((1, 2, 0)))
+        axes = axes.flatten()  # Flatten the 2D array of axes
+        for ii in range(8):
+            ax = axes[ii]
+            # show each conditional sample directly, matching the panels above
+            show_image(x_[ii], ax=ax, normalize=False)
         plt.title("Conditional samples")
         plt.show()
 
@@ -1410,28 +1473,32 @@ def forward(self, r, x, q, r_fn):
 seed_all()
 my_device = torch.device('cuda:0')
 
+data_iter = iter(train_loader)
+data = next(data_iter)
+
+# # Test the modified function
+img_shape = data["data"][0].shape
+
 model = NVIL(
     JointDistribution(
-        MoGPriorNet(128, 128),
+        MoGPriorNet(64, 64),
         MoGCPDNet(
             outcome_shape=img_shape,
-            num_inputs=batch_size,
-            hidden_size=32,
-            num_components=2
+            num_inputs=64,
+            hidden_size=None,
+            num_components=8
         )
    ),
-    InferenceModel(
-        latent_size=128,
+    InferenceModel(partial(MoGCPDNet),
+        latent_size=64,
         num_channels=img_shape[0],
         width=img_shape[1],
        height=img_shape[2],
-        cpd_net_type=partial(MoGCPDNet)  # Gaussian prior and Gaussian posterior: this is a classic VAE
-    ),
+        encoder_type=build_cnn_encoder),  # MoG prior and MoG posterior rather than the classic Gaussian VAE
     # VarianceReduction(),  # no variance reduction is needed for a VAE
     CVChain(  # variance reduction helps SFE
         CentredReward(),
-        Baseline(np.prod(img_shape), 512),  # this is how you would use a trained baseline
+        Baseline(np.prod(img_shape), 1024),  # this is how you would use a trained baseline
         ScaledReward()
     )
 ).to(my_device)
@@ -1442,32 +1509,39 @@ def forward(self, r, x, q, r_fn):
     # Adam is the go-to choice for (reparameterised) VAEs
     opt.Adam(model.gen_params(), lr=5e-4, weight_decay=1e-6),
     opt.Adam(model.inf_params(), lr=1e-4),
+    opt.Adam(model.cv_params(), lr=1e-4, weight_decay=1e-6),
     # Adam is not often a good choice for SFE-based optimisation
     # a possible reason: SFE is too noisy and the design choices behind Adam
     # were made having reparameterised gradients in mind
-    #opt.RMSprop(model.gen_params(), lr=5e-4, weight_decay=1e-6),
-    #opt.RMSprop(model.inf_params(), lr=1e-4),
-    #opt.RMSprop(model.cv_params(), lr=1e-4, weight_decay=1e-6)  # you need this if your baseline has trainable parameters
+    # opt.RMSprop(model.gen_params(), lr=5e-4, weight_decay=1e-6),
+    # opt.RMSprop(model.inf_params(), lr=1e-4),
+    # opt.RMSprop(model.cv_params(), lr=1e-4, weight_decay=1e-6)  # you need this if your baseline has trainable parameters
 )
 
-model
+# model
 
+# Set other parameters such as batch size, sample sizes, etc.
+batch_size = 64
+num_epochs = 20
+check_every = 10
+sample_size_training = 8
+sample_size_eval = 8
+grad_clip = 5.
+device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')
+
+# Train the VAE
+log = train_vae(model, opts, train_loader, val_loader, num_epochs, check_every,
+                sample_size_training, sample_size_eval, grad_clip, device)
+
+import gc
+
+# NOTE: inspect_lvm below still needs the trained model, so free the GPU memory
+# only once the inspection is done.
+# model.cpu()
+# del model
+# gc.collect()
+# torch.cuda.empty_cache()
 
-log = train_vae(
-    model=model,
-    opts=opts,
-    training_data=combined_train_data,
-    dev_data=combined_val_data,
-    batch_size=64,
-    num_epochs=100,  # use more for better models
-    check_every=100,
-    sample_size_training=1,
-    sample_size_eval=1,
-    grad_clip=5.,
-    device=my_device
-)
 
 log.keys()
 
@@ -1514,7 +1588,7 @@ def forward(self, r, x, q, r_fn):
 
 fig.tight_layout(h_pad=2, w_pad=2)
 
-inspect_lvm(model, DataLoader(combined_val_data, 128, num_workers=2, pin_memory=True), my_device)
+inspect_lvm(model, val_loader, my_device)
 
@@ -1525,76 +1599,46 @@ def forward(self, r, x, q, r_fn):
 
 # Set up your optimizer
-gen_optimizer = opt.Adam(model.gen_params(), lr=5e-4, weight_decay=1e-6)
-inf_optimizer = opt.Adam(model.inf_params(), lr=1e-4)
-cv_optimizer = opt.Adam(model.cv_params(), lr=1e-4, weight_decay=1e-6)
-
-opts = OptCollection(gen_optimizer, inf_optimizer, cv_optimizer)
+from tqdm import tqdm
 
-def train_semisupervised_vae(model: NVIL, opts: OptCollection, train_loader, val_loader, test_loader,
+def train_semisupervised_classifier(model: NVIL, opts: OptCollection,
+              labeled_train_loader, unlabeled_train_loader, unlabeled_val_loader, unlabeled_test_loader,
               num_epochs=10, check_every=10,
               sample_size_training=1,
              sample_size_eval=10,
              grad_clip=5.,
              device=torch.device('cuda:0')
             ):
-    """
-    Train a semi-supervised NVIL model.
-
-    model: NVIL model
-    opts: OptCollection containing optimizers
-    train_loader: DataLoader for combined labeled and unlabeled data
-    val_loader: DataLoader for validation data
-    test_loader: DataLoader for test data
-    num_epochs: number of training epochs
-    check_every: frequency to check performance on dev set
-    sample_size_training: number of samples for training
-    sample_size_eval: number of samples for evaluation
-    grad_clip: gradient clipping value
-    device: device for training
-
-    Returns a log of quantities computed during training (for plotting)
-    """
-
-    total_steps = num_epochs * len(train_loader)
+    total_steps = num_epochs * len(labeled_train_loader)
     log = defaultdict(list)
 
     step = 0
     model.eval()
-    for k, v in assess(model, sample_size_eval, val_loader, device=device).items():
+
+    # Use unlabeled_val_loader for validation since there is no labeled validation data
+    for k, v in assess(model, sample_size_eval, unlabeled_val_loader, device=device).items():
         log[f"dev.{k}"].append((step, v))
 
     with tqdm(range(total_steps)) as bar:
        for epoch in range(num_epochs):
-            for batch_x, batch_y in train_loader:
+            for labeled_batch, unlabeled_batch in zip(labeled_train_loader, unlabeled_train_loader):
                 model.train()
                 opts.zero_grad()
 
-                # Check if the data is labeled or unlabeled
-                labeled_mask = (batch_y != -1).to(device)
-
-                # Forward pass for unlabeled data
-                unlabeled_loss_dict = model(
-                    batch_x.to(device),
-                    sample_size=sample_size_training,
-                )
+                labeled_batch_x = labeled_batch['data'].to(device)
+                labeled_batch_y = labeled_batch['label'].to(device)
 
-                # Forward pass for labeled data with classification loss
-                labeled_loss_dict = model(
-                    batch_x[labeled_mask].to(device),
-                    sample_size=sample_size_training,
-                )
+                unlabeled_batch_x = unlabeled_batch['data'].to(device)
+                unlabeled_batch_y = unlabeled_batch['label'].to(device)
 
-                # Classification loss only for labeled data
-                classification_loss = F.cross_entropy(
-                    labeled_loss_dict['logits'].squeeze(), batch_y[labeled_mask].to(device)
-                )
+                # NVIL.forward takes (x, sample_size) and has no classification head,
+                # so both streams contribute their ELBO-based loss for now
+                labeled_loss_dict = model(labeled_batch_x, sample_size=sample_size_training)
+                unlabeled_loss_dict = model(unlabeled_batch_x, sample_size=sample_size_training)
 
-                # Combine losses
-                total_loss = unlabeled_loss_dict['loss'] + classification_loss
+                # Combine losses from both streams
+                total_loss = labeled_loss_dict['loss'] + unlabeled_loss_dict['loss']
+
+                for metric, value in labeled_loss_dict.items():
+                    log[f'training.{metric}'].append((step, value))
 
                 total_loss.backward()
 
                 nn.utils.clip_grad_norm_(
                     model.parameters(),
@@ -1603,26 +1647,54 @@ def train_semisupervised_vae(model: NVIL, opts: OptCollection, train_loader, val
                 opts.step()
 
                 bar_dict = OrderedDict()
-                bar_dict['training.loss'] = f"{total_loss:.2f}"
-                for metric, value in unlabeled_loss_dict.items():
-                    bar_dict[f'training.{metric}'] = f"{unlabeled_loss_dict[metric]:.2f}"
+                for metric, value in labeled_loss_dict.items():
+                    bar_dict[f'training.{metric}'] = f"{labeled_loss_dict[metric]:.2f}"
+                # assess() reports ELBO/D/R/L; there are no classification metrics to show yet
                 for metric in ['ELBO', 'D', 'R', 'L']:
                     bar_dict[f"dev.{metric}"] = "{:.2f}".format(log[f"dev.{metric}"][-1][1])
                 bar.set_postfix(bar_dict)
                 bar.update()
 
                 if step % check_every == 0:
                     model.eval()
-                    for k, v in assess(model, sample_size_eval, val_loader, device=device).items():
+                    for k, v in assess(model, sample_size_eval, unlabeled_val_loader, device=device).items():
                         log[f"dev.{k}"].append((step, v))
 
                 step += 1
 
     model.eval()
-    for k, v in assess(model, sample_size_eval, test_loader, device=device).items():
-        log[f"test.{k}"].append((step, v))
+
+    # Use unlabeled_val_loader for validation since there is no labeled validation data
+    for k, v in assess(model, sample_size_eval, unlabeled_val_loader, device=device).items():
+        log[f"dev.{k}"].append((step, v))
+
+    # Evaluate on the unlabeled test set
+    unlabeled_test_metrics = assess(model, sample_size_eval, unlabeled_test_loader, device=device)
+    log.update({f"unlabeled_test.{k}": [(step, v)] for k, v in unlabeled_test_metrics.items()})
 
     return log
 
-# Assuming you have the NVIL model (nvil), optimizers (opts), and data loaders (train_loader, val_loader, test_loader)
-log = train_semisupervised_vae(nvil, opts, train_loader, val_loader, test_loader)
+# Example usage
+
+from torch.utils.data import DataLoader, random_split, Subset
+
+# Split indices by label using the dataset's segment_name -> label map; this avoids
+# calling __getitem__, which would load every segment from disk just to read its label
+train_ds = train_loader.dataset
+labeled_train_indices = [i for i, name in enumerate(train_ds.segment_names) if train_ds.labels[name] != -1]
+unlabeled_train_indices = [i for i, name in enumerate(train_ds.segment_names) if train_ds.labels[name] == -1]
+
+val_ds = val_loader.dataset
+labeled_val_indices = [i for i, name in enumerate(val_ds.segment_names) if val_ds.labels[name] != -1]
+unlabeled_val_indices = [i for i, name in enumerate(val_ds.segment_names) if val_ds.labels[name] == -1]
+
+# Assuming the test loader only contains unlabeled data
+test_ds = test_loader.dataset
+unlabeled_test_indices = [i for i, name in enumerate(test_ds.segment_names) if test_ds.labels[name] == -1]
+
+# Creating separate loaders for labeled and unlabeled data
+labeled_train_loader = DataLoader(Subset(train_ds, labeled_train_indices), batch_size=batch_size, shuffle=True)
+unlabeled_train_loader = DataLoader(Subset(train_ds, unlabeled_train_indices), batch_size=batch_size, shuffle=True)
+
+labeled_val_loader = DataLoader(Subset(val_ds, labeled_val_indices), batch_size=batch_size, shuffle=False)
+unlabeled_val_loader = DataLoader(Subset(val_ds, unlabeled_val_indices), batch_size=batch_size, shuffle=False)
+
+unlabeled_test_loader = DataLoader(Subset(test_ds, unlabeled_test_indices), batch_size=batch_size, shuffle=False)
+
+log = train_semisupervised_classifier(model, opts, labeled_train_loader, unlabeled_train_loader, unlabeled_val_loader, unlabeled_test_loader,
+                                      num_epochs, check_every, sample_size_training, sample_size_eval, grad_clip, device)
\ No newline at end of file
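
Note: the earlier revision of this function (removed above) combined the unlabeled
ELBO loss with a cross-entropy term on labeled batches. If NVIL is later given a
classifier head whose loss dict exposes 'logits' (an assumption; the model does not
produce logits as written), the combination inside the training loop would look
roughly like this sketch:

    classification_loss = F.cross_entropy(
        labeled_loss_dict['logits'].squeeze(), labeled_batch_y
    )
    total_loss = labeled_loss_dict['loss'] + unlabeled_loss_dict['loss'] + classification_loss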