Skip to content

Commit

Permalink
Merge pull request #5 from lrm22005/main
Browse files Browse the repository at this point in the history
Logging updated in loading

I found an error with a file, related to its size; I'd like to ensure that this error is not associated with our data.

The error was in 
    _, labeled_data, _, _, labeled_labels = load_data(data_path, labels_path, dataset_size=141, train=True, data_format=data_format, return_all="labeled")

and it is associated with the size of the data read

                    time_freq_plot = np.array(pd.read_csv(seg_path, header=None))
                    time_freq_tensor = torch.Tensor(time_freq_plot).reshape(1, 128, 128)
  • Loading branch information
lrm22005 committed Oct 25, 2023
2 parents 1d4c1bb + 70db268 commit 2baed00
Showing 1 changed file with 114 additions and 12 deletions.
126 changes: 114 additions & 12 deletions project_1.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,10 @@ def get_data_paths(data_format, is_linux=False, is_hpc=False):

return data_path, labels_path, saving_path

# Create a module-level logger for recording data-loading errors.
# Fix: the original used the undefined name `__name`, which raises a
# NameError at import time; the module's name is spelled `__name__`.
logger = logging.getLogger(__name__)
# Send ERROR-and-above records to a file so failed segments can be audited later.
logging.basicConfig(filename='error_log.txt', level=logging.ERROR)

