Skip to content

Commit

Permalink
docs: add notebook
Browse files Browse the repository at this point in the history
  • Loading branch information
augustebaum committed Jul 16, 2024
1 parent 4183f11 commit 8eea556
Show file tree
Hide file tree
Showing 2 changed files with 1,407 additions and 0 deletions.
1,203 changes: 1,203 additions & 0 deletions examples/energy_forecasting.ipynb

Large diffs are not rendered by default.

204 changes: 204 additions & 0 deletions examples/energy_forecasting.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,204 @@
# ---
# jupyter:
# jupytext:
# cell_metadata_filter: -all
# formats: ipynb,py
# text_representation:
# extension: .py
# format_name: light
# format_version: '1.5'
# jupytext_version: 1.16.1
# kernelspec:
# display_name: .venv
# language: python
# name: python3
# ---

# +
# Plotting (altair) and dataframe (polars) libraries used throughout.
import altair as alt
import polars as pl

# Needed for larger datasets: the "vegafusion" data transformer pre-aggregates
# data before embedding, avoiding altair's default row limit on chart data.
alt.data_transformers.enable("vegafusion")

# Uncomment to render charts in a separate browser window instead of inline.
# alt.renderers.enable("browser")
# -

# Load the training data. X_train holds the forecast features (NWP* columns,
# presumably numerical weather prediction runs — TODO confirm), Y_train the
# production target; the two are joined on the shared "ID" key.
X = pl.read_csv("X_train.csv")
y = pl.read_csv("Y_train.csv")
# "Time" is read as a string in day/month/year hour:minute form; parse it.
X = X.with_columns(pl.col("Time").str.to_datetime("%d/%m/%Y %H:%M"))
X_y = X.join(y, on="ID")
# Restrict the exploration to the first wind farm.
X_WF1 = X.filter(pl.col("WF") == "WF1")
# Bare expressions below are notebook cells: their value is displayed, not stored.
X_WF1.filter(pl.col("Time").dt.hour() == 0)
X_WF1["Time"].max()
X_y_WF1 = X_y.filter(pl.col("WF") == "WF1")

# Scatter one forecast column against time; the tooltip shows the hour of day,
# useful to see at which hours this run has values.
alt.Chart(X_WF1).mark_point().encode(
    x="Time", y="NWP1_00h_D-2_U", tooltip=alt.Tooltip("Time", format="%H:%M")
).interactive()
# For each NWP1 run (00h/06h/12h/18h), list the hours of day at which the run
# actually has non-null values — i.e. which timestamps each run covers.
(
    list(
        X_WF1.filter(pl.col("NWP1_00h_D-2_U").is_not_null())["Time"].dt.hour().unique()
    ),
    list(
        X_WF1.filter(pl.col("NWP1_06h_D-2_U").is_not_null())["Time"].dt.hour().unique()
    ),
    list(
        X_WF1.filter(pl.col("NWP1_12h_D-2_U").is_not_null())["Time"].dt.hour().unique()
    ),
    list(
        X_WF1.filter(pl.col("NWP1_18h_D-2_U").is_not_null())["Time"].dt.hour().unique()
    ),
)
# tbl_rows=-1 disables polars' row truncation so the whole frame is printed.
with pl.Config(tbl_rows=-1):
    print(X_WF1)

# +
# Complementary data: per-turbine measurements, semicolon-separated.
X_comp = pl.read_csv("WindFarms_complementary_data.csv", separator=";")
# Drop rows with no timestamp (e.g. trailing blanks), then parse the timestamp.
X_comp = X_comp.filter(pl.col("Time (UTC)").is_not_null())
X_comp = X_comp.with_columns(pl.col("Time (UTC)").str.to_datetime("%d/%m/%Y %H:%M"))

# Nacelle misalignment = wind direction minus nacelle direction, in degrees,
# plotted over time for turbine TE1 of wind farm WF1.
# NOTE(review): the column names here contained a mojibake byte ("(�)") where
# a degree sign belongs; restored to "(°)" — confirm against the actual CSV
# header, which must match exactly.
(
    alt.Chart(
        X_comp.filter(
            (pl.col("Wind Farm") == "WF1") & (pl.col("Wind Turbine") == "TE1")
        ).with_columns(
            (pl.col("Wind direction (°)") - pl.col("Nacelle direction (°)")).alias(
                "Nacelle misalignment (deg)"
            )
        )
    )
    .mark_point()
    .encode(x="Time (UTC)", y="Nacelle misalignment (deg)")
)


# +
# Summary statistics of WF1's power production.
X_y_WF1["Production"].describe()

# Histogram of production, binned in 0.5-wide buckets.
hist = alt.Chart(X_y_WF1).mark_bar()
hist = hist.encode(alt.X("Production", bin=alt.Bin(step=0.5)), y="count()")
hist.properties(width=800)

# The distribution of production is heavily right skewed. The median is 0.82 MW.
# According to Our World in Data 2017 (https://ourworldindata.org/scale-for-electricity), a French person consumes 0.019 MWh/day
# -

# Total production per month: sum of Production, aggregated by year-month.
alt.Chart(X_y_WF1).mark_line().encode(
    x="yearmonth(Time):T", y="sum(Production)"
).properties(width=800)


