diff --git a/setup.cfg b/setup.cfg
index 960d575..1687d78 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -58,6 +58,7 @@ pyaro.timeseries =
     harp = pyaro_readers.harpreader:AeronetHARPEngine
     nilupmfabsorption = pyaro_readers.nilupmfabsorptionreader:NILUPMFAbsorptionTimeseriesEngine
     nilupmfebas = pyaro_readers.nilupmfebas:EbasPmfTimeseriesEngine
+    parquet = pyaro_readers.parquet:ParquetTimeseriesEngine
diff --git a/src/pyaro_readers/parquet/__init__.py b/src/pyaro_readers/parquet/__init__.py
new file mode 100644
index 0000000..13ed99f
--- /dev/null
+++ b/src/pyaro_readers/parquet/__init__.py
@@ -0,0 +1,147 @@
+from pyaro.timeseries.AutoFilterReaderEngine import AutoFilterReader, AutoFilterEngine
+from pyaro.timeseries import Data, Station
+import polars
+import numpy as np
+
+
+class ParquetReaderException(Exception):
+    pass
+
+
+class ParquetData(Data):
+    def __init__(self, dataset: polars.DataFrame, variable: str):
+        self._variable = variable
+        self._dataset = dataset
+
+    def __len__(self) -> int:
+        return len(self._dataset)
+
+    def slice(self, index):
+        return ParquetData(self._dataset[index], self._variable)
+
+    @property
+    def altitudes(self):
+        return self._dataset["altitude"].to_numpy()
+
+    @property
+    def start_times(self):
+        return self._dataset["start_time"].to_numpy()
+
+    @property
+    def end_times(self):
+        return self._dataset["end_time"].to_numpy()
+
+    @property
+    def flags(self):
+        return self._dataset["flag"].to_numpy()
+
+    def keys(self):
+        return set(self._dataset.columns) - {"variable", "units"}
+
+    @property
+    def latitudes(self):
+        return self._dataset["latitude"].to_numpy()
+
+    @property
+    def longitudes(self):
+        return self._dataset["longitude"].to_numpy()
+
+    @property
+    def standard_deviations(self):
+        return self._dataset["standard_deviation"].to_numpy()
+
+    @property
+    def stations(self):
+        return self._dataset["station"].to_numpy()
+
+    @property
+    def values(self):
+        return self._dataset["value"].to_numpy()
+
+    @property
+    def units(self):
+        units = self._dataset["units"].unique()
+        if len(units) > 1:
+            raise ParquetReaderException(
+                f"This dataset contains more than one unit: {units}"
+            )
+        return units[0]
+
+
+class ParquetTimeseriesReader(AutoFilterReader):
+    MANDATORY_COLUMNS = {
+        "variable",
+        "units",
+        "value",
+        "station",
+        "longitude",
+        "latitude",
+        "start_time",
+        "end_time",
+    }
+    OPTIONAL_COLUMNS = {
+        "country": "",
+        "flag": 0,
+        "altitude": np.nan,
+        "standard_deviation": np.nan,
+    }
+
+    def __init__(self, filename: str, filters):
+        self._set_filters(filters)
+        dataset = polars.read_parquet(filename)
+
+        ds_cols = dataset.columns
+        missing_mandatory = self.MANDATORY_COLUMNS - set(ds_cols)
+        if missing_mandatory:
+            raise ParquetReaderException(
+                f"Missing mandatory columns: {missing_mandatory}"
+            )
+
+        missing_optional = set(self.OPTIONAL_COLUMNS.keys()) - set(ds_cols)
+        for missing in missing_optional:
+            dataset = dataset.with_columns(
+                polars.lit(self.OPTIONAL_COLUMNS[missing]).alias(missing)
+            )
+
+        self.dataset = dataset
+
+    def _unfiltered_data(self, varname: str) -> ParquetData:
+        return ParquetData(
+            self.dataset.filter(polars.col("variable").eq(varname)), varname
+        )
+
+    def _unfiltered_stations(self) -> dict[str, Station]:
+        ds = self.dataset.group_by("station").first()
+
+        stations = dict()
+        for row in ds.rows(named=True):
+            stations[row["station"]] = Station(
+                {
+                    "station": row["station"],
+                    "longitude": row["longitude"],
+                    "latitude": row["latitude"],
+                    "altitude": row["altitude"],
+                    "country": row["country"],
+                    "url": "",
+                    "long_name": row["station"],
+                }
+            )
+        return stations
+
+    def _unfiltered_variables(self) -> list[str]:
+        return list(self.dataset["variable"].unique())
+
+    def close(self):
+        pass
+
+
+class ParquetTimeseriesEngine(AutoFilterEngine):
+    def description(self) -> str:
+        return """Reader for tabular timeseries data stored in a Parquet file
+        """
+
+    def url(self) -> str:
+        return "https://github.com/metno/pyaro-readers"
+
+    def reader_class(self) -> type[AutoFilterReader]:
+        return ParquetTimeseriesReader
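With the `parquet` entry point registered in setup.cfg, the reader is normally reached through pyaro's engine registry rather than by importing `pyaro_readers` directly. A minimal usage sketch, assuming pyaro's `list_timeseries_engines()` / `open_timeseries()` helpers behave as documented upstream; `obs.parquet` is a hypothetical file with the columns required by `MANDATORY_COLUMNS`:

```python
import pyaro

# The setup.cfg entry point makes the engine discoverable by name.
assert "parquet" in pyaro.list_timeseries_engines()

# "obs.parquet" is a hypothetical file following the mandatory column
# layout; filters=[] disables all filtering, as in the unit test below.
with pyaro.open_timeseries("parquet", "obs.parquet", filters=[]) as ts:
    for var in ts.variables():
        data = ts.data(var)
        print(var, data.units, len(data))
```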
row["country"], + "url": "", + "long_name": row["station"], + } + ) + return stations + + def _unfiltered_variables(self) -> list[str]: + return list(self.dataset["variable"].unique()) + + def close(self): + pass + + +class ParquetTimeseriesEngine(AutoFilterEngine): + def description(self) -> str: + return """Parquet reader + """ + + def url(self) -> str: + return "https://github.com/metno/pyaro-readers" + + def reader_class(self) -> AutoFilterReader: + return ParquetTimeseriesReader diff --git a/tests/test_parquetreader.py b/tests/test_parquetreader.py new file mode 100644 index 0000000..0d5dcfd --- /dev/null +++ b/tests/test_parquetreader.py @@ -0,0 +1,51 @@ +from io import BytesIO + +import numpy as np +import polars +import pandas + +from pyaro_readers.parquet import ParquetTimeseriesReader + + +def test_reading(): + N = 1000 + + times = pandas.date_range("2025-02-28 00:00", freq="1h", periods=N + 1) + + ds_tmp = polars.DataFrame( + { + "variable": "vespene", + "units": "kg/m^3", + "value": np.random.random(N), + "station": "base", + "longitude": 10, + "latitude": 59, + "start_time": times[:-1], + "end_time": times[1:], + } + ) + tmpfile = BytesIO() + ds_tmp.write_parquet(tmpfile) + tmpfile.seek(0) + + ds = ParquetTimeseriesReader(tmpfile, filters=[]) + + stations = ds.stations() + assert len(stations) == 1 + station = stations["base"] + assert station.longitude == 10 + assert station.latitude == 59 + assert station.long_name == "base" + + assert np.unique(ds.variables()) == ["vespene"] + + data = ds.data("vespene") + + assert np.all(data.start_times == times[:-1]) + assert np.all(data.end_times == times[1:]) + assert np.all(ds_tmp["value"].to_numpy() == data.values) + + data_slice = data[:500] + assert len(data_slice) == 500 + assert np.all(ds_tmp["value"][:500].to_numpy() == data_slice.values) + assert data.units == "kg/m^3"