Commit

error in file names ..._STFT.csv
Nothing important
lrm22005 committed Dec 7, 2023
1 parent 1914d44 commit 123c76f
Showing 2 changed files with 96 additions and 1 deletion.
95 changes: 95 additions & 0 deletions option.py
@@ -0,0 +1,95 @@
# -*- coding: utf-8 -*-
"""
Created on Wed Dec 6 22:07:18 2023
@author: lrm22005
"""

import os
import pandas as pd
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision.transforms import ToTensor
from sklearn.preprocessing import StandardScaler
from PIL import Image

class CustomDataset(Dataset):
    def __init__(self, data_path, labels_path, UIDs, standardize=True, data_format='csv'):
        self.data_path = data_path
        self.labels_path = labels_path
        self.UIDs = UIDs
        self.standardize = standardize
        self.data_format = data_format
        self.transforms = ToTensor()

        # Extract unique segment names and their corresponding labels
        self.segment_names, self.labels = self.extract_segment_names_and_labels()

    def __len__(self):
        return len(self.segment_names)

    def __getitem__(self, idx):
        segment_name = self.segment_names[idx]
        label = self.labels[segment_name]

        # Load data on the fly based on the segment_name
        time_freq_tensor = self.load_data(segment_name)

        return {'data': time_freq_tensor.unsqueeze(0), 'label': label, 'segment_name': segment_name}

    def extract_segment_names_and_labels(self):
        segment_names = []
        labels = {}

        for UID in self.UIDs:
            label_file = os.path.join(self.labels_path, UID + "_final_attemp_4_1_Dong.csv")
            if os.path.exists(label_file):
                label_data = pd.read_csv(label_file, sep=',', header=0, names=['segment', 'label'])
                label_segment_names = label_data['segment'].apply(lambda x: x.split('.')[0])
                for idx, segment_name in enumerate(label_segment_names):
                    if segment_name not in segment_names:
                        segment_names.append(segment_name)
                        labels[segment_name] = label_data['label'].values[idx]

        return segment_names, labels

    def load_data(self, segment_name):
        data_path_UID = os.path.join(self.data_path, segment_name.split('_')[0])
        seg_path = os.path.join(data_path_UID, segment_name + '_filt.csv')

        try:
            if self.data_format == 'csv' and seg_path.endswith('.csv'):
                time_freq_plot = np.array(pd.read_csv(seg_path, header=None))
                time_freq_tensor = torch.Tensor(time_freq_plot).reshape(1, 128, 128)
            elif self.data_format == 'png' and seg_path.endswith('.png'):
                img = Image.open(seg_path)
                img_data = np.array(img)
                time_freq_tensor = torch.Tensor(img_data).unsqueeze(0)
            else:
                raise ValueError("Unsupported file format")

            if self.standardize:
                time_freq_tensor = self.standard_scaling(time_freq_tensor)  # Standardize the data

            return time_freq_tensor.clone()

        except Exception as e:
            print(f"Error processing segment: {segment_name}. Exception: {str(e)}")
            return torch.zeros((1, 128, 128))  # Return zeros in case of an error

    def standard_scaling(self, data):
        scaler = StandardScaler()
        data = scaler.fit_transform(data.reshape(-1, data.shape[-1])).reshape(data.shape)
        return torch.Tensor(data)

def load_data_split_batched(data_path, labels_path, UIDs, batch_size, standardize=False, data_format='csv'):
    dataset = CustomDataset(data_path, labels_path, UIDs, standardize, data_format)
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False)
    return dataloader

# To validate the length of the dataloader:
# number of examples
# len(train_loader.dataset)
# number of batches
# len(train_loader)
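
A quick usage sketch for the new file follows. The paths, UID strings, and batch size are placeholders for illustration only; none of these values appear in the commit.

# Hypothetical usage sketch for option.py -- placeholder paths and UIDs, adjust to the real data layout.
from option import load_data_split_batched

data_path = "/path/to/tfs_segments"     # one sub-folder per UID holding <segment>_filt.csv files
labels_path = "/path/to/ground_truth"   # holds <UID>_final_attemp_4_1_Dong.csv label files
train_uids = ["001", "002"]             # placeholder UID strings

train_loader = load_data_split_batched(data_path, labels_path, train_uids, batch_size=32)

print(len(train_loader.dataset), "segments")   # number of examples
print(len(train_loader), "batches")            # number of batches

for batch in train_loader:
    x, y = batch['data'], batch['label']       # x has shape (batch, 1, 1, 128, 128) for CSV input
    print(x.shape, y, batch['segment_name'][:3])
    break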
2 changes: 1 addition & 1 deletion project_NVIL.py
@@ -146,7 +146,7 @@ def extract_segment_names_and_labels(self):

     def load_data(self, segment_name):
         data_path_UID = os.path.join(self.data_path, segment_name.split('_')[0])
-        seg_path = os.path.join(data_path_UID, segment_name + '_filt.csv')
+        seg_path = os.path.join(data_path_UID, segment_name + '_filt_STFT.csv')
 
         try:
             if self.data_format == 'csv' and seg_path.endswith('.csv'):
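
To make the effect of the renamed suffix concrete, here is a small illustration with made-up values; the path, UID, and segment name are placeholders, not values from the repository.

# Hypothetical illustration of the path change above (placeholder names only).
import os

data_path = "/path/to/tfs_segments"
segment_name = "001_segment_0001"   # "<UID>_<segment id>" as used by load_data

data_path_UID = os.path.join(data_path, segment_name.split('_')[0])
old_path = os.path.join(data_path_UID, segment_name + '_filt.csv')        # before this commit
new_path = os.path.join(data_path_UID, segment_name + '_filt_STFT.csv')   # after this commit

print(old_path)  # /path/to/tfs_segments/001/001_segment_0001_filt.csv
print(new_path)  # /path/to/tfs_segments/001/001_segment_0001_filt_STFT.csv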
