# -*- coding: utf-8 -*-
"""
Dataset utilities for lazily loading time-frequency ECG segments and labels.

Created on Wed Dec  6 22:07:18 2023
@author: lrm22005

Reconstructed from git patch 123c76f5 ("error in file names ..._STFT.csv",
2023-12-07), which adds this file (option.py) and also renames the segment
file suffix used by project_NVIL.py (see the note at the end of this module).
"""

import os

import numpy as np
import pandas as pd
import torch
from torch.utils.data import DataLoader, Dataset

# Optional third-party dependencies. Each is needed on exactly one code path
# (ToTensor: the otherwise-unused ``transforms`` attribute; StandardScaler:
# ``standardize=True``; PIL: ``data_format='png'``), so the imports are
# guarded to keep the module importable in csv-only environments.
try:
    from torchvision.transforms import ToTensor
except ImportError:
    ToTensor = None
try:
    from sklearn.preprocessing import StandardScaler
except ImportError:
    StandardScaler = None
try:
    from PIL import Image
except ImportError:
    Image = None


class CustomDataset(Dataset):
    """Map-style dataset yielding one time-frequency segment per item.

    Labels are read eagerly at construction time from
    ``<labels_path>/<UID>_final_attemp_4_1_Dong.csv`` files; the segment data
    itself is loaded on-the-fly in ``__getitem__`` from
    ``<data_path>/<UID>/<segment>_filt.csv`` (or ``.png``).
    """

    def __init__(self, data_path, labels_path, UIDs, standardize=True, data_format='csv'):
        """
        Args:
            data_path: root directory holding one sub-directory per UID.
            labels_path: directory holding the per-UID label csv files.
            UIDs: iterable of subject identifiers to include.
            standardize: if True, z-score each loaded segment.
            data_format: 'csv' or 'png' segment files.
        """
        self.data_path = data_path
        self.labels_path = labels_path
        self.UIDs = UIDs
        self.standardize = standardize
        self.data_format = data_format
        # Kept for interface compatibility; nothing in this module uses it.
        self.transforms = ToTensor() if ToTensor is not None else None

        # Eagerly build the index of segment names and their labels.
        self.segment_names, self.labels = self.extract_segment_names_and_labels()

    def __len__(self):
        return len(self.segment_names)

    def __getitem__(self, idx):
        segment_name = self.segment_names[idx]
        label = self.labels[segment_name]

        # Load the segment tensor on-the-fly based on the segment name.
        time_freq_tensor = self.load_data(segment_name)

        # NOTE(review): load_data already returns shape (1, 128, 128) for csv
        # input, so unsqueeze(0) yields (1, 1, 128, 128) — confirm downstream
        # consumers really expect that extra dimension.
        return {'data': time_freq_tensor.unsqueeze(0), 'label': label, 'segment_name': segment_name}

    def extract_segment_names_and_labels(self):
        """Scan each UID's label file; return (ordered segment names, name->label).

        UIDs without a label file are silently skipped; when a segment name
        appears more than once, the first occurrence's label wins.
        """
        segment_names = []
        labels = {}
        seen = set()  # O(1) membership; replaces O(n) list scans per segment

        for UID in self.UIDs:
            label_file = os.path.join(self.labels_path, UID + "_final_attemp_4_1_Dong.csv")
            if not os.path.exists(label_file):
                continue
            label_data = pd.read_csv(label_file, sep=',', header=0, names=['segment', 'label'])
            # Strip the file extension from each segment entry.
            label_segment_names = label_data['segment'].apply(lambda x: x.split('.')[0])
            for idx, segment_name in enumerate(label_segment_names):
                if segment_name not in seen:
                    seen.add(segment_name)
                    segment_names.append(segment_name)
                    labels[segment_name] = label_data['label'].values[idx]

        return segment_names, labels

    def load_data(self, segment_name):
        """Load one segment as a torch.Tensor; zeros of (1, 128, 128) on error."""
        # Segment files live under a per-UID directory; the UID is the prefix
        # of the segment name (before the first underscore).
        data_path_UID = os.path.join(self.data_path, segment_name.split('_')[0])
        # NOTE(review): the same patch changed project_NVIL.py to use
        # '_filt_STFT.csv' — confirm whether this module should match.
        seg_path = os.path.join(data_path_UID, segment_name + '_filt.csv')

        try:
            if self.data_format == 'csv' and seg_path.endswith('.csv'):
                time_freq_plot = np.array(pd.read_csv(seg_path, header=None))
                time_freq_tensor = torch.Tensor(time_freq_plot).reshape(1, 128, 128)
            elif self.data_format == 'png' and seg_path.endswith('.png'):
                img = Image.open(seg_path)
                img_data = np.array(img)
                time_freq_tensor = torch.Tensor(img_data).unsqueeze(0)
            else:
                raise ValueError("Unsupported file format")

            if self.standardize:
                time_freq_tensor = self.standard_scaling(time_freq_tensor)  # Standardize the data

            return time_freq_tensor.clone()

        except Exception as e:
            # Best-effort loader: report the failure but keep the batch alive.
            print(f"Error processing segment: {segment_name}. Exception: {str(e)}")
            return torch.zeros((1, 128, 128))  # Return zeros in case of an error

    def standard_scaling(self, data):
        """Z-score *data* column-wise over its last dimension; returns a Tensor."""
        scaler = StandardScaler()
        data = scaler.fit_transform(data.reshape(-1, data.shape[-1])).reshape(data.shape)
        return torch.Tensor(data)


def load_data_split_batched(data_path, labels_path, UIDs, batch_size, standardize=False, data_format='csv'):
    """Build a non-shuffled DataLoader over the requested UIDs."""
    dataset = CustomDataset(data_path, labels_path, UIDs, standardize, data_format)
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False)
    return dataloader

# To validate the dataloader:
#   len(train_loader) -> number of batches

# --- Companion change carried by the same patch (project_NVIL.py) ---
# In project_NVIL.py's load_data, the segment file suffix was changed:
#     segment_name + '_filt.csv'  ->  segment_name + '_filt_STFT.csv'