# +

# There's a big drop in December 2018, compared to November and January. Is it because demand dropped, or because the data was corrupted, or because the wind farms were in maintenance?

# Zoom in on December: plot every observation to locate the drop precisely.
# NOTE(review): this filters on the month only — if the data spans several
# years it mixes all Decembers together; confirm the data covers one year.
(
    alt.Chart(X_y_WF1.filter(pl.col("Time").dt.month() == 12))
    .mark_point()
    .encode(x="Time", y="Production")
    .properties(width=3000)
)

# The power production was very near zero for 9 consecutive days, from 12 December to 21 December.

# +
# Trying out a classic sktime forecasting workflow

import numpy as np
from sktime.forecasting.compose import make_reduction
from sklearn.ensemble import RandomForestRegressor
from sktime.performance_metrics.forecasting import MeanAbsolutePercentageError
from sktime.split import temporal_train_test_split

# Format the data to be sktime-friendly: convert the polars target series and
# feature frame to pandas, dropping the non-feature columns (ID, WF, Time).
y_train, y_test, X_train, X_test = temporal_train_test_split(
    y=X_y_WF1["Production"].to_pandas(), X=X_WF1.drop(["ID", "WF", "Time"]).to_pandas()
)

# Forecast every step of the held-out tail (relative horizon 1..len(y_test)).
fh = np.arange(1, len(y_test) + 1)  # forecasting horizon
regressor = RandomForestRegressor()
# Reduce forecasting to tabular regression: the regressor is trained on
# sliding windows of the last `window_length` observations and applied
# recursively over the horizon.
forecaster = make_reduction(
    regressor,
    strategy="recursive",
    window_length=12,
)
# -

# Takes a while: fits the regressor on all lagged training windows.
forecaster.fit(y=y_train, X=X_train, fh=fh)

y_pred = forecaster.predict(fh=fh, X=X_test)
# The variable name says sMAPE, but MeanAbsolutePercentageError() relied on the
# library default for `symmetric`, which has differed across sktime versions.
# Pin symmetric=True so the metric really is the symmetric MAPE the name claims.
smape = MeanAbsolutePercentageError(symmetric=True)
smape(y_test, y_pred)

# Show predictions with test: long-format frame with one line per series.
df = pl.DataFrame({"y_pred": y_pred, "y_test": y_test}).with_row_index().melt("index")
alt.Chart(df).mark_line().encode(x="index", y="value", color="variable")
# It's not great

# Full WF1 production series, for visual comparison with the forecast above.
alt.Chart(X_y_WF1).mark_line().encode(x="Time", y="Production").properties(
    width=2000, height=400
)
# +
# Average production depending on the day of the week.
# Build the chart base once instead of recomputing the derived "Day of week"
# column twice — same pattern as the month-of-year cell below.
base = alt.Chart(X_y_WF1.with_columns(pl.col("Time").dt.weekday().alias("Day of week")))
(
    base.mark_bar().encode(x="Day of week", y="mean(Production)")
    + base.mark_errorbar(extent="iqr").encode(x="Day of week", y="Production")
)
# 1 is Monday, 7 is Sunday
# Top production is on Mondays and Sundays, bottom is Thursdays

# Error bars are the IQR
# +
# Mean production by calendar month, with IQR error bars layered on top.
base = alt.Chart(X_y_WF1.with_columns(pl.col("Time").dt.month().alias("Month")))
bars = base.mark_bar().encode(x="Month", y="mean(Production)")
spread = base.mark_errorbar(extent="iqr").encode(x="Month", y="Production")
bars + spread

# 1 is January, 12 is December
# Top production is January by far, bottom is August/September
# December is low, as mentioned earlier

# The error bars show the inter-quartile range (bottom is 25% quantile, top is 75% quantile)
# This way we can clearly see that a lot of the data is very close to 0
# -
import polars.selectors as cs

# Columns from NWP1 whose name contains "_U" (presumably the U wind
# component — confirm), selected together with the timestamp.
u_cols = cs.matches("NWP1") & cs.matches("_U")
nwp1 = X_y_WF1.select(cs.matches("Time") | u_cols)
# Long format: one (Time, variable, value) row per run and timestamp.
nwp1_long = nwp1.melt(id_vars="Time")

# All runs scattered on one axis, colored by run.
alt.Chart(nwp1_long).mark_point().encode(
    x="Time", y="value", color="variable"
).properties(width=5000, height=500)

# Row-wise summary of the selected columns across runs.
X_y_WF1.with_columns(
    mean_U=pl.mean_horizontal(u_cols),
    min_U=pl.min_horizontal(u_cols),
    max_U=pl.max_horizontal(u_cols),
)

# Mean value over time with a confidence-interval band layered around it.
trend = alt.Chart(nwp1_long).mark_line().encode(x="Time", y="mean(value)")
band = alt.Chart(nwp1_long).mark_errorband(extent="ci").encode(x="Time", y="value")
(trend + band).properties(width=5000, height=500)

# +
# Correlation between the different variables

X_y_WF1

0 comments on commit 8eea556

Please sign in to comment.