Skip to content

Logging updated in loading #5

Merged
merged 1 commit into from
Oct 25, 2023
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
126 changes: 114 additions & 12 deletions project_1.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,10 @@ def get_data_paths(data_format, is_linux=False, is_hpc=False):

return data_path, labels_path, saving_path

# Module-level logger for recording data-loading failures.
# Bug fix: the original passed `__name` (an undefined name, raising NameError at
# import time); the standard idiom is `__name__`, the module's dotted name.
logger = logging.getLogger(__name__)
# Send ERROR-and-above records to a file so failed segments can be audited later.
# NOTE(review): basicConfig at import time configures the root logger globally —
# confirm no other module in the project also calls basicConfig.
logging.basicConfig(filename='error_log.txt', level=logging.ERROR)

# Standardize the data
def standard_scaling(data):
scaler = StandardScaler()
Expand Down Expand Up @@ -75,27 +79,124 @@ def load_data(data_path, labels_path, dataset_size=2, train=True, standardize=Tr
for seg in dir_list_seg[:len(dir_list_seg)]: # Limiting to 50 segments
seg_path = os.path.join(data_path_UID, seg)

if data_format == 'csv' and seg.endswith('.csv'):
time_freq_plot = np.array(pd.read_csv(seg_path, header=None))
time_freq_tensor = torch.Tensor(time_freq_plot).reshape(1, 128, 128)
elif data_format == 'png' and seg.endswith('.png'):
img = Image.open(seg_path)
img_data = np.array(img)
time_freq_tensor = torch.Tensor(img_data).unsqueeze(0)
else:
continue # Skip other file formats
try:
if data_format == 'csv' and seg.endswith('.csv'):
time_freq_plot = np.array(pd.read_csv(seg_path, header=None))
time_freq_tensor = torch.Tensor(time_freq_plot).reshape(1, 128, 128)
elif data_format == 'png' and seg.endswith('.png'):
img = Image.open(seg_path)
img_data = np.array(img)
time_freq_tensor = torch.Tensor(img_data).unsqueeze(0)
else:
continue # Skip other file formats

X_data.append(time_freq_tensor)
X_data_original.append(time_freq_tensor.clone()) # Store a copy of the original data

X_data.append(time_freq_tensor)
X_data_original.append(time_freq_tensor.clone()) # Store a copy of the original data
segment_names.append(seg.split('_filt')[0]) # Extract and store segment names

segment_names.append(seg.split('_filt')[0]) # Extract and store segment names
except Exception as e:
logger.error(f"Error processing segment: {seg} in UID: {UID}. Exception: {str(e)}")
# You can also add more information to the error log, such as the value of time_freq_plot.
continue # Continue to the next segment

X_data = torch.cat(X_data, 0)
X_data_original = torch.cat(X_data_original, 0)

if standardize:
X_data = standard_scaling(X_data) # Standardize the data
# Extract labels from CSV files
labels = extract_labels(UID_list, labels_path, segment_names)

important_labels = [0.0, 1.0, 2.0, 3.0] # List of important labels

# Initialize labels for segments as unlabeled (-1)
segment_labels = {segment_name: -1 for segment_name in segment_names}

for UID in labels.keys():
if UID not in UID_list:
# Skip UIDs that are not in the dataset
continue

label_data, label_segment_names = labels[UID]

for idx, segment_label in enumerate(label_data):
segment_name = label_segment_names[idx]
if segment_label in important_labels:
segment_labels[segment_name] = segment_label
else:
# Set labels that are not in the important list as -1 (Unlabeled)
segment_labels[segment_name] = -1

# Return all segments along with labels
if return_all:
return X_data_original, X_data, segment_names, segment_labels, segment_labels.values()

# Filter out segments that are unlabeled (-1)
filtered_segment_names = [segment_name for segment_name, label in segment_labels.items() if label != -1]

# Filter data to match the filtered segment names
filtered_data = torch.stack([X_data[segment_names.index(segment_name)] for segment_name in filtered_segment_names])

# Return labeled and unlabeled segments along with labels
if return_all == 'labeled':
return X_data_original, filtered_data, filtered_segment_names, {seg: segment_labels[seg] for seg in filtered_segment_names}, {seg: segment_labels[seg] for seg in filtered_segment_names}.values()

# Return unlabeled segments along with labels
if return_all == 'unlabeled':
unlabeled_segment_names = [segment_name for segment_name, label in segment_labels.items() if label == -1]
unlabeled_data = torch.stack([X_data[segment_names.index(segment_name)] for segment_name in unlabeled_segment_names])
return X_data_original, unlabeled_data, unlabeled_segment_names, {seg: segment_labels[seg] for seg in unlabeled_segment_names}, {seg: segment_labels[seg] for seg in unlabeled_segment_names}.values()

# By default, return only labeled segments along with labels
return X_data_original, filtered_data, filtered_segment_names, {seg: segment_labels[seg] for seg in filtered_segment_names}, {seg: segment_labels[seg] for seg in filtered_segment_names}.values()

def load_data(data_path, labels_path, dataset_size=2, train=True, standardize=True, data_format='csv', return_all=False):
if data_format not in ['csv', 'png']:
raise ValueError("Invalid data_format. Choose 'csv' or 'png.")

dir_list_UID = os.listdir(data_path)
UID_list = dir_list_UID[:dataset_size] if train else dir_list_UID[dataset_size:]

X_data = [] # Store all data
X_data_original = [] # Store original data without standardization
segment_names = []

for UID in UID_list:
data_path_UID = os.path.join(data_path, UID)
dir_list_seg = os.listdir(data_path_UID)

for seg in dir_list_seg[:len(dir_list_seg)]: # Limiting to 50 segments
seg_path = os.path.join(data_path_UID, seg)

try:
if data_format == 'csv' and seg.endswith('.csv'):
time_freq_plot = np.array(pd.read_csv(seg_path, header=None))
time_freq_tensor = torch.Tensor(time_freq_plot).reshape(1, 128, 128)
elif data_format == 'png' and seg.endswith('.png'):
img = Image.open(seg_path)
img_data = np.array(img)
time_freq_tensor = torch.Tensor(img_data).unsqueeze(0)
else:
continue # Skip other file formats

X_data.append(time_freq_tensor)
X_data_original.append(time_freq_tensor.clone()) # Store a copy of the original data

segment_names.append(seg.split('_filt')[0]) # Extract and store segment names

except Exception as e:
logger.error(f"Error processing segment: {seg} in UID: {UID}. Exception: {str(e)}")
logger.error(f"Error processing segment: {time_freq_plot.size()} in UID: {UID}. Exception: {str(e)}")
# You can also add more information to the error log, such as the value of time_freq_plot.
continue # Continue to the next segment

X_data = torch.cat(X_data, 0)
X_data_original = torch.cat(X_data_original, 0)

if standardize:
X_data = standard_scaling(X_data) # Standardize the data

# Extract labels from CSV files
labels = extract_labels(UID_list, labels_path, segment_names)

Expand Down Expand Up @@ -143,6 +244,7 @@ def load_data(data_path, labels_path, dataset_size=2, train=True, standardize=Tr
return X_data_original, filtered_data, filtered_segment_names, {seg: segment_labels[seg] for seg in filtered_segment_names}, {seg: segment_labels[seg] for seg in filtered_segment_names}.values()



def extract_labels(UID_list, labels_path, segment_names):
labels = {}
for UID in UID_list:
Expand Down