From 3f4624f30ca3319e2d919d0a45ba9a46c9d35752 Mon Sep 17 00:00:00 2001 From: yul19079 Date: Thu, 1 Feb 2024 09:14:29 -0500 Subject: [PATCH] save changes --- EDA.ipynb | 34 +-- lab.ipynb | 686 ++++++++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 705 insertions(+), 15 deletions(-) diff --git a/EDA.ipynb b/EDA.ipynb index 3e7c33c..6ef4bb5 100644 --- a/EDA.ipynb +++ b/EDA.ipynb @@ -29,11 +29,21 @@ }, { "cell_type": "code", - "execution_count": 3, + "execution_count": 2, "metadata": {}, "outputs": [], "source": [ - "raw_data = pd.read_pickle(\"./data/raw_feature.pkl\")" + "import pandas as pd\n", + "import numpy as np\n", + "raw_data = pd.read_pickle(\"./data/raw_feature.pkl\")\n", + "X_train = pd.read_pickle(\"./data/X_train.pkl\")\n", + "y_train = pd.read_pickle(\"./data/y_train.pkl\")\n", + "X_test = pd.read_pickle(\"./data/X_test.pkl\")\n", + "y_test = pd.read_pickle(\"./data/y_test.pkl\")\n", + "\n", + "\n", + "trip_avg_in = y_train.mean(axis = 0)[0]\n", + "trip_avg_out = y_train.mean(axis=0)[1]" ] }, { @@ -41,20 +51,14 @@ "execution_count": 4, "metadata": {}, "outputs": [], - "source": [ - "X_train = pd.read_pickle(\"./data/X_train.pkl\")\n", - "y_train = pd.read_pickle(\"./data/y_train.pkl\")" - ] + "source": [] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [], - "source": [ - "trip_avg_in = y_train.mean(axis = 0)[0]\n", - "trip_avg_out = y_train.mean(axis=0)[1]" - ] + "source": [] }, { "cell_type": "code", @@ -64,7 +68,7 @@ { "data": { "text/plain": [ - "" + "" ] }, "execution_count": 6, @@ -95,7 +99,7 @@ { "data": { "text/plain": [ - "" + "" ] }, "execution_count": 7, @@ -417,7 +421,7 @@ "outputs": [], "source": [ "import sys\n", - "file_path = \"pearson_output.txt\"\n", + "file_path = \"pearson_output_filter0.txt\"\n", "\n", "with open(file_path,\"w\") as file:\n", " original_stdout = sys.stdout\n", @@ -427,7 +431,7 @@ " feature_name_list\n", " for feature_name in feature_name_list:\n", " temp_feature = get_feature_in_2d(raw_data,feature_name)\n", - " cal_pearson(temp_feature,trip_avg_in,log_text=feature_name)\n", + " cal_pearson(temp_feature,trip_avg_in,log_text=feature_name,filter_zero=True)\n", " print(\"\")\n", " sys.stdout = original_stdout" ] @@ -451,7 +455,7 @@ { "data": { "text/plain": [ - "" + "" ] }, "execution_count": 12, diff --git a/lab.ipynb b/lab.ipynb index e69de29..31d8e71 100644 --- a/lab.ipynb +++ b/lab.ipynb @@ -0,0 +1,686 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 379, + "metadata": {}, + "outputs": [], + "source": [ + "import pandas as pd\n", + "import numpy as np\n", + "import tensorflow as tf\n", + "raw_data = pd.read_pickle(\"./data/raw_feature.pkl\")\n", + "X_train = pd.read_pickle(\"./data/X_train.pkl\")\n", + "y_train = pd.read_pickle(\"./data/y_train.pkl\")\n", + "X_test = pd.read_pickle(\"./data/X_test.pkl\")\n", + "y_test = pd.read_pickle(\"./data/y_test.pkl\")\n", + "\n", + "\n", + "trip_avg_in = y_train.mean(axis = 0)[0]\n", + "trip_avg_out = y_train.mean(axis=0)[1]" + ] + }, + { + "cell_type": "code", + "execution_count": 380, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "(1248, 2, 16, 8)" + ] + }, + "execution_count": 380, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "y_train.shape" + ] + }, + { + "cell_type": "code", + "execution_count": 381, + "metadata": {}, + "outputs": [], + "source": [ + "from sklearn.preprocessing import StandardScaler,MinMaxScaler\n", + "std = StandardScaler()\n", + "train_shape = y_train.shape\n", + "test_shape = y_test.shape\n", + "y_train = std.fit_transform(np.reshape(y_train,(-1,y_train.shape[-2]*y_train.shape[-1])))\n", + "y_train = y_train.reshape(train_shape)\n", + "y_test = std.transform(np.reshape(y_test,(-1,y_test.shape[-2]*y_test.shape[-1])))\n", + "y_test = y_test.reshape(test_shape)\n", + "\n", + "# std.transform(y_test)" + ] + }, + { + "cell_type": "code", + "execution_count": 382, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "array([[[[ 0. , 0. , -0.69192794, ..., -0.73855994,\n", + " -0.85896564, -0.76920397],\n", + " [ 0. , -0.64681111, -0.72444045, ..., -0.79249501,\n", + " -0.69425498, -0.88850136],\n", + " [ 0. , -0.67416004, -0.96678391, ..., -0.52298898,\n", + " -0.9914674 , -0.60389946],\n", + " ...,\n", + " [-0.87073037, -0.71904056, 0. , ..., -0.19596545,\n", + " 0. , -0.51739465],\n", + " [-0.64317053, 0. , -0.66993143, ..., -0.39701126,\n", + " 0. , 0. ],\n", + " [ 0. , 0. , -0.77124363, ..., -0.81526382,\n", + " -0.65719543, -0.30593534]],\n", + "\n", + " [[ 0. , 0. , -0.69192794, ..., -0.73855994,\n", + " -0.85896564, -0.76920397],\n", + " [ 0. , -0.64681111, -1.0379836 , ..., -0.71802238,\n", + " -0.69425498, -0.88850136],\n", + " [ 0. , -0.67416004, -0.77548118, ..., -0.52298898,\n", + " -0.77921055, -0.60389946],\n", + " ...,\n", + " [-0.87073037, -0.57282481, 0. , ..., -0.19596545,\n", + " 0. , -0.51739465],\n", + " [-0.64317053, 0. , -0.66993143, ..., -0.39701126,\n", + " 0. , 0. ],\n", + " [ 0. , 0. , -0.77124363, ..., -0.81526382,\n", + " -0.65719543, -0.30593534]]],\n", + "\n", + "\n", + " [[[ 0. , 0. , -0.69192794, ..., -0.73855994,\n", + " -0.85896564, -0.9306107 ],\n", + " [ 0. , -0.64681111, -1.0379836 , ..., -0.71802238,\n", + " -0.89901221, -0.88850136],\n", + " [ 0. , -0.85536868, -0.77548118, ..., -0.60752339,\n", + " -0.9914674 , -0.60389946],\n", + " ...,\n", + " [-0.87073037, -0.79214844, 0. , ..., -0.19596545,\n", + " 0. , -0.51739465],\n", + " [-0.78565332, 0. , -0.66993143, ..., -0.39701126,\n", + " 0. , 0. ],\n", + " [ 0. , 0. , -0.77124363, ..., -0.38322592,\n", + " -0.65719543, -0.30593534]],\n", + "\n", + " [[ 0. , 0. , -0.69192794, ..., -0.73855994,\n", + " -0.85896564, -0.9306107 ],\n", + " [ 0. , -0.64681111, -1.0379836 , ..., -0.79249501,\n", + " -0.69425498, -0.88850136],\n", + " [ 0. , -0.85536868, -0.96678391, ..., -0.60752339,\n", + " -0.9914674 , -0.60389946],\n", + " ...,\n", + " [-0.87073037, -0.79214844, 0. , ..., -0.19596545,\n", + " 0. , -0.51739465],\n", + " [-0.78565332, 0. , -0.66993143, ..., -0.39701126,\n", + " 0. , 0. ],\n", + " [ 0. , 0. , -0.77124363, ..., -0.81526382,\n", + " -0.65719543, -0.30593534]]],\n", + "\n", + "\n", + " [[[ 0. , 0. , -0.69192794, ..., -0.73855994,\n", + " -0.85896564, -0.9306107 ],\n", + " [ 0. , -0.64681111, -0.82895483, ..., -0.79249501,\n", + " -0.89901221, -0.88850136],\n", + " [ 0. , -0.85536868, -0.96678391, ..., -0.60752339,\n", + " -0.67308212, -0.60389946],\n", + " ...,\n", + " [-0.87073037, -0.79214844, 0. , ..., -0.19596545,\n", + " 0. , -0.51739465],\n", + " [-0.78565332, 0. , -0.66993143, ..., -0.39701126,\n", + " 0. , 0. ],\n", + " [ 0. , 0. , -0.77124363, ..., -0.81526382,\n", + " -0.65719543, -0.30593534]],\n", + "\n", + " [[ 0. , 0. , -0.69192794, ..., -0.73855994,\n", + " -0.85896564, -0.9306107 ],\n", + " [ 0. , -0.64681111, -0.93346922, ..., -0.71802238,\n", + " -0.89901221, -0.88850136],\n", + " [ 0. , -0.31174275, -0.77548118, ..., -0.52298898,\n", + " -0.9914674 , -0.60389946],\n", + " ...,\n", + " [-0.77718082, -0.57282481, 0. , ..., -0.19596545,\n", + " 0. , -0.51739465],\n", + " [-0.78565332, 0. , -0.66993143, ..., -0.39701126,\n", + " 0. , 0. ],\n", + " [ 0. , 0. , -0.77124363, ..., -0.81526382,\n", + " -0.65719543, -0.30593534]]],\n", + "\n", + "\n", + " ...,\n", + "\n", + "\n", + " [[[ 0. , 0. , -0.69192794, ..., -0.73855994,\n", + " -0.49766159, -0.9306107 ],\n", + " [ 0. , -0.64681111, -0.51541168, ..., -0.49460449,\n", + " -0.89901221, 0.29553462],\n", + " [ 0. , -0.85536868, -0.96678391, ..., -0.52298898,\n", + " -0.77921055, -0.60389946],\n", + " ...,\n", + " [-0.68363128, -0.71904056, 0. , ..., -0.19596545,\n", + " 0. , -0.51739465],\n", + " [-0.78565332, 0. , -0.66993143, ..., -0.39701126,\n", + " 0. , 0. ],\n", + " [ 0. , 0. , -0.77124363, ..., -0.81526382,\n", + " -0.65719543, -0.30593534]],\n", + "\n", + " [[ 0. , 0. , -0.69192794, ..., -0.73855994,\n", + " -0.85896564, -0.76920397],\n", + " [ 0. , -0.5325468 , 0.00716024, ..., -0.56907712,\n", + " -0.48949775, -0.88850136],\n", + " [ 0. , -0.85536868, -0.77548118, ..., -0.43845456,\n", + " -0.67308212, -0.60389946],\n", + " ...,\n", + " [-0.87073037, -0.49971693, 0. , ..., -0.19596545,\n", + " 0. , -0.51739465],\n", + " [-0.50068774, 0. , -0.026426 , ..., -0.39701126,\n", + " 0. , 0. ],\n", + " [ 0. , 0. , -0.77124363, ..., -0.81526382,\n", + " -0.65719543, -0.30593534]]],\n", + "\n", + "\n", + " [[[ 0. , 0. , -0.69192794, ..., -0.48802371,\n", + " -0.49766159, -0.4463905 ],\n", + " [ 0. , -0.64681111, 0.11167462, ..., -0.34565923,\n", + " -0.48949775, -0.09914404],\n", + " [ 0. , -0.4929514 , 0.3723352 , ..., -0.60752339,\n", + " -0.5669537 , -0.60389946],\n", + " ...,\n", + " [-0.77718082, -0.64593269, 0. , ..., -0.19596545,\n", + " 0. , 0.78970762],\n", + " [-0.78565332, 0. , -0.66993143, ..., -0.39701126,\n", + " 0. , 0. ],\n", + " [ 0. , 0. , -0.3349281 , ..., -0.81526382,\n", + " -0.65719543, -0.30593534]],\n", + "\n", + " [[ 0. , 0. , -0.69192794, ..., -0.73855994,\n", + " 0.22494651, 0.0378297 ],\n", + " [ 0. , -0.41828248, 0.00716024, ..., -0.56907712,\n", + " -0.48949775, -0.88850136],\n", + " [ 0. , -0.67416004, -0.39287572, ..., -0.43845456,\n", + " -0.5669537 , 0.40977442],\n", + " ...,\n", + " [-0.87073037, -0.71904056, 0. , ..., -0.19596545,\n", + " 0. , -0.51739465],\n", + " [-0.50068774, 0. , -0.66993143, ..., -0.39701126,\n", + " 0. , 0. ],\n", + " [ 0. , 0. , -0.3349281 , ..., -0.38322592,\n", + " -0.65719543, -0.30593534]]],\n", + "\n", + "\n", + " [[[ 0. , 0. , -0.69192794, ..., -0.73855994,\n", + " -0.85896564, -0.9306107 ],\n", + " [ 0. , -0.64681111, -0.82895483, ..., -0.71802238,\n", + " -0.48949775, -0.88850136],\n", + " [ 0. , -0.67416004, -0.96678391, ..., -0.35392014,\n", + " -0.88533898, -0.60389946],\n", + " ...,\n", + " [-0.68363128, -0.79214844, 0. , ..., -0.19596545,\n", + " 0. , 0.78970762],\n", + " [-0.78565332, 0. , -0.026426 , ..., -0.39701126,\n", + " 0. , 0. ],\n", + " [ 0. , 0. , -0.77124363, ..., -0.81526382,\n", + " -0.65719543, -0.30593534]],\n", + "\n", + " [[ 0. , 0. , -0.69192794, ..., -0.23748747,\n", + " -0.85896564, -0.76920397],\n", + " [ 0. , -0.5325468 , -0.61992607, ..., -0.71802238,\n", + " -0.69425498, -0.88850136],\n", + " [ 0. , -0.85536868, -0.01027026, ..., -0.60752339,\n", + " -0.77921055, -0.60389946],\n", + " ...,\n", + " [-0.59008174, -0.64593269, 0. , ..., -0.19596545,\n", + " 0. , -0.51739465],\n", + " [-0.78565332, 0. , -0.66993143, ..., -0.39701126,\n", + " 0. , 0. ],\n", + " [ 0. , 0. , -0.77124363, ..., -0.38322592,\n", + " -0.65719543, -0.30593534]]]])" + ] + }, + "execution_count": 382, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "y_train" + ] + }, + { + "cell_type": "code", + "execution_count": 383, + "metadata": {}, + "outputs": [], + "source": [ + "from numpy.lib.stride_tricks import as_strided\n", + "\n", + "def create_rolling_window(matrix,t):\n", + " matrix_shape = matrix.shape\n", + " return_length = matrix_shape[0] - t \n", + " dataset = tf.data.Dataset.from_tensor_slices(matrix)\n", + " windows = dataset.window(t,shift = 1,drop_remainder=True)\n", + " windows = windows.take(return_length)\n", + " windows = windows.flat_map(lambda window: window.batch(t))\n", + "\n", + " \n", + " return windows\n", + "def create_result_ds(matrix,delay):\n", + " dataset = tf.data.Dataset.from_tensor_slices(matrix)\n", + " dataset = dataset.skip(delay)\n", + " return dataset\n", + "\n", + "def combine_ds(ds1,ds2):\n", + " combined_ds = tf.data.Dataset.zip(((ds1),ds2))\n", + " combined_ds = combined_ds.batch(batch_size=32)\n", + " return combined_ds\n", + "\n", + "def gen_train_ds(X_train,y_train,step_len):\n", + " X = create_rolling_window(y_train,step_len)\n", + " y = create_result_ds(y_train,step_len)\n", + " train_ds = combine_ds(X,y)\n", + " return train_ds\n", + "\n", + "def gen_test_ds(X_test,y_test,step_len):\n", + " X = create_rolling_window(y_test,step_len)\n", + " y = create_result_ds(y_test,step_len)\n", + " test_ds = combine_ds(X,y)\n", + " return test_ds" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": 384, + "metadata": {}, + "outputs": [], + "source": [ + "time_step = 48\n", + "train_ds = gen_train_ds(X_train,y_train,step_len=time_step)\n", + "test_ds = gen_test_ds(X_test,y_test,step_len=time_step)" + ] + }, + { + "cell_type": "code", + "execution_count": 385, + "metadata": {}, + "outputs": [], + "source": [ + "\n", + "# ds = tf.data.Dataset.from_tensor_slices(((X_train,X_train),y_train))\n", + "# test_ds = tf.data.Dataset.from_tensor_slices(((X_test,X_test),y_test))\n", + "# for batch, ((feature1,feature2), labels) in enumerate(ds):\n", + "# print(feature1.shape,feature2.shape,labels.shape)\n", + "# print(labels)" + ] + }, + { + "cell_type": "code", + "execution_count": 386, + "metadata": {}, + "outputs": [], + "source": [ + "import keras\n", + "from keras import layers\n", + "from keras import models\n", + "\n", + "spatio_feature = None\n", + "\n", + "input_shape = (time_step,2,16,8)\n", + "\n", + "inputs = layers.Input(shape = input_shape)\n", + "\n", + "\n", + "x = layers.Reshape((time_step,2*16*8))(inputs)\n", + "\n", + "x = layers.LSTM(100, activation = \"relu\")(x)\n", + "\n", + "x = layers.Dense(2*16*8, activation = \"relu\")(x)\n", + "\n", + "output = layers.Reshape((2,16,8))(x)\n", + "\n", + "model = models.Model([inputs],outputs = output)\n", + "\n", + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": 387, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Model: \"model_26\"\n", + "_________________________________________________________________\n", + " Layer (type) Output Shape Param # \n", + "=================================================================\n", + " input_37 (InputLayer) [(None, 48, 2, 16, 8)] 0 \n", + " \n", + " reshape_41 (Reshape) (None, 48, 256) 0 \n", + " \n", + " lstm_29 (LSTM) (None, 100) 142800 \n", + " \n", + " dense_27 (Dense) (None, 256) 25856 \n", + " \n", + " reshape_42 (Reshape) (None, 2, 16, 8) 0 \n", + " \n", + "=================================================================\n", + "Total params: 168656 (658.81 KB)\n", + "Trainable params: 168656 (658.81 KB)\n", + "Non-trainable params: 0 (0.00 Byte)\n", + "_________________________________________________________________\n" + ] + } + ], + "source": [ + "model.summary()" + ] + }, + { + "cell_type": "code", + "execution_count": 388, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 1/50\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "38/38 [==============================] - 2s 33ms/step - loss: 0.6559 - val_loss: 0.4928\n", + "Epoch 2/50\n", + "38/38 [==============================] - 1s 27ms/step - loss: 0.5243 - val_loss: 0.4425\n", + "Epoch 3/50\n", + "38/38 [==============================] - 1s 24ms/step - loss: 0.4840 - val_loss: 0.4237\n", + "Epoch 4/50\n", + "38/38 [==============================] - 1s 25ms/step - loss: 0.4592 - val_loss: 0.4084\n", + "Epoch 5/50\n", + "38/38 [==============================] - 1s 26ms/step - loss: 0.4430 - val_loss: 0.4001\n", + "Epoch 6/50\n", + "38/38 [==============================] - 1s 26ms/step - loss: 0.4327 - val_loss: 0.3982\n", + "Epoch 7/50\n", + "38/38 [==============================] - 1s 25ms/step - loss: 0.4230 - val_loss: 0.3922\n", + "Epoch 8/50\n", + "38/38 [==============================] - 1s 24ms/step - loss: 0.4190 - val_loss: 0.3961\n", + "Epoch 9/50\n", + "38/38 [==============================] - 1s 24ms/step - loss: 0.4144 - val_loss: 0.3938\n", + "Epoch 10/50\n", + "38/38 [==============================] - 1s 25ms/step - loss: 0.4146 - val_loss: 0.3889\n", + "Epoch 11/50\n", + "38/38 [==============================] - 1s 28ms/step - loss: 0.4071 - val_loss: 0.3866\n", + "Epoch 12/50\n", + "38/38 [==============================] - 1s 26ms/step - loss: 0.4029 - val_loss: 0.3874\n", + "Epoch 13/50\n", + "38/38 [==============================] - 1s 25ms/step - loss: 0.3994 - val_loss: 0.3860\n", + "Epoch 14/50\n", + "38/38 [==============================] - 1s 23ms/step - loss: 0.3962 - val_loss: 0.3839\n", + "Epoch 15/50\n", + "38/38 [==============================] - 1s 24ms/step - loss: 0.3926 - val_loss: 0.3825\n", + "Epoch 16/50\n", + "38/38 [==============================] - 1s 24ms/step - loss: 0.3908 - val_loss: 0.3840\n", + "Epoch 17/50\n", + "38/38 [==============================] - 1s 25ms/step - loss: 0.3883 - val_loss: 0.3828\n", + "Epoch 18/50\n", + "38/38 [==============================] - 1s 25ms/step - loss: 0.3880 - val_loss: 0.3812\n", + "Epoch 19/50\n", + "38/38 [==============================] - 1s 25ms/step - loss: 0.3890 - val_loss: 0.3849\n", + "Epoch 20/50\n", + "38/38 [==============================] - 1s 26ms/step - loss: 0.3894 - val_loss: 0.3904\n", + "Epoch 21/50\n", + "38/38 [==============================] - 1s 26ms/step - loss: 0.3803 - val_loss: 0.3805\n", + "Epoch 22/50\n", + "38/38 [==============================] - 1s 25ms/step - loss: 0.3778 - val_loss: 0.3776\n", + "Epoch 23/50\n", + "38/38 [==============================] - 1s 27ms/step - loss: 0.3758 - val_loss: 0.3819\n", + "Epoch 24/50\n", + "38/38 [==============================] - 1s 26ms/step - loss: 0.3737 - val_loss: 0.3820\n", + "Epoch 25/50\n", + "38/38 [==============================] - 1s 28ms/step - loss: 0.3706 - val_loss: 0.3769\n", + "Epoch 26/50\n", + "38/38 [==============================] - 1s 27ms/step - loss: 0.3680 - val_loss: 0.3805\n", + "Epoch 27/50\n", + "38/38 [==============================] - 1s 28ms/step - loss: 0.3659 - val_loss: 0.3804\n", + "Epoch 28/50\n", + "38/38 [==============================] - 1s 27ms/step - loss: 0.3638 - val_loss: 0.3784\n", + "Epoch 29/50\n", + "38/38 [==============================] - 1s 27ms/step - loss: 0.3620 - val_loss: 0.3832\n", + "Epoch 30/50\n", + "38/38 [==============================] - 1s 27ms/step - loss: 0.3605 - val_loss: 0.3790\n", + "Epoch 31/50\n", + "38/38 [==============================] - 1s 26ms/step - loss: 0.3585 - val_loss: 0.3794\n", + "Epoch 32/50\n", + "38/38 [==============================] - 1s 27ms/step - loss: 0.3575 - val_loss: 0.3793\n", + "Epoch 33/50\n", + "38/38 [==============================] - 1s 28ms/step - loss: 0.3560 - val_loss: 0.3819\n", + "Epoch 34/50\n", + "38/38 [==============================] - 1s 27ms/step - loss: 0.3532 - val_loss: 0.3840\n", + "Epoch 35/50\n", + "38/38 [==============================] - 1s 27ms/step - loss: 0.3516 - val_loss: 0.3804\n", + "Epoch 36/50\n", + "38/38 [==============================] - 1s 27ms/step - loss: 0.3517 - val_loss: 0.3809\n", + "Epoch 37/50\n", + "38/38 [==============================] - 1s 27ms/step - loss: 0.3525 - val_loss: 0.3801\n", + "Epoch 38/50\n", + "38/38 [==============================] - 1s 28ms/step - loss: 0.3501 - val_loss: 0.3810\n", + "Epoch 39/50\n", + "38/38 [==============================] - 1s 29ms/step - loss: 0.3490 - val_loss: 0.3821\n", + "Epoch 40/50\n", + "38/38 [==============================] - 1s 29ms/step - loss: 0.3456 - val_loss: 0.3821\n", + "Epoch 41/50\n", + "38/38 [==============================] - 1s 29ms/step - loss: 0.3435 - val_loss: 0.3792\n", + "Epoch 42/50\n", + "38/38 [==============================] - 1s 28ms/step - loss: 0.3421 - val_loss: 0.3797\n", + "Epoch 43/50\n", + "38/38 [==============================] - 1s 28ms/step - loss: 0.3412 - val_loss: 0.3801\n", + "Epoch 44/50\n", + "38/38 [==============================] - 1s 28ms/step - loss: 0.3380 - val_loss: 0.3819\n", + "Epoch 45/50\n", + "38/38 [==============================] - 1s 27ms/step - loss: 0.3373 - val_loss: 0.3808\n", + "Epoch 46/50\n", + "38/38 [==============================] - 1s 26ms/step - loss: 0.3344 - val_loss: 0.3828\n", + "Epoch 47/50\n", + "38/38 [==============================] - 1s 27ms/step - loss: 0.3330 - val_loss: 0.3923\n", + "Epoch 48/50\n", + "38/38 [==============================] - 1s 26ms/step - loss: 0.3336 - val_loss: 0.3863\n", + "Epoch 49/50\n", + "38/38 [==============================] - 1s 26ms/step - loss: 0.3338 - val_loss: 0.3826\n", + "Epoch 50/50\n", + "38/38 [==============================] - 1s 26ms/step - loss: 0.3314 - val_loss: 0.3867\n" + ] + }, + { + "data": { + "text/plain": [ + "" + ] + }, + "execution_count": 388, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "model.compile(optimizer='adam', loss='mse')\n", + "# model.fit([X_train,X_train],y_train,epochs=50, batch_size=10, validation_data=([X_test,X_test], y_test))\n", + "# model.fit(ds,epochs=50, batch_size=10, validation_data=(test_ds),shuffle=False)\n", + "\n", + "model.fit(train_ds,epochs=50,shuffle=False,validation_data=test_ds)" + ] + }, + { + "cell_type": "code", + "execution_count": 389, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + " 1/Unknown - 0s 25ms/step - loss: 0.4388" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "6/6 [==============================] - 0s 11ms/step - loss: 0.3867\n" + ] + }, + { + "data": { + "text/plain": [ + "0.3866962194442749" + ] + }, + "execution_count": 389, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "model.evaluate(test_ds)" + ] + }, + { + "cell_type": "code", + "execution_count": 390, + "metadata": {}, + "outputs": [], + "source": [ + "y_test_check = [iter(test_ds)]" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": 391, + "metadata": {}, + "outputs": [ + { + "ename": "ValueError", + "evalue": "in user code:\n\n File \"c:\\Users\\yuyao\\AppData\\Local\\Programs\\Python\\Python311\\Lib\\site-packages\\keras\\src\\engine\\training.py\", line 2416, in predict_function *\n return step_function(self, iterator)\n File \"c:\\Users\\yuyao\\AppData\\Local\\Programs\\Python\\Python311\\Lib\\site-packages\\keras\\src\\engine\\training.py\", line 2401, in step_function **\n outputs = model.distribute_strategy.run(run_step, args=(data,))\n File \"c:\\Users\\yuyao\\AppData\\Local\\Programs\\Python\\Python311\\Lib\\site-packages\\keras\\src\\engine\\training.py\", line 2389, in run_step **\n outputs = model.predict_step(data)\n File \"c:\\Users\\yuyao\\AppData\\Local\\Programs\\Python\\Python311\\Lib\\site-packages\\keras\\src\\engine\\training.py\", line 2357, in predict_step\n return self(x, training=False)\n File \"c:\\Users\\yuyao\\AppData\\Local\\Programs\\Python\\Python311\\Lib\\site-packages\\keras\\src\\utils\\traceback_utils.py\", line 70, in error_handler\n raise e.with_traceback(filtered_tb) from None\n File \"c:\\Users\\yuyao\\AppData\\Local\\Programs\\Python\\Python311\\Lib\\site-packages\\keras\\src\\engine\\input_spec.py\", line 298, in assert_input_compatibility\n raise ValueError(\n\n ValueError: Input 0 of layer \"model_26\" is incompatible with the layer: expected shape=(None, 48, 2, 16, 8), found shape=(None, 5, 1)\n", + "output_type": "error", + "traceback": [ + "\u001b[1;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[1;31mValueError\u001b[0m Traceback (most recent call last)", + "Cell \u001b[1;32mIn[391], line 1\u001b[0m\n\u001b[1;32m----> 1\u001b[0m v \u001b[38;5;241m=\u001b[39m \u001b[43mmodel\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mpredict\u001b[49m\u001b[43m(\u001b[49m\u001b[43mX_test\u001b[49m\u001b[43m)\u001b[49m\n", + "File \u001b[1;32mc:\\Users\\yuyao\\AppData\\Local\\Programs\\Python\\Python311\\Lib\\site-packages\\keras\\src\\utils\\traceback_utils.py:70\u001b[0m, in \u001b[0;36mfilter_traceback..error_handler\u001b[1;34m(*args, **kwargs)\u001b[0m\n\u001b[0;32m 67\u001b[0m filtered_tb \u001b[38;5;241m=\u001b[39m _process_traceback_frames(e\u001b[38;5;241m.\u001b[39m__traceback__)\n\u001b[0;32m 68\u001b[0m \u001b[38;5;66;03m# To get the full stack trace, call:\u001b[39;00m\n\u001b[0;32m 69\u001b[0m \u001b[38;5;66;03m# `tf.debugging.disable_traceback_filtering()`\u001b[39;00m\n\u001b[1;32m---> 70\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m e\u001b[38;5;241m.\u001b[39mwith_traceback(filtered_tb) \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m\n\u001b[0;32m 71\u001b[0m \u001b[38;5;28;01mfinally\u001b[39;00m:\n\u001b[0;32m 72\u001b[0m \u001b[38;5;28;01mdel\u001b[39;00m filtered_tb\n", + "File \u001b[1;32m~\\AppData\\Local\\Temp\\__autograph_generated_filer42i32m1.py:15\u001b[0m, in \u001b[0;36mouter_factory..inner_factory..tf__predict_function\u001b[1;34m(iterator)\u001b[0m\n\u001b[0;32m 13\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[0;32m 14\u001b[0m do_return \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mTrue\u001b[39;00m\n\u001b[1;32m---> 15\u001b[0m retval_ \u001b[38;5;241m=\u001b[39m ag__\u001b[38;5;241m.\u001b[39mconverted_call(ag__\u001b[38;5;241m.\u001b[39mld(step_function), (ag__\u001b[38;5;241m.\u001b[39mld(\u001b[38;5;28mself\u001b[39m), ag__\u001b[38;5;241m.\u001b[39mld(iterator)), \u001b[38;5;28;01mNone\u001b[39;00m, fscope)\n\u001b[0;32m 16\u001b[0m \u001b[38;5;28;01mexcept\u001b[39;00m:\n\u001b[0;32m 17\u001b[0m do_return \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mFalse\u001b[39;00m\n", + "\u001b[1;31mValueError\u001b[0m: in user code:\n\n File \"c:\\Users\\yuyao\\AppData\\Local\\Programs\\Python\\Python311\\Lib\\site-packages\\keras\\src\\engine\\training.py\", line 2416, in predict_function *\n return step_function(self, iterator)\n File \"c:\\Users\\yuyao\\AppData\\Local\\Programs\\Python\\Python311\\Lib\\site-packages\\keras\\src\\engine\\training.py\", line 2401, in step_function **\n outputs = model.distribute_strategy.run(run_step, args=(data,))\n File \"c:\\Users\\yuyao\\AppData\\Local\\Programs\\Python\\Python311\\Lib\\site-packages\\keras\\src\\engine\\training.py\", line 2389, in run_step **\n outputs = model.predict_step(data)\n File \"c:\\Users\\yuyao\\AppData\\Local\\Programs\\Python\\Python311\\Lib\\site-packages\\keras\\src\\engine\\training.py\", line 2357, in predict_step\n return self(x, training=False)\n File \"c:\\Users\\yuyao\\AppData\\Local\\Programs\\Python\\Python311\\Lib\\site-packages\\keras\\src\\utils\\traceback_utils.py\", line 70, in error_handler\n raise e.with_traceback(filtered_tb) from None\n File \"c:\\Users\\yuyao\\AppData\\Local\\Programs\\Python\\Python311\\Lib\\site-packages\\keras\\src\\engine\\input_spec.py\", line 298, in assert_input_compatibility\n raise ValueError(\n\n ValueError: Input 0 of layer \"model_26\" is incompatible with the layer: expected shape=(None, 48, 2, 16, 8), found shape=(None, 5, 1)\n" + ] + } + ], + "source": [ + "v = model.predict(X_test)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "(192, 2, 16, 8)" + ] + }, + "execution_count": 308, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "v.shape" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "array([[0., 0., 0., 0., 0., 0., 0., 0.],\n", + " [0., 0., 0., 0., 0., 0., 0., 0.],\n", + " [0., 0., 0., 0., 0., 0., 0., 0.],\n", + " [0., 0., 0., 0., 0., 0., 0., 0.],\n", + " [0., 0., 0., 0., 0., 0., 0., 0.],\n", + " [0., 0., 0., 0., 0., 0., 0., 0.],\n", + " [0., 0., 0., 0., 0., 0., 0., 0.],\n", + " [0., 0., 0., 0., 0., 0., 0., 0.],\n", + " [0., 0., 0., 0., 0., 0., 0., 0.],\n", + " [0., 0., 0., 0., 0., 0., 0., 0.],\n", + " [0., 0., 0., 0., 0., 0., 0., 0.],\n", + " [0., 0., 0., 0., 0., 0., 0., 0.],\n", + " [0., 0., 0., 0., 0., 0., 0., 0.],\n", + " [0., 0., 0., 0., 0., 0., 0., 0.],\n", + " [0., 0., 0., 0., 0., 0., 0., 0.],\n", + " [0., 0., 0., 0., 0., 0., 0., 0.]], dtype=float32)" + ] + }, + "execution_count": 316, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "v[2,0,:,:]" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.2" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +}