-
-
Notifications
You must be signed in to change notification settings - Fork 11
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #347 from openclimatefix/trigonometric_time
Include trigonometric time features
- Loading branch information
Showing
5 changed files
with
119 additions
and
2 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,60 @@ | ||
"""Datapipes to add trigonometric date and time features to NumpyBatch""" | ||
|
||
import numpy as np | ||
from numpy.typing import NDArray | ||
from torch.utils.data import IterDataPipe, functional_datapipe | ||
|
||
from ocf_datapipes.batch import BatchKey | ||
|
||
|
||
def _get_date_time_in_pi( | ||
dt: NDArray[np.datetime64], | ||
) -> tuple[NDArray[np.float64], NDArray[np.float64]]: | ||
day_of_year = (dt - dt.astype("datetime64[Y]")).astype(int) | ||
minute_of_day = (dt - dt.astype("datetime64[D]")).astype(int) | ||
|
||
# converting into positions on sin-cos circle | ||
time_in_pi = (2 * np.pi) * (minute_of_day / (24 * 3600)) | ||
date_in_pi = (2 * np.pi) * (day_of_year / (365 * 24 * 3600)) | ||
|
||
return date_in_pi, time_in_pi | ||
|
||
|
||
@functional_datapipe("add_trigonometric_date_time")
class AddTrigonometricDateTimeIterDataPipe(IterDataPipe):
    """Adds the trigonometric encodings of date of year, time of day to the NumpyBatch"""

    def __init__(self, source_datapipe: IterDataPipe, modality_name: str):
        """
        Adds the sine and cosine of time to the NumpyBatch

        Args:
            source_datapipe: Datapipe of NumpyBatch
            modality_name: Modality to add the time for. Only "wind" is
                currently accepted.
        """
        self.source_datapipe = source_datapipe
        self.modality_name = modality_name
        # Kept as an assert so the exception type seen by callers is unchanged.
        assert self.modality_name in [
            "wind",
        ], f"Trigonometric time not implemented for {self.modality_name}"

    def __iter__(self):
        """Yield each batch with the four date/time sin/cos entries added.

        The batch's ``<modality>_time_utc`` entry is read, and the results are
        stored under the ``<modality>_date_sin``, ``<modality>_date_cos``,
        ``<modality>_time_sin`` and ``<modality>_time_cos`` batch keys. The
        batch object from the source datapipe is modified in place.
        """
        # All keys are derived from the modality name, so the input key always
        # matches the output keys. They do not change between batches, so they
        # are resolved once, outside the loop.
        prefix = self.modality_name
        time_utc_batch_key = BatchKey[prefix + "_time_utc"]
        date_sin_batch_key = BatchKey[prefix + "_date_sin"]
        date_cos_batch_key = BatchKey[prefix + "_date_cos"]
        time_sin_batch_key = BatchKey[prefix + "_time_sin"]
        time_cos_batch_key = BatchKey[prefix + "_time_cos"]

        for np_batch in self.source_datapipe:
            time_utc = np_batch[time_utc_batch_key]

            # The angle helper works in seconds, so convert to second resolution.
            times: NDArray[np.datetime64] = time_utc.astype("datetime64[s]")

            date_in_pi, time_in_pi = _get_date_time_in_pi(times)

            # Store
            np_batch[date_sin_batch_key] = np.sin(date_in_pi)
            np_batch[date_cos_batch_key] = np.cos(date_in_pi)
            np_batch[time_sin_batch_key] = np.sin(time_in_pi)
            np_batch[time_cos_batch_key] = np.cos(time_in_pi)

            yield np_batch
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,43 @@ | ||
import numpy as np | ||
|
||
from ocf_datapipes.transform.numpy_batch.datetime_features import _get_date_time_in_pi | ||
|
||
|
||
def test_get_date_time_in_pi():
    """Check date/time angles at quarter points of the day and year, for 2D and 1D input."""
    timestamps = (
        np.array(
            [
                "2020-01-01T00:00:00",
                "2020-04-01T06:00:00",
                "2020-07-01T12:00:00",
                "2020-09-30T18:00:00",
                "2020-12-31T23:59:59",
                "2021-01-01T00:00:00",
                "2021-04-02T06:00:00",
                "2021-07-02T12:00:00",
                "2021-10-01T18:00:00",
                "2021-12-31T23:59:59",
            ]
        )
        .reshape((2, 5))
        .astype("datetime64[s]")
    )

    # Each row is expected to sweep the circle in quarter turns: 0, pi/2, pi, 3pi/2, 2pi.
    quarter_turns = [0, 0.5 * np.pi, np.pi, 1.5 * np.pi, 2 * np.pi]
    expected_angles = np.array(quarter_turns * 2).reshape((2, 5))

    date_in_pi, time_in_pi = _get_date_time_in_pi(timestamps)

    # Note on precision: times are compared with a tolerance equivalent to 1 second,
    # dates with a tolerance equivalent to 5 minutes. None of the data we use has a
    # higher time resolution, so this is a good test of whether not accounting for
    # leap years breaks things.
    time_atol = 7.3e-05
    date_atol = 0.02182
    cases = ((time_in_pi, time_atol), (date_in_pi, date_atol))

    # Full 2D array
    for angles, atol in cases:
        for trig in (np.cos, np.sin):
            assert np.allclose(trig(angles), trig(expected_angles), atol=atol)

    # 1D array test
    for angles, atol in cases:
        for trig in (np.cos, np.sin):
            assert np.allclose(trig(angles[0]), trig(expected_angles[0]), atol=atol)