diff --git a/project_1.py b/project_1.py
index 1c645ca..d13455a 100644
--- a/project_1.py
+++ b/project_1.py
@@ -47,6 +47,10 @@ def get_data_paths(data_format, is_linux=False, is_hpc=False):
 
     return data_path, labels_path, saving_path
 
+# Create a logger
+logger = logging.getLogger(__name__)
+logging.basicConfig(filename='error_log.txt', level=logging.ERROR)
+
 # Standardize the data
 def standard_scaling(data):
     scaler = StandardScaler()
@@ -75,27 +79,124 @@ def load_data(data_path, labels_path, dataset_size=2, train=True, standardize=Tr
         for seg in dir_list_seg[:len(dir_list_seg)]: # Limiting to 50 segments
             seg_path = os.path.join(data_path_UID, seg)
 
-            if data_format == 'csv' and seg.endswith('.csv'):
-                time_freq_plot = np.array(pd.read_csv(seg_path, header=None))
-                time_freq_tensor = torch.Tensor(time_freq_plot).reshape(1, 128, 128)
-            elif data_format == 'png' and seg.endswith('.png'):
-                img = Image.open(seg_path)
-                img_data = np.array(img)
-                time_freq_tensor = torch.Tensor(img_data).unsqueeze(0)
-            else:
-                continue # Skip other file formats
+            try:
+                if data_format == 'csv' and seg.endswith('.csv'):
+                    time_freq_plot = np.array(pd.read_csv(seg_path, header=None))
+                    time_freq_tensor = torch.Tensor(time_freq_plot).reshape(1, 128, 128)
+                elif data_format == 'png' and seg.endswith('.png'):
+                    img = Image.open(seg_path)
+                    img_data = np.array(img)
+                    time_freq_tensor = torch.Tensor(img_data).unsqueeze(0)
+                else:
+                    continue # Skip other file formats
+
+                X_data.append(time_freq_tensor)
+                X_data_original.append(time_freq_tensor.clone()) # Store a copy of the original data
 
-            X_data.append(time_freq_tensor)
-            X_data_original.append(time_freq_tensor.clone()) # Store a copy of the original data
+                segment_names.append(seg.split('_filt')[0]) # Extract and store segment names
 
-            segment_names.append(seg.split('_filt')[0]) # Extract and store segment names
+            except Exception as e:
+                logger.error(f"Error processing segment: {seg} in UID: {UID}. Exception: {str(e)}")
+                # You can also add more information to the error log, such as the value of time_freq_plot.
+                continue # Continue to the next segment
 
     X_data = torch.cat(X_data, 0)
     X_data_original = torch.cat(X_data_original, 0)
 
     if standardize:
         X_data = standard_scaling(X_data) # Standardize the data
 
+    # Extract labels from CSV files
+    labels = extract_labels(UID_list, labels_path, segment_names)
+
+    important_labels = [0.0, 1.0, 2.0, 3.0] # List of important labels
+
+    # Initialize labels for segments as unlabeled (-1)
+    segment_labels = {segment_name: -1 for segment_name in segment_names}
+
+    for UID in labels.keys():
+        if UID not in UID_list:
+            # Skip UIDs that are not in the dataset
+            continue
+
+        label_data, label_segment_names = labels[UID]
+
+        for idx, segment_label in enumerate(label_data):
+            segment_name = label_segment_names[idx]
+            if segment_label in important_labels:
+                segment_labels[segment_name] = segment_label
+            else:
+                # Set labels that are not in the important list as -1 (Unlabeled)
+                segment_labels[segment_name] = -1
+
+    # Return all segments along with labels
+    if return_all:
+        return X_data_original, X_data, segment_names, segment_labels, segment_labels.values()
+
+    # Filter out segments that are unlabeled (-1)
+    filtered_segment_names = [segment_name for segment_name, label in segment_labels.items() if label != -1]
+
+    # Filter data to match the filtered segment names
+    filtered_data = torch.stack([X_data[segment_names.index(segment_name)] for segment_name in filtered_segment_names])
+
+    # Return labeled and unlabeled segments along with labels
+    if return_all == 'labeled':
+        return X_data_original, filtered_data, filtered_segment_names, {seg: segment_labels[seg] for seg in filtered_segment_names}, {seg: segment_labels[seg] for seg in filtered_segment_names}.values()
+    # Return unlabeled segments along with labels
+    if return_all == 'unlabeled':
+        unlabeled_segment_names = [segment_name for segment_name, label in segment_labels.items() if 
label == -1]
+        unlabeled_data = torch.stack([X_data[segment_names.index(segment_name)] for segment_name in unlabeled_segment_names])
+        return X_data_original, unlabeled_data, unlabeled_segment_names, {seg: segment_labels[seg] for seg in unlabeled_segment_names}, {seg: segment_labels[seg] for seg in unlabeled_segment_names}.values()
+
+    # By default, return only labeled segments along with labels
+    return X_data_original, filtered_data, filtered_segment_names, {seg: segment_labels[seg] for seg in filtered_segment_names}, {seg: segment_labels[seg] for seg in filtered_segment_names}.values()
+
+def load_data(data_path, labels_path, dataset_size=2, train=True, standardize=True, data_format='csv', return_all=False):
+    if data_format not in ['csv', 'png']:
+        raise ValueError("Invalid data_format. Choose 'csv' or 'png'.")
+
+    dir_list_UID = os.listdir(data_path)
+    UID_list = dir_list_UID[:dataset_size] if train else dir_list_UID[dataset_size:]
+
+    X_data = [] # Store all data
+    X_data_original = [] # Store original data without standardization
+    segment_names = []
+
+    for UID in UID_list:
+        data_path_UID = os.path.join(data_path, UID)
+        dir_list_seg = os.listdir(data_path_UID)
+
+        for seg in dir_list_seg[:len(dir_list_seg)]: # Limiting to 50 segments
+            seg_path = os.path.join(data_path_UID, seg)
+
+            try:
+                if data_format == 'csv' and seg.endswith('.csv'):
+                    time_freq_plot = np.array(pd.read_csv(seg_path, header=None))
+                    time_freq_tensor = torch.Tensor(time_freq_plot).reshape(1, 128, 128)
+                elif data_format == 'png' and seg.endswith('.png'):
+                    img = Image.open(seg_path)
+                    img_data = np.array(img)
+                    time_freq_tensor = torch.Tensor(img_data).unsqueeze(0)
+                else:
+                    continue # Skip other file formats
+
+                X_data.append(time_freq_tensor)
+                X_data_original.append(time_freq_tensor.clone()) # Store a copy of the original data
+
+                segment_names.append(seg.split('_filt')[0]) # Extract and store segment names
+
+            except Exception as e:
+                logger.error(f"Error processing segment: {seg} in UID: {UID}. Exception: {str(e)}")
+                logger.error(f"Error processing segment: {time_freq_plot.size()} in UID: {UID}. Exception: {str(e)}")
+                # You can also add more information to the error log, such as the value of time_freq_plot.
+                continue # Continue to the next segment
+
+    X_data = torch.cat(X_data, 0)
+    X_data_original = torch.cat(X_data_original, 0)
+
+    if standardize:
+        X_data = standard_scaling(X_data) # Standardize the data
+
+    # Extract labels from CSV files
+    labels = extract_labels(UID_list, labels_path, segment_names)
@@ -143,6 +244,7 @@ def load_data(data_path, labels_path, dataset_size=2, train=True, standardize=Tr
 
         return X_data_original, filtered_data, filtered_segment_names, {seg: segment_labels[seg] for seg in filtered_segment_names}, {seg: segment_labels[seg] for seg in filtered_segment_names}.values()
 
+
 def extract_labels(UID_list, labels_path, segment_names):
     labels = {}
     for UID in UID_list: