Skip to content

Commit

Permalink
Bug fix for file retrieval per ensemble member
Browse files Browse the repository at this point in the history
  • Loading branch information
sadamov committed Nov 30, 2024
1 parent dcc0b46 commit a3b3bde
Showing 1 changed file with 20 additions and 31 deletions.
51 changes: 20 additions & 31 deletions neural_lam/datastore/npyfilesmeps/store.py
Original file line number Diff line number Diff line change
Expand Up @@ -244,9 +244,7 @@ def get_dataarray(self, category: str, split: str) -> DataArray:
# them separately
features = ["toa_downwelling_shortwave_flux", "open_water_fraction"]
das = [
self._get_single_timeseries_dataarray(
features=[feature], split=split
)
self._get_single_timeseries_dataarray(features=[feature], split=split)
for feature in features
]
da = xr.concat(das, dim="feature")
Expand All @@ -259,9 +257,9 @@ def get_dataarray(self, category: str, split: str) -> DataArray:
# variable is turned into a dask array and so execution of the
# calculation is delayed until the feature values are actually
# used.
da_forecast_time = (
da.analysis_time + da.elapsed_forecast_duration
).chunk({"elapsed_forecast_duration": 1})
da_forecast_time = (da.analysis_time + da.elapsed_forecast_duration).chunk(
{"elapsed_forecast_duration": 1}
)
da_datetime_forcing_features = self._calc_datetime_forcing_features(
da_time=da_forecast_time
)
Expand Down Expand Up @@ -339,10 +337,7 @@ def _get_single_timeseries_dataarray(
for all categories of data
"""
if (
set(features).difference(self.get_vars_names(category="static"))
== set()
):
if set(features).difference(self.get_vars_names(category="static")) == set():
assert split in (
"train",
"val",
Expand All @@ -356,12 +351,8 @@ def _get_single_timeseries_dataarray(
"test",
), f"Unknown dataset split {split} for features {features}"

if member is not None and features != self.get_vars_names(
category="state"
):
raise ValueError(
"Member can only be specified for the 'state' category"
)
if member is not None and features != self.get_vars_names(category="state"):
raise ValueError("Member can only be specified for the 'state' category")

concat_axis = 0

Expand All @@ -377,9 +368,7 @@ def _get_single_timeseries_dataarray(
fp_samples = self.root_path / "samples" / split
if self._remove_state_features_with_index:
n_to_drop = len(self._remove_state_features_with_index)
feature_dim_mask = np.ones(
len(features) + n_to_drop, dtype=bool
)
feature_dim_mask = np.ones(len(features) + n_to_drop, dtype=bool)
feature_dim_mask[self._remove_state_features_with_index] = False
elif features == ["toa_downwelling_shortwave_flux"]:
filename_format = TOA_SW_DOWN_FLUX_FILENAME_FORMAT
Expand Down Expand Up @@ -445,7 +434,7 @@ def _get_single_timeseries_dataarray(
* np.timedelta64(1, "h")
)
elif d == "analysis_time":
coord_values = self._get_analysis_times(split=split)
coord_values = self._get_analysis_times(split=split, member_id=member)
elif d == "y":
coord_values = y
elif d == "x":
Expand All @@ -464,9 +453,7 @@ def _get_single_timeseries_dataarray(
if features_vary_with_analysis_time:
filepaths = [
fp_samples
/ filename_format.format(
analysis_time=analysis_time, **file_params
)
/ filename_format.format(analysis_time=analysis_time, **file_params)
for analysis_time in coords["analysis_time"]
]
else:
Expand Down Expand Up @@ -505,23 +492,29 @@ def _get_single_timeseries_dataarray(

return da

def _get_analysis_times(self, split) -> List[np.datetime64]:
def _get_analysis_times(self, split, member_id) -> List[np.datetime64]:
"""Get the analysis times for the given split by parsing the filenames
of all the files found for the given split.
Parameters
----------
split : str
The dataset split to get the analysis times for.
member_id : int
The ensemble member to get the analysis times for.
Returns
-------
List[dt.datetime]
The analysis times for the given split.
"""
if member_id is None:
# Only interior state data files have member_id, to avoid duplicates
# we only look at the first member for all other categories
member_id = 0
pattern = re.sub(r"{analysis_time:[^}]*}", "*", STATE_FILENAME_FORMAT)
pattern = re.sub(r"{member_id:[^}]*}", "*", pattern)
pattern = re.sub(r"{member_id:[^}]*}", f"{member_id:03d}", pattern)

sample_dir = self.root_path / "samples" / split
sample_files = sample_dir.glob(pattern)
Expand All @@ -531,9 +524,7 @@ def _get_analysis_times(self, split) -> List[np.datetime64]:
times.append(name_parts["analysis_time"])

if len(times) == 0:
raise ValueError(
f"No files found in {sample_dir} with pattern {pattern}"
)
raise ValueError(f"No files found in {sample_dir} with pattern {pattern}")

return times

Expand Down Expand Up @@ -690,9 +681,7 @@ def get_standardization_dataarray(self, category: str) -> xr.Dataset:
"""

def load_pickled_tensor(fn):
return torch.load(
self.root_path / "static" / fn, weights_only=True
).numpy()
return torch.load(self.root_path / "static" / fn, weights_only=True).numpy()

mean_diff_values = None
std_diff_values = None
Expand Down

0 comments on commit a3b3bde

Please sign in to comment.