Skip to content

Commit

Permalink
Merge pull request #347 from openclimatefix/trigonometric_time
Browse files Browse the repository at this point in the history
Include trigonometric time features
  • Loading branch information
AUdaltsova authored Aug 6, 2024
2 parents 40f30a6 + 7703590 commit 3b7f5a6
Show file tree
Hide file tree
Showing 5 changed files with 119 additions and 2 deletions.
9 changes: 9 additions & 0 deletions ocf_datapipes/batch/batches.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,15 @@ class BatchKey(Enum):
gsp_x_osgb_fourier = auto()
gsp_time_utc_fourier = auto() # (batch_size, time, n_fourier_features)

# -------------- TIME -------------------------------------------
# Sine and cosine of date of year and time of day at every timestep.
# shape = (batch_size, n_timesteps)
# This is calculated for wind only inside datapipes.
wind_date_sin = auto()
wind_date_cos = auto()
wind_time_sin = auto()
wind_time_cos = auto()

# -------------- SUN --------------------------------------------
# Solar position at every timestep. shape = (batch_size, n_timesteps)
# The solar position data comes from two alternative sources: either the Sun pre-prepared
Expand Down
8 changes: 6 additions & 2 deletions ocf_datapipes/training/windnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,11 +213,15 @@ def __iter__(self):
numpy_modalities.append(datapipes_dict["wind"].convert_wind_to_numpy_batch())

logger.debug("Combine all the data sources")
combined_datapipe = MergeNumpyModalities(numpy_modalities).add_sun_position(
logger.debug("Adding trigonometric date and time")
combined_datapipe = MergeNumpyModalities(numpy_modalities).add_trigonometric_date_time(
modality_name="wind"
)
# combined_datapipe = MergeNumpyModalities(numpy_modalities).add_sun_position(
# modality_name="wind"
# )

logger.info("Filtering out samples with no data")
# logger.info("Filtering out samples with no data")
# if self.check_satellite_no_zeros:
# in production we don't want any nans in the satellite data
# combined_datapipe = combined_datapipe.map(check_nans_in_satellite_data)
Expand Down
1 change: 1 addition & 0 deletions ocf_datapipes/transform/numpy_batch/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,4 +2,5 @@

from .add_fourier_space_time import AddFourierSpaceTimeIterDataPipe as AddFourierSpaceTime
from .add_topographic_data import AddTopographicDataIterDataPipe as AddTopographicData
from .datetime_features import AddTrigonometricDateTimeIterDataPipe as AddTrigonometricDateTime
from .sun_position import AddSunPositionIterDataPipe as AddSunPosition
60 changes: 60 additions & 0 deletions ocf_datapipes/transform/numpy_batch/datetime_features.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
"""Datapipes to trigonometric date and time to NumpyBatch"""

import numpy as np
from numpy.typing import NDArray
from torch.utils.data import IterDataPipe, functional_datapipe

from ocf_datapipes.batch import BatchKey


def _get_date_time_in_pi(
dt: NDArray[np.datetime64],
) -> tuple[NDArray[np.float64], NDArray[np.float64]]:
day_of_year = (dt - dt.astype("datetime64[Y]")).astype(int)
minute_of_day = (dt - dt.astype("datetime64[D]")).astype(int)

# converting into positions on sin-cos circle
time_in_pi = (2 * np.pi) * (minute_of_day / (24 * 3600))
date_in_pi = (2 * np.pi) * (day_of_year / (365 * 24 * 3600))

return date_in_pi, time_in_pi


@functional_datapipe("add_trigonometric_date_time")
class AddTrigonometricDateTimeIterDataPipe(IterDataPipe):
"""Adds the trigonometric encodings of date of year, time of day to the NumpyBatch"""

def __init__(self, source_datapipe: IterDataPipe, modality_name: str):
"""
Adds the sine and cosine of time to the NumpyBatch
Args:
source_datapipe: Datapipe of NumpyBatch
modality_name: Modality to add the time for
"""
self.source_datapipe = source_datapipe
self.modality_name = modality_name
assert self.modality_name in [
"wind",
], f"Trigonometric time not implemented for {self.modality_name}"

def __iter__(self):
for np_batch in self.source_datapipe:
time_utc = np_batch[BatchKey.wind_time_utc]

times: NDArray[np.datetime64] = time_utc.astype("datetime64[s]")

date_in_pi, time_in_pi = _get_date_time_in_pi(times)

# Store
date_sin_batch_key = BatchKey[self.modality_name + "_date_sin"]
date_cos_batch_key = BatchKey[self.modality_name + "_date_cos"]
time_sin_batch_key = BatchKey[self.modality_name + "_time_sin"]
time_cos_batch_key = BatchKey[self.modality_name + "_time_cos"]

np_batch[date_sin_batch_key] = np.sin(date_in_pi)
np_batch[date_cos_batch_key] = np.cos(date_in_pi)
np_batch[time_sin_batch_key] = np.sin(time_in_pi)
np_batch[time_cos_batch_key] = np.cos(time_in_pi)

yield np_batch
43 changes: 43 additions & 0 deletions tests/transform/numpy_batch/test_datetime_features.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
import numpy as np

from ocf_datapipes.transform.numpy_batch.datetime_features import _get_date_time_in_pi


def test_get_date_time_in_pi():
times = np.array(
[
"2020-01-01T00:00:00",
"2020-04-01T06:00:00",
"2020-07-01T12:00:00",
"2020-09-30T18:00:00",
"2020-12-31T23:59:59",
"2021-01-01T00:00:00",
"2021-04-02T06:00:00",
"2021-07-02T12:00:00",
"2021-10-01T18:00:00",
"2021-12-31T23:59:59",
]
).reshape((2, 5))

expected_times_in_pi = np.array([0, 0.5 * np.pi, np.pi, 1.5 * np.pi, 2 * np.pi] * 2).reshape(
(2, 5)
)

times = times.astype("datetime64[s]")

date_in_pi, time_in_pi = _get_date_time_in_pi(times)

# Note on precision: times are compared with tolerance equivalent to 1 second,
# dates are compared with tolerance equivalent to 5 minutes
# None of the data we use has a higher time resolution, so this is a good test of
# whether not accounting for leap years breaks things
assert np.isclose(np.cos(time_in_pi), np.cos(expected_times_in_pi), atol=7.3e-05).all()
assert np.isclose(np.sin(time_in_pi), np.sin(expected_times_in_pi), atol=7.3e-05).all()
assert np.isclose(np.cos(date_in_pi), np.cos(expected_times_in_pi), atol=0.02182).all()
assert np.isclose(np.sin(date_in_pi), np.sin(expected_times_in_pi), atol=0.02182).all()

# 1D array test
assert np.isclose(np.cos(time_in_pi[0]), np.cos(expected_times_in_pi[0]), atol=7.3e-05).all()
assert np.isclose(np.sin(time_in_pi[0]), np.sin(expected_times_in_pi[0]), atol=7.3e-05).all()
assert np.isclose(np.cos(date_in_pi[0]), np.cos(expected_times_in_pi[0]), atol=0.02182).all()
assert np.isclose(np.sin(date_in_pi[0]), np.sin(expected_times_in_pi[0]), atol=0.02182).all()

0 comments on commit 3b7f5a6

Please sign in to comment.