Skip to content

Commit

Permalink
lstm baseline
Browse files Browse the repository at this point in the history
  • Loading branch information
yul19079 committed Jan 26, 2024
1 parent 84368ae commit 74a907c
Show file tree
Hide file tree
Showing 3 changed files with 342 additions and 0 deletions.
Empty file removed model.ipynb
Empty file.
63 changes: 63 additions & 0 deletions model_base.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import h5py\n",
"import numpy as np\n",
"from datetime import datetime\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import keras\n",
"from keras import layers\n",
"from keras import models\n",
"\n",
"spatio_feature = None\n",
"\n",
"input_shape = (5,1)\n",
"\n",
"left_inputs = layers.Input(shape = input_shape)\n",
"\n",
"x = models.layers.LSTM(50, activation = \"tanh\")(left_inputs)\n",
"\n",
"x = layers.Dense(2*16*8, activation = \"linear\")(x)\n",
"\n",
"output = layers.Reshape((2,16,8))(x)\n",
"\n",
"model = models.Model([left_inputs],outputs = output)\n",
"\n",
"\n",
"\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "py38",
"language": "python",
"name": "python3"
},
"language_info": {
"name": "python",
"version": "3.10.13"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
279 changes: 279 additions & 0 deletions model_v1.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,279 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 15,
"metadata": {},
"outputs": [],
"source": [
"import h5py\n",
"import numpy as np\n",
"from datetime import datetime\n"
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {},
"outputs": [],
"source": [
"with h5py.File('./data/TrainingData.h5', 'r') as f:\n",
" # Access the trip dataset and their corresponding timestamps\n",
" traffic_data = f['trip'][()]\n",
" dates = f['timeslot'][()]"
]
},
{
"cell_type": "code",
"execution_count": 17,
"metadata": {},
"outputs": [],
"source": [
"formatted_dates = []\n",
"\n",
"for date_string in dates:\n",
" formatted_date = datetime.strptime(date_string.decode(), '%Y%m%d%H%M')\n",
"\n",
" year = formatted_date.year\n",
" month = formatted_date.month\n",
" day = formatted_date.day\n",
" hour = formatted_date.hour\n",
" minute = formatted_date.minute\n",
"\n",
" formatted_dates.append(np.array([year, month, day, hour, minute]))\n",
"\n",
"formatted_dates = np.array(formatted_dates).reshape(1488, 5, 1)"
]
},
{
"cell_type": "code",
"execution_count": 18,
"metadata": {},
"outputs": [],
"source": [
"test_size = 240\n",
"\n",
"train_traffic_data = traffic_data[:-test_size]\n",
"test_traffic_data = traffic_data[-test_size:]\n",
"\n",
"train_formatted_dates = formatted_dates[:-test_size]\n",
"test_formatted_dates = formatted_dates[-test_size:]"
]
},
{
"cell_type": "code",
"execution_count": 19,
"metadata": {},
"outputs": [],
"source": [
"X_train, X_test = train_formatted_dates, test_formatted_dates\n",
"y_train, y_test = train_traffic_data, test_traffic_data"
]
},
{
"cell_type": "code",
"execution_count": 20,
"metadata": {},
"outputs": [],
"source": [
"import keras\n",
"from keras import layers\n",
"from keras import models\n",
"\n",
"spatio_feature = None\n",
"\n",
"input_shape = (5,1)\n",
"\n",
"left_inputs = layers.Input(shape = input_shape)\n",
"\n",
"x = layers.LSTM(50, activation = \"tanh\")(left_inputs)\n",
"\n",
"x = layers.Dense(2*16*8, activation = \"linear\")(x)\n",
"\n",
"output = layers.Reshape((2,16,8))(x)\n",
"\n",
"model = models.Model([left_inputs],outputs = output)\n",
"\n",
"\n",
"\n"
]
},
{
"cell_type": "code",
"execution_count": 21,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 1/50\n",
"39/39 [==============================] - 1s 9ms/step - loss: 212.9115 - val_loss: 144.5777\n",
"Epoch 2/50\n",
"39/39 [==============================] - 0s 2ms/step - loss: 183.6068 - val_loss: 118.8507\n",
"Epoch 3/50\n",
"39/39 [==============================] - 0s 2ms/step - loss: 158.2998 - val_loss: 105.8254\n",
"Epoch 4/50\n",
"39/39 [==============================] - 0s 2ms/step - loss: 143.8591 - val_loss: 98.0465\n",
"Epoch 5/50\n",
"39/39 [==============================] - 0s 2ms/step - loss: 134.1396 - val_loss: 93.0498\n",
"Epoch 6/50\n",
"39/39 [==============================] - 0s 2ms/step - loss: 127.2775 - val_loss: 89.7777\n",
"Epoch 7/50\n",
"39/39 [==============================] - 0s 2ms/step - loss: 122.2552 - val_loss: 87.7291\n",
"Epoch 8/50\n",
"39/39 [==============================] - 0s 2ms/step - loss: 118.6310 - val_loss: 86.4313\n",
"Epoch 9/50\n",
"39/39 [==============================] - 0s 2ms/step - loss: 115.9718 - val_loss: 85.7162\n",
"Epoch 10/50\n",
"39/39 [==============================] - 0s 2ms/step - loss: 114.0570 - val_loss: 85.3153\n",
"Epoch 11/50\n",
"39/39 [==============================] - 0s 2ms/step - loss: 112.6366 - val_loss: 85.1707\n",
"Epoch 12/50\n",
"39/39 [==============================] - 0s 2ms/step - loss: 111.6274 - val_loss: 85.1579\n",
"Epoch 13/50\n",
"39/39 [==============================] - 0s 2ms/step - loss: 110.8910 - val_loss: 85.2295\n",
"Epoch 14/50\n",
"39/39 [==============================] - 0s 2ms/step - loss: 110.3705 - val_loss: 85.3980\n",
"Epoch 15/50\n",
"39/39 [==============================] - 0s 2ms/step - loss: 109.9801 - val_loss: 85.5780\n",
"Epoch 16/50\n",
"39/39 [==============================] - 0s 2ms/step - loss: 109.7158 - val_loss: 85.8827\n",
"Epoch 17/50\n",
"39/39 [==============================] - 0s 2ms/step - loss: 109.5257 - val_loss: 86.0089\n",
"Epoch 18/50\n",
"39/39 [==============================] - 0s 2ms/step - loss: 109.4247 - val_loss: 86.1505\n",
"Epoch 19/50\n",
"39/39 [==============================] - 0s 2ms/step - loss: 109.0258 - val_loss: 85.7924\n",
"Epoch 20/50\n",
"39/39 [==============================] - 0s 2ms/step - loss: 108.8352 - val_loss: 85.9186\n",
"Epoch 21/50\n",
"39/39 [==============================] - 0s 2ms/step - loss: 108.5763 - val_loss: 85.5510\n",
"Epoch 22/50\n",
"39/39 [==============================] - 0s 2ms/step - loss: 108.1767 - val_loss: 85.3380\n",
"Epoch 23/50\n",
"39/39 [==============================] - 0s 2ms/step - loss: 107.4323 - val_loss: 84.9938\n",
"Epoch 24/50\n",
"39/39 [==============================] - 0s 2ms/step - loss: 106.1936 - val_loss: 82.6167\n",
"Epoch 25/50\n",
"39/39 [==============================] - 0s 2ms/step - loss: 104.6931 - val_loss: 81.7814\n",
"Epoch 26/50\n",
"39/39 [==============================] - 0s 2ms/step - loss: 102.9895 - val_loss: 79.9620\n",
"Epoch 27/50\n",
"39/39 [==============================] - 0s 2ms/step - loss: 101.5745 - val_loss: 78.4792\n",
"Epoch 28/50\n",
"39/39 [==============================] - 0s 2ms/step - loss: 98.5761 - val_loss: 76.5478\n",
"Epoch 29/50\n",
"39/39 [==============================] - 0s 2ms/step - loss: 94.6209 - val_loss: 72.0995\n",
"Epoch 30/50\n",
"39/39 [==============================] - 0s 2ms/step - loss: 90.4683 - val_loss: 70.5429\n",
"Epoch 31/50\n",
"39/39 [==============================] - 0s 2ms/step - loss: 86.6393 - val_loss: 67.5324\n",
"Epoch 32/50\n",
"39/39 [==============================] - 0s 2ms/step - loss: 83.6719 - val_loss: 66.9005\n",
"Epoch 33/50\n",
"39/39 [==============================] - 0s 2ms/step - loss: 81.2981 - val_loss: 65.7497\n",
"Epoch 34/50\n",
"39/39 [==============================] - 0s 2ms/step - loss: 79.8582 - val_loss: 65.4873\n",
"Epoch 35/50\n",
"39/39 [==============================] - 0s 2ms/step - loss: 78.4644 - val_loss: 64.4355\n",
"Epoch 36/50\n",
"39/39 [==============================] - 0s 2ms/step - loss: 77.4534 - val_loss: 65.0040\n",
"Epoch 37/50\n",
"39/39 [==============================] - 0s 2ms/step - loss: 76.9951 - val_loss: 65.3200\n",
"Epoch 38/50\n",
"39/39 [==============================] - 0s 2ms/step - loss: 76.4038 - val_loss: 64.2204\n",
"Epoch 39/50\n",
"39/39 [==============================] - 0s 2ms/step - loss: 75.8992 - val_loss: 64.0133\n",
"Epoch 40/50\n",
"39/39 [==============================] - 0s 2ms/step - loss: 75.4885 - val_loss: 64.4360\n",
"Epoch 41/50\n",
"39/39 [==============================] - 0s 2ms/step - loss: 75.2467 - val_loss: 64.6135\n",
"Epoch 42/50\n",
"39/39 [==============================] - 0s 2ms/step - loss: 74.6445 - val_loss: 64.3743\n",
"Epoch 43/50\n",
"39/39 [==============================] - 0s 2ms/step - loss: 74.2044 - val_loss: 63.9837\n",
"Epoch 44/50\n",
"39/39 [==============================] - 0s 2ms/step - loss: 73.8331 - val_loss: 64.3629\n",
"Epoch 45/50\n",
"39/39 [==============================] - 0s 2ms/step - loss: 73.4110 - val_loss: 63.6357\n",
"Epoch 46/50\n",
"39/39 [==============================] - 0s 2ms/step - loss: 72.9355 - val_loss: 63.7812\n",
"Epoch 47/50\n",
"39/39 [==============================] - 0s 2ms/step - loss: 72.2754 - val_loss: 63.3691\n",
"Epoch 48/50\n",
"39/39 [==============================] - 0s 2ms/step - loss: 71.8866 - val_loss: 64.9375\n",
"Epoch 49/50\n",
"39/39 [==============================] - 0s 2ms/step - loss: 71.0808 - val_loss: 63.7018\n",
"Epoch 50/50\n",
"39/39 [==============================] - 0s 2ms/step - loss: 70.5971 - val_loss: 62.9802\n"
]
},
{
"data": {
"text/plain": [
"<keras.callbacks.History at 0x1b1db2e4a90>"
]
},
"execution_count": 21,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"model.compile(optimizer='adam', loss='mse')\n",
"model.fit(X_train, y_train, epochs=50, batch_size=32, validation_data=(X_test, y_test))"
]
},
{
"cell_type": "code",
"execution_count": 22,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"8/8 [==============================] - 0s 1000us/step - loss: 62.9802\n",
"Root Mean Squared Error: 7.936009544951329\n"
]
}
],
"source": [
"mse = model.evaluate(X_test, y_test)\n",
"\n",
"# Show rmse to see how model performs on the test set\n",
"rmse = np.sqrt(mse)\n",
"print(f'Root Mean Squared Error: {rmse}')"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "py38",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.13"
}
},
"nbformat": 4,
"nbformat_minor": 2
}

0 comments on commit 74a907c

Please sign in to comment.