diff --git a/project_1.py b/project_1.py index 26feb38..09a30a9 100644 --- a/project_1.py +++ b/project_1.py @@ -334,7 +334,7 @@ def main(): data_format = 'csv' # Choose 'csv' or 'png' data_path, labels_path, saving_path = get_data_paths(data_format, is_linux=is_linux, is_hpc=is_hpc) - original_data, standardized_data, segment_names, labels = load_data(data_path, labels_path, dataset_size=10, train=True, data_format=data_format) + original_data, standardized_data, segment_names, labels = load_data(data_path, labels_path, dataset_size=10, train=False, data_format=data_format) # test_data, _, _ = load_data(data_path, labels_path, dataset_size=30, train=False) train_dataloader = create_dataloader(standardized_data)