"""
The dataset class to be used with fwi-forcings and fwi-reanalysis data.
"""
from glob import glob
import xarray as xr
import numpy as np
from dataloader.base_loader import ModelDataset as BaseDataset
from pytorch_lightning import _logger as log


class ModelDataset(BaseDataset):
    """
    The dataset class responsible for loading the data and providing the samples for \
    training.

    :param BaseDataset: Base Dataset class to inherit from
    :type BaseDataset: base_loader.BaseDataset
    """

    def __init__(
        self, forcings_dir=None, reanalysis_dir=None, hparams=None, **kwargs,
    ):
        """
        Constructor for the ModelDataset class

        :param forcings_dir: The directory containing the FWI-Forcings data, defaults \
        to None
        :type forcings_dir: str, optional
        :param reanalysis_dir: The directory containing the FWI-Reanalysis data, \
        defaults to None
        :type reanalysis_dir: str, optional
        :param hparams: Holds configuration values, defaults to None
        :type hparams: Namespace, optional
        """
        super().__init__(
            forcings_dir=forcings_dir,
            reanalysis_dir=reanalysis_dir,
            hparams=hparams,
            **kwargs,
        )

        # Number of input and prediction days
        assert (
            self.hparams.in_days > 0 and self.hparams.out_days > 0
        ), "The number of input and output days must be > 0."
        inp_files = glob(f"{forcings_dir}/ECMWF_FO_20*.nc")
        out_files = glob(f"{reanalysis_dir}/ECMWF_FWI_20*_1200_hr_fwi_e5.nc")

        # Consider only ground truth and discard forecast values
        preprocess = lambda x: x.isel(time=slice(0, 1))
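
        # Open all forcings files lazily as a single dataset combined along
        # shared coordinates; "minimal" coords/data_vars with compat="override"
        # avoids expensive cross-file consistency checks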
        with xr.open_mfdataset(
            inp_files,
            preprocess=preprocess,
            engine="h5netcdf",
            parallel=not self.hparams.dry_run,
            combine="by_coords",
            coords="minimal",
            data_vars="minimal",
            compat="override",
        ) as ds:
            input_ = ds.sortby("time")
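
        # The reanalysis (ground-truth) files are opened the same way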
        with xr.open_mfdataset(
            out_files,
            preprocess=preprocess,
            engine="h5netcdf",
            parallel=not self.hparams.dry_run,
            combine="by_coords",
            coords="minimal",
            data_vars="minimal",
            compat="override",
        ) as ds:
            self.output = ds.sortby("time")
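
        # Optional SMOS soil-moisture input: each GRIB file holds a single time
        # step, so expand_dims("time") gives every file a time axis for
        # open_mfdataset to concatenate along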
        if self.hparams.smos_input:
            self.smos_files = glob(f"{self.hparams.smos_dir}/20*_20*.as1.grib")
            with xr.open_mfdataset(
                self.smos_files,
                preprocess=lambda x: x.expand_dims("time"),
                engine="cfgrib",
                parallel=False,
            ) as ds:
                smos_input = ds
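
        # A t=0 date is kept only if it passes the optional date filters below
        # and its full input window (t-in_days+1 ... t) and output window
        # (t ... t+out_days-1) are present in the data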
        # The t=0 dates
        self.dates = []
        for t in self.output.time.values:
            t = t.astype("datetime64[D]")
            if (
                # Date is within the range if specified
                (
                    not self.hparams.date_range
                    or self.hparams.date_range[0] <= t <= self.hparams.date_range[-1]
                )
                # Date falls in at least one case-study range if specified
                and (
                    not self.hparams.case_study_dates
                    or any(r[0] <= t <= r[-1] for r in self.hparams.case_study_dates)
                )
                # Input data for preceding dates is available
                and (
                    all(
                        t - np.timedelta64(i, "D") in input_.time.values
                        for i in range(self.hparams.in_days)
                    )
                    and (
                        all(
                            t - np.timedelta64(i, "D") in smos_input.time.values
                            for i in range(self.hparams.in_days)
                        )
                        if self.hparams.smos_input
                        else True
                    )
                )
                # Output data for later dates is available
                and all(
                    t + np.timedelta64(i, "D") in self.output.time.values
                    for i in range(self.hparams.out_days)
                )
            ):
                self.dates.append(t)
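                # A few qualifying dates are enough to exercise the pipeline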
                if self.hparams.dry_run and len(self.dates) == 4:
                    break

        self.min_date = min(self.dates)

        # Required output dates for operating on t=0 dates
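        # (for each t=0 date d, these are d, d+1, ..., d+out_days-1)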
        out_dates_spread = list(
            {
                d - np.timedelta64(i + 1 - self.hparams.out_days, "D")
                for d in self.dates
                for i in range(self.hparams.out_days)
            }
        )
        # Load the data only for required dates
        self.output = self.output.sel(time=out_dates_spread).load()

        if not self.hparams.benchmark:
            # Required input dates for operating on t=0 dates
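            # (for each t=0 date d, these are d-in_days+1, ..., d-1, d)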
            in_dates_spread = list(
                {
                    d + np.timedelta64(i + 1 - self.hparams.in_days, "D")
                    for d in self.dates
                    for i in range(self.hparams.in_days)
                }
            )
            self.input = input_.sel(time=in_dates_spread).load()

            if self.hparams.smos_input:
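                # SMOS timestamps need not coincide exactly with the FWI dates,
                # so nearest-neighbour selection is used; it can map two
                # requested dates to the same step, hence the de-duplication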
                smos_input = smos_input.sel(time=in_dates_spread, method="nearest")
                # Drop duplicates
                self.smos_input = smos_input.isel(
                    time=np.unique(smos_input["time"], return_index=True)[1]
                ).load()

        log.info(f"\nTest Set Range: {min(self.dates)} to {max(self.dates)}")
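

# Minimal usage sketch (illustrative only: the exact hparams fields are defined
# by the training script's argument parser, and the paths are placeholders):
#
#     from argparse import Namespace
#
#     hparams = Namespace(
#         in_days=2,
#         out_days=1,
#         dry_run=True,
#         date_range=None,
#         case_study_dates=None,
#         smos_input=False,
#         benchmark=False,
#     )
#     dataset = ModelDataset(
#         forcings_dir="/path/to/fwi-forcings",
#         reanalysis_dir="/path/to/fwi-reanalysis",
#         hparams=hparams,
#     )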