diff --git a/examples/energy_forecasting.ipynb b/examples/energy_forecasting.ipynb new file mode 100644 index 00000000..5e9e6d24 --- /dev/null +++ b/examples/energy_forecasting.ipynb @@ -0,0 +1,1203 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 3, + "id": "5eb2166a", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "DataTransformerRegistry.enable('vegafusion')" + ] + }, + "execution_count": 3, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "import altair as alt\n", + "import polars as pl\n", + "\n", + "# Needed for larger datasets\n", + "alt.data_transformers.enable(\"vegafusion\")\n", + "\n", + "# alt.renderers.enable(\"browser\")" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "15980cb0", + "metadata": {}, + "outputs": [], + "source": [ + "X = pl.read_csv(\"X_train.csv\")\n", + "y = pl.read_csv(\"Y_train.csv\")\n", + "X = X.with_columns(pl.col(\"Time\").str.to_datetime(\"%d/%m/%Y %H:%M\"))\n", + "X_y = X.join(y, on=\"ID\")\n", + "X_WF1 = X.filter(pl.col(\"WF\") == \"WF1\")\n", + "X_WF1.filter(pl.col(\"Time\").dt.hour() == 0)\n", + "X_WF1[\"Time\"].max()\n", + "X_y_WF1 = X_y.filter(pl.col(\"WF\") == \"WF1\")" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "f3b139cd", + "metadata": {}, + "outputs": [], + "source": [ + "alt.Chart(X_WF1).mark_point().encode(\n", + " x=\"Time\", y=\"NWP1_00h_D-2_U\", tooltip=alt.Tooltip(\"Time\", format=\"%H:%M\")\n", + ").interactive()\n", + "(\n", + " list(\n", + " X_WF1.filter(pl.col(\"NWP1_00h_D-2_U\").is_not_null())[\"Time\"].dt.hour().unique()\n", + " ),\n", + " list(\n", + " X_WF1.filter(pl.col(\"NWP1_06h_D-2_U\").is_not_null())[\"Time\"].dt.hour().unique()\n", + " ),\n", + " list(\n", + " X_WF1.filter(pl.col(\"NWP1_12h_D-2_U\").is_not_null())[\"Time\"].dt.hour().unique()\n", + " ),\n", + " list(\n", + " X_WF1.filter(pl.col(\"NWP1_18h_D-2_U\").is_not_null())[\"Time\"].dt.hour().unique()\n", + " ),\n", + ")\n", + "with pl.Config(tbl_rows=-1):\n", + " print(X_WF1)" + ] + }, + { + "cell_type": "code", + "execution_count": 18, + "id": "9e6a09e0", + "metadata": { + "lines_to_next_cell": 2 + }, + "outputs": [ + { + "data": { + "text/html": [ + "\n", + "\n", + "
\n", + "" + ], + "text/plain": [ + "alt.Chart(...)" + ] + }, + "execution_count": 18, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# Complementary data\n", + "X_comp = pl.read_csv(\"WindFarms_complementary_data.csv\", separator=\";\")\n", + "X_comp = X_comp.filter(pl.col(\"Time (UTC)\").is_not_null())\n", + "X_comp = X_comp.with_columns(pl.col(\"Time (UTC)\").str.to_datetime(\"%d/%m/%Y %H:%M\"))\n", + "\n", + "(\n", + " alt.Chart(\n", + " X_comp.filter(\n", + " (pl.col(\"Wind Farm\") == \"WF1\") & (pl.col(\"Wind Turbine\") == \"TE1\")\n", + " ).with_columns(\n", + " (pl.col(\"Wind direction (�)\") - pl.col(\"Nacelle direction (�)\")).alias(\n", + " \"Nacelle misalignment (deg)\"\n", + " )\n", + " )\n", + " )\n", + " .mark_point()\n", + " .encode(x=\"Time (UTC)\", y=\"Nacelle misalignment (deg)\")\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "id": "7edeafcd", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "\n", + "\n", + "
\n", + "" + ], + "text/plain": [ + "alt.Chart(...)" + ] + }, + "execution_count": 10, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# Statistics of power production\n", + "X_y_WF1[\"Production\"].describe()\n", + "\n", + "# Histogram\n", + "(\n", + " alt.Chart(X_y_WF1)\n", + " .mark_bar()\n", + " .encode(alt.X(\"Production\", bin=alt.Bin(step=0.5)), y=\"count()\")\n", + " .properties(width=800)\n", + ")\n", + "\n", + "# The distribution of production is heavily right skewed. The median is 0.82 MW.\n", + "# According to Our World in Data 2017 (https://ourworldindata.org/scale-for-electricity), a French person consumes 0.019 MWh/day" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "id": "2ea436d9", + "metadata": { + "lines_to_next_cell": 2 + }, + "outputs": [ + { + "data": { + "text/html": [ + "\n", + "\n", + "
\n", + "" + ], + "text/plain": [ + "alt.Chart(...)" + ] + }, + "execution_count": 11, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# Total production for the month\n", + "(\n", + " alt.Chart(X_y_WF1)\n", + " .mark_line()\n", + " .encode(x=\"yearmonth(Time):T\", y=\"sum(Production)\")\n", + " .properties(width=800)\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "id": "e8ec30bb", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "\n", + "\n", + "
\n", + "" + ], + "text/plain": [ + "alt.Chart(...)" + ] + }, + "execution_count": 12, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "\n", + "# There's a big drop in December 2018, compared to November and January. Is it because demand dropped, or because the data was corrupted, or because the wind farms were in maintenance?\n", + "\n", + "(\n", + " alt.Chart(X_y_WF1.filter(pl.col(\"Time\").dt.month() == 12))\n", + " .mark_point()\n", + " .encode(x=\"Time\", y=\"Production\")\n", + " .properties(width=3000)\n", + ")\n", + "\n", + "# The power production was very near zero for 9 consecutive days, from 12 December to 21 December." + ] + }, + { + "cell_type": "code", + "execution_count": 36, + "id": "6db088ea", + "metadata": {}, + "outputs": [], + "source": [ + "# Trying out a classic sktime forecasting workflow\n", + "\n", + "import numpy as np\n", + "from sktime.forecasting.compose import make_reduction\n", + "from sklearn.ensemble import RandomForestRegressor\n", + "from sktime.performance_metrics.forecasting import MeanAbsolutePercentageError\n", + "from sktime.split import temporal_train_test_split\n", + "\n", + "# Format the data to be sktime-friendly\n", + "y_train, y_test, X_train, X_test = temporal_train_test_split(\n", + " y=X_y_WF1[\"Production\"].to_pandas(), X=X_WF1.drop([\"ID\", \"WF\", \"Time\"]).to_pandas()\n", + ")\n", + "\n", + "fh = np.arange(1, len(y_test) + 1) # forecasting horizon\n", + "regressor = RandomForestRegressor()\n", + "forecaster = make_reduction(\n", + " regressor,\n", + " strategy=\"recursive\",\n", + " window_length=12,\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 37, + "id": "efd0dc07", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
RecursiveTabularRegressionForecaster(estimator=RandomForestRegressor(),\n",
+       "                                     window_length=12)
Please rerun this cell to show the HTML repr or trust the notebook.
" + ], + "text/plain": [ + "RecursiveTabularRegressionForecaster(estimator=RandomForestRegressor(),\n", + " window_length=12)" + ] + }, + "execution_count": 37, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# Takes a while\n", + "forecaster.fit(y=y_train, X=X_train, fh=fh)" + ] + }, + { + "cell_type": "code", + "execution_count": 38, + "id": "c76071c1", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "5531673844307267.0" + ] + }, + "execution_count": 38, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "y_pred = forecaster.predict(fh=fh, X=X_test)\n", + "smape = MeanAbsolutePercentageError()\n", + "smape(y_test, y_pred)" + ] + }, + { + "cell_type": "code", + "execution_count": 39, + "id": "c3c14c89", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "\n", + "\n", + "
\n", + "" + ], + "text/plain": [ + "alt.Chart(...)" + ] + }, + "execution_count": 39, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# Show predictions with test\n", + "df = pl.DataFrame({\"y_pred\": y_pred, \"y_test\": y_test}).with_row_index().melt(\"index\")\n", + "alt.Chart(df).mark_line().encode(x=\"index\", y=\"value\", color=\"variable\")\n", + "# It's not great" + ] + }, + { + "cell_type": "code", + "execution_count": 23, + "id": "e423d93f", + "metadata": { + "lines_to_next_cell": 0 + }, + "outputs": [ + { + "data": { + "text/html": [ + "\n", + "\n", + "
\n", + "" + ], + "text/plain": [ + "alt.Chart(...)" + ] + }, + "execution_count": 23, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "alt.Chart(X_y_WF1).mark_line().encode(x=\"Time\", y=\"Production\").properties(\n", + " width=2000, height=400\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 28, + "id": "07318749", + "metadata": { + "lines_to_next_cell": 0 + }, + "outputs": [ + { + "data": { + "text/html": [ + "\n", + "\n", + "
\n", + "" + ], + "text/plain": [ + "alt.LayerChart(...)" + ] + }, + "execution_count": 28, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# Average production depending on the day of the week\n", + "(\n", + " alt.Chart(X_y_WF1.with_columns(pl.col(\"Time\").dt.weekday().alias(\"Day of week\")))\n", + " .mark_bar()\n", + " .encode(x=\"Day of week\", y=\"mean(Production)\")\n", + " + alt.Chart(X_y_WF1.with_columns(pl.col(\"Time\").dt.weekday().alias(\"Day of week\")))\n", + " .mark_errorbar(extent=\"iqr\")\n", + " .encode(x=\"Day of week\", y=\"Production\")\n", + ")\n", + "# 1 is Monday, 7 is Sunday\n", + "# Top production is on Mondays and Sundays, bottom is Thursdays\n", + "\n", + "# Error bars are the IQR" + ] + }, + { + "cell_type": "code", + "execution_count": 29, + "id": "d2ecb957", + "metadata": { + "lines_to_next_cell": 0 + }, + "outputs": [ + { + "data": { + "text/html": [ + "\n", + "\n", + "
\n", + "" + ], + "text/plain": [ + "alt.LayerChart(...)" + ] + }, + "execution_count": 29, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# Average production depending on month of the year\n", + "base = alt.Chart(X_y_WF1.with_columns(pl.col(\"Time\").dt.month().alias(\"Month\")))\n", + "(\n", + " base.mark_bar().encode(x=\"Month\", y=\"mean(Production)\")\n", + " + base.mark_errorbar(extent=\"iqr\").encode(x=\"Month\", y=\"Production\")\n", + ")\n", + "\n", + "# 1 is January, 12 is December\n", + "# Top production is January by far, bottom is August/September\n", + "# December is low, as mentioned earlier\n", + "\n", + "# The error bars show the inter-quartile range (bottom is 25% quantile, top is 75% quantile)\n", + "# This way we can clearly see that a lot of the data is very close to 0" + ] + }, + { + "cell_type": "code", + "execution_count": 31, + "id": "accaa63b", + "metadata": {}, + "outputs": [], + "source": [ + "import polars.selectors as cs" + ] + }, + { + "cell_type": "code", + "execution_count": 33, + "id": "e483d6d8", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "\n", + "\n", + "
\n", + "" + ], + "text/plain": [ + "alt.Chart(...)" + ] + }, + "execution_count": 33, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "nwp1 = X_y_WF1.select(cs.matches(\"Time\") | (cs.matches(\"NWP1\") & cs.matches(\"_U\")))\n", + "alt.Chart(nwp1.melt(id_vars=\"Time\")).mark_point().encode(\n", + " x=\"Time\", y=\"value\", color=\"variable\"\n", + ").properties(width=5000, height=500)\n", + "X_y_WF1.with_columns(\n", + " mean_U=pl.mean_horizontal((cs.matches(\"NWP1\") & cs.matches(\"_U\"))),\n", + " min_U=pl.min_horizontal((cs.matches(\"NWP1\") & cs.matches(\"_U\"))),\n", + " max_U=pl.max_horizontal((cs.matches(\"NWP1\") & cs.matches(\"_U\"))),\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 35, + "id": "05431927", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "\n", + "\n", + "
\n", + "" + ], + "text/plain": [ + "alt.LayerChart(...)" + ] + }, + "execution_count": 35, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "(\n", + " alt.Chart(nwp1.melt(id_vars=\"Time\")).mark_line().encode(x=\"Time\", y=\"mean(value)\")\n", + " + alt.Chart(nwp1.melt(id_vars=\"Time\"))\n", + " .mark_errorband(extent=\"ci\")\n", + " .encode(x=\"Time\", y=\"value\")\n", + ").properties(width=5000, height=500)" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "id": "077057a3", + "metadata": {}, + "outputs": [ + { + "ename": "NameError", + "evalue": "name 'X_y_WF1' is not defined", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mNameError\u001b[0m Traceback (most recent call last)", + "Cell \u001b[0;32mIn[1], line 3\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[38;5;66;03m# Correlation between the different variables\u001b[39;00m\n\u001b[0;32m----> 3\u001b[0m \u001b[43mX_y_WF1\u001b[49m\n", + "\u001b[0;31mNameError\u001b[0m: name 'X_y_WF1' is not defined" + ] + } + ], + "source": [ + "# Correlation between the different variables\n", + "\n", + "X_y_WF1" + ] + } + ], + "metadata": { + "jupytext": { + "cell_metadata_filter": "-all", + "formats": "ipynb,py", + "main_language": "python" + }, + "kernelspec": { + "display_name": ".venv", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.12.4" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/examples/energy_forecasting.py b/examples/energy_forecasting.py new file mode 100644 index 00000000..f3cbf8f8 --- /dev/null +++ b/examples/energy_forecasting.py @@ -0,0 +1,204 @@ +# --- +# jupyter: +# jupytext: +# cell_metadata_filter: -all +# formats: ipynb,py +# text_representation: +# extension: .py +# format_name: light +# format_version: '1.5' +# jupytext_version: 1.16.1 +# kernelspec: +# display_name: .venv +# language: python +# name: python3 +# --- + +# + +import altair as alt +import polars as pl + +# Needed for larger datasets +alt.data_transformers.enable("vegafusion") + +# alt.renderers.enable("browser") +# - + +X = pl.read_csv("X_train.csv") +y = pl.read_csv("Y_train.csv") +X = X.with_columns(pl.col("Time").str.to_datetime("%d/%m/%Y %H:%M")) +X_y = X.join(y, on="ID") +X_WF1 = X.filter(pl.col("WF") == "WF1") +X_WF1.filter(pl.col("Time").dt.hour() == 0) +X_WF1["Time"].max() +X_y_WF1 = X_y.filter(pl.col("WF") == "WF1") + +alt.Chart(X_WF1).mark_point().encode( + x="Time", y="NWP1_00h_D-2_U", tooltip=alt.Tooltip("Time", format="%H:%M") +).interactive() +( + list( + X_WF1.filter(pl.col("NWP1_00h_D-2_U").is_not_null())["Time"].dt.hour().unique() + ), + list( + X_WF1.filter(pl.col("NWP1_06h_D-2_U").is_not_null())["Time"].dt.hour().unique() + ), + list( + X_WF1.filter(pl.col("NWP1_12h_D-2_U").is_not_null())["Time"].dt.hour().unique() + ), + list( + X_WF1.filter(pl.col("NWP1_18h_D-2_U").is_not_null())["Time"].dt.hour().unique() + ), +) +with pl.Config(tbl_rows=-1): + print(X_WF1) + +# + +# Complementary data +X_comp = pl.read_csv("WindFarms_complementary_data.csv", separator=";") +X_comp = X_comp.filter(pl.col("Time (UTC)").is_not_null()) +X_comp = X_comp.with_columns(pl.col("Time (UTC)").str.to_datetime("%d/%m/%Y %H:%M")) + +( + alt.Chart( + X_comp.filter( + (pl.col("Wind Farm") == "WF1") & (pl.col("Wind Turbine") == "TE1") + ).with_columns( + (pl.col("Wind direction (�)") - pl.col("Nacelle direction (�)")).alias( + "Nacelle misalignment (deg)" + ) + ) + ) + .mark_point() + .encode(x="Time (UTC)", y="Nacelle misalignment (deg)") +) + + +# + +# Statistics of power production +X_y_WF1["Production"].describe() + +# Histogram +( + alt.Chart(X_y_WF1) + .mark_bar() + .encode(alt.X("Production", bin=alt.Bin(step=0.5)), y="count()") + .properties(width=800) +) + +# The distribution of production is heavily right skewed. The median is 0.82 MW. +# According to Our World in Data 2017 (https://ourworldindata.org/scale-for-electricity), a French person consumes 0.019 MWh/day +# - + +# Total production for the month +( + alt.Chart(X_y_WF1) + .mark_line() + .encode(x="yearmonth(Time):T", y="sum(Production)") + .properties(width=800) +) + + +# + + +# There's a big drop in December 2018, compared to November and January. Is it because demand dropped, or because the data was corrupted, or because the wind farms were in maintenance? + +( + alt.Chart(X_y_WF1.filter(pl.col("Time").dt.month() == 12)) + .mark_point() + .encode(x="Time", y="Production") + .properties(width=3000) +) + +# The power production was very near zero for 9 consecutive days, from 12 December to 21 December. + +# + +# Trying out a classic sktime forecasting workflow + +import numpy as np +from sktime.forecasting.compose import make_reduction +from sklearn.ensemble import RandomForestRegressor +from sktime.performance_metrics.forecasting import MeanAbsolutePercentageError +from sktime.split import temporal_train_test_split + +# Format the data to be sktime-friendly +y_train, y_test, X_train, X_test = temporal_train_test_split( + y=X_y_WF1["Production"].to_pandas(), X=X_WF1.drop(["ID", "WF", "Time"]).to_pandas() +) + +fh = np.arange(1, len(y_test) + 1) # forecasting horizon +regressor = RandomForestRegressor() +forecaster = make_reduction( + regressor, + strategy="recursive", + window_length=12, +) +# - + +# Takes a while +forecaster.fit(y=y_train, X=X_train, fh=fh) + +y_pred = forecaster.predict(fh=fh, X=X_test) +smape = MeanAbsolutePercentageError() +smape(y_test, y_pred) + +# Show predictions with test +df = pl.DataFrame({"y_pred": y_pred, "y_test": y_test}).with_row_index().melt("index") +alt.Chart(df).mark_line().encode(x="index", y="value", color="variable") +# It's not great + +alt.Chart(X_y_WF1).mark_line().encode(x="Time", y="Production").properties( + width=2000, height=400 +) +# + +# Average production depending on the day of the week +( + alt.Chart(X_y_WF1.with_columns(pl.col("Time").dt.weekday().alias("Day of week"))) + .mark_bar() + .encode(x="Day of week", y="mean(Production)") + + alt.Chart(X_y_WF1.with_columns(pl.col("Time").dt.weekday().alias("Day of week"))) + .mark_errorbar(extent="iqr") + .encode(x="Day of week", y="Production") +) +# 1 is Monday, 7 is Sunday +# Top production is on Mondays and Sundays, bottom is Thursdays + +# Error bars are the IQR +# + +# Average production depending on month of the year +base = alt.Chart(X_y_WF1.with_columns(pl.col("Time").dt.month().alias("Month"))) +( + base.mark_bar().encode(x="Month", y="mean(Production)") + + base.mark_errorbar(extent="iqr").encode(x="Month", y="Production") +) + +# 1 is January, 12 is December +# Top production is January by far, bottom is August/September +# December is low, as mentioned earlier + +# The error bars show the inter-quartile range (bottom is 25% quantile, top is 75% quantile) +# This way we can clearly see that a lot of the data is very close to 0 +# - +import polars.selectors as cs + +nwp1 = X_y_WF1.select(cs.matches("Time") | (cs.matches("NWP1") & cs.matches("_U"))) +alt.Chart(nwp1.melt(id_vars="Time")).mark_point().encode( + x="Time", y="value", color="variable" +).properties(width=5000, height=500) +X_y_WF1.with_columns( + mean_U=pl.mean_horizontal((cs.matches("NWP1") & cs.matches("_U"))), + min_U=pl.min_horizontal((cs.matches("NWP1") & cs.matches("_U"))), + max_U=pl.max_horizontal((cs.matches("NWP1") & cs.matches("_U"))), +) + +( + alt.Chart(nwp1.melt(id_vars="Time")).mark_line().encode(x="Time", y="mean(value)") + + alt.Chart(nwp1.melt(id_vars="Time")) + .mark_errorband(extent="ci") + .encode(x="Time", y="value") +).properties(width=5000, height=500) + +# + +# Correlation between the different variables + +X_y_WF1