-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #11 from lrm22005/Luis
error in file names ..._STFT.csv
- Loading branch information
Showing
2 changed files
with
96 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,95 @@ | ||
# -*- coding: utf-8 -*- | ||
""" | ||
Created on Wed Dec 6 22:07:18 2023 | ||
@author: lrm22005 | ||
""" | ||
|
||
import os | ||
import pandas as pd | ||
import numpy as np | ||
import torch | ||
from torch.utils.data import Dataset, DataLoader | ||
from torchvision.transforms import ToTensor | ||
from sklearn.preprocessing import StandardScaler | ||
from PIL import Image | ||
|
||
class CustomDataset(Dataset):
    """Dataset of time-frequency segments loaded lazily from CSV matrices or PNG images.

    Segment names and their labels are read once (at construction) from
    per-UID label CSVs; the actual data matrices are loaded on demand in
    ``__getitem__``.
    """

    def __init__(self, data_path, labels_path, UIDs, standardize=True, data_format='csv'):
        """
        Args:
            data_path: Root directory holding one sub-directory per UID with
                the segment files (``<segment>_filt.csv`` / ``.png``).
            labels_path: Directory holding ``<UID>_final_attemp_4_1_Dong.csv``
                label files (two columns: segment file name, label).
            UIDs: Iterable of subject/record identifiers to include.
            standardize: If True, z-score each loaded segment before returning it.
            data_format: ``'csv'`` or ``'png'`` — which on-disk representation to read.
        """
        self.data_path = data_path
        self.labels_path = labels_path
        self.UIDs = UIDs
        self.standardize = standardize
        self.data_format = data_format
        self.transforms = ToTensor()  # kept for interface compatibility; not used below

        # Extract unique segment names and their corresponding labels up front.
        self.segment_names, self.labels = self.extract_segment_names_and_labels()

    def __len__(self):
        return len(self.segment_names)

    def __getitem__(self, idx):
        segment_name = self.segment_names[idx]
        label = self.labels[segment_name]

        # Load data on-the-fly based on the segment name.
        time_freq_tensor = self.load_data(segment_name)

        # NOTE(review): load_data already returns shape (1, 128, 128) for CSV,
        # so unsqueeze(0) yields (1, 1, 128, 128) — confirm downstream expects
        # this extra leading dimension.
        return {'data': time_freq_tensor.unsqueeze(0), 'label': label, 'segment_name': segment_name}

    def extract_segment_names_and_labels(self):
        """Collect unique segment names (file order preserved) and their labels.

        Missing label files are skipped silently. Only the first occurrence of
        a duplicated segment name is kept.

        Returns:
            tuple: (ordered list of segment names, dict mapping name -> label).
        """
        segment_names = []
        seen = set()  # O(1) membership test instead of O(n) list scans per row
        labels = {}

        for UID in self.UIDs:
            label_file = os.path.join(self.labels_path, UID + "_final_attemp_4_1_Dong.csv")
            if not os.path.exists(label_file):
                continue
            label_data = pd.read_csv(label_file, sep=',', header=0, names=['segment', 'label'])
            # Strip the file extension to get the bare segment name.
            label_segment_names = label_data['segment'].apply(lambda x: x.split('.')[0])
            for idx, segment_name in enumerate(label_segment_names):
                if segment_name not in seen:
                    seen.add(segment_name)
                    segment_names.append(segment_name)
                    labels[segment_name] = label_data['label'].values[idx]

        return segment_names, labels

    def load_data(self, segment_name):
        """Load one segment as a tensor; return zeros of shape (1, 128, 128) on any error.

        The segment is looked up under ``<data_path>/<UID>/<segment_name>_filt.csv``,
        where the UID is the part of the segment name before the first underscore.
        """
        data_path_UID = os.path.join(self.data_path, segment_name.split('_')[0])
        seg_path = os.path.join(data_path_UID, segment_name + '_filt.csv')

        try:
            if self.data_format == 'csv' and seg_path.endswith('.csv'):
                # CSV files are assumed to hold a 128x128 matrix — TODO confirm.
                time_freq_plot = np.array(pd.read_csv(seg_path, header=None))
                time_freq_tensor = torch.Tensor(time_freq_plot).reshape(1, 128, 128)
            elif self.data_format == 'png' and seg_path.endswith('.png'):
                img = Image.open(seg_path)
                img_data = np.array(img)
                time_freq_tensor = torch.Tensor(img_data).unsqueeze(0)
            else:
                raise ValueError("Unsupported file format")

            if self.standardize:
                time_freq_tensor = self.standard_scaling(time_freq_tensor)  # Standardize the data

            return time_freq_tensor.clone()

        except Exception as e:
            # Deliberate best-effort: a bad/missing segment becomes all-zeros
            # rather than aborting the whole epoch.
            print(f"Error processing segment: {segment_name}. Exception: {str(e)}")
            return torch.zeros((1, 128, 128))  # Return zeros in case of an error

    def standard_scaling(self, data):
        """Z-score the tensor column-wise (per last axis) and restore its shape."""
        scaler = StandardScaler()
        data = scaler.fit_transform(data.reshape(-1, data.shape[-1])).reshape(data.shape)
        return torch.Tensor(data)
|
||
def load_data_split_batched(data_path, labels_path, UIDs, batch_size, standardize=False, data_format='csv'):
    """Build a non-shuffling DataLoader over a CustomDataset for the given UIDs."""
    segment_dataset = CustomDataset(
        data_path,
        labels_path,
        UIDs,
        standardize=standardize,
        data_format=data_format,
    )
    return DataLoader(segment_dataset, batch_size=batch_size, shuffle=False)
|
||
# To validate the sizes of the dataloader:
# number of examples: len(train_loader.dataset)
# number of batches: len(train_loader)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters