Commit 0ad8b3a (1 parent: bfb555f)
Showing 3 changed files with 189 additions and 0 deletions.
@@ -0,0 +1,138 @@
from pyaro.timeseries.AutoFilterReaderEngine import AutoFilterReader, AutoFilterEngine
from pyaro.timeseries import Reader, Data, Station
import polars
import numpy as np


class ParquetReaderException(Exception):
    pass


class ParquetData(Data):
    def __init__(self, dataset: polars.DataFrame, variable: str):
        self._variable = variable
        self._dataset = dataset

    def __len__(self) -> int:
        return len(self._dataset)

    def slice(self, index):
        return ParquetData(self._dataset[index], self._variable)

    @property
    def altitudes(self):
        return self._dataset["altitude"].to_numpy()

    @property
    def start_times(self):
        return self._dataset["start_time"].to_numpy()

    @property
    def end_times(self):
        return self._dataset["end_time"].to_numpy()

    @property
    def flags(self):
        return self._dataset["flag"].to_numpy()

    def keys(self):
        return set(self._dataset.columns) - set(["variable", "units"])

    @property
    def latitudes(self):
        return self._dataset["latitude"].to_numpy()

    @property
    def longitudes(self):
        return self._dataset["longitude"].to_numpy()

    @property
    def standard_deviations(self):
        return self._dataset["standard_deviation"].to_numpy()

    @property
    def stations(self):
        return self._dataset["station"].to_numpy()

    @property
    def values(self):
        return self._dataset["value"].to_numpy()


class ParquetTimeseriesReader(AutoFilterReader):
    MANDATORY_COLUMNS = {
        "variable",
        "units",
        "value",
        "station",
        "longitude",
        "latitude",
        "start_time",
        "end_time",
    }
    OPTIONAL_COLUMNS = {
        "country": "",
        "flag": 0,
        "altitude": np.nan,
        "standard_deviation": np.nan,
    }

    def __init__(self, filename: str, filters):
        self._set_filters(filters)
        dataset = polars.read_parquet(filename)

        # fail early if any mandatory column is absent
        ds_cols = dataset.columns
        missing_mandatory = self.MANDATORY_COLUMNS - set(ds_cols)
        if len(missing_mandatory):
            raise ParquetReaderException(
                f"Missing mandatory columns: {missing_mandatory}"
            )

        # backfill absent optional columns with their default values
        missing_optional = set(self.OPTIONAL_COLUMNS.keys()) - set(ds_cols)
        for missing in missing_optional:
            dataset = dataset.with_columns(
                polars.lit(self.OPTIONAL_COLUMNS[missing]).alias(missing)
            )

        self.dataset = dataset

    def _unfiltered_data(self, varname: str) -> ParquetData:
        return ParquetData(
            self.dataset.filter(polars.col("variable").eq(varname)), varname
        )

    def _unfiltered_stations(self) -> dict[str, Station]:
        ds = self.dataset.group_by("station").first()

        stations = dict()
        for row in ds.rows(named=True):
            stations[row["station"]] = Station(
                {
                    "station": row["station"],
                    "longitude": row["longitude"],
                    "latitude": row["latitude"],
                    "altitude": row["altitude"],
                    "country": row["country"],
                    "url": "",
                    "long_name": row["station"],
                }
            )
        return stations

    def _unfiltered_variables(self) -> list[str]:
        return list(self.dataset["variable"].unique())

    def close(self):
        pass


class ParquetTimeseriesEngine(AutoFilterEngine):
    def description(self) -> str:
        return """Parquet reader
        """

    def url(self) -> str:
        return "https://github.com/metno/pyaro-readers"

    def reader_class(self) -> AutoFilterReader:
        return ParquetTimeseriesReader
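
A minimal usage sketch for the reader above, assuming a Parquet file that already contains the mandatory columns; the file name obs.parquet and the data values are made up for illustration and mirror the unit test below:

import numpy as np
import pandas
import polars

from pyaro_readers.parquet import ParquetTimeseriesReader

# build a small table with only the mandatory columns (illustrative values)
times = pandas.date_range("2025-02-28 00:00", freq="1h", periods=25)
table = polars.DataFrame(
    {
        "variable": "vespene",
        "units": "kg/m^3",
        "value": np.random.random(24),
        "station": "base",
        "longitude": 10,
        "latitude": 59,
        "start_time": times[:-1],
        "end_time": times[1:],
    }
)
table.write_parquet("obs.parquet")  # hypothetical file name

# open the file and inspect its contents; absent optional columns get defaults
reader = ParquetTimeseriesReader("obs.parquet", filters=[])
print(reader.variables())        # ['vespene']
print(list(reader.stations()))   # ['base']
data = reader.data("vespene")
print(len(data), float(data.values.mean()))
reader.close()
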
@@ -0,0 +1,50 @@
from io import BytesIO

import numpy as np
import polars
import pandas

from pyaro_readers.parquet import ParquetTimeseriesReader


def test_reading():
    N = 1000

    times = pandas.date_range("2025-02-28 00:00", freq="1h", periods=N + 1)

    # synthetic dataset with only the mandatory columns, written to an
    # in-memory parquet "file"
    ds_tmp = polars.DataFrame(
        {
            "variable": "vespene",
            "units": "kg/m^3",
            "value": np.random.random(N),
            "station": "base",
            "longitude": 10,
            "latitude": 59,
            "start_time": times[:-1],
            "end_time": times[1:],
        }
    )
    tmpfile = BytesIO()
    ds_tmp.write_parquet(tmpfile)
    tmpfile.seek(0)

    ds = ParquetTimeseriesReader(tmpfile, filters=[])

    # station metadata is derived from the first row of each station group
    stations = ds.stations()
    assert len(stations) == 1
    station = stations["base"]
    assert station.longitude == 10
    assert station.latitude == 59
    assert station.long_name == "base"

    assert np.unique(ds.variables()) == ["vespene"]

    # the data round-trips unchanged
    data = ds.data("vespene")

    assert np.all(data.start_times == times[:-1])
    assert np.all(data.end_times == times[1:])
    assert np.all(ds_tmp["value"].to_numpy() == data.values)

    # slicing returns a new ParquetData with the selected rows
    data_slice = data[:500]
    assert len(data_slice) == 500
    assert np.all(ds_tmp["value"][:500].to_numpy() == data_slice.values)
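
The diff does not show how (or under what name) ParquetTimeseriesEngine is registered as a pyaro entry point, so the following is only a sketch under assumptions: it supposes the engine is installed under the name "parquet", reuses the hypothetical obs.parquet file from above, and relies on pyaro exposing list_timeseries_engines() and open_timeseries() as described in the pyaro documentation.

import pyaro

# available engines depend on what is installed; "parquet" is an assumed name
print(pyaro.list_timeseries_engines())

# open a parquet file through the generic pyaro interface (hypothetical file name)
with pyaro.open_timeseries("parquet", "obs.parquet", filters=[]) as ts:
    for varname in ts.variables():
        data = ts.data(varname)
        print(varname, len(data), float(data.values.mean()))
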