# Standardize the data
def standard_scaling(data):
scaler = StandardScaler()
Expand Down Expand Up @@ -75,27 +79,124 @@ def load_data(data_path, labels_path, dataset_size=2, train=True, standardize=Tr
for seg in dir_list_seg[:len(dir_list_seg)]: # Limiting to 50 segments
seg_path = os.path.join(data_path_UID, seg)

if data_format == 'csv' and seg.endswith('.csv'):
time_freq_plot = np.array(pd.read_csv(seg_path, header=None))
time_freq_tensor = torch.Tensor(time_freq_plot).reshape(1, 128, 128)
elif data_format == 'png' and seg.endswith('.png'):
img = Image.open(seg_path)
img_data = np.array(img)
time_freq_tensor = torch.Tensor(img_data).unsqueeze(0)
else:
continue # Skip other file formats
try:
if data_format == 'csv' and seg.endswith('.csv'):
time_freq_plot = np.array(pd.read_csv(seg_path, header=None))
time_freq_tensor = torch.Tensor(time_freq_plot).reshape(1, 128, 128)
elif data_format == 'png' and seg.endswith('.png'):
img = Image.open(seg_path)
img_data = np.array(img)
time_freq_tensor = torch.Tensor(img_data).unsqueeze(0)
else:
continue # Skip other file formats

X_data.append(time_freq_tensor)
X_data_original.append(time_freq_tensor.clone()) # Store a copy of the original data

X_data.append(time_freq_tensor)
X_data_original.append(time_freq_tensor.clone()) # Store a copy of the original data
segment_names.append(seg.split('_filt')[0]) # Extract and store segment names

segment_names.append(seg.split('_filt')[0]) # Extract and store segment names
except Exception as e:
logger.error(f"Error processing segment: {seg} in UID: {UID}. Exception: {str(e)}")
# You can also add more information to the error log, such as the value of time_freq_plot.
continue # Continue to the next segment

X_data = torch.cat(X_data, 0)
X_data_original = torch.cat(X_data_original, 0)

if standardize:
X_data = standard_scaling(X_data) # Standardize the data
# Extract labels from CSV files
labels = extract_labels(UID_list, labels_path, segment_names)

important_labels = [0.0, 1.0, 2.0, 3.0] # List of important labels

# Initialize labels for segments as unlabeled (-1)
segment_labels = {segment_name: -1 for segment_name in segment_names}

for UID in labels.keys():
if UID not in UID_list:
# Skip UIDs that are not in the dataset
continue

label_data, label_segment_names = labels[UID]

for idx, segment_label in enumerate(label_data):
segment_name = label_segment_names[idx]
if segment_label in important_labels:
segment_labels[segment_name] = segment_label
else:
# Set labels that are not in the important list as -1 (Unlabeled)
segment_labels[segment_name] = -1

# Return all segments along with labels
if return_all:
return X_data_original, X_data, segment_names, segment_labels, segment_labels.values()

# Filter out segments that are unlabeled (-1)
filtered_segment_names = [segment_name for segment_name, label in segment_labels.items() if label != -1]

# Filter data to match the filtered segment names
filtered_data = torch.stack([X_data[segment_names.index(segment_name)] for segment_name in filtered_segment_names])

# Return labeled and unlabeled segments along with labels
if return_all == 'labeled':
return X_data_original, filtered_data, filtered_segment_names, {seg: segment_labels[seg] for seg in filtered_segment_names}, {seg: segment_labels[seg] for seg in filtered_segment_names}.values()

# Return unlabeled segments along with labels
if return_all == 'unlabeled':
unlabeled_segment_names = [segment_name for segment_name, label in segment_labels.items() if label == -1]
unlabeled_data = torch.stack([X_data[segment_names.index(segment_name)] for segment_name in unlabeled_segment_names])
return X_data_original, unlabeled_data, unlabeled_segment_names, {seg: segment_labels[seg] for seg in unlabeled_segment_names}, {seg: segment_labels[seg] for seg in unlabeled_segment_names}.values()

# By default, return only labeled segments along with labels
return X_data_original, filtered_data, filtered_segment_names, {seg: segment_labels[seg] for seg in filtered_segment_names}, {seg: segment_labels[seg] for seg in filtered_segment_names}.values()

def load_data(data_path, labels_path, dataset_size=2, train=True, standardize=True, data_format='csv', return_all=False):
if data_format not in ['csv', 'png']:
raise ValueError("Invalid data_format. Choose 'csv' or 'png.")

dir_list_UID = os.listdir(data_path)
UID_list = dir_list_UID[:dataset_size] if train else dir_list_UID[dataset_size:]

X_data = [] # Store all data
X_data_original = [] # Store original data without standardization
segment_names = []

for UID in UID_list:
data_path_UID = os.path.join(data_path, UID)
dir_list_seg = os.listdir(data_path_UID)

for seg in dir_list_seg[:len(dir_list_seg)]: # Limiting to 50 segments
seg_path = os.path.join(data_path_UID, seg)

try:
if data_format == 'csv' and seg.endswith('.csv'):
time_freq_plot = np.array(pd.read_csv(seg_path, header=None))
time_freq_tensor = torch.Tensor(time_freq_plot).reshape(1, 128, 128)
elif data_format == 'png' and seg.endswith('.png'):
img = Image.open(seg_path)
img_data = np.array(img)
time_freq_tensor = torch.Tensor(img_data).unsqueeze(0)
else:
continue # Skip other file formats

X_data.append(time_freq_tensor)
X_data_original.append(time_freq_tensor.clone()) # Store a copy of the original data

segment_names.append(seg.split('_filt')[0]) # Extract and store segment names

except Exception as e:
logger.error(f"Error processing segment: {seg} in UID: {UID}. Exception: {str(e)}")
logger.error(f"Error processing segment: {time_freq_plot.size()} in UID: {UID}. Exception: {str(e)}")
# You can also add more information to the error log, such as the value of time_freq_plot.
continue # Continue to the next segment

X_data = torch.cat(X_data, 0)
X_data_original = torch.cat(X_data_original, 0)

if standardize:
X_data = standard_scaling(X_data) # Standardize the data

# Extract labels from CSV files
labels = extract_labels(UID_list, labels_path, segment_names)

Expand Down Expand Up @@ -143,6 +244,7 @@ def load_data(data_path, labels_path, dataset_size=2, train=True, standardize=Tr
return X_data_original, filtered_data, filtered_segment_names, {seg: segment_labels[seg] for seg in filtered_segment_names}, {seg: segment_labels[seg] for seg in filtered_segment_names}.values()



def extract_labels(UID_list, labels_path, segment_names):
labels = {}
for UID in UID_list:
Expand Down

0 comments on commit 2baed00

Please sign in to comment.