error in file names ..._STFT.csv #11

Merged 1 commit on Dec 7, 2023
option.py: 95 additions & 0 deletions (new file)
@@ -0,0 +1,95 @@
# -*- coding: utf-8 -*-
"""
Created on Wed Dec 6 22:07:18 2023
@author: lrm22005
"""

import os
import pandas as pd
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision.transforms import ToTensor
from sklearn.preprocessing import StandardScaler
from PIL import Image

class CustomDataset(Dataset):
    def __init__(self, data_path, labels_path, UIDs, standardize=True, data_format='csv'):
        self.data_path = data_path
        self.labels_path = labels_path
        self.UIDs = UIDs
        self.standardize = standardize
        self.data_format = data_format
        self.transforms = ToTensor()

        # Extract unique segment names and their corresponding labels
        self.segment_names, self.labels = self.extract_segment_names_and_labels()

    def __len__(self):
        return len(self.segment_names)

    def __getitem__(self, idx):
        segment_name = self.segment_names[idx]
        label = self.labels[segment_name]

        # Load data on-the-fly based on the segment_name
        time_freq_tensor = self.load_data(segment_name)

        return {'data': time_freq_tensor.unsqueeze(0), 'label': label, 'segment_name': segment_name}

    def extract_segment_names_and_labels(self):
        segment_names = []
        labels = {}

        for UID in self.UIDs:
            label_file = os.path.join(self.labels_path, UID + "_final_attemp_4_1_Dong.csv")
            if os.path.exists(label_file):
                label_data = pd.read_csv(label_file, sep=',', header=0, names=['segment', 'label'])
                label_segment_names = label_data['segment'].apply(lambda x: x.split('.')[0])
                for idx, segment_name in enumerate(label_segment_names):
                    if segment_name not in segment_names:
                        segment_names.append(segment_name)
                        labels[segment_name] = label_data['label'].values[idx]

        return segment_names, labels

    def load_data(self, segment_name):
        data_path_UID = os.path.join(self.data_path, segment_name.split('_')[0])
        seg_path = os.path.join(data_path_UID, segment_name + '_filt.csv')

        try:
            if self.data_format == 'csv' and seg_path.endswith('.csv'):
                time_freq_plot = np.array(pd.read_csv(seg_path, header=None))
                time_freq_tensor = torch.Tensor(time_freq_plot).reshape(1, 128, 128)
            elif self.data_format == 'png' and seg_path.endswith('.png'):
                img = Image.open(seg_path)
                img_data = np.array(img)
                time_freq_tensor = torch.Tensor(img_data).unsqueeze(0)
            else:
                raise ValueError("Unsupported file format")

            if self.standardize:
                time_freq_tensor = self.standard_scaling(time_freq_tensor)  # Standardize the data

            return time_freq_tensor.clone()

        except Exception as e:
            print(f"Error processing segment: {segment_name}. Exception: {str(e)}")
            return torch.zeros((1, 128, 128))  # Return zeros in case of an error

    def standard_scaling(self, data):
        scaler = StandardScaler()
        data = scaler.fit_transform(data.reshape(-1, data.shape[-1])).reshape(data.shape)
        return torch.Tensor(data)

def load_data_split_batched(data_path, labels_path, UIDs, batch_size, standardize=False, data_format='csv'):
    dataset = CustomDataset(data_path, labels_path, UIDs, standardize, data_format)
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False)
    return dataloader

# To validate the length of the dataloader:
# number of examples
# len(train_loader.dataset)
# number of batches
# len(train_loader)
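
# For context, a minimal usage sketch of the loader defined above. The directory
# paths, UID list, and batch size are hypothetical placeholders for illustration,
# not values taken from this repository.
#
# if __name__ == '__main__':
#     data_path = '/path/to/time_freq_data'        # per-UID folders of segment CSVs (placeholder)
#     labels_path = '/path/to/adjudicated_labels'  # *_final_attemp_4_1_Dong.csv files (placeholder)
#     train_UIDs = ['001', '002']                  # example subject IDs (placeholder)
#
#     train_loader = load_data_split_batched(data_path, labels_path, train_UIDs,
#                                             batch_size=32, standardize=True,
#                                             data_format='csv')
#
#     batch = next(iter(train_loader))
#     # Each CSV is reshaped to (1, 128, 128) and unsqueezed once more in __getitem__,
#     # so a full batch has shape (batch_size, 1, 1, 128, 128).
#     print(batch['data'].shape)
#     print(len(train_loader.dataset))  # number of segments
#     print(len(train_loader))          # number of batches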
project_NVIL.py: 1 addition & 1 deletion
@@ -146,7 +146,7 @@ def extract_segment_names_and_labels(self):

    def load_data(self, segment_name):
        data_path_UID = os.path.join(self.data_path, segment_name.split('_')[0])
-       seg_path = os.path.join(data_path_UID, segment_name + '_filt.csv')
+       seg_path = os.path.join(data_path_UID, segment_name + '_filt_STFT.csv')

        try:
            if self.data_format == 'csv' and seg_path.endswith('.csv'):
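
# As a quick sanity check of the renamed suffix, a small sketch of the path that
# load_data now builds after this change; the segment name and base directory
# below are hypothetical placeholders.
#
# import os
#
# segment_name = '001_segment_0001'        # hypothetical segment identifier
# data_path = '/path/to/time_freq_data'    # hypothetical base directory
#
# data_path_UID = os.path.join(data_path, segment_name.split('_')[0])
# seg_path = os.path.join(data_path_UID, segment_name + '_filt_STFT.csv')
# print(seg_path)  # .../001/001_segment_0001_filt_STFT.csv
# assert seg_path.endswith('_filt_STFT.csv')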